Skip to content

Commit 026991b

Browse files
authored
Reduce sample inputs for prototype transform kernels (#6714)
* pad_image_tensor * pad_mask and pad_bounding_box * resize_{image_tensor, mask, bounding_box} * center_crop_{image_tensor, mask} * {five, ten}_crop_image_tensor * crop_{image_tensor, mask} * convert_color_space_image_tensor * affine_{image_tensor, mask, bounding_box} * rotate_{image_tensor, mask} * gaussian_blur_image_tensor * cleanup
1 parent e3941af commit 026991b

File tree

4 files changed

+309
-153
lines changed

4 files changed

+309
-153
lines changed

test/prototype_common_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"assert_close",
2929
"assert_equal",
3030
"ArgsKwargs",
31+
"VALID_EXTRA_DIMS",
3132
"make_image_loaders",
3233
"make_image",
3334
"make_images",
@@ -201,7 +202,10 @@ def _parse_image_size(size, *, name="size"):
201202
)
202203

203204

204-
DEFAULT_EXTRA_DIMS = ((), (0,), (4,), (2, 3), (5, 0), (0, 5))
205+
VALID_EXTRA_DIMS = ((), (4,), (2, 3))
206+
DEGENERATE_BATCH_DIMS = ((0,), (5, 0), (0, 5))
207+
208+
DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS)
205209

206210

207211
def from_loader(loader_fn):

test/prototype_transforms_dispatcher_infos.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,40 @@ def sample_inputs(self, *feature_types, filter_metadata=True):
6363
yield args_kwargs
6464

6565

66-
def xfail_python_scalar_arg_jit(name, *, reason=None):
66+
def xfail_jit_python_scalar_arg(name, *, reason=None):
6767
reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
6868
return TestMark(
6969
("TestDispatchers", "test_scripted_smoke"),
7070
pytest.mark.xfail(reason=reason),
71-
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)),
71+
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
7272
)
7373

7474

75-
def xfail_integer_size_jit(name="size"):
76-
return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.")
75+
def xfail_jit_integer_size(name="size"):
76+
return xfail_jit_python_scalar_arg(name, reason=f"Integer `{name}` is not supported when scripting.")
77+
78+
79+
def xfail_jit_tuple_instead_of_list(name, *, reason=None):
80+
reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting"
81+
return TestMark(
82+
("TestDispatchers", "test_scripted_smoke"),
83+
pytest.mark.xfail(reason=reason),
84+
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
85+
)
86+
87+
88+
def is_list_of_ints(args_kwargs):
89+
fill = args_kwargs.kwargs.get("fill")
90+
return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
91+
92+
93+
def xfail_jit_list_of_ints(name, *, reason=None):
94+
reason = reason or f"Passing a list of integers for `{name}` is not supported when scripting"
95+
return TestMark(
96+
("TestDispatchers", "test_scripted_smoke"),
97+
pytest.mark.xfail(reason=reason),
98+
condition=is_list_of_ints,
99+
)
77100

78101

79102
skip_dispatch_feature = TestMark(
@@ -123,7 +146,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
123146
},
124147
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
125148
test_marks=[
126-
xfail_integer_size_jit(),
149+
xfail_jit_integer_size(),
127150
],
128151
),
129152
DispatcherInfo(
@@ -136,7 +159,10 @@ def fill_sequence_needs_broadcast(args_kwargs):
136159
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
137160
test_marks=[
138161
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
139-
xfail_python_scalar_arg_jit("shear"),
162+
xfail_jit_python_scalar_arg("shear"),
163+
xfail_jit_tuple_instead_of_list("fill"),
164+
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
165+
xfail_jit_list_of_ints("fill"),
140166
],
141167
),
142168
DispatcherInfo(
@@ -156,6 +182,11 @@ def fill_sequence_needs_broadcast(args_kwargs):
156182
features.Mask: F.rotate_mask,
157183
},
158184
pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
185+
test_marks=[
186+
xfail_jit_tuple_instead_of_list("fill"),
187+
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
188+
xfail_jit_list_of_ints("fill"),
189+
],
159190
),
160191
DispatcherInfo(
161192
F.crop,
@@ -194,7 +225,12 @@ def fill_sequence_needs_broadcast(args_kwargs):
194225
),
195226
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
196227
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
197-
)
228+
),
229+
xfail_jit_python_scalar_arg("padding"),
230+
xfail_jit_tuple_instead_of_list("padding"),
231+
xfail_jit_tuple_instead_of_list("fill"),
232+
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
233+
xfail_jit_list_of_ints("fill"),
198234
],
199235
),
200236
DispatcherInfo(
@@ -227,7 +263,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
227263
},
228264
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
229265
test_marks=[
230-
xfail_integer_size_jit("output_size"),
266+
xfail_jit_integer_size("output_size"),
231267
],
232268
),
233269
DispatcherInfo(
@@ -237,8 +273,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
237273
},
238274
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
239275
test_marks=[
240-
xfail_python_scalar_arg_jit("kernel_size"),
241-
xfail_python_scalar_arg_jit("sigma"),
276+
xfail_jit_python_scalar_arg("kernel_size"),
277+
xfail_jit_python_scalar_arg("sigma"),
242278
],
243279
),
244280
DispatcherInfo(
@@ -335,7 +371,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
335371
},
336372
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
337373
test_marks=[
338-
xfail_integer_size_jit(),
374+
xfail_jit_integer_size(),
339375
skip_dispatch_feature,
340376
],
341377
),
@@ -345,7 +381,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
345381
features.Image: F.ten_crop_image_tensor,
346382
},
347383
test_marks=[
348-
xfail_integer_size_jit(),
384+
xfail_jit_integer_size(),
349385
skip_dispatch_feature,
350386
],
351387
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),

0 commit comments

Comments
 (0)