Skip to content

Reduce sample inputs for prototype transform kernels #6714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 6, 2022
6 changes: 5 additions & 1 deletion test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"assert_close",
"assert_equal",
"ArgsKwargs",
"VALID_EXTRA_DIMS",
"make_image_loaders",
"make_image",
"make_images",
Expand Down Expand Up @@ -201,7 +202,10 @@ def _parse_image_size(size, *, name="size"):
)


# Extra leading ("batch") dimensions used to parametrize sample-input loaders.
# Shapes with no zero-sized dimension, i.e. inputs that actually contain data.
VALID_EXTRA_DIMS = ((), (4,), (2, 3))
# Degenerate shapes: at least one batch dimension is 0, i.e. empty batches.
DEGENERATE_BATCH_DIMS = ((0,), (5, 0), (0, 5))

# Default set used by the make_*_loaders helpers: valid shapes first, then the
# degenerate ones, so both cases are exercised by default.
DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS)


def from_loader(loader_fn):
Expand Down
60 changes: 48 additions & 12 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,40 @@ def sample_inputs(self, *feature_types, filter_metadata=True):
yield args_kwargs


def xfail_jit_python_scalar_arg(name, *, reason=None):
    """xfail the JIT smoke test when the ``name`` kwarg is a plain Python scalar.

    Renamed from ``xfail_python_scalar_arg_jit`` so the ``jit`` term sits at the
    front of the name, making it clear this is a JIT(scripting)-only issue.

    Args:
        name: keyword-argument name to inspect on each sample input.
        reason: optional custom xfail reason; a default message is used if omitted.

    Returns:
        A ``TestMark`` that applies ``pytest.mark.xfail`` to
        ``TestDispatchers.test_scripted_smoke`` for matching sample inputs.
    """
    reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
    return TestMark(
        ("TestDispatchers", "test_scripted_smoke"),
        pytest.mark.xfail(reason=reason),
        # .get() rather than [] so sample inputs that omit the kwarg entirely
        # don't raise KeyError while evaluating the condition.
        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
    )


def xfail_jit_integer_size(name="size"):
    """xfail the JIT smoke test when ``name`` (default ``"size"``) is a bare int.

    Thin wrapper around ``xfail_jit_python_scalar_arg`` with a size-specific
    reason message; renamed from ``xfail_integer_size_jit`` for consistency with
    the ``xfail_jit_*`` naming scheme.
    """
    return xfail_jit_python_scalar_arg(name, reason=f"Integer `{name}` is not supported when scripting.")


def xfail_jit_tuple_instead_of_list(name, *, reason=None):
    """xfail the JIT smoke test when the ``name`` kwarg is a tuple rather than a list.

    Torchscript accepts ``List[...]`` annotations but not tuples in their place,
    so sample inputs passing a tuple are expected to fail when scripted.
    """
    if reason is None:
        reason = f"Passing a tuple instead of a list for `{name}` is not supported when scripting"

    def _kwarg_is_tuple(args_kwargs):
        # .get() so samples omitting the kwarg don't raise KeyError.
        return isinstance(args_kwargs.kwargs.get(name), tuple)

    return TestMark(
        ("TestDispatchers", "test_scripted_smoke"),
        pytest.mark.xfail(reason=reason),
        condition=_kwarg_is_tuple,
    )


def is_list_of_ints(args_kwargs):
    """Return ``True`` if the ``fill`` kwarg is a list containing at least one int.

    NOTE(review): despite the name, this uses ``any`` — a mixed list such as
    ``[1, 0.5]`` also matches. Presumably intentional, since a single int
    element is enough to trip scripting; confirm against the JIT failure mode.
    """
    fill = args_kwargs.kwargs.get("fill")
    if not isinstance(fill, list):
        return False
    return any(isinstance(element, int) for element in fill)


def xfail_jit_list_of_ints(name, *, reason=None):
    """xfail the JIT smoke test when ``name`` receives a list of integers.

    NOTE(review): the condition (``is_list_of_ints``) always inspects the
    ``"fill"`` kwarg regardless of ``name`` — ``name`` only affects the reason
    message. Fine for the current callers, which all pass ``"fill"``.
    """
    if reason is None:
        reason = f"Passing a list of integers for `{name}` is not supported when scripting"
    mark = pytest.mark.xfail(reason=reason)
    return TestMark(
        ("TestDispatchers", "test_scripted_smoke"),
        mark,
        condition=is_list_of_ints,
    )


skip_dispatch_feature = TestMark(
Expand Down Expand Up @@ -123,7 +146,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
test_marks=[
xfail_integer_size_jit(),
xfail_jit_integer_size(),
],
),
DispatcherInfo(
Expand All @@ -136,7 +159,10 @@ def fill_sequence_needs_broadcast(args_kwargs):
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
xfail_python_scalar_arg_jit("shear"),
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
Comment on lines +164 to +165
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a few of these. I'm not sure if this is a regression of #6636. Will check and send a follow-up PR, since this one contains only test changes.

],
),
DispatcherInfo(
Expand All @@ -156,6 +182,11 @@ def fill_sequence_needs_broadcast(args_kwargs):
features.Mask: F.rotate_mask,
},
pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
],
),
DispatcherInfo(
F.crop,
Expand Down Expand Up @@ -194,7 +225,12 @@ def fill_sequence_needs_broadcast(args_kwargs):
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
)
),
xfail_jit_python_scalar_arg("padding"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why all this xfail appeared in this PR and not before ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we didn't test scalar padding before

padding=[[1], [1, 1], [1, 1, 2, 2]],

Thus, while reducing the number of sample inputs now, the tests are actually more comprehensive.

xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
],
),
DispatcherInfo(
Expand Down Expand Up @@ -227,7 +263,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
test_marks=[
xfail_integer_size_jit("output_size"),
xfail_jit_integer_size("output_size"),
],
),
DispatcherInfo(
Expand All @@ -237,8 +273,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
test_marks=[
xfail_python_scalar_arg_jit("kernel_size"),
xfail_python_scalar_arg_jit("sigma"),
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
],
),
DispatcherInfo(
Expand Down Expand Up @@ -335,7 +371,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
test_marks=[
xfail_integer_size_jit(),
xfail_jit_integer_size(),
skip_dispatch_feature,
],
),
Expand All @@ -345,7 +381,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
features.Image: F.ten_crop_image_tensor,
},
test_marks=[
xfail_integer_size_jit(),
xfail_jit_integer_size(),
skip_dispatch_feature,
],
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
Expand Down
Loading