Skip to content

Commit 9bd25d0

Browse files
authored
Unified inputs for T.RandomAffine transformation (2292) (#2478)
* [WIP] Unified input for T.RandomAffine
* Unified inputs for T.RandomAffine transformation
* Update transforms.py
* Updated docs of F.affine fillcolor
* Update transforms.py
1 parent 1aef87d commit 9bd25d0

File tree

3 files changed

+108
-69
lines changed

3 files changed

+108
-69
lines changed

test/test_transforms_tensor.py

Lines changed: 34 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -248,21 +248,40 @@ def test_resize(self):
248248
def test_resized_crop(self):
249249
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
250250

251-
scale = (0.7, 1.2)
252-
ratio = (0.75, 1.333)
253-
254-
for size in [(32, ), [32, ], [32, 32], (32, 32)]:
255-
for interpolation in [NEAREST, BILINEAR, BICUBIC]:
256-
transform = T.RandomResizedCrop(
257-
size=size, scale=scale, ratio=ratio, interpolation=interpolation
258-
)
259-
s_transform = torch.jit.script(transform)
260-
261-
torch.manual_seed(12)
262-
out1 = transform(tensor)
263-
torch.manual_seed(12)
264-
out2 = s_transform(tensor)
265-
self.assertTrue(out1.equal(out2))
251+
for scale in [(0.7, 1.2), [0.7, 1.2]]:
252+
for ratio in [(0.75, 1.333), [0.75, 1.333]]:
253+
for size in [(32, ), [32, ], [32, 32], (32, 32)]:
254+
for interpolation in [NEAREST, BILINEAR, BICUBIC]:
255+
transform = T.RandomResizedCrop(
256+
size=size, scale=scale, ratio=ratio, interpolation=interpolation
257+
)
258+
s_transform = torch.jit.script(transform)
259+
260+
torch.manual_seed(12)
261+
out1 = transform(tensor)
262+
torch.manual_seed(12)
263+
out2 = s_transform(tensor)
264+
self.assertTrue(out1.equal(out2))
265+
266+
def test_random_affine(self):
267+
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
268+
269+
for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
270+
for scale in [(0.7, 1.2), [0.7, 1.2]]:
271+
for translate in [(0.1, 0.2), [0.2, 0.1]]:
272+
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
273+
for interpolation in [NEAREST, BILINEAR]:
274+
transform = T.RandomAffine(
275+
degrees=degrees, translate=translate,
276+
scale=scale, shear=shear, resample=interpolation
277+
)
278+
s_transform = torch.jit.script(transform)
279+
280+
torch.manual_seed(12)
281+
out1 = transform(tensor)
282+
torch.manual_seed(12)
283+
out2 = s_transform(tensor)
284+
self.assertTrue(out1.equal(out2))
266285

267286

268287
if __name__ == '__main__':

torchvision/transforms/functional.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -858,7 +858,9 @@ def affine(
858858
An optional resampling filter. See `filters`_ for more information.
859859
If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
860860
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
861-
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
861+
fillcolor (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
862+
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
863+
image is always 0.
862864
863865
Returns:
864866
PIL Image or Tensor: Transformed image.

torchvision/transforms/transforms.py

Lines changed: 71 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
from collections.abc import Sequence
66
from typing import Tuple, List, Optional
77

8-
import numpy as np
98
import torch
109
from PIL import Image
1110
from torch import Tensor
@@ -721,9 +720,9 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat
721720
raise ValueError("Please provide only two dimensions (h, w) for size.")
722721
self.size = size
723722

724-
if not isinstance(scale, (tuple, list)):
723+
if not isinstance(scale, Sequence):
725724
raise TypeError("Scale should be a sequence")
726-
if not isinstance(ratio, (tuple, list)):
725+
if not isinstance(ratio, Sequence):
727726
raise TypeError("Ratio should be a sequence")
728727
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
729728
warnings.warn("Scale and ratio should be of kind (min, max)")
@@ -734,14 +733,14 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat
734733

735734
@staticmethod
736735
def get_params(
737-
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
736+
img: Tensor, scale: List[float], ratio: List[float]
738737
) -> Tuple[int, int, int, int]:
739738
"""Get parameters for ``crop`` for a random sized crop.
740739
741740
Args:
742741
img (PIL Image or Tensor): Input image.
743-
scale (tuple): range of scale of the origin size cropped
744-
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
742+
scale (list): range of scale of the origin size cropped
743+
ratio (list): range of aspect ratio of the origin aspect ratio cropped
745744
746745
Returns:
747746
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
@@ -751,7 +750,7 @@ def get_params(
751750
area = height * width
752751

753752
for _ in range(10):
754-
target_area = area * torch.empty(1).uniform_(*scale).item()
753+
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
755754
log_ratio = torch.log(torch.tensor(ratio))
756755
aspect_ratio = torch.exp(
757756
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
@@ -1173,8 +1172,10 @@ def __repr__(self):
11731172
return format_string
11741173

11751174

1176-
class RandomAffine(object):
1177-
"""Random affine transformation of the image keeping center invariant
1175+
class RandomAffine(torch.nn.Module):
1176+
"""Random affine transformation of the image keeping center invariant.
1177+
The image can be a PIL Image or a Tensor, in which case it is expected
1178+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
11781179
11791180
Args:
11801181
degrees (sequence or float or int): Range of degrees to select from.
@@ -1188,41 +1189,51 @@ class RandomAffine(object):
11881189
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
11891190
shear (sequence or float or int, optional): Range of degrees to select from.
11901191
If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
1191-
will be apllied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
1192+
will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
11921193
range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
11931194
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
1194-
Will not apply shear by default
1195-
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
1196-
An optional resampling filter. See `filters`_ for more information.
1197-
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
1198-
fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
1199-
outside the transform in the output image.(Pillow>=5.0.0)
1195+
Will not apply shear by default.
1196+
resample (int, optional): An optional resampling filter. See `filters`_ for more information.
1197+
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
1198+
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
1199+
fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
1200+
outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
1201+
input. Fill value for the area outside the transform in the output image is always 0.
12001202
12011203
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
12021204
12031205
"""
12041206

1205-
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
1207+
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0):
1208+
super().__init__()
12061209
if isinstance(degrees, numbers.Number):
12071210
if degrees < 0:
12081211
raise ValueError("If degrees is a single number, it must be positive.")
1209-
self.degrees = (-degrees, degrees)
1212+
degrees = [-degrees, degrees]
12101213
else:
1211-
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
1212-
"degrees should be a list or tuple and it must be of length 2."
1213-
self.degrees = degrees
1214+
if not isinstance(degrees, Sequence):
1215+
raise TypeError("degrees should be a sequence of length 2.")
1216+
if len(degrees) != 2:
1217+
raise ValueError("degrees should be sequence of length 2.")
1218+
1219+
self.degrees = [float(d) for d in degrees]
12141220

12151221
if translate is not None:
1216-
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
1217-
"translate should be a list or tuple and it must be of length 2."
1222+
if not isinstance(translate, Sequence):
1223+
raise TypeError("translate should be a sequence of length 2.")
1224+
if len(translate) != 2:
1225+
raise ValueError("translate should be sequence of length 2.")
12181226
for t in translate:
12191227
if not (0.0 <= t <= 1.0):
12201228
raise ValueError("translation values should be between 0 and 1")
12211229
self.translate = translate
12221230

12231231
if scale is not None:
1224-
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
1225-
"scale should be a list or tuple and it must be of length 2."
1232+
if not isinstance(scale, Sequence):
1233+
raise TypeError("scale should be a sequence of length 2.")
1234+
if len(scale) != 2:
1235+
raise ValueError("scale should be sequence of length 2.")
1236+
12261237
for s in scale:
12271238
if s <= 0:
12281239
raise ValueError("scale values should be positive")
@@ -1232,62 +1243,69 @@ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Fal
12321243
if isinstance(shear, numbers.Number):
12331244
if shear < 0:
12341245
raise ValueError("If shear is a single number, it must be positive.")
1235-
self.shear = (-shear, shear)
1246+
shear = [-shear, shear]
12361247
else:
1237-
assert isinstance(shear, (tuple, list)) and \
1238-
(len(shear) == 2 or len(shear) == 4), \
1239-
"shear should be a list or tuple and it must be of length 2 or 4."
1240-
# X-Axis shear with [min, max]
1241-
if len(shear) == 2:
1242-
self.shear = [shear[0], shear[1], 0., 0.]
1243-
elif len(shear) == 4:
1244-
self.shear = [s for s in shear]
1248+
if not isinstance(shear, Sequence):
1249+
raise TypeError("shear should be a sequence of length 2 or 4.")
1250+
if len(shear) not in (2, 4):
1251+
raise ValueError("shear should be sequence of length 2 or 4.")
1252+
1253+
self.shear = [float(s) for s in shear]
12451254
else:
12461255
self.shear = shear
12471256

12481257
self.resample = resample
12491258
self.fillcolor = fillcolor
12501259

12511260
@staticmethod
1252-
def get_params(degrees, translate, scale_ranges, shears, img_size):
1261+
def get_params(
1262+
degrees: List[float],
1263+
translate: Optional[List[float]],
1264+
scale_ranges: Optional[List[float]],
1265+
shears: Optional[List[float]],
1266+
img_size: List[int]
1267+
) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
12531268
"""Get parameters for affine transformation
12541269
12551270
Returns:
1256-
sequence: params to be passed to the affine transformation
1271+
params to be passed to the affine transformation
12571272
"""
1258-
angle = random.uniform(degrees[0], degrees[1])
1273+
angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
12591274
if translate is not None:
1260-
max_dx = translate[0] * img_size[0]
1261-
max_dy = translate[1] * img_size[1]
1262-
translations = (np.round(random.uniform(-max_dx, max_dx)),
1263-
np.round(random.uniform(-max_dy, max_dy)))
1275+
max_dx = float(translate[0] * img_size[0])
1276+
max_dy = float(translate[1] * img_size[1])
1277+
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
1278+
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
1279+
translations = (tx, ty)
12641280
else:
12651281
translations = (0, 0)
12661282

12671283
if scale_ranges is not None:
1268-
scale = random.uniform(scale_ranges[0], scale_ranges[1])
1284+
scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
12691285
else:
12701286
scale = 1.0
12711287

1288+
shear_x = shear_y = 0.0
12721289
if shears is not None:
1273-
if len(shears) == 2:
1274-
shear = [random.uniform(shears[0], shears[1]), 0.]
1275-
elif len(shears) == 4:
1276-
shear = [random.uniform(shears[0], shears[1]),
1277-
random.uniform(shears[2], shears[3])]
1278-
else:
1279-
shear = 0.0
1290+
shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
1291+
if len(shears) == 4:
1292+
shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
1293+
1294+
shear = (shear_x, shear_y)
12801295

12811296
return angle, translations, scale, shear
12821297

1283-
def __call__(self, img):
1298+
def forward(self, img):
12841299
"""
1285-
img (PIL Image): Image to be transformed.
1300+
img (PIL Image or Tensor): Image to be transformed.
12861301
12871302
Returns:
1288-
PIL Image: Affine transformed image.
1303+
PIL Image or Tensor: Affine transformed image.
12891304
"""
1290-
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
1305+
1306+
img_size = F._get_image_size(img)
1307+
1308+
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
12911309
return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
12921310

12931311
def __repr__(self):

0 commit comments

Comments
 (0)