7
7
8
8
from . import functional as F , InterpolationMode
9
9
10
- __all__ = ["AutoAugmentPolicy" , "AutoAugment" , "TrivialAugmentWide" ]
10
# Public names exported by ``from <module> import *``. Every entry must match a
# module-level attribute exactly; stray whitespace inside a name makes the
# star-import fail with AttributeError.
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
11
11
12
12
13
13
def _apply_op (img : Tensor , op_name : str , magnitude : float ,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
58
58
SVHN = "svhn"
59
59
60
60
61
+ # FIXME: Eliminate the copy-pasted code for fill standardization and _augmentation_space() by moving it into a shared base class
61
62
class AutoAugment (torch .nn .Module ):
62
63
r"""AutoAugment data augmentation method based on
63
64
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
85
86
self .policy = policy
86
87
self .interpolation = interpolation
87
88
self .fill = fill
88
- self .transforms = self ._get_transforms (policy )
89
+ self .policies = self ._get_policies (policy )
89
90
90
- def _get_transforms (
91
+ def _get_policies (
91
92
self ,
92
93
policy : AutoAugmentPolicy
93
94
) -> List [Tuple [Tuple [str , float , Optional [int ]], Tuple [str , float , Optional [int ]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
178
179
else :
179
180
raise ValueError ("The provided policy {} is not recognized." .format (policy ))
180
181
181
- def _get_magnitudes (self , num_bins : int , image_size : List [int ]) -> Dict [str , Tuple [Tensor , bool ]]:
182
+ def _augmentation_space (self , num_bins : int , image_size : List [int ]) -> Dict [str , Tuple [Tensor , bool ]]:
182
183
return {
183
- # name : (magnitudes, signed)
184
+ # op_name : (magnitudes, signed)
184
185
"ShearX" : (torch .linspace (0.0 , 0.3 , num_bins ), True ),
185
186
"ShearY" : (torch .linspace (0.0 , 0.3 , num_bins ), True ),
186
187
"TranslateX" : (torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
224
225
elif fill is not None :
225
226
fill = [float (f ) for f in fill ]
226
227
227
- transform_id , probs , signs = self .get_params (len (self .transforms ))
228
+ transform_id , probs , signs = self .get_params (len (self .policies ))
228
229
229
- for i , (op_name , p , magnitude_id ) in enumerate (self .transforms [transform_id ]):
230
+ for i , (op_name , p , magnitude_id ) in enumerate (self .policies [transform_id ]):
230
231
if probs [i ] <= p :
231
- op_meta = self ._get_magnitudes (10 , F .get_image_size (img ))
232
+ op_meta = self ._augmentation_space (10 , F .get_image_size (img ))
232
233
magnitudes , signed = op_meta [op_name ]
233
234
magnitude = float (magnitudes [magnitude_id ].item ()) if magnitude_id is not None else 0.0
234
235
if signed and signs [i ] == 0 :
@@ -241,6 +242,91 @@ def __repr__(self) -> str:
241
242
return self .__class__ .__name__ + '(policy={}, fill={})' .format (self .policy , self .fill )
242
243
243
244
245
class RandAugment(torch.nn.Module):
    r"""RandAugment data augmentation method, as introduced in
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    A torch Tensor input should be of type torch.uint8 and is expected to have
    [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    A PIL Image input is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations. It is used as an index
            into the magnitude bins of the sampled operation.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
                 fill: Optional[List[float]] = None) -> None:
        super().__init__()
        # Plain attributes only; they are read back by forward() and __repr__().
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate operations.

        Args:
            num_bins (int): Number of magnitude values generated per operation.
            image_size (List[int]): Image size; entries 0 and 1 scale the
                TranslateX and TranslateY ranges respectively.

        Returns:
            Dict mapping op_name to ``(magnitudes, signed)``. Operations that take no
            magnitude carry a 0-dim tensor; ``signed`` tells whether the magnitude
            may be negated by the caller.
        """
        # Translation limits grow with the image extent along each axis.
        translate_x_max = 150.0 / 331.0 * image_size[0]
        translate_y_max = 150.0 / 331.0 * image_size[1]

        space = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, translate_x_max, num_bins), True),
            "TranslateY": (torch.linspace(0.0, translate_y_max, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            # Integer values running from 8 at the first bin down to 4 at the last.
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }
        return space

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        # Standardize ``fill`` for Tensor inputs: a scalar is repeated once per
        # channel, a sequence is converted element-wise to floats.
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        for _ in range(self.num_ops):
            # The space is rebuilt on every pass from the current image's size.
            space = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
            op_names = list(space.keys())
            # Draw one operation uniformly at random.
            op_name = op_names[int(torch.randint(len(space), (1,)).item())]
            bins, signed = space[op_name]

            # 0-dim entries mark operations without a magnitude.
            magnitude = 0.0
            if bins.ndim > 0:
                magnitude = float(bins[self.magnitude].item())

            # The sign is sampled only for signed operations.
            if signed and torch.randint(2, (1,)):
                magnitude = -magnitude

            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        """Return the class name followed by the constructor settings."""
        return (
            f"{self.__class__.__name__}("
            f"num_ops={self.num_ops}"
            f", magnitude={self.magnitude}"
            f", num_magnitude_bins={self.num_magnitude_bins}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            ")"
        )
328
+
329
+
244
330
class TrivialAugmentWide (torch .nn .Module ):
245
331
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
246
332
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
@@ -264,9 +350,9 @@ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMod
264
350
self .interpolation = interpolation
265
351
self .fill = fill
266
352
267
- def _get_magnitudes (self , num_bins : int ) -> Dict [str , Tuple [Tensor , bool ]]:
353
+ def _augmentation_space (self , num_bins : int ) -> Dict [str , Tuple [Tensor , bool ]]:
268
354
return {
269
- # name : (magnitudes, signed)
355
+ # op_name : (magnitudes, signed)
270
356
"ShearX" : (torch .linspace (0.0 , 0.99 , num_bins ), True ),
271
357
"ShearY" : (torch .linspace (0.0 , 0.99 , num_bins ), True ),
272
358
"TranslateX" : (torch .linspace (0.0 , 32.0 , num_bins ), True ),
@@ -283,7 +369,7 @@ def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
283
369
"Invert" : (torch .tensor (0.0 ), False ),
284
370
}
285
371
286
- def forward (self , img : Tensor ):
372
+ def forward (self , img : Tensor ) -> Tensor :
287
373
"""
288
374
img (PIL Image or Tensor): Image to be transformed.
289
375
@@ -297,7 +383,7 @@ def forward(self, img: Tensor):
297
383
elif fill is not None :
298
384
fill = [float (f ) for f in fill ]
299
385
300
- op_meta = self ._get_magnitudes (self .num_magnitude_bins )
386
+ op_meta = self ._augmentation_space (self .num_magnitude_bins )
301
387
op_index = int (torch .randint (len (op_meta ), (1 ,)).item ())
302
388
op_name = list (op_meta .keys ())[op_index ]
303
389
magnitudes , signed = op_meta [op_name ]
0 commit comments