33import random
44import warnings
55from collections .abc import Sequence , Iterable
6- from typing import Tuple
6+ from typing import Tuple , List , Optional
77
88import numpy as np
99import torch
@@ -1343,7 +1343,7 @@ def __repr__(self):
13431343 return self .__class__ .__name__ + '(p={0})' .format (self .p )
13441344
13451345
1346- class RandomErasing (object ):
1346+ class RandomErasing (torch . nn . Module ):
13471347 """ Randomly selects a rectangle region in an image and erases its pixels.
13481348 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf
13491349
@@ -1370,13 +1370,21 @@ class RandomErasing(object):
13701370 """
13711371
13721372 def __init__ (self , p = 0.5 , scale = (0.02 , 0.33 ), ratio = (0.3 , 3.3 ), value = 0 , inplace = False ):
1373- assert isinstance (value , (numbers .Number , str , tuple , list ))
1373+ super ().__init__ ()
1374+ if not isinstance (value , (numbers .Number , str , tuple , list )):
1375+ raise TypeError ("Argument value should be either a number or str or a sequence" )
1376+ if isinstance (value , str ) and value != "random" :
1377+ raise ValueError ("If value is str, it should be 'random'" )
1378+ if not isinstance (scale , (tuple , list )):
1379+ raise TypeError ("Scale should be a sequence" )
1380+ if not isinstance (ratio , (tuple , list )):
1381+ raise TypeError ("Ratio should be a sequence" )
13741382 if (scale [0 ] > scale [1 ]) or (ratio [0 ] > ratio [1 ]):
1375- warnings .warn ("range should be of kind (min, max)" )
1383+ warnings .warn ("Scale and ratio should be of kind (min, max)" )
13761384 if scale [0 ] < 0 or scale [1 ] > 1 :
1377- raise ValueError ("range of scale should be between 0 and 1" )
1385+ raise ValueError ("Scale should be between 0 and 1" )
13781386 if p < 0 or p > 1 :
1379- raise ValueError ("range of random erasing probability should be between 0 and 1" )
1387+ raise ValueError ("Random erasing probability should be between 0 and 1" )
13801388
13811389 self .p = p
13821390 self .scale = scale
@@ -1385,13 +1393,18 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace
13851393 self .inplace = inplace
13861394
13871395 @staticmethod
1388- def get_params (img , scale , ratio , value = 0 ):
1396+ def get_params (
1397+ img : Tensor , scale : Tuple [float , float ], ratio : Tuple [float , float ], value : Optional [List [float ]] = None
1398+ ) -> Tuple [int , int , int , int , Tensor ]:
13891399 """Get parameters for ``erase`` for a random erasing.
13901400
13911401 Args:
13921402 img (Tensor): Tensor image of size (C, H, W) to be erased.
1393- scale: range of proportion of erased area against input image.
1394- ratio: range of aspect ratio of erased area.
1403+ scale (tuple or list): range of proportion of erased area against input image.
1404+ ratio (tuple or list): range of aspect ratio of erased area.
1405+ value (list, optional): erasing value. If None, it is interpreted as "random"
1406+ (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
1407+ i.e. ``value[0]``.
13951408
13961409 Returns:
13971410 tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
@@ -1400,35 +1413,52 @@ def get_params(img, scale, ratio, value=0):
14001413 area = img_h * img_w
14011414
14021415 for _ in range (10 ):
1403- erase_area = random . uniform ( scale [0 ], scale [1 ]) * area
1404- aspect_ratio = random . uniform ( ratio [0 ], ratio [1 ])
1416+ erase_area = area * torch . empty ( 1 ). uniform_ ( scale [0 ], scale [1 ]). item ()
1417+ aspect_ratio = torch . empty ( 1 ). uniform_ ( ratio [0 ], ratio [1 ]). item ( )
14051418
14061419 h = int (round (math .sqrt (erase_area * aspect_ratio )))
14071420 w = int (round (math .sqrt (erase_area / aspect_ratio )))
1421+ if not (h < img_h and w < img_w ):
1422+ continue
1423+
1424+ if value is None :
1425+ v = torch .empty ([img_c , h , w ], dtype = torch .float32 ).normal_ ()
1426+ else :
1427+ v = torch .tensor (value )[:, None , None ]
14081428
1409- if h < img_h and w < img_w :
1410- i = random .randint (0 , img_h - h )
1411- j = random .randint (0 , img_w - w )
1412- if isinstance (value , numbers .Number ):
1413- v = value
1414- elif isinstance (value , torch ._six .string_classes ):
1415- v = torch .empty ([img_c , h , w ], dtype = torch .float32 ).normal_ ()
1416- elif isinstance (value , (list , tuple )):
1417- v = torch .tensor (value , dtype = torch .float32 ).view (- 1 , 1 , 1 ).expand (- 1 , h , w )
1418- return i , j , h , w , v
1429+ i = torch .randint (0 , img_h - h , size = (1 , )).item ()
1430+ j = torch .randint (0 , img_w - w , size = (1 , )).item ()
1431+ return i , j , h , w , v
14191432
14201433 # Return original image
14211434 return 0 , 0 , img_h , img_w , img
14221435
1423- def __call__ (self , img ):
1436+ def forward (self , img ):
14241437 """
14251438 Args:
14261439 img (Tensor): Tensor image of size (C, H, W) to be erased.
14271440
14281441 Returns:
14291442 img (Tensor): Erased Tensor image.
14301443 """
1431- if random .uniform (0 , 1 ) < self .p :
1432- x , y , h , w , v = self .get_params (img , scale = self .scale , ratio = self .ratio , value = self .value )
1444+ if torch .rand (1 ) < self .p :
1445+
1446+ # cast self.value to script acceptable type
1447+ if isinstance (self .value , (int , float )):
1448+ value = [self .value , ]
1449+ elif isinstance (self .value , str ):
1450+ value = None
1451+ elif isinstance (self .value , tuple ):
1452+ value = list (self .value )
1453+ else :
1454+ value = self .value
1455+
1456+ if value is not None and not (len (value ) in (1 , img .shape [- 3 ])):
1457+ raise ValueError (
1458+ "If value is a sequence, it should have either a single value or "
1459+ "{} (number of input channels)" .format (img .shape [- 3 ])
1460+ )
1461+
1462+ x , y , h , w , v = self .get_params (img , scale = self .scale , ratio = self .ratio , value = value )
14331463 return F .erase (img , x , y , h , w , v , self .inplace )
14341464 return img
0 commit comments