5
5
from collections .abc import Sequence
6
6
from typing import Tuple , List , Optional
7
7
8
- import numpy as np
9
8
import torch
10
9
from PIL import Image
11
10
from torch import Tensor
@@ -721,9 +720,9 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat
721
720
raise ValueError ("Please provide only two dimensions (h, w) for size." )
722
721
self .size = size
723
722
724
- if not isinstance (scale , ( tuple , list ) ):
723
+ if not isinstance (scale , Sequence ):
725
724
raise TypeError ("Scale should be a sequence" )
726
- if not isinstance (ratio , ( tuple , list ) ):
725
+ if not isinstance (ratio , Sequence ):
727
726
raise TypeError ("Ratio should be a sequence" )
728
727
if (scale [0 ] > scale [1 ]) or (ratio [0 ] > ratio [1 ]):
729
728
warnings .warn ("Scale and ratio should be of kind (min, max)" )
@@ -734,14 +733,14 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat
734
733
735
734
@staticmethod
736
735
def get_params (
737
- img : Tensor , scale : Tuple [float , float ], ratio : Tuple [ float , float ]
736
+ img : Tensor , scale : List [float ], ratio : List [ float ]
738
737
) -> Tuple [int , int , int , int ]:
739
738
"""Get parameters for ``crop`` for a random sized crop.
740
739
741
740
Args:
742
741
img (PIL Image or Tensor): Input image.
743
- scale (tuple ): range of scale of the origin size cropped
744
- ratio (tuple ): range of aspect ratio of the origin aspect ratio cropped
742
+ scale (list ): range of scale of the origin size cropped
743
+ ratio (list ): range of aspect ratio of the origin aspect ratio cropped
745
744
746
745
Returns:
747
746
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
@@ -751,7 +750,7 @@ def get_params(
751
750
area = height * width
752
751
753
752
for _ in range (10 ):
754
- target_area = area * torch .empty (1 ).uniform_ (* scale ).item ()
753
+ target_area = area * torch .empty (1 ).uniform_ (scale [ 0 ], scale [ 1 ] ).item ()
755
754
log_ratio = torch .log (torch .tensor (ratio ))
756
755
aspect_ratio = torch .exp (
757
756
torch .empty (1 ).uniform_ (log_ratio [0 ], log_ratio [1 ])
@@ -1173,8 +1172,10 @@ def __repr__(self):
1173
1172
return format_string
1174
1173
1175
1174
1176
- class RandomAffine (object ):
1177
- """Random affine transformation of the image keeping center invariant
1175
+ class RandomAffine (torch .nn .Module ):
1176
+ """Random affine transformation of the image keeping center invariant.
1177
+ The image can be a PIL Image or a Tensor, in which case it is expected
1178
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
1178
1179
1179
1180
Args:
1180
1181
degrees (sequence or float or int): Range of degrees to select from.
@@ -1188,41 +1189,51 @@ class RandomAffine(object):
1188
1189
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
1189
1190
shear (sequence or float or int, optional): Range of degrees to select from.
1190
1191
If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
1191
- will be apllied . Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
1192
+ will be applied . Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
1192
1193
range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
1193
1194
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
1194
- Will not apply shear by default
1195
- resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
1196
- An optional resampling filter. See `filters`_ for more information.
1197
- If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
1198
- fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
1199
- outside the transform in the output image.(Pillow>=5.0.0)
1195
+ Will not apply shear by default.
1196
+ resample (int, optional): An optional resampling filter. See `filters`_ for more information.
1197
+ If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
1198
+ If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
1199
+ fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
1200
+ outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
1201
+ input. Fill value for the area outside the transform in the output image is always 0.
1200
1202
1201
1203
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
1202
1204
1203
1205
"""
1204
1206
1205
- def __init__ (self , degrees , translate = None , scale = None , shear = None , resample = False , fillcolor = 0 ):
1207
+ def __init__ (self , degrees , translate = None , scale = None , shear = None , resample = 0 , fillcolor = 0 ):
1208
+ super ().__init__ ()
1206
1209
if isinstance (degrees , numbers .Number ):
1207
1210
if degrees < 0 :
1208
1211
raise ValueError ("If degrees is a single number, it must be positive." )
1209
- self . degrees = ( - degrees , degrees )
1212
+ degrees = [ - degrees , degrees ]
1210
1213
else :
1211
- assert isinstance (degrees , (tuple , list )) and len (degrees ) == 2 , \
1212
- "degrees should be a list or tuple and it must be of length 2."
1213
- self .degrees = degrees
1214
+ if not isinstance (degrees , Sequence ):
1215
+ raise TypeError ("degrees should be a sequence of length 2." )
1216
+ if len (degrees ) != 2 :
1217
+ raise ValueError ("degrees should be sequence of length 2." )
1218
+
1219
+ self .degrees = [float (d ) for d in degrees ]
1214
1220
1215
1221
if translate is not None :
1216
- assert isinstance (translate , (tuple , list )) and len (translate ) == 2 , \
1217
- "translate should be a list or tuple and it must be of length 2."
1222
+ if not isinstance (translate , Sequence ):
1223
+ raise TypeError ("translate should be a sequence of length 2." )
1224
+ if len (translate ) != 2 :
1225
+ raise ValueError ("translate should be sequence of length 2." )
1218
1226
for t in translate :
1219
1227
if not (0.0 <= t <= 1.0 ):
1220
1228
raise ValueError ("translation values should be between 0 and 1" )
1221
1229
self .translate = translate
1222
1230
1223
1231
if scale is not None :
1224
- assert isinstance (scale , (tuple , list )) and len (scale ) == 2 , \
1225
- "scale should be a list or tuple and it must be of length 2."
1232
+ if not isinstance (scale , Sequence ):
1233
+ raise TypeError ("scale should be a sequence of length 2." )
1234
+ if len (scale ) != 2 :
1235
+ raise ValueError ("scale should be sequence of length 2." )
1236
+
1226
1237
for s in scale :
1227
1238
if s <= 0 :
1228
1239
raise ValueError ("scale values should be positive" )
@@ -1232,62 +1243,69 @@ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Fal
1232
1243
if isinstance (shear , numbers .Number ):
1233
1244
if shear < 0 :
1234
1245
raise ValueError ("If shear is a single number, it must be positive." )
1235
- self . shear = ( - shear , shear )
1246
+ shear = [ - shear , shear ]
1236
1247
else :
1237
- assert isinstance (shear , (tuple , list )) and \
1238
- (len (shear ) == 2 or len (shear ) == 4 ), \
1239
- "shear should be a list or tuple and it must be of length 2 or 4."
1240
- # X-Axis shear with [min, max]
1241
- if len (shear ) == 2 :
1242
- self .shear = [shear [0 ], shear [1 ], 0. , 0. ]
1243
- elif len (shear ) == 4 :
1244
- self .shear = [s for s in shear ]
1248
+ if not isinstance (shear , Sequence ):
1249
+ raise TypeError ("shear should be a sequence of length 2 or 4." )
1250
+ if len (shear ) not in (2 , 4 ):
1251
+ raise ValueError ("shear should be sequence of length 2 or 4." )
1252
+
1253
+ self .shear = [float (s ) for s in shear ]
1245
1254
else :
1246
1255
self .shear = shear
1247
1256
1248
1257
self .resample = resample
1249
1258
self .fillcolor = fillcolor
1250
1259
1251
1260
@staticmethod
def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int]
) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
    """Sample one random set of parameters for an affine transformation.

    Every value is drawn uniformly with ``torch.empty(1).uniform_``, in the
    fixed order angle, x translation, y translation, scale, x shear, y shear.

    Args:
        degrees (list of float): ``[min, max]`` range the rotation angle is
            drawn from.
        translate (list of float, optional): Two fractions. The x offset is
            drawn from ``+/- translate[0] * img_size[0]`` and the y offset
            from ``+/- translate[1] * img_size[1]``; both are rounded to
            integers. ``None`` yields ``(0, 0)``.
        scale_ranges (list of float, optional): ``[min, max]`` range the
            scale factor is drawn from. ``None`` yields ``1.0``.
        shears (list of float, optional): ``[min, max]`` range for the x
            shear, optionally followed by ``[min, max]`` for the y shear
            (4 values in total). A shear that is not sampled is ``0.0``.
        img_size (list of int): The two image dimensions used to turn the
            ``translate`` fractions into pixel bounds.

    Returns:
        params to be passed to the affine transformation, as the tuple
        ``(angle, (tx, ty), scale, (shear_x, shear_y))``.
    """
    # Rotation angle.
    min_angle = float(degrees[0])
    max_angle = float(degrees[1])
    angle = float(torch.empty(1).uniform_(min_angle, max_angle).item())

    # Translation in whole pixels; defaults to no shift.
    translations = (0, 0)
    if translate is not None:
        bound_x = float(translate[0] * img_size[0])
        bound_y = float(translate[1] * img_size[1])
        # Draw x before y so the random stream is consumed in a fixed order.
        offset_x = int(round(torch.empty(1).uniform_(-bound_x, bound_x).item()))
        offset_y = int(round(torch.empty(1).uniform_(-bound_y, bound_y).item()))
        translations = (offset_x, offset_y)

    # Scale factor; defaults to keeping the original scale.
    scale = 1.0
    if scale_ranges is not None:
        scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())

    # Shear angles; the y shear is only sampled when four bounds are given.
    shear_x = 0.0
    shear_y = 0.0
    if shears is not None:
        shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
        if len(shears) == 4:
            shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

    return angle, translations, scale, (shear_x, shear_y)
1282
1297
1283
def forward(self, img):
    """Apply a randomly sampled affine transformation to ``img``.

    Args:
        img (PIL Image or Tensor): Image to be transformed.

    Returns:
        PIL Image or Tensor: Affine transformed image.
    """
    # The image size is needed to convert the translate fractions to pixels.
    size = F._get_image_size(img)

    # Sample a fresh set of parameters on every call.
    angle, translations, scale, shear = self.get_params(
        self.degrees, self.translate, self.scale, self.shear, size
    )
    return F.affine(
        img, angle, translations, scale, shear,
        resample=self.resample, fillcolor=self.fillcolor
    )
1292
1310
1293
1311
def __repr__ (self ):
0 commit comments