Skip to content

make convert_format_bounding_box a hybrid kernel dispatcher #7228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ class TensorLoader:
def load(self, device):
return self.fn(self.shape, self.dtype, device)

def unwrap(self):
return TensorLoader(
fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=self.shape,
dtype=self.dtype,
)


@dataclasses.dataclass
class ImageLoader(TensorLoader):
Expand Down
35 changes: 25 additions & 10 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
make_video_loader,
make_video_loaders,
mark_framework_limitation,
TensorLoader,
TestMark,
)
from torch.utils._pytree import tree_map
Expand Down Expand Up @@ -660,7 +659,8 @@ def sample_inputs_affine_video():
def sample_inputs_convert_format_bounding_box():
formats = list(datapoints.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
yield ArgsKwargs(bounding_box_loader, new_format=new_format)
yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)


def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
Expand All @@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):

def reference_inputs_convert_format_bounding_box():
for args_kwargs in sample_inputs_convert_format_bounding_box():
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs
if len(args_kwargs.args[0].shape) != 2:
continue

(loader, *other_args), kwargs = args_kwargs
if isinstance(loader, BoundingBoxLoader):
kwargs["old_format"] = loader.format
loader = loader.unwrap()
yield ArgsKwargs(loader, *other_args, **kwargs)


KERNEL_INFOS.append(
Expand All @@ -682,6 +688,18 @@ def reference_inputs_convert_format_bounding_box():
reference_fn=reference_convert_format_bounding_box,
reference_inputs_fn=reference_inputs_convert_format_bounding_box,
logs_usage=True,
test_marks=[
mark_framework_limitation(
("TestKernels", "test_scripted_vs_eager"),
reason=(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
and arg_kwargs.kwargs.get("old_format") is None,
)
],
),
)

Expand Down Expand Up @@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_box_loader)

simple_tensor_loader = TensorLoader(
fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=bounding_box_loader.shape,
dtype=bounding_box_loader.dtype,
)
yield ArgsKwargs(
simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
bounding_box_loader.unwrap(),
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
)


Expand Down
31 changes: 29 additions & 2 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ class TestClampBoundingBox:
def test_simple_tensor_insufficient_metadata(self, metadata):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

with pytest.raises(ValueError, match="simple tensor"):
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")):
F.clamp_bounding_box(simple_tensor, **metadata)

@pytest.mark.parametrize(
Expand All @@ -586,10 +586,37 @@ def test_simple_tensor_insufficient_metadata(self, metadata):
def test_datapoint_explicit_metadata(self, metadata):
datapoint = next(make_bounding_boxes())

with pytest.raises(ValueError, match="bounding box datapoint"):
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")):
F.clamp_bounding_box(datapoint, **metadata)


class TestConvertFormatBoundingBox:
@pytest.mark.parametrize(
("inpt", "old_format"),
[
(next(make_bounding_boxes()), None),
(next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
],
)
def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_format_bounding_box(inpt, old_format)

def test_simple_tensor_insufficient_metadata(self):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

def test_datapoint_explicit_metadata(self):
datapoint = next(make_bounding_boxes())

with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_format_bounding_box(
datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
)


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`

Expand Down
7 changes: 1 addition & 6 deletions torchvision/prototype/transforms/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@ def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
self.format = format

def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
output = F.convert_format_bounding_box(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"]
)
return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"])
return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]


class ConvertDtype(Transform):
Expand Down
35 changes: 32 additions & 3 deletions torchvision/prototype/transforms/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
return xyxy


def convert_format_bounding_box(
def _convert_format_bounding_box(
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_box)

if new_format == old_format:
return bounding_box
Expand All @@ -209,6 +207,37 @@ def convert_format_bounding_box(
return bounding_box


def convert_format_bounding_box(
inpt: datapoints.InputTypeJIT,
old_format: Optional[BoundingBoxFormat] = None,
new_format: Optional[BoundingBoxFormat] = None,
inplace: bool = False,
) -> datapoints.InputTypeJIT:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
if new_format is None:
raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")

if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_box)

if torch.jit.is_scripting() or is_simple_tensor(inpt):
if old_format is None:
raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBox):
if old_format is not None:
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
output = _convert_format_bounding_box(inpt, old_format=inpt.format, new_format=new_format, inplace=inplace)
return datapoints.BoundingBox.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)


def _clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
Expand Down