diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 8b1665a3d31..04093309774 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -269,6 +269,21 @@ def test_common(self, transform, adapter, container_type, image_or_video, device else: assert output_item is input_item + if isinstance(input_item, datapoints.BoundingBox) and not isinstance( + transform, transforms.ConvertBoundingBoxFormat + ): + assert output_item.format == input_item.format + + # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future + # transform that does this), back into a valid one. + # TODO: we should test that against all degenerate boxes above + for format in list(datapoints.BoundingBoxFormat): + sample = dict( + boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)), + labels=torch.tensor([3]), + ) + assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4) + @parametrize( [ ( diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 7dff7a509ad..ffee57eea6f 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -508,6 +508,22 @@ def test_unkown_type(self, info): with pytest.raises(TypeError, match=re.escape(str(type(unkown_input)))): info.dispatcher(unkown_input, *other_args, **kwargs) + @make_info_args_kwargs_parametrization( + [ + info + for info in DISPATCHER_INFOS + if datapoints.BoundingBox in info.kernels and info.dispatcher is not F.convert_format_bounding_box + ], + args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBox), + ) + def test_bounding_box_format_consistency(self, info, args_kwargs): + (bounding_box, *other_args), kwargs = args_kwargs.load() + format = bounding_box.format + + output = info.dispatcher(bounding_box, *other_args, **kwargs) + + assert output.format == format + @pytest.mark.parametrize( ("alias", "target"),