Skip to content

Commit 08af5cb

Browse files
authored
Unified inputs for T.RandomRotation (#2496)
* Added code for F_t.rotate with test - updated F.affine tests * Rotate test tolerance to 2% * Fixes failing test * [WIP] RandomRotation * Unified RandomRotation with tests
1 parent 025b71d commit 08af5cb

File tree

3 files changed

+49
-18
lines changed

3 files changed

+49
-18
lines changed

test/test_transforms_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,24 @@ def test_random_affine(self):
283283
out2 = s_transform(tensor)
284284
self.assertTrue(out1.equal(out2))
285285

286+
def test_random_rotate(self):
287+
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
288+
289+
for center in [(0, 0), [10, 10], None, (56, 44)]:
290+
for expand in [True, False]:
291+
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
292+
for interpolation in [NEAREST, BILINEAR]:
293+
transform = T.RandomRotation(
294+
degrees=degrees, resample=interpolation, expand=expand, center=center
295+
)
296+
s_transform = torch.jit.script(transform)
297+
298+
torch.manual_seed(12)
299+
out1 = transform(tensor)
300+
torch.manual_seed(12)
301+
out2 = s_transform(tensor)
302+
self.assertTrue(out1.equal(out2))
303+
286304

287305
if __name__ == '__main__':
288306
unittest.main()

torchvision/transforms/functional.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,8 @@ def rotate(
829829
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
830830
image. If int or float, the value is used for all bands respectively.
831831
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
832+
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
833+
image is always 0.
832834
833835
Returns:
834836
PIL Image or Tensor: Rotated image.

torchvision/transforms/transforms.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,68 +1102,79 @@ def __repr__(self):
11021102
return format_string
11031103

11041104

1105-
class RandomRotation(object):
1105+
class RandomRotation(torch.nn.Module):
11061106
"""Rotate the image by angle.
1107+
The image can be a PIL Image or a Tensor, in which case it is expected
1108+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
11071109
11081110
Args:
11091111
degrees (sequence or float or int): Range of degrees to select from.
11101112
If degrees is a number instead of sequence like (min, max), the range of degrees
11111113
will be (-degrees, +degrees).
1112-
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
1113-
An optional resampling filter. See `filters`_ for more information.
1114+
resample (int, optional): An optional resampling filter. See `filters`_ for more information.
11141115
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
1116+
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
11151117
expand (bool, optional): Optional expansion flag.
11161118
If true, expands the output to make it large enough to hold the entire rotated image.
11171119
If false or omitted, make the output image the same size as the input image.
11181120
Note that the expand flag assumes rotation around the center and no translation.
1119-
center (2-tuple, optional): Optional center of rotation.
1120-
Origin is the upper left corner.
1121+
center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
11211122
Default is the center of the image.
11221123
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
11231124
image. If int or float, the value is used for all bands respectively.
1124-
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
1125+
Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0.
1126+
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
1127+
image is always 0.
11251128
11261129
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
11271130
11281131
"""
11291132

11301133
def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
1134+
super().__init__()
11311135
if isinstance(degrees, numbers.Number):
11321136
if degrees < 0:
11331137
raise ValueError("If degrees is a single number, it must be positive.")
1134-
self.degrees = (-degrees, degrees)
1138+
degrees = [-degrees, degrees]
11351139
else:
1140+
if not isinstance(degrees, Sequence):
1141+
raise TypeError("degrees should be a sequence of length 2.")
11361142
if len(degrees) != 2:
11371143
raise ValueError("If degrees is a sequence, it must be of len 2.")
1138-
self.degrees = degrees
1144+
1145+
self.degrees = [float(d) for d in degrees]
1146+
1147+
if center is not None:
1148+
if not isinstance(center, Sequence):
1149+
raise TypeError("center should be a sequence of length 2.")
1150+
if len(center) != 2:
1151+
raise ValueError("center should be a sequence of length 2.")
1152+
1153+
self.center = center
11391154

11401155
self.resample = resample
11411156
self.expand = expand
1142-
self.center = center
11431157
self.fill = fill
11441158

11451159
@staticmethod
1146-
def get_params(degrees):
1160+
def get_params(degrees: List[float]) -> float:
11471161
"""Get parameters for ``rotate`` for a random rotation.
11481162
11491163
Returns:
1150-
sequence: params to be passed to ``rotate`` for random rotation.
1164+
float: angle parameter to be passed to ``rotate`` for random rotation.
11511165
"""
1152-
angle = random.uniform(degrees[0], degrees[1])
1153-
1166+
angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
11541167
return angle
11551168

1156-
def __call__(self, img):
1169+
def forward(self, img):
11571170
"""
11581171
Args:
1159-
img (PIL Image): Image to be rotated.
1172+
img (PIL Image or Tensor): Image to be rotated.
11601173
11611174
Returns:
1162-
PIL Image: Rotated image.
1175+
PIL Image or Tensor: Rotated image.
11631176
"""
1164-
11651177
angle = self.get_params(self.degrees)
1166-
11671178
return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
11681179

11691180
def __repr__(self):

0 commit comments

Comments
 (0)