7
7
8
8
from . import functional as F , InterpolationMode
9
9
10
- __all__ = ["AutoAugmentPolicy" , "AutoAugment" , "AugmentationSpace" , "TrivialAugment " ]
10
+ __all__ = ["AutoAugmentPolicy" , "AutoAugment" , "TrivialAugmentWide " ]
11
11
12
12
13
13
def _apply_op (img : Tensor , op_name : str , magnitude : float ,
@@ -178,8 +178,7 @@ def _get_transforms(
178
178
else :
179
179
raise ValueError ("The provided policy {} is not recognized." .format (policy ))
180
180
181
- @staticmethod
182
- def _get_magnitudes (num_bins : int , image_size : List [int ]) -> Dict [str , Tuple [Tensor , bool ]]:
181
+ def _get_magnitudes (self , num_bins : int , image_size : List [int ]) -> Dict [str , Tuple [Tensor , bool ]]:
183
182
return {
184
183
# name: (magnitudes, signed)
185
184
"ShearX" : (torch .linspace (0.0 , 0.3 , num_bins ), True ),
@@ -243,24 +242,14 @@ def __repr__(self) -> str:
243
242
return self .__class__ .__name__ + '(policy={}, fill={})' .format (self .policy , self .fill )
244
243
245
244
246
- class AugmentationSpace (Enum ):
247
- """The augmentation space to use.
248
- Available spaces are `AA` for AutoAugment and `TA_WIDE` for the TrivialAugment.
249
- """
250
- AA = "aa"
251
- TA_WIDE = "ta_wide"
252
-
253
-
254
- class TrivialAugment (torch .nn .Module ):
255
- r"""Dataset-independent data-augmentation with TrivialAugment, as described in
245
+ class TrivialAugmentWide (torch .nn .Module ):
246
+ r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
256
247
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
257
248
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
258
249
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
259
250
If img is PIL Image, it is expected to be in mode "L" or "RGB".
260
251
261
252
Args:
262
- augmentation_space (AugmentationSpace): Desired augmentation space enum defined by
263
- :class:`torchvision.transforms.autoaugment.AugmentationSpace`. Default is ``AugmentationSpace.TA_WIDE``.
264
253
num_magnitude_bins (int): The number of different magnitude values.
265
254
interpolation (InterpolationMode): Desired interpolation enum defined by
266
255
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
@@ -269,17 +258,14 @@ class TrivialAugment(torch.nn.Module):
269
258
image. If given a number, the value is used for all bands respectively.
270
259
"""
271
260
272
- def __init__ (self , augmentation_space : AugmentationSpace = AugmentationSpace .TA_WIDE , num_magnitude_bins : int = 30 ,
273
- interpolation : InterpolationMode = InterpolationMode .NEAREST ,
261
+ def __init__ (self , num_magnitude_bins : int = 30 , interpolation : InterpolationMode = InterpolationMode .NEAREST ,
274
262
fill : Optional [List [float ]] = None ) -> None :
275
263
super ().__init__ ()
276
- self .augmentation_space = augmentation_space
277
264
self .num_magnitude_bins = num_magnitude_bins
278
265
self .interpolation = interpolation
279
266
self .fill = fill
280
267
281
- @staticmethod
282
- def _get_magnitudes (num_bins : int ) -> Dict [str , Tuple [Tensor , bool ]]:
268
+ def _get_magnitudes (self , num_bins : int ) -> Dict [str , Tuple [Tensor , bool ]]:
283
269
return {
284
270
# name: (magnitudes, signed)
285
271
"ShearX" : (torch .linspace (0.0 , 0.99 , num_bins ), True ),
@@ -303,7 +289,7 @@ def forward(self, img: Tensor):
303
289
img (PIL Image or Tensor): Image to be transformed.
304
290
305
291
Returns:
306
- PIL Image or Tensor: TrivialAugmented image.
292
+ PIL Image or Tensor: Transformed image.
307
293
"""
308
294
fill = self .fill
309
295
if isinstance (img , Tensor ):
@@ -312,12 +298,7 @@ def forward(self, img: Tensor):
312
298
elif fill is not None :
313
299
fill = [float (f ) for f in fill ]
314
300
315
- if self .augmentation_space == AugmentationSpace .AA :
316
- op_meta = AutoAugment ._get_magnitudes (self .num_magnitude_bins , F .get_image_size (img ))
317
- elif self .augmentation_space == AugmentationSpace .TA_WIDE :
318
- op_meta = self ._get_magnitudes (self .num_magnitude_bins )
319
- else :
320
- raise ValueError (f"Provided augmentation_space arguments { self .augmentation_space } not available." )
301
+ op_meta = self ._get_magnitudes (self .num_magnitude_bins )
321
302
op_index = int (torch .randint (len (op_meta ), (1 ,)).item ())
322
303
op_name = list (op_meta .keys ())[op_index ]
323
304
magnitudes , signed = op_meta [op_name ]
@@ -330,8 +311,7 @@ def forward(self, img: Tensor):
330
311
331
312
def __repr__ (self ) -> str :
332
313
s = self .__class__ .__name__ + '('
333
- s += 'augmentation_space={augmentation_space}'
334
- s += ', num_magnitude_bins={num_magnitude_bins}'
314
+ s += 'num_magnitude_bins={num_magnitude_bins}'
335
315
s += ', interpolation={interpolation}'
336
316
s += ', fill={fill}'
337
317
s += ')'
0 commit comments