diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index a14d5eaf007..11a4c35ae18 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -2,12 +2,12 @@ import dataclasses from collections import defaultdict + from typing import Callable, Dict, List, Optional, Sequence, Type import pytest import torchvision.prototype.transforms.functional as F -from prototype_common_utils import BoundingBoxLoader -from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip +from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark from torchvision.prototype import features __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] @@ -24,35 +24,27 @@ def __post_init__(self): self.kernel_name = self.kernel_name or self.kernel.__name__ -def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"): - return Skip( - "test_scripted_smoke", - condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)), - reason=reason, - ) - - -def skip_integer_size_jit(name="size"): - return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.") - - @dataclasses.dataclass class DispatcherInfo: dispatcher: Callable kernels: Dict[Type, Callable] - kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None) pil_kernel_info: Optional[PILKernelInfo] = None method_name: str = dataclasses.field(default=None) - skips: Sequence[Skip] = dataclasses.field(default_factory=list) - _skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False) + test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list) + _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False) def __post_init__(self): self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()} self.method_name = self.method_name or self.dispatcher.__name__ - skips_map = defaultdict(list) - for skip in self.skips: - skips_map[skip.test_name].append(skip) - self._skips_map = dict(skips_map) + test_marks_map = defaultdict(list) + for test_mark in self.test_marks: + test_marks_map[test_mark.test_id].append(test_mark) + self._test_marks_map = dict(test_marks_map) + + def get_marks(self, test_id, args_kwargs): + return [ + test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs) + ] def sample_inputs(self, *feature_types, filter_metadata=True): for feature_type in feature_types or self.kernels.keys(): @@ -70,17 +62,27 @@ def sample_inputs(self, *feature_types, filter_metadata=True): yield args_kwargs - def maybe_skip(self, *, test_name, args_kwargs, device): - skips = self._skips_map.get(test_name) - if not skips: - return - for skip in skips: - if skip.condition(args_kwargs, device): - pytest.skip(skip.reason) +def xfail_python_scalar_arg_jit(name, *, reason=None): + reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting" + return TestMark( + ("TestDispatchers", "test_scripted_smoke"), + pytest.mark.xfail(reason=reason), + condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)), + ) + +def xfail_integer_size_jit(name="size"): + return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.") -def fill_sequence_needs_broadcast(args_kwargs, device): + +skip_dispatch_feature = TestMark( + ("TestDispatchers", "test_dispatch_feature"), + pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."), +) + + +def fill_sequence_needs_broadcast(args_kwargs): (image_loader, *_), kwargs = args_kwargs try: fill = kwargs["fill"] @@ -93,15 +95,12 @@ def fill_sequence_needs_broadcast(args_kwargs, device): return image_loader.num_channels > 1 -skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip( - "test_dispatch_pil", +xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark( + ("TestDispatchers", "test_dispatch_pil"), + pytest.mark.xfail( + reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger." + ), condition=fill_sequence_needs_broadcast, - reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.", -) - -skip_dispatch_feature = Skip( - "test_dispatch_feature", - reason="Dispatcher doesn't support arbitrary feature dispatch.", ) @@ -123,8 +122,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.Mask: F.resize_mask, }, pil_kernel_info=PILKernelInfo(F.resize_image_pil), - skips=[ - skip_integer_size_jit(), + test_marks=[ + xfail_integer_size_jit(), ], ), DispatcherInfo( @@ -135,9 +134,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.Mask: F.affine_mask, }, pil_kernel_info=PILKernelInfo(F.affine_image_pil), - skips=[ - skip_dispatch_pil_if_fill_sequence_needs_broadcast, - skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"), + test_marks=[ + xfail_dispatch_pil_if_fill_sequence_needs_broadcast, + xfail_python_scalar_arg_jit("shear"), ], ), DispatcherInfo( @@ -166,16 +165,6 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.Mask: F.crop_mask, }, pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"), - skips=[ - Skip( - "test_dispatch_feature", - condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader), - reason=( - "F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two " - "since that is sufficient for the kernel." - ), - ) - ], ), DispatcherInfo( F.resized_crop, @@ -193,10 +182,20 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.BoundingBox: F.pad_bounding_box, features.Mask: F.pad_mask, }, - skips=[ - skip_dispatch_pil_if_fill_sequence_needs_broadcast, - ], pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"), + test_marks=[ + TestMark( + ("TestDispatchers", "test_dispatch_pil"), + pytest.mark.xfail( + reason=( + "PIL kernel doesn't support sequences of length 1 for argument `fill` and " + "`padding_mode='constant'`, if the number of color channels is larger." + ) + ), + condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs) + and args_kwargs.kwargs.get("padding_mode", "constant") == "constant", + ) + ], ), DispatcherInfo( F.perspective, @@ -205,10 +204,10 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.BoundingBox: F.perspective_bounding_box, features.Mask: F.perspective_mask, }, - skips=[ - skip_dispatch_pil_if_fill_sequence_needs_broadcast, - ], pil_kernel_info=PILKernelInfo(F.perspective_image_pil), + test_marks=[ + xfail_dispatch_pil_if_fill_sequence_needs_broadcast, + ], ), DispatcherInfo( F.elastic, @@ -227,8 +226,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.Mask: F.center_crop_mask, }, pil_kernel_info=PILKernelInfo(F.center_crop_image_pil), - skips=[ - skip_integer_size_jit("output_size"), + test_marks=[ + xfail_integer_size_jit("output_size"), ], ), DispatcherInfo( @@ -237,9 +236,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.Image: F.gaussian_blur_image_tensor, }, pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil), - skips=[ - skip_python_scalar_arg_jit("kernel_size"), - skip_python_scalar_arg_jit("sigma"), + test_marks=[ + xfail_python_scalar_arg_jit("kernel_size"), + xfail_python_scalar_arg_jit("sigma"), ], ), DispatcherInfo( @@ -290,7 +289,7 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.Image: F.erase_image_tensor, }, pil_kernel_info=PILKernelInfo(F.erase_image_pil), - skips=[ + test_marks=[ skip_dispatch_feature, ], ), @@ -335,8 +334,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device): features.Image: F.five_crop_image_tensor, }, pil_kernel_info=PILKernelInfo(F.five_crop_image_pil), - skips=[ - skip_integer_size_jit(), + test_marks=[ + xfail_integer_size_jit(), skip_dispatch_feature, ], ), @@ -345,18 +344,18 @@ def fill_sequence_needs_broadcast(args_kwargs, device): kernels={ features.Image: F.ten_crop_image_tensor, }, - pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil), - skips=[ - skip_integer_size_jit(), + test_marks=[ + xfail_integer_size_jit(), skip_dispatch_feature, ], + pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil), ), DispatcherInfo( F.normalize, kernels={ features.Image: F.normalize_image_tensor, }, - skips=[ + test_marks=[ skip_dispatch_feature, ], ), diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index a047a2d576b..2e02989b458 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -3,13 +3,15 @@ import itertools import math from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple import numpy as np import pytest import torch.testing import torchvision.ops import torchvision.prototype.transforms.functional as F + +from _pytest.mark.structures import MarkDecorator from datasets_utils import combinations_grid from prototype_common_utils import ArgsKwargs, make_bounding_box_loaders, make_image_loaders, make_mask_loaders from torchvision.prototype import features @@ -18,11 +20,14 @@ __all__ = ["KernelInfo", "KERNEL_INFOS"] +TestID = Tuple[Optional[str], str] + + @dataclasses.dataclass -class Skip: - test_name: str - reason: str - condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True +class TestMark: + test_id: TestID + mark: MarkDecorator + condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True @dataclasses.dataclass @@ -44,26 +49,22 @@ class KernelInfo: reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - skips: Sequence[Skip] = dataclasses.field(default_factory=list) - _skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False) + test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list) + _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False) def __post_init__(self): self.kernel_name = self.kernel_name or self.kernel.__name__ self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn - skips_map = defaultdict(list) - for skip in self.skips: - skips_map[skip.test_name].append(skip) - self._skips_map = dict(skips_map) + test_marks_map = defaultdict(list) + for test_mark in self.test_marks: + test_marks_map[test_mark.test_id].append(test_mark) + self._test_marks_map = dict(test_marks_map) - def maybe_skip(self, *, test_name, args_kwargs, device): - skips = self._skips_map.get(test_name) - if not skips: - return - - for skip in skips: - if skip.condition(args_kwargs, device): - pytest.skip(skip.reason) + def get_marks(self, test_id, args_kwargs): + return [ + test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs) + ] DEFAULT_IMAGE_CLOSENESS_KWARGS = dict( @@ -87,16 +88,27 @@ def wrapper(image_tensor, *other_args, **kwargs): return wrapper -def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"): - return Skip( - "test_scripted_vs_eager", - condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)), - reason=reason, +def mark_framework_limitation(test_id, reason): + # The purpose of this function is to have a single entry point for skip marks that are only there, because the test + # framework cannot handle the kernel in general or a specific parameter combination. + # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is + # still justified. + # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus, + # we are wasting CI resources for no reason for most of the time. + return TestMark(test_id, pytest.mark.skip(reason=reason)) + + +def xfail_python_scalar_arg_jit(name, *, reason=None): + reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting" + return TestMark( + ("TestKernels", "test_scripted_vs_eager"), + pytest.mark.xfail(reason=reason), + condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)), ) -def skip_integer_size_jit(name="size"): - return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.") +def xfail_integer_size_jit(name="size"): + return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.") KERNEL_INFOS = [] @@ -151,8 +163,7 @@ def sample_inputs_horizontal_flip_mask(): def _get_resize_sizes(image_size): height, width = image_size length = max(image_size) - # FIXME: enable me when the kernels are fixed - # yield length + yield length yield [length] yield (length,) new_height = int(height * 0.75) @@ -236,15 +247,15 @@ def reference_inputs_resize_mask(): reference_fn=reference_resize_image_tensor, reference_inputs_fn=reference_inputs_resize_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit(), + test_marks=[ + xfail_integer_size_jit(), ], ), KernelInfo( F.resize_bounding_box, sample_inputs_fn=sample_inputs_resize_bounding_box, - skips=[ - skip_integer_size_jit(), + test_marks=[ + xfail_integer_size_jit(), ], ), KernelInfo( @@ -253,8 +264,8 @@ def reference_inputs_resize_mask(): reference_fn=reference_resize_mask, reference_inputs_fn=reference_inputs_resize_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit(), + test_marks=[ + xfail_integer_size_jit(), ], ), ] @@ -436,16 +447,6 @@ def reference_inputs_resize_mask(): yield ArgsKwargs(mask_loader, **affine_kwargs) -# FIXME: @datumbox, remove this as soon as you have fixed the behavior in https://github.com/pytorch/vision/pull/6636 -def skip_scalar_shears(*test_names): - for test_name in test_names: - yield Skip( - test_name, - condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["shear"], (int, float)), - reason="The kernel is broken for a scalar `shear`", - ) - - KERNEL_INFOS.extend( [ KernelInfo( @@ -454,7 +455,7 @@ def skip_scalar_shears(*test_names): reference_fn=pil_reference_wrapper(F.affine_image_pil), reference_inputs_fn=reference_inputs_affine_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")], + test_marks=[xfail_python_scalar_arg_jit("shear")], ), KernelInfo( F.affine_bounding_box, @@ -462,13 +463,8 @@ def skip_scalar_shears(*test_names): reference_fn=reference_affine_bounding_box, reference_inputs_fn=reference_inputs_affine_bounding_box, closeness_kwargs=dict(atol=1, rtol=0), - skips=[ - skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"), - *skip_scalar_shears( - "test_batched_vs_single", - "test_no_inplace", - "test_dtype_and_device_consistency", - ), + test_marks=[ + xfail_python_scalar_arg_jit("shear"), ], ), KernelInfo( @@ -477,7 +473,7 @@ def skip_scalar_shears(*test_names): reference_fn=reference_affine_mask, reference_inputs_fn=reference_inputs_resize_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")], + test_marks=[xfail_python_scalar_arg_jit("shear")], ), ] ) @@ -1093,15 +1089,15 @@ def reference_inputs_center_crop_mask(): reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_inputs_fn=reference_inputs_center_crop_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit("output_size"), + test_marks=[ + xfail_integer_size_jit("output_size"), ], ), KernelInfo( F.center_crop_bounding_box, sample_inputs_fn=sample_inputs_center_crop_bounding_box, - skips=[ - skip_integer_size_jit("output_size"), + test_marks=[ + xfail_integer_size_jit("output_size"), ], ), KernelInfo( @@ -1110,8 +1106,8 @@ def reference_inputs_center_crop_mask(): reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_inputs_fn=reference_inputs_center_crop_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_integer_size_jit("output_size"), + test_marks=[ + xfail_integer_size_jit("output_size"), ], ), ] @@ -1138,9 +1134,9 @@ def sample_inputs_gaussian_blur_image_tensor(): F.gaussian_blur_image_tensor, sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - skips=[ - skip_python_scalar_arg_jit("kernel_size"), - skip_python_scalar_arg_jit("sigma"), + test_marks=[ + xfail_python_scalar_arg_jit("kernel_size"), + xfail_python_scalar_arg_jit("sigma"), ], ) ) @@ -1551,9 +1547,9 @@ def reference_inputs_ten_crop_image_tensor(): sample_inputs_fn=sample_inputs_five_crop_image_tensor, reference_fn=pil_reference_wrapper(F.five_crop_image_pil), reference_inputs_fn=reference_inputs_five_crop_image_tensor, - skips=[ - skip_integer_size_jit(), - Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."), + test_marks=[ + xfail_integer_size_jit(), + mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."), ], closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ), @@ -1562,9 +1558,9 @@ def reference_inputs_ten_crop_image_tensor(): sample_inputs_fn=sample_inputs_ten_crop_image_tensor, reference_fn=pil_reference_wrapper(F.ten_crop_image_pil), reference_inputs_fn=reference_inputs_ten_crop_image_tensor, - skips=[ - skip_integer_size_jit(), - Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."), + test_marks=[ + xfail_integer_size_jit(), + mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."), ], closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 143a5cd2228..a6523045c2d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,3 +1,4 @@ +import functools import math import os @@ -26,33 +27,60 @@ def script(fn): raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error -@pytest.fixture(autouse=True) -def maybe_skip(request): - # In case the test uses no parametrization or fixtures, the `callspec` attribute does not exist - try: - callspec = request.node.callspec - except AttributeError: - return +def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, name_fn=lambda info: str(info)): + if condition is None: - try: - info = callspec.params["info"] - args_kwargs = callspec.params["args_kwargs"] - except KeyError: - return + def condition(info): + return True - info.maybe_skip( - test_name=request.node.originalname, args_kwargs=args_kwargs, device=callspec.params.get("device", "cpu") - ) + def decorator(test_fn): + parts = test_fn.__qualname__.split(".") + if len(parts) == 1: + test_class_name = None + test_function_name = parts[0] + elif len(parts) == 2: + test_class_name, test_function_name = parts + else: + raise pytest.UsageError("Unable to parse the test class and test name from test function") + test_id = (test_class_name, test_function_name) + + argnames = ("info", "args_kwargs") + argvalues = [] + for info in infos: + if not condition(info): + continue + + args_kwargs = list(args_kwargs_fn(info)) + name = name_fn(info) + idx_field_len = len(str(len(args_kwargs))) + + for idx, args_kwargs_ in enumerate(args_kwargs): + argvalues.append( + pytest.param( + info, + args_kwargs_, + marks=info.get_marks(test_id, args_kwargs_), + id=f"{name}-{idx:0{idx_field_len}}", + ) + ) + + return pytest.mark.parametrize(argnames, argvalues)(test_fn) + + return decorator class TestKernels: - sample_inputs = pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}") - for info in KERNEL_INFOS - for idx, args_kwargs in enumerate(info.sample_inputs_fn()) - ], + make_kernel_args_kwargs_parametrization = functools.partial( + make_args_kwargs_parametrization, name_fn=lambda info: info.kernel_name + ) + sample_inputs = kernel_sample_inputs = make_kernel_args_kwargs_parametrization( + KERNEL_INFOS, + args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(), + ) + reference_inputs = make_kernel_args_kwargs_parametrization( + KERNEL_INFOS, + args_kwargs_fn=lambda info: info.reference_inputs_fn(), + condition=lambda info: info.reference_fn is not None, ) @sample_inputs @@ -156,15 +184,7 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device): assert output.dtype == input.dtype assert output.device == input.device - @pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.kernel_name}-{idx}") - for info in KERNEL_INFOS - for idx, args_kwargs in enumerate(info.reference_inputs_fn()) - if info.reference_fn is not None - ], - ) + @reference_inputs def test_against_reference(self, info, args_kwargs): args, kwargs = args_kwargs.load("cpu") @@ -187,15 +207,16 @@ def make_spy(fn, *, module=None, name=None): class TestDispatchers: - @pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}") - for info in DISPATCHER_INFOS - for idx, args_kwargs in enumerate(info.sample_inputs(features.Image)) - if features.Image in info.kernels - ], + make_dispatcher_args_kwargs_parametrization = functools.partial( + make_args_kwargs_parametrization, name_fn=lambda info: info.dispatcher.__name__ ) + image_sample_inputs = kernel_sample_inputs = make_dispatcher_args_kwargs_parametrization( + DISPATCHER_INFOS, + args_kwargs_fn=lambda info: info.sample_inputs(features.Image), + condition=lambda info: features.Image in info.kernels, + ) + + @image_sample_inputs @pytest.mark.parametrize("device", cpu_and_gpu()) def test_scripted_smoke(self, info, args_kwargs, device): dispatcher = script(info.dispatcher) @@ -223,15 +244,7 @@ def test_scripted_smoke(self, info, args_kwargs, device): def test_scriptable(self, dispatcher): script(dispatcher) - @pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}") - for info in DISPATCHER_INFOS - for idx, args_kwargs in enumerate(info.sample_inputs(features.Image)) - if features.Image in info.kernels - ], - ) + @image_sample_inputs def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): (image_feature, *other_args), kwargs = args_kwargs.load() image_simple_tensor = torch.Tensor(image_feature) @@ -243,14 +256,10 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): spy.assert_called_once() - @pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}") - for info in DISPATCHER_INFOS - for idx, args_kwargs in enumerate(info.sample_inputs(features.Image)) - if features.Image in info.kernels and info.pil_kernel_info is not None - ], + @make_dispatcher_args_kwargs_parametrization( + DISPATCHER_INFOS, + args_kwargs_fn=lambda info: info.sample_inputs(features.Image), + condition=lambda info: info.pil_kernel_info is not None, ) def test_dispatch_pil(self, info, args_kwargs, spy_on): (image_feature, *other_args), kwargs = args_kwargs.load() @@ -267,13 +276,9 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on): spy.assert_called_once() - @pytest.mark.parametrize( - ("info", "args_kwargs"), - [ - pytest.param(info, args_kwargs, id=f"{info.dispatcher.__name__}-{idx}") - for info in DISPATCHER_INFOS - for idx, args_kwargs in enumerate(info.sample_inputs()) - ], + @make_dispatcher_args_kwargs_parametrization( + DISPATCHER_INFOS, + args_kwargs_fn=lambda info: info.sample_inputs(), ) def test_dispatch_feature(self, info, args_kwargs, spy_on): (feature, *other_args), kwargs = args_kwargs.load()