From bd8b506b1829ac66e41365bf9a51048190912e08 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 15 Aug 2022 05:32:06 +0200
Subject: [PATCH 1/3] port `RandomShortestSize` from detection references to prototype transforms

---
Review notes (supporting explanation only; placed after the three-dash
separator so it is not part of the commit message or of the diff):

  What this patch does
  - Adds RandomShortestSize to the import list in
    torchvision/prototype/transforms/__init__.py.
  - Adds class RandomShortestSize(Transform) to _geometry.py.

  RandomShortestSize.__init__(min_size, max_size, interpolation)
  - min_size: an int, or a list/tuple of ints. A bare int is wrapped into a
    one-element list; anything else is copied with list().
  - max_size: stored unchanged.
  - interpolation: stored unchanged; defaults to InterpolationMode.BILINEAR.

  RandomShortestSize._get_params(sample)
  - Looks up the image with query_image(sample) and reads its height and
    width from get_image_dimensions(image).
  - Picks one entry of self.min_size at a random index from torch.randint.
  - Computes the scale factor
        r = min(min_size / min(height, width), max_size / max(height, width))
    and returns dict(size=(new_height, new_width)), where each side is
    int(side * r).

  RandomShortestSize._transform(inpt, params)
  - Returns F.resize(inpt, size=params["size"],
    interpolation=self.interpolation).

  NOTE(review): Tuple[int] in the min_size annotation describes a tuple of
  exactly one int. Tuple[int, ...] is probably what is meant -- confirm.

  NOTE(review): int() truncates, so when side * r is not exactly
  representable the resulting shorter side may come out one below the
  selected min_size -- confirm that this is acceptable.

 torchvision/prototype/transforms/__init__.py  |  1 +
 torchvision/prototype/transforms/_geometry.py | 28 +++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 5617c010e5f..c449321b7a3 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -27,6 +27,7 @@
     RandomPerspective,
     RandomResizedCrop,
     RandomRotation,
+    RandomShortestSize,
     RandomVerticalFlip,
     RandomZoomOut,
     Resize,
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index aa1ca109cc4..b439aa091fb 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -611,3 +611,31 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             fill=self.fill,
             interpolation=self.interpolation,
         )
+
+
+class RandomShortestSize(Transform):
+    def __init__(
+        self,
+        min_size: Union[List[int], Tuple[int], int],
+        max_size: int,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ):
+        super().__init__()
+        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
+        self.max_size = max_size
+        self.interpolation = interpolation
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        _, orig_height, orig_width = get_image_dimensions(image)
+
+        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
+        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
+
+        new_width = int(orig_width * r)
+        new_height = int(orig_height * r)
+
+        return dict(size=(new_height, new_width))
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.resize(inpt, size=params["size"], interpolation=self.interpolation)

From bc6a5ea3c935713569c82501eb9f54b3f14e8ac4 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 15 Aug 2022 05:58:56 +0200
Subject: [PATCH 2/3] mypy

---
Review notes (supporting explanation only; placed after the three-dash
separator so it is not part of the commit message or of the diff):

  What this patch does
  - Changes only the random index expression in
    RandomShortestSize._get_params:
        before: torch.randint(len(self.min_size), (1,)).item()
        after:  int(torch.randint(len(self.min_size), ()))
    The requested shape goes from (1,) to (), and the result is converted
    with int() where .item() was used before.

  NOTE(review): the subject line only says "mypy". Presumably the int()
  form is there to satisfy the type checker when the value is used as a
  list index -- confirm.

 torchvision/prototype/transforms/_geometry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index b439aa091fb..a3ba191504f 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -629,7 +629,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
         _, orig_height, orig_width = get_image_dimensions(image)
 
-        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
+        min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
         r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
 
         new_width = int(orig_width * r)

From b320e99460f9c819556ad98cefe32fe2a9cb7b97 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 16 Aug 2022 11:10:53 +0200
Subject: [PATCH 3/3] add test

---
Review notes (supporting explanation only; placed after the three-dash
separator so it is not part of the commit message or of the diff):

  What this patch does
  - Appends class TestRandomShortestSize to
    test/test_prototype_transforms.py. It has two tests.

  test__get_params
  - Builds the transform with min_size=[5, 9] and max_size=20.
  - Calls _get_params on a MagicMock created with spec=features.Image,
    num_channels=3 and image_size=(3, 10).
  - Asserts that "size" is in the returned params, that it is a tuple of
    length 2, and that its longer side is <= max_size.
  - If the longer side equals max_size, it asserts only that the shorter
    side is <= max_size; otherwise it asserts that the shorter side is one
    of the min_size entries.

  test__transform
  - Builds the transform with min_size=[3, 5, 7], max_size=12 and a
    MagicMock as the interpolation.
  - Sets _transformed_types to (mocker.MagicMock,).
  - Patches RandomShortestSize._get_params to return
    dict(size=size_sentinel), and patches F.resize in the _geometry module.
  - Calls the transform on a MagicMock input and asserts that the patched
    F.resize was called exactly once with (inpt_sentinel,
    size=size_sentinel, interpolation=interpolation_sentinel).

  NOTE(review): test__get_params does not fix a random seed, so which
  min_size entry is picked can differ between runs. Both branches of the
  final check are written to cover that -- confirm this is intended.

 test/test_prototype_transforms.py | 45 +++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 28b21ebbaf6..1b951196dc9 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1164,3 +1164,48 @@ def test__transform(self, mocker):
         transform(inpt_sentinel)
 
         mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
+
+
+class TestRandomShortestSize:
+    def test__get_params(self, mocker):
+        image_size = (3, 10)
+        min_size = [5, 9]
+        max_size = 20
+
+        transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
+
+        sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size)
+        params = transform._get_params(sample)
+
+        assert "size" in params
+        size = params["size"]
+
+        assert isinstance(size, tuple) and len(size) == 2
+
+        longer = max(size)
+        assert longer <= max_size
+
+        shorter = min(size)
+        if longer == max_size:
+            assert shorter <= max_size
+        else:
+            assert shorter in min_size
+
+    def test__transform(self, mocker):
+        interpolation_sentinel = mocker.MagicMock()
+
+        transform = transforms.RandomShortestSize(min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel)
+        transform._transformed_types = (mocker.MagicMock,)
+
+        size_sentinel = mocker.MagicMock()
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.RandomShortestSize._get_params",
+            return_value=dict(size=size_sentinel),
+        )
+
+        inpt_sentinel = mocker.MagicMock()
+
+        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
+        transform(inpt_sentinel)
+
+        mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)