Skip to content

make clamp_bounding_box a kernel / dispatcher hybrid #7227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,14 +640,14 @@ def __init__(
self.condition = condition or (lambda args_kwargs: True)


def mark_framework_limitation(test_id, reason):
def mark_framework_limitation(test_id, reason, condition=None):
# Single entry point for skip marks that exist only because the test framework cannot handle the kernel in general
# or a specific parameter combination.
# As development progresses, we can temporarily change the `mark.skip` to `mark.xfail` to check whether the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we would be wasting CI resources for no reason most of the time.
# `condition` is forwarded unchanged to `TestMark`; presumably `None` means the mark applies to every sample
# (see the `condition or (lambda args_kwargs: True)` fallback above) -- TODO confirm.
return TestMark(test_id, pytest.mark.skip(reason=reason))
return TestMark(test_id, pytest.mark.skip(reason=reason), condition=condition)


class InfoBase:
Expand Down
24 changes: 23 additions & 1 deletion test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datasets_utils import combinations_grid
from prototype_common_utils import (
ArgsKwargs,
BoundingBoxLoader,
get_num_channels,
ImageLoader,
InfoBase,
Expand All @@ -25,6 +26,7 @@
make_video_loader,
make_video_loaders,
mark_framework_limitation,
TensorLoader,
TestMark,
)
from torch.utils._pytree import tree_map
Expand Down Expand Up @@ -2010,8 +2012,15 @@ def sample_inputs_adjust_saturation_video():

def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_box_loader)

simple_tensor_loader = TensorLoader(
fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
Comment on lines +2017 to +2018
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a quick-and-dirty workaround. If more such cases come up in the future, we should add something like an `unwrap` method to get the plain tensor.

shape=bounding_box_loader.shape,
dtype=bounding_box_loader.dtype,
)
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
)


Expand All @@ -2020,6 +2029,19 @@ def sample_inputs_clamp_bounding_box():
F.clamp_bounding_box,
sample_inputs_fn=sample_inputs_clamp_bounding_box,
logs_usage=True,
test_marks=[
mark_framework_limitation(
("TestKernels", "test_scripted_vs_eager"),
reason=(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
and arg_kwargs.kwargs.get("format") is None
and arg_kwargs.kwargs.get("spatial_size") is None,
)
],
)
)

Expand Down
44 changes: 38 additions & 6 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,14 @@ def _unbatch(self, batch, *, data_dims):
if batched_tensor.ndim == data_dims:
return batch

return [
self._unbatch(unbatched, data_dims=data_dims)
for unbatched in (
batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
)
]
unbatcheds = []
for unbatched in (
batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
):
if isinstance(batch, datapoints._datapoint.Datapoint):
unbatched = type(batch).wrap_like(batch, unbatched)
Comment on lines +162 to +163
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small fix: previously, unbatching did not preserve datapoint types. This was not an issue so far, since this is only called from a kernel test and all kernels have operated on plain tensors only, meaning all datapoints would have been unwrapped anyway.

unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
return unbatcheds

@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
Expand Down Expand Up @@ -558,6 +560,36 @@ def assert_samples_from_standard_normal(t):
assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


class TestClampBoundingBox:
    """Checks the input validation of the hybrid kernel / dispatcher `F.clamp_bounding_box`."""

    @pytest.mark.parametrize(
        "metadata",
        [
            {},
            {"format": datapoints.BoundingBoxFormat.XYXY},
            {"spatial_size": (1, 1)},
        ],
    )
    def test_simple_tensor_insufficient_metadata(self, metadata):
        """A plain tensor without both `format` and `spatial_size` has to be rejected."""
        bounding_box = next(make_bounding_boxes())
        plain_tensor = bounding_box.as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match="simple tensor"):
            F.clamp_bounding_box(plain_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            {"format": datapoints.BoundingBoxFormat.XYXY},
            {"spatial_size": (1, 1)},
            {"format": datapoints.BoundingBoxFormat.XYXY, "spatial_size": (1, 1)},
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        """A bounding box datapoint combined with any explicit metadata has to be rejected."""
        bounding_box = next(make_bounding_boxes())

        with pytest.raises(ValueError, match="bounding box datapoint"):
            F.clamp_bounding_box(bounding_box, **metadata)


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`

Expand Down
7 changes: 1 addition & 6 deletions torchvision/prototype/transforms/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,4 @@ class ClampBoundingBoxes(Transform):
_transformed_types = (datapoints.BoundingBox,)

def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `clamp_bounding_box` does not have a dispatcher function that would do that for us
output = F.clamp_bounding_box(
inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
)
return datapoints.BoundingBox.wrap_like(inpt, output)
return F.clamp_bounding_box(inpt) # type: ignore[return-value]
30 changes: 25 additions & 5 deletions torchvision/prototype/transforms/functional/_meta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import PIL.Image
import torch
Expand Down Expand Up @@ -209,12 +209,9 @@ def convert_format_bounding_box(
return bounding_box


def clamp_bounding_box(
def _clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_box)

# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
xyxy_boxes = convert_format_bounding_box(
Expand All @@ -225,6 +222,29 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)


# Hybrid kernel / dispatcher for clamping bounding boxes.
#
# Two calling conventions are accepted:
#   * plain tensor input: `format` and `spatial_size` must both be passed and the result of `_clamp_bounding_box`
#     is returned as is;
#   * `datapoints.BoundingBox` input: `format` and `spatial_size` must NOT be passed, they are read from the
#     datapoint and the result is re-wrapped as a `datapoints.BoundingBox`.
#
# Raises `ValueError` if the metadata does not match the calling convention and `TypeError` for any other input type.
def clamp_bounding_box(
inpt: datapoints.InputTypeJIT,
format: Optional[BoundingBoxFormat] = None,
spatial_size: Optional[Tuple[int, int]] = None,
) -> datapoints.InputTypeJIT:
# API usage is only logged in eager mode.
if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_box)

# Kernel path. Under scripting this branch is always taken, so explicit metadata is mandatory there as well.
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if format is None or spatial_size is None:
raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
# Dispatcher path: the datapoint carries its own metadata, so passing any explicitly is treated as an error.
elif isinstance(inpt, datapoints.BoundingBox):
if format is not None or spatial_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
output = _clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
# Re-wrap the kernel output as a `BoundingBox`, using the input as the reference for `wrap_like`.
return datapoints.BoundingBox.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)


def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
Expand Down