Skip to content

Commit 5a81554

Browse files
authored
Adding RandAugment implementation (#4348)
* Adding randaugment implementation * Refactoring. * Adding num_magnitude_bins. * Adding FIXME.
1 parent f52ddb0 commit 5a81554

File tree

4 files changed

+128
-13
lines changed

4 files changed

+128
-13
lines changed

references/classification/presets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
99
if hflip_prob > 0:
1010
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
1111
if auto_augment_policy is not None:
12-
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
13-
trans.append(autoaugment.AutoAugment(policy=aa_policy))
12+
if auto_augment_policy == "ra":
13+
trans.append(autoaugment.RandAugment())
14+
else:
15+
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
16+
trans.append(autoaugment.AutoAugment(policy=aa_policy))
1417
trans.extend([
1518
transforms.ToTensor(),
1619
transforms.Normalize(mean=mean, std=std),

test/test_transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
14901490
transform.__repr__()
14911491

14921492

1493+
@pytest.mark.parametrize('num_ops', [1, 2, 3])
1494+
@pytest.mark.parametrize('magnitude', [7, 9, 11])
1495+
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
1496+
def test_randaugment(num_ops, magnitude, fill):
1497+
random.seed(42)
1498+
img = Image.open(GRACE_HOPPER)
1499+
transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
1500+
for _ in range(100):
1501+
img = transform(img)
1502+
transform.__repr__()
1503+
1504+
14931505
def test_random_crop():
14941506
height = random.randint(10, 32) * 2
14951507
width = random.randint(10, 32) * 2

test/test_transforms_tensor.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,16 +525,31 @@ def test_autoaugment(device, policy, fill):
525525
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
526526
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
527527

528-
s_transform = None
529528
transform = T.AutoAugment(policy=policy, fill=fill)
530529
s_transform = torch.jit.script(transform)
531530
for _ in range(25):
532531
_test_transform_vs_scripted(transform, s_transform, tensor)
533532
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
534533

535534

536-
def test_autoaugment_save(tmpdir):
537-
transform = T.AutoAugment()
535+
@pytest.mark.parametrize('device', cpu_and_gpu())
536+
@pytest.mark.parametrize('num_ops', [1, 2, 3])
537+
@pytest.mark.parametrize('magnitude', [7, 9, 11])
538+
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
539+
def test_randaugment(device, num_ops, magnitude, fill):
540+
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
541+
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
542+
543+
transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
544+
s_transform = torch.jit.script(transform)
545+
for _ in range(25):
546+
_test_transform_vs_scripted(transform, s_transform, tensor)
547+
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
548+
549+
550+
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment])
551+
def test_autoaugment_save(augmentation, tmpdir):
552+
transform = augmentation()
538553
s_transform = torch.jit.script(transform)
539554
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
540555

torchvision/transforms/autoaugment.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import functional as F, InterpolationMode
99

10-
__all__ = ["AutoAugmentPolicy", "AutoAugment"]
10+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]
1111

1212

1313
def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
5858
SVHN = "svhn"
5959

6060

61+
# FIXME: Eliminate the copy-pasted code for fill standardization and _augmentation_space() by moving it into a base class
6162
class AutoAugment(torch.nn.Module):
6263
r"""AutoAugment data augmentation method based on
6364
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
8586
self.policy = policy
8687
self.interpolation = interpolation
8788
self.fill = fill
88-
self.transforms = self._get_transforms(policy)
89+
self.policies = self._get_policies(policy)
8990

90-
def _get_transforms(
91+
def _get_policies(
9192
self,
9293
policy: AutoAugmentPolicy
9394
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
178179
else:
179180
raise ValueError("The provided policy {} is not recognized.".format(policy))
180181

181-
def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
182+
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
182183
return {
183-
# name: (magnitudes, signed)
184+
# op_name: (magnitudes, signed)
184185
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
185186
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
186187
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
224225
elif fill is not None:
225226
fill = [float(f) for f in fill]
226227

227-
transform_id, probs, signs = self.get_params(len(self.transforms))
228+
transform_id, probs, signs = self.get_params(len(self.policies))
228229

229-
for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
230+
for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
230231
if probs[i] <= p:
231-
op_meta = self._get_magnitudes(10, F.get_image_size(img))
232+
op_meta = self._augmentation_space(10, F.get_image_size(img))
232233
magnitudes, signed = op_meta[op_name]
233234
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
234235
if signed and signs[i] == 0:
@@ -239,3 +240,87 @@ def forward(self, img: Tensor) -> Tensor:
239240

240241
def __repr__(self) -> str:
241242
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
243+
244+
245+
class RandAugment(torch.nn.Module):
246+
r"""RandAugment data augmentation method based on
247+
`"RandAugment: Practical automated data augmentation with a reduced search space"
248+
<https://arxiv.org/abs/1909.13719>`_.
249+
If the image is a torch Tensor, it should be of type torch.uint8, and it is expected
250+
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
251+
If img is a PIL Image, it is expected to be in mode "L" or "RGB".
252+
253+
Args:
254+
num_ops (int): Number of augmentation transformations to apply sequentially.
255+
magnitude (int): Magnitude for all the transformations.
256+
num_magnitude_bins (int): The number of different magnitude values.
257+
interpolation (InterpolationMode): Desired interpolation enum defined by
258+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
259+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
260+
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
261+
image. If given a number, the value is used for all bands respectively.
262+
"""
263+
264+
def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
265+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
266+
fill: Optional[List[float]] = None) -> None:
267+
super().__init__()
268+
self.num_ops = num_ops
269+
self.magnitude = magnitude
270+
self.num_magnitude_bins = num_magnitude_bins
271+
self.interpolation = interpolation
272+
self.fill = fill
273+
274+
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
275+
return {
276+
# op_name: (magnitudes, signed)
277+
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
278+
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
279+
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
280+
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
281+
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
282+
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
283+
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
284+
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
285+
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
286+
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
287+
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
288+
"AutoContrast": (torch.tensor(0.0), False),
289+
"Equalize": (torch.tensor(0.0), False),
290+
"Invert": (torch.tensor(0.0), False),
291+
}
292+
293+
def forward(self, img: Tensor) -> Tensor:
294+
"""
295+
img (PIL Image or Tensor): Image to be transformed.
296+
Returns:
297+
PIL Image or Tensor: Transformed image.
298+
"""
299+
fill = self.fill
300+
if isinstance(img, Tensor):
301+
if isinstance(fill, (int, float)):
302+
fill = [float(fill)] * F.get_image_num_channels(img)
303+
elif fill is not None:
304+
fill = [float(f) for f in fill]
305+
306+
for _ in range(self.num_ops):
307+
op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
308+
op_index = int(torch.randint(len(op_meta), (1,)).item())
309+
op_name = list(op_meta.keys())[op_index]
310+
magnitudes, signed = op_meta[op_name]
311+
magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
312+
if signed and torch.randint(2, (1,)):
313+
magnitude *= -1.0
314+
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
315+
316+
return img
317+
318+
def __repr__(self) -> str:
319+
s = self.__class__.__name__ + '('
320+
s += 'num_ops={num_ops}'
321+
s += ', magnitude={magnitude}'
322+
s += ', num_magnitude_bins={num_magnitude_bins}'
323+
s += ', interpolation={interpolation}'
324+
s += ', fill={fill}'
325+
s += ')'
326+
return s.format(**self.__dict__)

0 commit comments

Comments
 (0)