Skip to content

Commit 29b0831

Browse files
authored
diversify parameter types for a couple of prototype kernels (#6635)
* add more size types for prototype resize sample inputs * add skip for dispatcher * add more sizes to resize kernel info * add more skips * add more diversity to gaussian_blur parameters * diversify affine parameters and fix bounding box kernel * fix center_crop dispatcher info * revert kernel fixes * add skips for scalar shears in affine_bounding_box
1 parent 2d92728 commit 29b0831

File tree

2 files changed

+151
-36
lines changed

2 files changed

+151
-36
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
2-
from typing import Callable, Dict, Sequence, Type
2+
from collections import defaultdict
3+
from typing import Callable, Dict, List, Sequence, Type
34

45
import pytest
56
import torchvision.prototype.transforms.functional as F
@@ -11,15 +12,30 @@
1112
KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}
1213

1314

15+
def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
16+
return Skip(
17+
"test_scripted_smoke",
18+
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
19+
reason=reason,
20+
)
21+
22+
23+
def skip_integer_size_jit(name="size"):
24+
return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")
25+
26+
1427
@dataclasses.dataclass
1528
class DispatcherInfo:
1629
dispatcher: Callable
1730
kernels: Dict[Type, Callable]
1831
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
19-
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
32+
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
2033

2134
def __post_init__(self):
22-
self._skips_map = {skip.test_name: skip for skip in self.skips}
35+
skips_map = defaultdict(list)
36+
for skip in self.skips:
37+
skips_map[skip.test_name].append(skip)
38+
self._skips_map = dict(skips_map)
2339

2440
def sample_inputs(self, *types):
2541
for type in types or self.kernels.keys():
@@ -29,9 +45,13 @@ def sample_inputs(self, *types):
2945
yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()
3046

3147
def maybe_skip(self, *, test_name, args_kwargs, device):
32-
skip = self._skips_map.get(test_name)
33-
if skip and skip.condition(args_kwargs, device):
34-
pytest.skip(skip.reason)
48+
skips = self._skips_map.get(test_name)
49+
if not skips:
50+
return
51+
52+
for skip in skips:
53+
if skip.condition(args_kwargs, device):
54+
pytest.skip(skip.reason)
3555

3656

3757
DISPATCHER_INFOS = [
@@ -50,6 +70,9 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
5070
features.BoundingBox: F.resize_bounding_box,
5171
features.Mask: F.resize_mask,
5272
},
73+
skips=[
74+
skip_integer_size_jit(),
75+
],
5376
),
5477
DispatcherInfo(
5578
F.affine,
@@ -58,6 +81,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
5881
features.BoundingBox: F.affine_bounding_box,
5982
features.Mask: F.affine_mask,
6083
},
84+
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
6185
),
6286
DispatcherInfo(
6387
F.vertical_flip,
@@ -122,12 +146,19 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
122146
features.BoundingBox: F.center_crop_bounding_box,
123147
features.Mask: F.center_crop_mask,
124148
},
149+
skips=[
150+
skip_integer_size_jit("output_size"),
151+
],
125152
),
126153
DispatcherInfo(
127154
F.gaussian_blur,
128155
kernels={
129156
features.Image: F.gaussian_blur_image_tensor,
130157
},
158+
skips=[
159+
skip_python_scalar_arg_jit("kernel_size"),
160+
skip_python_scalar_arg_jit("sigma"),
161+
],
131162
),
132163
DispatcherInfo(
133164
F.equalize,
@@ -207,11 +238,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
207238
features.Image: F.five_crop_image_tensor,
208239
},
209240
skips=[
210-
Skip(
211-
"test_scripted_smoke",
212-
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
213-
reason="Integer size is not supported when scripting five_crop_image_tensor.",
214-
),
241+
skip_integer_size_jit(),
215242
],
216243
),
217244
DispatcherInfo(
@@ -220,11 +247,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
220247
features.Image: F.ten_crop_image_tensor,
221248
},
222249
skips=[
223-
Skip(
224-
"test_scripted_smoke",
225-
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
226-
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
227-
),
250+
skip_integer_size_jit(),
228251
],
229252
),
230253
DispatcherInfo(

0 commit comments

Comments
 (0)