3
3
import random
4
4
import warnings
5
5
from collections .abc import Sequence , Iterable
6
- from typing import Tuple
6
+ from typing import Tuple , List , Optional
7
7
8
8
import numpy as np
9
9
import torch
@@ -1343,7 +1343,7 @@ def __repr__(self):
1343
1343
return self .__class__ .__name__ + '(p={0})' .format (self .p )
1344
1344
1345
1345
1346
- class RandomErasing (object ):
1346
+ class RandomErasing (torch . nn . Module ):
1347
1347
""" Randomly selects a rectangle region in an image and erases its pixels.
1348
1348
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf
1349
1349
@@ -1370,13 +1370,21 @@ class RandomErasing(object):
1370
1370
"""
1371
1371
1372
1372
def __init__ (self , p = 0.5 , scale = (0.02 , 0.33 ), ratio = (0.3 , 3.3 ), value = 0 , inplace = False ):
1373
- assert isinstance (value , (numbers .Number , str , tuple , list ))
1373
+ super ().__init__ ()
1374
+ if not isinstance (value , (numbers .Number , str , tuple , list )):
1375
+ raise TypeError ("Argument value should be either a number or str or a sequence" )
1376
+ if isinstance (value , str ) and value != "random" :
1377
+ raise ValueError ("If value is str, it should be 'random'" )
1378
+ if not isinstance (scale , (tuple , list )):
1379
+ raise TypeError ("Scale should be a sequence" )
1380
+ if not isinstance (ratio , (tuple , list )):
1381
+ raise TypeError ("Ratio should be a sequence" )
1374
1382
if (scale [0 ] > scale [1 ]) or (ratio [0 ] > ratio [1 ]):
1375
- warnings .warn ("range should be of kind (min, max)" )
1383
+ warnings .warn ("Scale and ratio should be of kind (min, max)" )
1376
1384
if scale [0 ] < 0 or scale [1 ] > 1 :
1377
- raise ValueError ("range of scale should be between 0 and 1" )
1385
+ raise ValueError ("Scale should be between 0 and 1" )
1378
1386
if p < 0 or p > 1 :
1379
- raise ValueError ("range of random erasing probability should be between 0 and 1" )
1387
+ raise ValueError ("Random erasing probability should be between 0 and 1" )
1380
1388
1381
1389
self .p = p
1382
1390
self .scale = scale
@@ -1385,13 +1393,18 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace
1385
1393
self .inplace = inplace
1386
1394
1387
1395
@staticmethod
1388
- def get_params (img , scale , ratio , value = 0 ):
1396
+ def get_params (
1397
+ img : Tensor , scale : Tuple [float , float ], ratio : Tuple [float , float ], value : Optional [List [float ]] = None
1398
+ ) -> Tuple [int , int , int , int , Tensor ]:
1389
1399
"""Get parameters for ``erase`` for a random erasing.
1390
1400
1391
1401
Args:
1392
1402
img (Tensor): Tensor image of size (C, H, W) to be erased.
1393
- scale: range of proportion of erased area against input image.
1394
- ratio: range of aspect ratio of erased area.
1403
+ scale (tuple or list): range of proportion of erased area against input image.
1404
+ ratio (tuple or list): range of aspect ratio of erased area.
1405
+ value (list, optional): erasing value. If None, it is interpreted as "random"
1406
+ (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
1407
+ i.e. ``value[0]``.
1395
1408
1396
1409
Returns:
1397
1410
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
@@ -1400,35 +1413,52 @@ def get_params(img, scale, ratio, value=0):
1400
1413
area = img_h * img_w
1401
1414
1402
1415
for _ in range (10 ):
1403
- erase_area = random . uniform ( scale [0 ], scale [1 ]) * area
1404
- aspect_ratio = random . uniform ( ratio [0 ], ratio [1 ])
1416
+ erase_area = area * torch . empty ( 1 ). uniform_ ( scale [0 ], scale [1 ]). item ()
1417
+ aspect_ratio = torch . empty ( 1 ). uniform_ ( ratio [0 ], ratio [1 ]). item ( )
1405
1418
1406
1419
h = int (round (math .sqrt (erase_area * aspect_ratio )))
1407
1420
w = int (round (math .sqrt (erase_area / aspect_ratio )))
1421
+ if not (h < img_h and w < img_w ):
1422
+ continue
1423
+
1424
+ if value is None :
1425
+ v = torch .empty ([img_c , h , w ], dtype = torch .float32 ).normal_ ()
1426
+ else :
1427
+ v = torch .tensor (value )[:, None , None ]
1408
1428
1409
- if h < img_h and w < img_w :
1410
- i = random .randint (0 , img_h - h )
1411
- j = random .randint (0 , img_w - w )
1412
- if isinstance (value , numbers .Number ):
1413
- v = value
1414
- elif isinstance (value , torch ._six .string_classes ):
1415
- v = torch .empty ([img_c , h , w ], dtype = torch .float32 ).normal_ ()
1416
- elif isinstance (value , (list , tuple )):
1417
- v = torch .tensor (value , dtype = torch .float32 ).view (- 1 , 1 , 1 ).expand (- 1 , h , w )
1418
- return i , j , h , w , v
1429
+ i = torch .randint (0 , img_h - h , size = (1 , )).item ()
1430
+ j = torch .randint (0 , img_w - w , size = (1 , )).item ()
1431
+ return i , j , h , w , v
1419
1432
1420
1433
# Return original image
1421
1434
return 0 , 0 , img_h , img_w , img
1422
1435
1423
- def __call__ (self , img ):
1436
+ def forward (self , img ):
1424
1437
"""
1425
1438
Args:
1426
1439
img (Tensor): Tensor image of size (C, H, W) to be erased.
1427
1440
1428
1441
Returns:
1429
1442
img (Tensor): Erased Tensor image.
1430
1443
"""
1431
- if random .uniform (0 , 1 ) < self .p :
1432
- x , y , h , w , v = self .get_params (img , scale = self .scale , ratio = self .ratio , value = self .value )
1444
+ if torch .rand (1 ) < self .p :
1445
+
1446
+ # cast self.value to script acceptable type
1447
+ if isinstance (self .value , (int , float )):
1448
+ value = [self .value , ]
1449
+ elif isinstance (self .value , str ):
1450
+ value = None
1451
+ elif isinstance (self .value , tuple ):
1452
+ value = list (self .value )
1453
+ else :
1454
+ value = self .value
1455
+
1456
+ if value is not None and not (len (value ) in (1 , img .shape [- 3 ])):
1457
+ raise ValueError (
1458
+ "If value is a sequence, it should have either a single value or "
1459
+ "{} (number of input channels)" .format (img .shape [- 3 ])
1460
+ )
1461
+
1462
+ x , y , h , w , v = self .get_params (img , scale = self .scale , ratio = self .ratio , value = value )
1433
1463
return F .erase (img , x , y , h , w , v , self .inplace )
1434
1464
return img
0 commit comments