class TestRandomShortestSize:
    """Unit tests for ``transforms.RandomShortestSize``."""

    def test__get_params(self, mocker):
        """``_get_params`` yields a 2-tuple ``size`` bounded by ``max_size``.

        When the longer side does not hit ``max_size``, the shorter side must
        be one of the ``min_size`` candidates.
        """
        candidates = [5, 9]
        upper_bound = 20
        transform = transforms.RandomShortestSize(min_size=candidates, max_size=upper_bound)

        fake_image = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=(3, 10))
        params = transform._get_params(fake_image)

        assert "size" in params
        size = params["size"]
        assert isinstance(size, tuple)
        assert len(size) == 2

        longer = max(size)
        assert longer <= upper_bound

        shorter = min(size)
        if longer != upper_bound:
            # Not capped by max_size: the shorter side is the sampled candidate.
            assert shorter in candidates
        else:
            assert shorter <= upper_bound

    def test__transform(self, mocker):
        """Calling the transform forwards input, sampled size and interpolation to ``F.resize``."""
        interpolation_sentinel = mocker.MagicMock()
        transform = transforms.RandomShortestSize(
            min_size=[3, 5, 7],
            max_size=12,
            interpolation=interpolation_sentinel,
        )
        # Let the plain MagicMock input be treated as a transformable type.
        transform._transformed_types = (mocker.MagicMock,)

        size_sentinel = mocker.MagicMock()
        fixed_params = dict(size=size_sentinel)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.RandomShortestSize._get_params",
            return_value=fixed_params,
        )

        inpt_sentinel = mocker.MagicMock()
        resize_mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")

        transform(inpt_sentinel)

        resize_mock.assert_called_once_with(
            inpt_sentinel,
            size=size_sentinel,
            interpolation=interpolation_sentinel,
        )
class RandomShortestSize(Transform):
    """Resize so that the shorter side matches a randomly chosen target size.

    For every call one entry of ``min_size`` is drawn uniformly at random and
    the image found in the sample is rescaled, keeping its aspect ratio, so
    that its shorter side corresponds to that entry. If ``max_size`` is given,
    the scale factor is reduced when necessary so that the longer side does
    not exceed ``max_size``.

    Args:
        min_size: A single target size or a sequence of candidate target sizes
            for the shorter side.
        max_size: Optional upper bound for the longer side. ``None`` (the
            default) disables the cap.
        interpolation: Interpolation mode forwarded to ``F.resize``.

    Raises:
        ValueError: If ``min_size`` is an empty sequence.
    """

    def __init__(
        self,
        min_size: Union[List[int], Tuple[int, ...], int],
        max_size: Union[int, None] = None,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
        if not self.min_size:
            # Fail at construction instead of deep inside torch.randint on the first call.
            raise ValueError("min_size must contain at least one candidate size.")
        self.max_size = max_size
        self.interpolation = interpolation

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        """Sample a target size and return it as ``dict(size=(new_height, new_width))``."""
        image = query_image(sample)
        _, orig_height, orig_width = get_image_dimensions(image)

        # torch.randint is used so the draw follows torch's RNG state.
        min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]

        # Scale that maps the shorter side onto the sampled target ...
        r = min_size / min(orig_height, orig_width)
        if self.max_size is not None:
            # ... reduced if needed so the longer side stays within max_size.
            r = min(r, self.max_size / max(orig_height, orig_width))

        # int() truncates; float rounding in the product may therefore leave a
        # side one pixel below the exact value.
        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        return dict(size=(new_height, new_width))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Resize ``inpt`` to the sampled ``size`` using the configured interpolation."""
        return F.resize(inpt, size=params["size"], interpolation=self.interpolation)