Skip to content

Commit 18914b4

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] port ScaleJitter from detection reference to prototype transforms (#6411)
Summary: * port ScaleJitter from detection reference to prototype transforms * add test * use MagicMock as sentinel Reviewed By: datumbox Differential Revision: D38824248 fbshipit-source-id: b9ee3cd23ebadf8e77f336a62d3184ece714ee1a
1 parent 0adf189 commit 18914b4

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

test/test_prototype_transforms.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,3 +1125,42 @@ def test_ctor(self, trfms):
11251125
inpt = torch.rand(1, 3, 32, 32)
11261126
output = c(inpt)
11271127
assert isinstance(output, torch.Tensor)
1128+
1129+
1130+
class TestScaleJitter:
1131+
def test__get_params(self, mocker):
1132+
image_size = (24, 32)
1133+
target_size = (16, 12)
1134+
scale_range = (0.5, 1.5)
1135+
1136+
transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
1137+
1138+
sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size)
1139+
params = transform._get_params(sample)
1140+
1141+
assert "size" in params
1142+
size = params["size"]
1143+
1144+
assert isinstance(size, tuple) and len(size) == 2
1145+
height, width = size
1146+
1147+
assert int(target_size[0] * scale_range[0]) <= height <= int(target_size[0] * scale_range[1])
1148+
assert int(target_size[1] * scale_range[0]) <= width <= int(target_size[1] * scale_range[1])
1149+
1150+
def test__transform(self, mocker):
1151+
interpolation_sentinel = mocker.MagicMock()
1152+
1153+
transform = transforms.ScaleJitter(target_size=(16, 12), interpolation=interpolation_sentinel)
1154+
transform._transformed_types = (mocker.MagicMock,)
1155+
1156+
size_sentinel = mocker.MagicMock()
1157+
mocker.patch(
1158+
"torchvision.prototype.transforms._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
1159+
)
1160+
1161+
inpt_sentinel = mocker.MagicMock()
1162+
1163+
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
1164+
transform(inpt_sentinel)
1165+
1166+
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RandomVerticalFlip,
3131
RandomZoomOut,
3232
Resize,
33+
ScaleJitter,
3334
TenCrop,
3435
)
3536
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype

torchvision/prototype/transforms/_geometry.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,3 +631,29 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
631631
fill=self.fill,
632632
interpolation=self.interpolation,
633633
)
634+
635+
636+
class ScaleJitter(Transform):
637+
def __init__(
638+
self,
639+
target_size: Tuple[int, int],
640+
scale_range: Tuple[float, float] = (0.1, 2.0),
641+
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
642+
):
643+
super().__init__()
644+
self.target_size = target_size
645+
self.scale_range = scale_range
646+
self.interpolation = interpolation
647+
648+
def _get_params(self, sample: Any) -> Dict[str, Any]:
649+
image = query_image(sample)
650+
_, orig_height, orig_width = get_image_dimensions(image)
651+
652+
r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
653+
new_width = int(self.target_size[1] * r)
654+
new_height = int(self.target_size[0] * r)
655+
656+
return dict(size=(new_height, new_width))
657+
658+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
659+
return F.resize(inpt, size=params["size"], interpolation=self.interpolation)

0 commit comments

Comments
 (0)