Skip to content

Commit 75f5b57

Browse files
authored
[BC-breaking] RandomErasing is now scriptable (#2386)
* Related to #2292: RandomErasing was not scriptable. * Fixed the code according to review comments and added a check that the length of `value` matches the number of image channels.
1 parent e757d52 commit 75f5b57

File tree

3 files changed

+113
-57
lines changed

3 files changed

+113
-57
lines changed

test/test_transforms.py

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,38 +1618,64 @@ def test_random_grayscale(self):
16181618

16191619
def test_random_erasing(self):
16201620
"""Unit tests for random erasing transform"""
1621-
1622-
img = torch.rand([3, 60, 60])
1623-
1624-
# Test Set 1: Erasing with int value
1625-
img_re = transforms.RandomErasing(value=0.2)
1626-
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
1627-
img_output = F.erase(img, i, j, h, w, v)
1628-
self.assertEqual(img_output.size(0), 3)
1629-
1630-
# Test Set 2: Check if the unerased region is preserved
1631-
orig_unerased = img.clone()
1632-
orig_unerased[:, i:i + h, j:j + w] = 0
1633-
output_unerased = img_output.clone()
1634-
output_unerased[:, i:i + h, j:j + w] = 0
1635-
self.assertTrue(torch.equal(orig_unerased, output_unerased))
1636-
1637-
# Test Set 3: Erasing with random value
1638-
img_re = transforms.RandomErasing(value='random')(img)
1639-
self.assertEqual(img_re.size(0), 3)
1640-
1641-
# Test Set 4: Erasing with tuple value
1642-
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
1643-
self.assertEqual(img_re.size(0), 3)
1644-
1645-
# Test Set 5: Testing the inplace behaviour
1646-
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
1647-
self.assertTrue(torch.equal(img_re, img))
1648-
1649-
# Test Set 6: Checking when no erased region is selected
1650-
img = torch.rand([3, 300, 1])
1651-
img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
1652-
self.assertTrue(torch.equal(img_re, img))
1621+
for is_scripted in [False, True]:
1622+
torch.manual_seed(12)
1623+
img = torch.rand(3, 60, 60)
1624+
1625+
# Test Set 0: invalid value
1626+
random_erasing = transforms.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
1627+
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
1628+
img_re = random_erasing(img)
1629+
1630+
# Test Set 1: Erasing with int value
1631+
random_erasing = transforms.RandomErasing(value=0.2)
1632+
if is_scripted:
1633+
random_erasing = torch.jit.script(random_erasing)
1634+
1635+
i, j, h, w, v = transforms.RandomErasing.get_params(
1636+
img, scale=random_erasing.scale, ratio=random_erasing.ratio, value=[random_erasing.value, ]
1637+
)
1638+
img_output = F.erase(img, i, j, h, w, v)
1639+
self.assertEqual(img_output.size(0), 3)
1640+
1641+
# Test Set 2: Check if the unerased region is preserved
1642+
true_output = img.clone()
1643+
true_output[:, i:i + h, j:j + w] = random_erasing.value
1644+
self.assertTrue(torch.equal(true_output, img_output))
1645+
1646+
# Test Set 3: Erasing with random value
1647+
random_erasing = transforms.RandomErasing(value="random")
1648+
if is_scripted:
1649+
random_erasing = torch.jit.script(random_erasing)
1650+
img_re = random_erasing(img)
1651+
1652+
self.assertEqual(img_re.size(0), 3)
1653+
1654+
# Test Set 4: Erasing with tuple value
1655+
random_erasing = transforms.RandomErasing(value=(0.2, 0.2, 0.2))
1656+
if is_scripted:
1657+
random_erasing = torch.jit.script(random_erasing)
1658+
img_re = random_erasing(img)
1659+
self.assertEqual(img_re.size(0), 3)
1660+
true_output = img.clone()
1661+
true_output[:, i:i + h, j:j + w] = torch.tensor(random_erasing.value)[:, None, None]
1662+
self.assertTrue(torch.equal(true_output, img_output))
1663+
1664+
# Test Set 5: Testing the inplace behaviour
1665+
random_erasing = transforms.RandomErasing(value=(0.2,), inplace=True)
1666+
if is_scripted:
1667+
random_erasing = torch.jit.script(random_erasing)
1668+
1669+
img_re = random_erasing(img)
1670+
self.assertTrue(torch.equal(img_re, img))
1671+
1672+
# Test Set 6: Checking when no erased region is selected
1673+
img = torch.rand([3, 300, 1])
1674+
random_erasing = transforms.RandomErasing(ratio=(0.1, 0.2), value="random")
1675+
if is_scripted:
1676+
random_erasing = torch.jit.script(random_erasing)
1677+
img_re = random_erasing(img)
1678+
self.assertTrue(torch.equal(img_re, img))
16531679

16541680

16551681
if __name__ == '__main__':

torchvision/transforms/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def to_grayscale(img, num_output_channels=1):
950950
return img
951951

952952

953-
def erase(img, i, j, h, w, v, inplace=False):
953+
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
954954
""" Erase the input Tensor Image with given value.
955955
956956
Args:

torchvision/transforms/transforms.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import random
44
import warnings
55
from collections.abc import Sequence, Iterable
6-
from typing import Tuple
6+
from typing import Tuple, List, Optional
77

88
import numpy as np
99
import torch
@@ -1343,7 +1343,7 @@ def __repr__(self):
13431343
return self.__class__.__name__ + '(p={0})'.format(self.p)
13441344

13451345

1346-
class RandomErasing(object):
1346+
class RandomErasing(torch.nn.Module):
13471347
""" Randomly selects a rectangle region in an image and erases its pixels.
13481348
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf
13491349
@@ -1370,13 +1370,21 @@ class RandomErasing(object):
13701370
"""
13711371

13721372
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
1373-
assert isinstance(value, (numbers.Number, str, tuple, list))
1373+
super().__init__()
1374+
if not isinstance(value, (numbers.Number, str, tuple, list)):
1375+
raise TypeError("Argument value should be either a number or str or a sequence")
1376+
if isinstance(value, str) and value != "random":
1377+
raise ValueError("If value is str, it should be 'random'")
1378+
if not isinstance(scale, (tuple, list)):
1379+
raise TypeError("Scale should be a sequence")
1380+
if not isinstance(ratio, (tuple, list)):
1381+
raise TypeError("Ratio should be a sequence")
13741382
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
1375-
warnings.warn("range should be of kind (min, max)")
1383+
warnings.warn("Scale and ratio should be of kind (min, max)")
13761384
if scale[0] < 0 or scale[1] > 1:
1377-
raise ValueError("range of scale should be between 0 and 1")
1385+
raise ValueError("Scale should be between 0 and 1")
13781386
if p < 0 or p > 1:
1379-
raise ValueError("range of random erasing probability should be between 0 and 1")
1387+
raise ValueError("Random erasing probability should be between 0 and 1")
13801388

13811389
self.p = p
13821390
self.scale = scale
@@ -1385,13 +1393,18 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace
13851393
self.inplace = inplace
13861394

13871395
@staticmethod
1388-
def get_params(img, scale, ratio, value=0):
1396+
def get_params(
1397+
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
1398+
) -> Tuple[int, int, int, int, Tensor]:
13891399
"""Get parameters for ``erase`` for a random erasing.
13901400
13911401
Args:
13921402
img (Tensor): Tensor image of size (C, H, W) to be erased.
1393-
scale: range of proportion of erased area against input image.
1394-
ratio: range of aspect ratio of erased area.
1403+
scale (tuple or list): range of proportion of erased area against input image.
1404+
ratio (tuple or list): range of aspect ratio of erased area.
1405+
value (list, optional): erasing value. If None, it is interpreted as "random"
1406+
(erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
1407+
i.e. ``value[0]``.
13951408
13961409
Returns:
13971410
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
@@ -1400,35 +1413,52 @@ def get_params(img, scale, ratio, value=0):
14001413
area = img_h * img_w
14011414

14021415
for _ in range(10):
1403-
erase_area = random.uniform(scale[0], scale[1]) * area
1404-
aspect_ratio = random.uniform(ratio[0], ratio[1])
1416+
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
1417+
aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item()
14051418

14061419
h = int(round(math.sqrt(erase_area * aspect_ratio)))
14071420
w = int(round(math.sqrt(erase_area / aspect_ratio)))
1421+
if not (h < img_h and w < img_w):
1422+
continue
1423+
1424+
if value is None:
1425+
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
1426+
else:
1427+
v = torch.tensor(value)[:, None, None]
14081428

1409-
if h < img_h and w < img_w:
1410-
i = random.randint(0, img_h - h)
1411-
j = random.randint(0, img_w - w)
1412-
if isinstance(value, numbers.Number):
1413-
v = value
1414-
elif isinstance(value, torch._six.string_classes):
1415-
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
1416-
elif isinstance(value, (list, tuple)):
1417-
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
1418-
return i, j, h, w, v
1429+
i = torch.randint(0, img_h - h, size=(1, )).item()
1430+
j = torch.randint(0, img_w - w, size=(1, )).item()
1431+
return i, j, h, w, v
14191432

14201433
# Return original image
14211434
return 0, 0, img_h, img_w, img
14221435

1423-
def __call__(self, img):
1436+
def forward(self, img):
14241437
"""
14251438
Args:
14261439
img (Tensor): Tensor image of size (C, H, W) to be erased.
14271440
14281441
Returns:
14291442
img (Tensor): Erased Tensor image.
14301443
"""
1431-
if random.uniform(0, 1) < self.p:
1432-
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
1444+
if torch.rand(1) < self.p:
1445+
1446+
# cast self.value to script acceptable type
1447+
if isinstance(self.value, (int, float)):
1448+
value = [self.value, ]
1449+
elif isinstance(self.value, str):
1450+
value = None
1451+
elif isinstance(self.value, tuple):
1452+
value = list(self.value)
1453+
else:
1454+
value = self.value
1455+
1456+
if value is not None and not (len(value) in (1, img.shape[-3])):
1457+
raise ValueError(
1458+
"If value is a sequence, it should have either a single value or "
1459+
"{} (number of input channels)".format(img.shape[-3])
1460+
)
1461+
1462+
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
14331463
return F.erase(img, x, y, h, w, v, self.inplace)
14341464
return img

0 commit comments

Comments
 (0)