Skip to content

Commit 9ebf10a

Browse files
authored
Allow register_kernel() to take dispatcher name as input (#7796)
1 parent f3c89cc commit 9ebf10a

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

test/test_transforms_v2_refactored.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,3 +2181,36 @@ def test_unsupported_types(self, dispatcher, make_input):
21812181

21822182
with pytest.raises(TypeError, match=re.escape(str(type(input)))):
21832183
dispatcher(input)
2184+
2185+
2186+
class TestRegisterKernel:
2187+
@pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
2188+
def test_register_kernel(self, dispatcher):
2189+
class CustomDatapoint(datapoints.Datapoint):
2190+
pass
2191+
2192+
kernel_was_called = False
2193+
2194+
@F.register_kernel(dispatcher, CustomDatapoint)
2195+
def new_resize(dp, *args, **kwargs):
2196+
nonlocal kernel_was_called
2197+
kernel_was_called = True
2198+
return dp
2199+
2200+
t = transforms.Resize(size=(224, 224), antialias=True)
2201+
2202+
my_dp = CustomDatapoint(torch.rand(3, 10, 10))
2203+
out = t(my_dp)
2204+
assert out is my_dp
2205+
assert kernel_was_called
2206+
2207+
# Sanity check to make sure we didn't override the kernel of other types
2208+
t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
2209+
t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
2210+
2211+
def test_bad_disaptcher_name(self):
2212+
class CustomDatapoint(datapoints.Datapoint):
2213+
pass
2214+
2215+
with pytest.raises(ValueError, match="Could not find dispatcher with name"):
2216+
F.register_kernel("bad_name", CustomDatapoint)

torchvision/transforms/v2/functional/_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,18 @@ def decorator(kernel):
3737
return decorator
3838

3939

40+
def _name_to_dispatcher(name):
41+
import torchvision.transforms.v2.functional # noqa
42+
43+
try:
44+
return getattr(torchvision.transforms.v2.functional, name)
45+
except AttributeError:
46+
raise ValueError(f"Could not find dispatcher with name '{name}'.") from None
47+
48+
4049
def register_kernel(dispatcher, datapoint_cls):
50+
if isinstance(dispatcher, str):
51+
dispatcher = _name_to_dispatcher(name=dispatcher)
4152
return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
4253

4354

0 commit comments

Comments
 (0)