diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index c10cec94c31..1d5766b1fcf 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -3,6 +3,7 @@ import collections.abc import dataclasses import functools +from collections import defaultdict from typing import Callable, Optional, Sequence, Tuple, Union import PIL.Image @@ -47,6 +48,9 @@ "make_masks", "make_video", "make_videos", + "TestMark", + "mark_framework_limitation", + "InfoBase", ] @@ -588,3 +592,52 @@ def make_video_loaders( make_videos = from_loaders(make_video_loaders) + + +class TestMark: + def __init__( + self, + # Tuple of test class name and test function name that identifies the test the mark is applied to. If there is + # no test class, i.e. a standalone test function, use `None`. + test_id, + # `pytest.mark.*` to apply, e.g. `pytest.mark.skip` or `pytest.mark.xfail` + mark, + *, + # Callable, that will be passed an `ArgsKwargs` and should return a boolean to indicate if the mark will be + # applied. If omitted, defaults to always apply. + condition=None, + ): + self.test_id = test_id + self.mark = mark + self.condition = condition or (lambda args_kwargs: True) + + +def mark_framework_limitation(test_id, reason): + # The purpose of this function is to have a single entry point for skip marks that are only there, because the test + # framework cannot handle the kernel in general or a specific parameter combination. + # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is + # still justified. + # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus, + # we are wasting CI resources for no reason for most of the time + return TestMark(test_id, pytest.mark.skip(reason=reason)) + + +class InfoBase: + def __init__(self, *, id, test_marks=None, closeness_kwargs=None): + # Identifier if the info that shows up the parametrization. + self.id = id + # Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization. + # See the `TestMark` class for details + self.test_marks = test_marks or [] + # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. + self.closeness_kwargs = closeness_kwargs or dict() + + test_marks_map = defaultdict(list) + for test_mark in self.test_marks: + test_marks_map[test_mark.test_id].append(test_mark) + self._test_marks_map = dict(test_marks_map) + + def get_marks(self, test_id, args_kwargs): + return [ + test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs) + ] diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index de933c7e3fa..82173907c6f 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -1,57 +1,67 @@ import collections.abc -import dataclasses - -from collections import defaultdict - -from typing import Callable, Dict, List, Optional, Sequence, Type import pytest import torchvision.prototype.transforms.functional as F -from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark +from prototype_common_utils import InfoBase, TestMark +from prototype_transforms_kernel_infos import KERNEL_INFOS from torchvision.prototype import features __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] -KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS} - - -@dataclasses.dataclass -class PILKernelInfo: - kernel: Callable - kernel_name: str = dataclasses.field(default=None) - - def __post_init__(self): - self.kernel_name = self.kernel_name or self.kernel.__name__ - -@dataclasses.dataclass -class DispatcherInfo: - dispatcher: Callable - kernels: Dict[Type, Callable] - pil_kernel_info: Optional[PILKernelInfo] = None - method_name: str = dataclasses.field(default=None) - test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list) - _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False) - - def __post_init__(self): - self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()} - self.method_name = self.method_name or self.dispatcher.__name__ - test_marks_map = defaultdict(list) - for test_mark in self.test_marks: - test_marks_map[test_mark.test_id].append(test_mark) - self._test_marks_map = dict(test_marks_map) - - def get_marks(self, test_id, args_kwargs): - return [ - test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs) - ] +class PILKernelInfo(InfoBase): + def __init__( + self, + kernel, + *, + # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name + # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then + kernel_name=None, + ): + super().__init__(id=kernel_name or kernel.__name__) + self.kernel = kernel + + +class DispatcherInfo(InfoBase): + _KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS} + + def __init__( + self, + dispatcher, + *, + # Dictionary of types that map to the kernel the dispatcher dispatches to. + kernels, + # If omitted, no PIL dispatch test will be performed. + pil_kernel_info=None, + # See InfoBase + test_marks=None, + # See InfoBase + closeness_kwargs=None, + ): + super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs) + self.dispatcher = dispatcher + self.kernels = kernels + self.pil_kernel_info = pil_kernel_info + + kernel_infos = {} + for feature_type, kernel in self.kernels.items(): + kernel_info = self._KERNEL_INFO_MAP.get(kernel) + if not kernel_info: + raise pytest.UsageError( + f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. " + f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`." + ) + kernel_infos[feature_type] = kernel_info + self.kernel_infos = kernel_infos def sample_inputs(self, *feature_types, filter_metadata=True): - for feature_type in feature_types or self.kernels.keys(): - if feature_type not in self.kernels: - raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}") + for feature_type in feature_types or self.kernel_infos.keys(): + kernel_info = self.kernel_infos.get(feature_type) + if not kernel_info: + raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}") + + sample_inputs = kernel_info.sample_inputs_fn() - sample_inputs = self.kernel_infos[feature_type].sample_inputs_fn() if not filter_metadata: yield from sample_inputs else: diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 9ebfc7a00d2..34f1f875a05 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1,26 +1,24 @@ -import dataclasses import functools import itertools import math -from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple import numpy as np import pytest import torch.testing import torchvision.ops import torchvision.prototype.transforms.functional as F - -from _pytest.mark.structures import MarkDecorator from common_utils import cycle_over from datasets_utils import combinations_grid from prototype_common_utils import ( ArgsKwargs, + InfoBase, make_bounding_box_loaders, make_image_loader, make_image_loaders, make_mask_loaders, make_video_loaders, + mark_framework_limitation, + TestMark, VALID_EXTRA_DIMS, ) from torchvision.prototype import features @@ -29,51 +27,35 @@ __all__ = ["KernelInfo", "KERNEL_INFOS"] -TestID = Tuple[Optional[str], str] - - -@dataclasses.dataclass -class TestMark: - test_id: TestID - mark: MarkDecorator - condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True - - -@dataclasses.dataclass -class KernelInfo: - kernel: Callable - # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should - # not include extensive parameter combinations to keep to overall test count moderate. - sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]] - # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name - # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then - kernel_name: str = dataclasses.field(default=None) - # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take - # tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen - # inside the function. It should return a tensor or to be more precise an object that can be compared to a - # tensor by `assert_close`. If omitted, no reference test will be performed. - reference_fn: Optional[Callable] = None - # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter - # values to be tested. If not specified, `sample_inputs_fn` will be used. - reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None - # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. - closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list) - _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False) - - def __post_init__(self): - self.kernel_name = self.kernel_name or self.kernel.__name__ - self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn - - test_marks_map = defaultdict(list) - for test_mark in self.test_marks: - test_marks_map[test_mark.test_id].append(test_mark) - self._test_marks_map = dict(test_marks_map) - - def get_marks(self, test_id, args_kwargs): - return [ - test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs) - ] +class KernelInfo(InfoBase): + def __init__( + self, + kernel, + *, + # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name + # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then + kernel_name=None, + # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but + # should not include extensive parameter combinations to keep to overall test count moderate. + sample_inputs_fn, + # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also + # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should + # happen inside the function. It should return a tensor or to be more precise an object that can be compared to + # a tensor by `assert_close`. If omitted, no reference test will be performed. + reference_fn=None, + # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter + # values to be tested. If not specified, `sample_inputs_fn` will be used. + reference_inputs_fn=None, + # See InfoBase + test_marks=None, + # See InfoBase + closeness_kwargs=None, + ): + super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs) + self.kernel = kernel + self.sample_inputs_fn = sample_inputs_fn + self.reference_fn = reference_fn + self.reference_inputs_fn = reference_inputs_fn DEFAULT_IMAGE_CLOSENESS_KWARGS = dict( @@ -97,16 +79,6 @@ def wrapper(image_tensor, *other_args, **kwargs): return wrapper -def mark_framework_limitation(test_id, reason): - # The purpose of this function is to have a single entry point for skip marks that are only there, because the test - # framework cannot handle the kernel in general or a specific parameter combination. - # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is - # still justified. - # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus, - # we are wasting CI resources for no reason for most of the time. - return TestMark(test_id, pytest.mark.skip(reason=reason)) - - def xfail_jit_python_scalar_arg(name, *, reason=None): reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting" return TestMark( diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 5adea4d2663..8329de69782 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,4 +1,3 @@ -import functools import math import os @@ -27,7 +26,7 @@ def script(fn): raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error -def make_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None, name_fn=lambda info: str(info)): +def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None): if condition is None: def condition(info): @@ -41,7 +40,7 @@ def decorator(test_fn): elif len(parts) == 2: test_class_name, test_function_name = parts else: - raise pytest.UsageError("Unable to parse the test class and test name from test function") + raise pytest.UsageError("Unable to parse the test class name and test function name from test function") test_id = (test_class_name, test_function_name) argnames = ("info", "args_kwargs") @@ -51,7 +50,6 @@ def decorator(test_fn): continue args_kwargs = list(args_kwargs_fn(info)) - name = name_fn(info) idx_field_len = len(str(len(args_kwargs))) for idx, args_kwargs_ in enumerate(args_kwargs): @@ -60,7 +58,7 @@ def decorator(test_fn): info, args_kwargs_, marks=info.get_marks(test_id, args_kwargs_), - id=f"{name}-{idx:0{idx_field_len}}", + id=f"{info.id}-{idx:0{idx_field_len}}", ) ) @@ -70,14 +68,11 @@ def decorator(test_fn): class TestKernels: - make_kernel_args_kwargs_parametrization = functools.partial( - make_args_kwargs_parametrization, name_fn=lambda info: info.kernel_name - ) - sample_inputs = kernel_sample_inputs = make_kernel_args_kwargs_parametrization( + sample_inputs = make_info_args_kwargs_parametrization( KERNEL_INFOS, args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(), ) - reference_inputs = make_kernel_args_kwargs_parametrization( + reference_inputs = make_info_args_kwargs_parametrization( KERNEL_INFOS, args_kwargs_fn=lambda info: info.reference_inputs_fn(), condition=lambda info: info.reference_fn is not None, @@ -208,10 +203,7 @@ def make_spy(fn, *, module=None, name=None): class TestDispatchers: - make_dispatcher_args_kwargs_parametrization = functools.partial( - make_args_kwargs_parametrization, name_fn=lambda info: info.dispatcher.__name__ - ) - image_sample_inputs = kernel_sample_inputs = make_dispatcher_args_kwargs_parametrization( + image_sample_inputs = make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(features.Image), condition=lambda info: features.Image in info.kernels, @@ -251,13 +243,13 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): image_simple_tensor = torch.Tensor(image_feature) kernel_info = info.kernel_infos[features.Image] - spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.kernel_name) + spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id) info.dispatcher(image_simple_tensor, *other_args, **kwargs) spy.assert_called_once() - @make_dispatcher_args_kwargs_parametrization( + @make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(features.Image), condition=lambda info: info.pil_kernel_info is not None, @@ -271,22 +263,23 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on): image_pil = F.to_image_pil(image_feature) pil_kernel_info = info.pil_kernel_info - spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.kernel_name) + spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id) info.dispatcher(image_pil, *other_args, **kwargs) spy.assert_called_once() - @make_dispatcher_args_kwargs_parametrization( + @make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(), ) def test_dispatch_feature(self, info, args_kwargs, spy_on): (feature, *other_args), kwargs = args_kwargs.load() - method = getattr(feature, info.method_name) + method_name = info.id + method = getattr(feature, method_name) feature_type = type(feature) - spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{info.method_name}") + spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{method_name}") info.dispatcher(feature, *other_args, **kwargs)