1
1
import dataclasses
2
- from typing import Callable , Dict , Sequence , Type
2
+ from collections import defaultdict
3
+ from typing import Callable , Dict , List , Sequence , Type
3
4
4
5
import pytest
5
6
import torchvision .prototype .transforms .functional as F
11
12
KERNEL_SAMPLE_INPUTS_FN_MAP = {info .kernel : info .sample_inputs_fn for info in KERNEL_INFOS }
12
13
13
14
15
+ def skip_python_scalar_arg_jit (name , * , reason = "Python scalar int or float is not supported when scripting" ):
16
+ return Skip (
17
+ "test_scripted_smoke" ,
18
+ condition = lambda args_kwargs , device : isinstance (args_kwargs .kwargs [name ], (int , float )),
19
+ reason = reason ,
20
+ )
21
+
22
+
23
+ def skip_integer_size_jit (name = "size" ):
24
+ return skip_python_scalar_arg_jit (name , reason = "Integer size is not supported when scripting." )
25
+
26
+
14
27
@dataclasses .dataclass
15
28
class DispatcherInfo :
16
29
dispatcher : Callable
17
30
kernels : Dict [Type , Callable ]
18
31
skips : Sequence [Skip ] = dataclasses .field (default_factory = list )
19
- _skips_map : Dict [str , Skip ] = dataclasses .field (default = None , init = False )
32
+ _skips_map : Dict [str , List [ Skip ] ] = dataclasses .field (default = None , init = False )
20
33
21
34
def __post_init__ (self ):
22
- self ._skips_map = {skip .test_name : skip for skip in self .skips }
35
+ skips_map = defaultdict (list )
36
+ for skip in self .skips :
37
+ skips_map [skip .test_name ].append (skip )
38
+ self ._skips_map = dict (skips_map )
23
39
24
40
def sample_inputs (self , * types ):
25
41
for type in types or self .kernels .keys ():
@@ -29,9 +45,13 @@ def sample_inputs(self, *types):
29
45
yield from KERNEL_SAMPLE_INPUTS_FN_MAP [self .kernels [type ]]()
30
46
31
47
def maybe_skip (self , * , test_name , args_kwargs , device ):
32
- skip = self ._skips_map .get (test_name )
33
- if skip and skip .condition (args_kwargs , device ):
34
- pytest .skip (skip .reason )
48
+ skips = self ._skips_map .get (test_name )
49
+ if not skips :
50
+ return
51
+
52
+ for skip in skips :
53
+ if skip .condition (args_kwargs , device ):
54
+ pytest .skip (skip .reason )
35
55
36
56
37
57
DISPATCHER_INFOS = [
@@ -50,6 +70,9 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
50
70
features .BoundingBox : F .resize_bounding_box ,
51
71
features .Mask : F .resize_mask ,
52
72
},
73
+ skips = [
74
+ skip_integer_size_jit (),
75
+ ],
53
76
),
54
77
DispatcherInfo (
55
78
F .affine ,
@@ -58,6 +81,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
58
81
features .BoundingBox : F .affine_bounding_box ,
59
82
features .Mask : F .affine_mask ,
60
83
},
84
+ skips = [skip_python_scalar_arg_jit ("shear" , reason = "Scalar shear is not supported by JIT" )],
61
85
),
62
86
DispatcherInfo (
63
87
F .vertical_flip ,
@@ -122,12 +146,19 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
122
146
features .BoundingBox : F .center_crop_bounding_box ,
123
147
features .Mask : F .center_crop_mask ,
124
148
},
149
+ skips = [
150
+ skip_integer_size_jit ("output_size" ),
151
+ ],
125
152
),
126
153
DispatcherInfo (
127
154
F .gaussian_blur ,
128
155
kernels = {
129
156
features .Image : F .gaussian_blur_image_tensor ,
130
157
},
158
+ skips = [
159
+ skip_python_scalar_arg_jit ("kernel_size" ),
160
+ skip_python_scalar_arg_jit ("sigma" ),
161
+ ],
131
162
),
132
163
DispatcherInfo (
133
164
F .equalize ,
@@ -207,11 +238,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
207
238
features .Image : F .five_crop_image_tensor ,
208
239
},
209
240
skips = [
210
- Skip (
211
- "test_scripted_smoke" ,
212
- condition = lambda args_kwargs , device : isinstance (args_kwargs .kwargs ["size" ], int ),
213
- reason = "Integer size is not supported when scripting five_crop_image_tensor." ,
214
- ),
241
+ skip_integer_size_jit (),
215
242
],
216
243
),
217
244
DispatcherInfo (
@@ -220,11 +247,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
220
247
features .Image : F .ten_crop_image_tensor ,
221
248
},
222
249
skips = [
223
- Skip (
224
- "test_scripted_smoke" ,
225
- condition = lambda args_kwargs , device : isinstance (args_kwargs .kwargs ["size" ], int ),
226
- reason = "Integer size is not supported when scripting ten_crop_image_tensor." ,
227
- ),
250
+ skip_integer_size_jit (),
228
251
],
229
252
),
230
253
DispatcherInfo (
0 commit comments