Skip to content

Commit c3573c8

Browse files
authored
port RandomShortestSize from detection references to prototype transforms (#6418)
* port `RandomShortestSize` from detection references to prototype transforms * mypy * add test
1 parent c0ba3ec commit c3573c8

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

test/test_prototype_transforms.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,3 +1164,48 @@ def test__transform(self, mocker):
11641164
transform(inpt_sentinel)
11651165

11661166
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
1167+
1168+
1169+
class TestRandomShortestSize:
1170+
def test__get_params(self, mocker):
1171+
image_size = (3, 10)
1172+
min_size = [5, 9]
1173+
max_size = 20
1174+
1175+
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
1176+
1177+
sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size)
1178+
params = transform._get_params(sample)
1179+
1180+
assert "size" in params
1181+
size = params["size"]
1182+
1183+
assert isinstance(size, tuple) and len(size) == 2
1184+
1185+
longer = max(size)
1186+
assert longer <= max_size
1187+
1188+
shorter = min(size)
1189+
if longer == max_size:
1190+
assert shorter <= max_size
1191+
else:
1192+
assert shorter in min_size
1193+
1194+
def test__transform(self, mocker):
1195+
interpolation_sentinel = mocker.MagicMock()
1196+
1197+
transform = transforms.RandomShortestSize(min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel)
1198+
transform._transformed_types = (mocker.MagicMock,)
1199+
1200+
size_sentinel = mocker.MagicMock()
1201+
mocker.patch(
1202+
"torchvision.prototype.transforms._geometry.RandomShortestSize._get_params",
1203+
return_value=dict(size=size_sentinel),
1204+
)
1205+
1206+
inpt_sentinel = mocker.MagicMock()
1207+
1208+
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
1209+
transform(inpt_sentinel)
1210+
1211+
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
RandomPerspective,
2828
RandomResizedCrop,
2929
RandomRotation,
30+
RandomShortestSize,
3031
RandomVerticalFlip,
3132
RandomZoomOut,
3233
Resize,

torchvision/prototype/transforms/_geometry.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,3 +644,31 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
644644

645645
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
646646
return F.resize(inpt, size=params["size"], interpolation=self.interpolation)
647+
648+
649+
class RandomShortestSize(Transform):
650+
def __init__(
651+
self,
652+
min_size: Union[List[int], Tuple[int], int],
653+
max_size: int,
654+
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
655+
):
656+
super().__init__()
657+
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
658+
self.max_size = max_size
659+
self.interpolation = interpolation
660+
661+
def _get_params(self, sample: Any) -> Dict[str, Any]:
662+
image = query_image(sample)
663+
_, orig_height, orig_width = get_image_dimensions(image)
664+
665+
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
666+
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
667+
668+
new_width = int(orig_width * r)
669+
new_height = int(orig_height * r)
670+
671+
return dict(size=(new_height, new_width))
672+
673+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
674+
return F.resize(inpt, size=params["size"], interpolation=self.interpolation)

0 commit comments

Comments
 (0)