Skip to content

only use plain tensors in kernel tests #7230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 13, 2023
Merged

Conversation

pmeier
Copy link
Collaborator

@pmeier pmeier commented Feb 13, 2023

With #7227 and #7228 we have a new paradigm of kernel dispatcher hybrids. So far we just simply passed the datapoints in the kernel tests directly, since it made no difference. With these recent changes, we can no longer do that without some limitations of the test suite.

Since some of our kernel test depend on the datapoint type as metadata, e.g.

datapoint_type = (
datapoints.Image
if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
else type(batched_input)
)
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension.
data_dims = {
datapoints.Image: 3,
datapoints.BoundingBox: 1,
# `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
# it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
# type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
# common ground.
datapoints.Mask: 2,
datapoints.Video: 4,
}.get(datapoint_type)

I've opted to unwrap inside the tests.

The largest part of the diff in this PR comes from our correctness tests for bounding boxes that we haven't ported to the proper framework yet. They happily mixed plain tensors and datapoints.

cc @vfdev-5 @bjuncek

@@ -237,13 +237,6 @@ class TensorLoader:
def load(self, device):
return self.fn(self.shape, self.dtype, device)

def unwrap(self):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just undoes the change in #7228 since we no longer need this.

Comment on lines -671 to -672
yield ArgsKwargs(bounding_box_loader, new_format=new_format)
yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer need to pass two things here, since the datapoint will be unwrapped by the test automatically.

@@ -155,14 +156,12 @@ def _unbatch(self, batch, *, data_dims):
if batched_tensor.ndim == data_dims:
return batch

unbatcheds = []
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, just undoing what #7227 did, since we no longer care for datapoints at this stage.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Philip, just gave a light review - LGTM to unblock anyway

@@ -668,8 +669,7 @@ def sample_inputs_affine_video():
def sample_inputs_convert_format_bounding_box():
formats = list(datapoints.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, new_format=new_format)
yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we still want to test both pure-tensors and BoundingBoxes, at least for those "hybrid" kernels ? (perhaps that's still covered?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't until 99dd341, but now it is 😇 So the kernel tests actually only test plain tensors, but the dispatcher tests take over the other part.

@pmeier pmeier merged commit a63046c into pytorch:main Feb 13, 2023
@pmeier pmeier deleted the kernel-tests branch February 13, 2023 14:58
facebook-github-bot pushed a commit that referenced this pull request Mar 28, 2023
Reviewed By: vmoens

Differential Revision: D44416278

fbshipit-source-id: 7ac022e18aec970d7b4c5091bd9840ea2d7e0ed6
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants