Skip to content

Commit d2d448c

Browse files
authored
add tests for the output types of prototype functional dispatchers (#7118)
1 parent 01d138d commit d2d448c

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,15 @@ def xfail_jit_list_of_ints(name, *, reason=None):
112112
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
113113
)
114114

115+
multi_crop_skips = [
116+
TestMark(
117+
("TestDispatchers", test_name),
118+
pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
119+
)
120+
for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
121+
]
122+
multi_crop_skips.append(skip_dispatch_datapoint)
123+
115124

116125
def fill_sequence_needs_broadcast(args_kwargs):
117126
(image_loader, *_), kwargs = args_kwargs
@@ -404,7 +413,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
404413
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
405414
test_marks=[
406415
xfail_jit_python_scalar_arg("size"),
407-
skip_dispatch_datapoint,
416+
*multi_crop_skips,
408417
],
409418
),
410419
DispatcherInfo(
@@ -415,7 +424,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
415424
},
416425
test_marks=[
417426
xfail_jit_python_scalar_arg("size"),
418-
skip_dispatch_datapoint,
427+
*multi_crop_skips,
419428
],
420429
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
421430
),

test/test_prototype_transforms_functional.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,16 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
362362

363363
spy.assert_called_once()
364364

365+
@image_sample_inputs
366+
def test_simple_tensor_output_type(self, info, args_kwargs):
367+
(image_datapoint, *other_args), kwargs = args_kwargs.load()
368+
image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)
369+
370+
output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)
371+
372+
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
373+
assert type(output) is torch.Tensor
374+
365375
@make_info_args_kwargs_parametrization(
366376
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
367377
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
@@ -381,6 +391,22 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on):
381391

382392
spy.assert_called_once()
383393

394+
@make_info_args_kwargs_parametrization(
395+
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
396+
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
397+
)
398+
def test_pil_output_type(self, info, args_kwargs):
399+
(image_datapoint, *other_args), kwargs = args_kwargs.load()
400+
401+
if image_datapoint.ndim > 3:
402+
pytest.skip("Input is batched")
403+
404+
image_pil = F.to_image_pil(image_datapoint)
405+
406+
output = info.dispatcher(image_pil, *other_args, **kwargs)
407+
408+
assert isinstance(output, PIL.Image.Image)
409+
384410
@make_info_args_kwargs_parametrization(
385411
DISPATCHER_INFOS,
386412
args_kwargs_fn=lambda info: info.sample_inputs(),
@@ -397,6 +423,17 @@ def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
397423

398424
spy.assert_called_once()
399425

426+
@make_info_args_kwargs_parametrization(
427+
DISPATCHER_INFOS,
428+
args_kwargs_fn=lambda info: info.sample_inputs(),
429+
)
430+
def test_datapoint_output_type(self, info, args_kwargs):
431+
(datapoint, *other_args), kwargs = args_kwargs.load()
432+
433+
output = info.dispatcher(datapoint, *other_args, **kwargs)
434+
435+
assert isinstance(output, type(datapoint))
436+
400437
@pytest.mark.parametrize(
401438
("dispatcher_info", "datapoint_type", "kernel_info"),
402439
[

0 commit comments

Comments
 (0)