From 87c44420185cda667aa497bae9882f6a31ef3802 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 13:33:06 +0100 Subject: [PATCH 1/2] make convert_format_bounding_box a hybrid kernel dispatcher --- test/prototype_common_utils.py | 7 ++++ test/prototype_transforms_kernel_infos.py | 35 +++++++++++++------ test/test_prototype_transforms_functional.py | 27 ++++++++++++++ torchvision/prototype/transforms/_meta.py | 7 +--- .../prototype/transforms/functional/_meta.py | 35 +++++++++++++++++-- 5 files changed, 92 insertions(+), 19 deletions(-) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 3d34383319f..89358ee7dcf 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -237,6 +237,13 @@ class TensorLoader: def load(self, device): return self.fn(self.shape, self.dtype, device) + def unwrap(self): + return TensorLoader( + fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor), + shape=self.shape, + dtype=self.dtype, + ) + @dataclasses.dataclass class ImageLoader(TensorLoader): diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index ce80658ce8b..2ddf085ea19 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -26,7 +26,6 @@ make_video_loader, make_video_loaders, mark_framework_limitation, - TensorLoader, TestMark, ) from torch.utils._pytree import tree_map @@ -660,7 +659,8 @@ def sample_inputs_affine_video(): def sample_inputs_convert_format_bounding_box(): formats = list(datapoints.BoundingBoxFormat) for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): - yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format) + yield ArgsKwargs(bounding_box_loader, new_format=new_format) + yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format) def reference_convert_format_bounding_box(bounding_box, old_format, new_format): @@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format): def reference_inputs_convert_format_bounding_box(): for args_kwargs in sample_inputs_convert_format_bounding_box(): - if len(args_kwargs.args[0].shape) == 2: - yield args_kwargs + if len(args_kwargs.args[0].shape) != 2: + continue + + (loader, *other_args), kwargs = args_kwargs + if isinstance(loader, BoundingBoxLoader): + kwargs["old_format"] = loader.format + loader = loader.unwrap() + yield ArgsKwargs(loader, *other_args, **kwargs) KERNEL_INFOS.append( @@ -682,6 +688,18 @@ def reference_inputs_convert_format_bounding_box(): reference_fn=reference_convert_format_bounding_box, reference_inputs_fn=reference_inputs_convert_format_bounding_box, logs_usage=True, + test_marks=[ + mark_framework_limitation( + ("TestKernels", "test_scripted_vs_eager"), + reason=( + "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a " + "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor " + "`spatial_size` was passed" + ), + condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader) + and arg_kwargs.kwargs.get("old_format") is None, + ) + ], ), ) @@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box(): for bounding_box_loader in make_bounding_box_loaders(): yield ArgsKwargs(bounding_box_loader) - simple_tensor_loader = TensorLoader( - fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor), - shape=bounding_box_loader.shape, - dtype=bounding_box_loader.dtype, - ) yield ArgsKwargs( - simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size + bounding_box_loader.unwrap(), + format=bounding_box_loader.format, + spatial_size=bounding_box_loader.spatial_size, ) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 948143771ab..55ea89f38a3 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -590,6 +590,33 @@ def test_datapoint_explicit_metadata(self, metadata): F.clamp_bounding_box(datapoint, **metadata) +class TestConvertFormatBoundingBox: + @pytest.mark.parametrize( + ("inpt", "old_format"), + [ + (next(make_bounding_boxes()), None), + (next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY), + ], + ) + def test_missing_new_format(self, inpt, old_format): + with pytest.raises(TypeError, match="new_format"): + F.convert_format_bounding_box(inpt, old_format) + + def test_simple_tensor_insufficient_metadata(self): + simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) + + with pytest.raises(ValueError, match="simple tensor"): + F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH) + + def test_datapoint_explicit_metadata(self): + datapoint = next(make_bounding_boxes()) + + with pytest.raises(ValueError, match="bounding box datapoint"): + F.convert_format_bounding_box( + datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH + ) + + # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in # `prototype_transforms_kernel_infos.py` diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 946c00b0ee6..75085fff6d5 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -19,12 +19,7 @@ def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None: self.format = format def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: - # We need to unwrap here to avoid unnecessary `__torch_function__` calls, - # since `convert_format_bounding_box` does not have a dispatcher function that would do that for us - output = F.convert_format_bounding_box( - inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"] - ) - return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"]) + return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value] class ConvertDtype(Transform): diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 2c5180a8644..29350909014 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: return xyxy -def convert_format_bounding_box( +def _convert_format_bounding_box( bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False ) -> torch.Tensor: - if not torch.jit.is_scripting(): - _log_api_usage_once(convert_format_bounding_box) if new_format == old_format: return bounding_box @@ -209,6 +207,37 @@ def convert_format_bounding_box( return bounding_box +def convert_format_bounding_box( + inpt: datapoints.InputTypeJIT, + old_format: Optional[BoundingBoxFormat] = None, + new_format: Optional[BoundingBoxFormat] = None, + inplace: bool = False, +) -> datapoints.InputTypeJIT: + # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor + # inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on + # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the + # default error that would be thrown if `new_format` had no default value. + if new_format is None: + raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'") + + if not torch.jit.is_scripting(): + _log_api_usage_once(convert_format_bounding_box) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + if old_format is None: + raise ValueError("For simple tensor inputs, `old_format` has to be passed.") + return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace) + elif isinstance(inpt, datapoints.BoundingBox): + if old_format is not None: + raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.") + output = _convert_format_bounding_box(inpt, old_format=inpt.format, new_format=new_format, inplace=inplace) + return datapoints.BoundingBox.wrap_like(inpt, output) + else: + raise TypeError( + f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." + ) + + def _clamp_bounding_box( bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int] ) -> torch.Tensor: From 0e9de7c688cffef5be9a490ff98783727cc7ca31 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 13:59:12 +0100 Subject: [PATCH 2/2] fix error message test capture --- test/test_prototype_transforms_functional.py | 10 +++++----- torchvision/prototype/transforms/functional/_meta.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 55ea89f38a3..5469e56df96 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -572,7 +572,7 @@ class TestClampBoundingBox: def test_simple_tensor_insufficient_metadata(self, metadata): simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) - with pytest.raises(ValueError, match="simple tensor"): + with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")): F.clamp_bounding_box(simple_tensor, **metadata) @pytest.mark.parametrize( @@ -586,7 +586,7 @@ def test_simple_tensor_insufficient_metadata(self, metadata): def test_datapoint_explicit_metadata(self, metadata): datapoint = next(make_bounding_boxes()) - with pytest.raises(ValueError, match="bounding box datapoint"): + with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")): F.clamp_bounding_box(datapoint, **metadata) @@ -599,19 +599,19 @@ class TestConvertFormatBoundingBox: ], ) def test_missing_new_format(self, inpt, old_format): - with pytest.raises(TypeError, match="new_format"): + with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")): F.convert_format_bounding_box(inpt, old_format) def test_simple_tensor_insufficient_metadata(self): simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) - with pytest.raises(ValueError, match="simple tensor"): + with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")): F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH) def test_datapoint_explicit_metadata(self): datapoint = next(make_bounding_boxes()) - with pytest.raises(ValueError, match="bounding box datapoint"): + with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")): F.convert_format_bounding_box( datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH ) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 29350909014..a9917a80e7a 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -229,7 +229,7 @@ def convert_format_bounding_box( return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace) elif isinstance(inpt, datapoints.BoundingBox): if old_format is not None: - raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.") + raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.") output = _convert_format_bounding_box(inpt, old_format=inpt.format, new_format=new_format, inplace=inplace) return datapoints.BoundingBox.wrap_like(inpt, output) else: