Skip to content

only use plain tensors in kernel tests #7230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,6 @@ class TensorLoader:
def load(self, device):
return self.fn(self.shape, self.dtype, device)

def unwrap(self):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just undoes the change in #7228 since we no longer need this.

return TensorLoader(
fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=self.shape,
dtype=self.dtype,
)


@dataclasses.dataclass
class ImageLoader(TensorLoader):
Expand Down
34 changes: 28 additions & 6 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,21 @@ def sample_inputs(self, *datapoint_types, filter_metadata=True):

if not filter_metadata:
yield from sample_inputs
else:
for args_kwargs in sample_inputs:
for attribute in datapoint_type.__annotations__.keys():
if attribute in args_kwargs.kwargs:
del args_kwargs.kwargs[attribute]
return

yield args_kwargs
import itertools

for args_kwargs in sample_inputs:
for name in itertools.chain(
datapoint_type.__annotations__.keys(),
# FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
# per-dispatcher level. However, so far there is no option for that.
(f"old_{name}" for name in datapoint_type.__annotations__.keys()),
):
if name in args_kwargs.kwargs:
del args_kwargs.kwargs[name]

yield args_kwargs


def xfail_jit(reason, *, condition=None):
Expand Down Expand Up @@ -458,4 +466,18 @@ def fill_sequence_needs_broadcast(args_kwargs):
skip_dispatch_datapoint,
],
),
DispatcherInfo(
F.clamp_bounding_box,
kernels={datapoints.BoundingBox: F.clamp_bounding_box},
test_marks=[
skip_dispatch_datapoint,
],
),
DispatcherInfo(
F.convert_format_bounding_box,
kernels={datapoints.BoundingBox: F.convert_format_bounding_box},
test_marks=[
skip_dispatch_datapoint,
],
),
]
50 changes: 8 additions & 42 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from datasets_utils import combinations_grid
from prototype_common_utils import (
ArgsKwargs,
BoundingBoxLoader,
get_num_channels,
ImageLoader,
InfoBase,
Expand Down Expand Up @@ -337,7 +336,6 @@ def sample_inputs_resize_video():


def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):

old_height, old_width = spatial_size
new_height, new_width = F._geometry._compute_resized_output_size(spatial_size, size=size, max_size=max_size)

Expand All @@ -350,13 +348,15 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=
)

expected_bboxes = reference_affine_bounding_box_helper(
bounding_box, format=bounding_box.format, affine_matrix=affine_matrix
bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
)
return expected_bboxes, (new_height, new_width)


def reference_inputs_resize_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
for bounding_box_loader in make_bounding_box_loaders(
formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
):
for size in _get_resize_sizes(bounding_box_loader.spatial_size):
yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)

Expand Down Expand Up @@ -668,8 +668,7 @@ def sample_inputs_affine_video():
def sample_inputs_convert_format_bounding_box():
formats = list(datapoints.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, new_format=new_format)
yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
Comment on lines -671 to -672
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer need to pass two things here, since the datapoint will be unwrapped by the test automatically.

yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we still want to test both pure-tensors and BoundingBoxes, at least for those "hybrid" kernels ? (perhaps that's still covered?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't until 99dd341, but now it is 😇 So the kernel tests actually only test plain tensors, but the dispatcher tests take over the other part.



def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
Expand All @@ -680,14 +679,8 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):

def reference_inputs_convert_format_bounding_box():
for args_kwargs in sample_inputs_convert_format_bounding_box():
if len(args_kwargs.args[0].shape) != 2:
continue

(loader, *other_args), kwargs = args_kwargs
if isinstance(loader, BoundingBoxLoader):
kwargs["old_format"] = loader.format
loader = loader.unwrap()
yield ArgsKwargs(loader, *other_args, **kwargs)
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs


KERNEL_INFOS.append(
Expand All @@ -697,18 +690,6 @@ def reference_inputs_convert_format_bounding_box():
reference_fn=reference_convert_format_bounding_box,
reference_inputs_fn=reference_inputs_convert_format_bounding_box,
logs_usage=True,
test_marks=[
mark_framework_limitation(
("TestKernels", "test_scripted_vs_eager"),
reason=(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
and arg_kwargs.kwargs.get("old_format") is None,
)
],
),
)

Expand Down Expand Up @@ -2049,10 +2030,8 @@ def sample_inputs_adjust_saturation_video():

def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_box_loader)

yield ArgsKwargs(
bounding_box_loader.unwrap(),
bounding_box_loader,
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
)
Expand All @@ -2063,19 +2042,6 @@ def sample_inputs_clamp_bounding_box():
F.clamp_bounding_box,
sample_inputs_fn=sample_inputs_clamp_bounding_box,
logs_usage=True,
test_marks=[
mark_framework_limitation(
("TestKernels", "test_scripted_vs_eager"),
reason=(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
and arg_kwargs.kwargs.get("format") is None
and arg_kwargs.kwargs.get("spatial_size") is None,
)
],
)
)

Expand Down
Loading