Skip to content

Commit 3f1d9f6

Browse files
authored
Refactor KernelInfo and DispatcherInfo (#6710)
* make args and kwargs in ArgsKwargs more accessible * refactor KernelInfo and DispatcherInfo * remove ArgsKwargs __getitem__ shortcut again
1 parent 17969eb commit 3f1d9f6

File tree

4 files changed

+150
-122
lines changed

4 files changed

+150
-122
lines changed

test/prototype_common_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections.abc
44
import dataclasses
55
import functools
6+
from collections import defaultdict
67
from typing import Callable, Optional, Sequence, Tuple, Union
78

89
import PIL.Image
@@ -47,6 +48,9 @@
4748
"make_masks",
4849
"make_video",
4950
"make_videos",
51+
"TestMark",
52+
"mark_framework_limitation",
53+
"InfoBase",
5054
]
5155

5256

@@ -588,3 +592,52 @@ def make_video_loaders(
588592

589593

590594
make_videos = from_loaders(make_video_loaders)
595+
596+
597+
class TestMark:
598+
def __init__(
599+
self,
600+
# Tuple of test class name and test function name that identifies the test the mark is applied to. If there is
601+
# no test class, i.e. a standalone test function, use `None`.
602+
test_id,
603+
# `pytest.mark.*` to apply, e.g. `pytest.mark.skip` or `pytest.mark.xfail`
604+
mark,
605+
*,
606+
# Callable, that will be passed an `ArgsKwargs` and should return a boolean to indicate if the mark will be
607+
# applied. If omitted, defaults to always apply.
608+
condition=None,
609+
):
610+
self.test_id = test_id
611+
self.mark = mark
612+
self.condition = condition or (lambda args_kwargs: True)
613+
614+
615+
def mark_framework_limitation(test_id, reason):
616+
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
617+
# framework cannot handle the kernel in general or a specific parameter combination.
618+
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
619+
# still justified.
620+
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
621+
# we are wasting CI resources for no reason for most of the time
622+
return TestMark(test_id, pytest.mark.skip(reason=reason))
623+
624+
625+
class InfoBase:
626+
def __init__(self, *, id, test_marks=None, closeness_kwargs=None):
627+
# Identifier if the info that shows up the parametrization.
628+
self.id = id
629+
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
630+
# See the `TestMark` class for details
631+
self.test_marks = test_marks or []
632+
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
633+
self.closeness_kwargs = closeness_kwargs or dict()
634+
635+
test_marks_map = defaultdict(list)
636+
for test_mark in self.test_marks:
637+
test_marks_map[test_mark.test_id].append(test_mark)
638+
self._test_marks_map = dict(test_marks_map)
639+
640+
def get_marks(self, test_id, args_kwargs):
641+
return [
642+
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
643+
]

test/prototype_transforms_dispatcher_infos.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,67 @@
11
import collections.abc
2-
import dataclasses
3-
4-
from collections import defaultdict
5-
6-
from typing import Callable, Dict, List, Optional, Sequence, Type
72

83
import pytest
94
import torchvision.prototype.transforms.functional as F
10-
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
5+
from prototype_common_utils import InfoBase, TestMark
6+
from prototype_transforms_kernel_infos import KERNEL_INFOS
117
from torchvision.prototype import features
128

139
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
1410

15-
KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
16-
17-
18-
@dataclasses.dataclass
19-
class PILKernelInfo:
20-
kernel: Callable
21-
kernel_name: str = dataclasses.field(default=None)
22-
23-
def __post_init__(self):
24-
self.kernel_name = self.kernel_name or self.kernel.__name__
2511

26-
27-
@dataclasses.dataclass
28-
class DispatcherInfo:
29-
dispatcher: Callable
30-
kernels: Dict[Type, Callable]
31-
pil_kernel_info: Optional[PILKernelInfo] = None
32-
method_name: str = dataclasses.field(default=None)
33-
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
34-
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
35-
36-
def __post_init__(self):
37-
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
38-
self.method_name = self.method_name or self.dispatcher.__name__
39-
test_marks_map = defaultdict(list)
40-
for test_mark in self.test_marks:
41-
test_marks_map[test_mark.test_id].append(test_mark)
42-
self._test_marks_map = dict(test_marks_map)
43-
44-
def get_marks(self, test_id, args_kwargs):
45-
return [
46-
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
47-
]
12+
class PILKernelInfo(InfoBase):
13+
def __init__(
14+
self,
15+
kernel,
16+
*,
17+
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
18+
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
19+
kernel_name=None,
20+
):
21+
super().__init__(id=kernel_name or kernel.__name__)
22+
self.kernel = kernel
23+
24+
25+
class DispatcherInfo(InfoBase):
26+
_KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}
27+
28+
def __init__(
29+
self,
30+
dispatcher,
31+
*,
32+
# Dictionary of types that map to the kernel the dispatcher dispatches to.
33+
kernels,
34+
# If omitted, no PIL dispatch test will be performed.
35+
pil_kernel_info=None,
36+
# See InfoBase
37+
test_marks=None,
38+
# See InfoBase
39+
closeness_kwargs=None,
40+
):
41+
super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
42+
self.dispatcher = dispatcher
43+
self.kernels = kernels
44+
self.pil_kernel_info = pil_kernel_info
45+
46+
kernel_infos = {}
47+
for feature_type, kernel in self.kernels.items():
48+
kernel_info = self._KERNEL_INFO_MAP.get(kernel)
49+
if not kernel_info:
50+
raise pytest.UsageError(
51+
f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. "
52+
f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
53+
)
54+
kernel_infos[feature_type] = kernel_info
55+
self.kernel_infos = kernel_infos
4856

4957
def sample_inputs(self, *feature_types, filter_metadata=True):
50-
for feature_type in feature_types or self.kernels.keys():
51-
if feature_type not in self.kernels:
52-
raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}")
58+
for feature_type in feature_types or self.kernel_infos.keys():
59+
kernel_info = self.kernel_infos.get(feature_type)
60+
if not kernel_info:
61+
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
62+
63+
sample_inputs = kernel_info.sample_inputs_fn()
5364

54-
sample_inputs = self.kernel_infos[feature_type].sample_inputs_fn()
5565
if not filter_metadata:
5666
yield from sample_inputs
5767
else:

test/prototype_transforms_kernel_infos.py

Lines changed: 32 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
1-
import dataclasses
21
import functools
32
import itertools
43
import math
5-
from collections import defaultdict
6-
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
74

85
import numpy as np
96
import pytest
107
import torch.testing
118
import torchvision.ops
129
import torchvision.prototype.transforms.functional as F
13-
14-
from _pytest.mark.structures import MarkDecorator
1510
from common_utils import cycle_over
1611
from datasets_utils import combinations_grid
1712
from prototype_common_utils import (
1813
ArgsKwargs,
14+
InfoBase,
1915
make_bounding_box_loaders,
2016
make_image_loader,
2117
make_image_loaders,
2218
make_mask_loaders,
2319
make_video_loaders,
20+
mark_framework_limitation,
21+
TestMark,
2422
VALID_EXTRA_DIMS,
2523
)
2624
from torchvision.prototype import features
@@ -29,51 +27,35 @@
2927
__all__ = ["KernelInfo", "KERNEL_INFOS"]
3028

3129

32-
TestID = Tuple[Optional[str], str]
33-
34-
35-
@dataclasses.dataclass
36-
class TestMark:
37-
test_id: TestID
38-
mark: MarkDecorator
39-
condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True
40-
41-
42-
@dataclasses.dataclass
43-
class KernelInfo:
44-
kernel: Callable
45-
# Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should
46-
# not include extensive parameter combinations to keep to overall test count moderate.
47-
sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]]
48-
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
49-
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
50-
kernel_name: str = dataclasses.field(default=None)
51-
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
52-
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
53-
# inside the function. It should return a tensor or to be more precise an object that can be compared to a
54-
# tensor by `assert_close`. If omitted, no reference test will be performed.
55-
reference_fn: Optional[Callable] = None
56-
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
57-
# values to be tested. If not specified, `sample_inputs_fn` will be used.
58-
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
59-
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
60-
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
61-
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
62-
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
63-
64-
def __post_init__(self):
65-
self.kernel_name = self.kernel_name or self.kernel.__name__
66-
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn
67-
68-
test_marks_map = defaultdict(list)
69-
for test_mark in self.test_marks:
70-
test_marks_map[test_mark.test_id].append(test_mark)
71-
self._test_marks_map = dict(test_marks_map)
72-
73-
def get_marks(self, test_id, args_kwargs):
74-
return [
75-
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
76-
]
30+
class KernelInfo(InfoBase):
31+
def __init__(
32+
self,
33+
kernel,
34+
*,
35+
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
36+
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
37+
kernel_name=None,
38+
# Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but
39+
# should not include extensive parameter combinations to keep to overall test count moderate.
40+
sample_inputs_fn,
41+
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
42+
# take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
43+
# happen inside the function. It should return a tensor or to be more precise an object that can be compared to
44+
# a tensor by `assert_close`. If omitted, no reference test will be performed.
45+
reference_fn=None,
46+
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
47+
# values to be tested. If not specified, `sample_inputs_fn` will be used.
48+
reference_inputs_fn=None,
49+
# See InfoBase
50+
test_marks=None,
51+
# See InfoBase
52+
closeness_kwargs=None,
53+
):
54+
super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
55+
self.kernel = kernel
56+
self.sample_inputs_fn = sample_inputs_fn
57+
self.reference_fn = reference_fn
58+
self.reference_inputs_fn = reference_inputs_fn
7759

7860

7961
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
@@ -97,16 +79,6 @@ def wrapper(image_tensor, *other_args, **kwargs):
9779
return wrapper
9880

9981

100-
def mark_framework_limitation(test_id, reason):
101-
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
102-
# framework cannot handle the kernel in general or a specific parameter combination.
103-
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
104-
# still justified.
105-
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
106-
# we are wasting CI resources for no reason for most of the time.
107-
return TestMark(test_id, pytest.mark.skip(reason=reason))
108-
109-
11082
def xfail_jit_python_scalar_arg(name, *, reason=None):
11183
reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
11284
return TestMark(

0 commit comments

Comments
 (0)