diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index dc3de480d1f..ecbe0815692 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -3,7 +3,12 @@
 import pytest
 import torch
 from common_utils import assert_equal
-from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
+from test_prototype_transforms_functional import (
+    make_images,
+    make_bounding_boxes,
+    make_one_hot_labels,
+    make_segmentation_masks,
+)
 from torchvision.prototype import transforms, features
 from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 
@@ -153,6 +158,8 @@ def test_normalize(self, transform, input):
                 transforms.RandomResizedCrop([16, 16]),
                 itertools.chain(
                     make_images(extra_dims=[(4,)]),
+                    make_bounding_boxes(),
+                    make_segmentation_masks(),
                     make_vanilla_tensor_images(),
                     make_pil_images(),
                 ),
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 0487a71416e..06cbd18cf33 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -183,6 +183,12 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
                 input, **params, size=list(self.size), interpolation=self.interpolation
             )
             return features.Image.new_like(input, output)
+        elif isinstance(input, features.BoundingBox):
+            output = F.resized_crop_bounding_box(input, **params, size=list(self.size))
+            return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
+        elif isinstance(input, features.SegmentationMask):
+            output = F.resized_crop_segmentation_mask(input, **params, size=list(self.size))
+            return features.SegmentationMask.new_like(input, output)
         elif is_simple_tensor(input):
             return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
         elif isinstance(input, PIL.Image.Image):
@@ -190,12 +196,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         else:
             return input
 
-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if has_any(sample, features.BoundingBox, features.SegmentationMask):
-            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-        return super().forward(sample)
-
 
 class MultiCropResult(list):
     """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index 0517757a758..563cb8cdc65 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -20,7 +20,7 @@ def fn(
     try:
         return next(query_recursively(fn, sample))[1]
     except StopIteration:
-        raise TypeError("No image was found in the sample")
+        raise TypeError("No image was found in the sample") from None
 
 
 def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
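
Usage note (not part of the diff): with the forward() override removed, RandomResizedCrop now dispatches on BoundingBox and SegmentationMask inputs instead of raising a TypeError. The snippet below is a minimal sketch of how the change can be exercised; the exact constructor signatures of the prototype features (format=, image_size=, tensor shapes) are assumptions based on the prototype API of this era and may differ.

import torch
from torchvision.prototype import features, transforms

# Toy inputs; the keyword arguments below are assumed, not taken from the diff.
image = features.Image(torch.rand(3, 32, 32))
boxes = features.BoundingBox(
    torch.tensor([[2.0, 2.0, 20.0, 20.0]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)
mask = features.SegmentationMask(torch.zeros(1, 32, 32, dtype=torch.uint8))

transform = transforms.RandomResizedCrop([16, 16])

# Before this change, a sample containing a BoundingBox or SegmentationMask hit
# the removed forward() override and raised a TypeError; now all three features
# are cropped and resized with the same sampled parameters.
new_image, new_boxes, new_mask = transform(image, boxes, mask)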