Skip to content

Commit 446b2ca

Browse files
Integration of TrivialAugment with the current AutoAugment Code (#4221)
* Initial Proposal * Tensor Save Test + Test Name Fix * Formatting + removing unused argument * fix old argument * fix isnan check error + indexing error with jit * Fix Flake8 error. * Fix MyPy error. * Fix Flake8 error. * Fix PyTorch JIT error in UnitTests due to type annotation. * Fixing tests. * Removing type ignore. * Adding support of ta_wide in references. * Move methods in classes. * Moving new classes to the bottom. * Specialize to TA to TAwide * Add missing type * Fixing lint * Fix doc * Fix search space of TrivialAugment. Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 80d5f50 commit 446b2ca

File tree

6 files changed

+119
-2
lines changed

6 files changed

+119
-2
lines changed

docs/source/transforms.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ The new transform can be used standalone or mixed-and-matched with existing tran
234234
.. autoclass:: RandAugment
235235
:members:
236236

237+
`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.
238+
239+
.. autoclass:: TrivialAugmentWide
240+
:members:
241+
237242
.. _functional_transforms:
238243

239244
Functional Transforms

gallery/plot_transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
253253
imgs = [augmenter(orig_img) for _ in range(4)]
254254
plot(imgs)
255255

256+
####################################
257+
# TrivialAugmentWide
258+
# ~~~~~~~~~~~~~~~~~~
259+
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
260+
augmenter = T.TrivialAugmentWide()
261+
imgs = [augmenter(orig_img) for _ in range(4)]
262+
plot(imgs)
263+
256264
####################################
257265
# Randomly-applied transforms
258266
# ---------------------------

references/classification/presets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
1111
if auto_augment_policy is not None:
1212
if auto_augment_policy == "ra":
1313
trans.append(autoaugment.RandAugment())
14+
elif auto_augment_policy == "ta_wide":
15+
trans.append(autoaugment.TrivialAugmentWide())
1416
else:
1517
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
1618
trans.append(autoaugment.AutoAugment(policy=aa_policy))

test/test_transforms.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,6 +1502,17 @@ def test_randaugment(num_ops, magnitude, fill):
15021502
transform.__repr__()
15031503

15041504

1505+
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugmentwide(fill, num_magnitude_bins):
    """Smoke-test TrivialAugmentWide on a PIL image.

    The transform is applied 100 times in a row (each output feeds the next call)
    for every combination of ``fill`` and ``num_magnitude_bins``; the test passes
    if no call raises and ``__repr__`` can be built.
    """
    # Local import so the block does not rely on how the test module imports torch.
    import torch

    random.seed(42)
    # TrivialAugmentWide picks its op, magnitude and sign with torch.randint, so
    # torch's RNG must be seeded as well for the run to be reproducible.
    torch.manual_seed(42)
    img = Image.open(GRACE_HOPPER)
    transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()
1514+
1515+
15051516
def test_random_crop():
15061517
height = random.randint(10, 32) * 2
15071518
width = random.randint(10, 32) * 2

test/test_transforms_tensor.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,20 @@ def test_randaugment(device, num_ops, magnitude, fill):
547547
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
548548

549549

550-
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment])
550+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_trivialaugmentwide(device, fill):
    """Check that eager and TorchScript-compiled TrivialAugmentWide agree.

    A single uint8 image tensor and a batch of four are generated on ``device``;
    the eager transform and its scripted counterpart are then compared on both
    inputs, 25 times, through the shared comparison helpers.
    """
    image_shape = (3, 44, 56)
    batch_shape = (4,) + image_shape
    tensor = torch.randint(0, 256, size=image_shape, dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=batch_shape, dtype=torch.uint8, device=device)

    eager_transform = T.TrivialAugmentWide(fill=fill)
    scripted_transform = torch.jit.script(eager_transform)

    repetitions = 25
    for _ in range(repetitions):
        _test_transform_vs_scripted(eager_transform, scripted_transform, tensor)
        _test_transform_vs_scripted_on_batch(eager_transform, scripted_transform, batch_tensors)
561+
562+
563+
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
551564
def test_autoaugment_save(augmentation, tmpdir):
552565
transform = augmentation()
553566
s_transform = torch.jit.script(transform)

torchvision/transforms/autoaugment.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import functional as F, InterpolationMode
99

10-
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]
10+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
1111

1212

1313
def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
4444
img = F.equalize(img)
4545
elif op_name == "Invert":
4646
img = F.invert(img)
47+
elif op_name == "Identity":
48+
pass
4749
else:
4850
raise ValueError("The provided operator {} is not recognized.".format(op_name))
4951
return img
@@ -325,3 +327,79 @@ def __repr__(self) -> str:
325327
s += ', fill={fill}'
326328
s += ')'
327329
return s.format(**self.__dict__)
330+
331+
332+
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    On every call a single operation is drawn uniformly from the augmentation space,
    a magnitude is drawn uniformly from that operation's bins, and (for signed
    operations) the sign is flipped with probability 1/2.

    Args:
        num_magnitude_bins (int): The number of different magnitude values. Must be at least 2,
            because the bins are spread between the two ends of each magnitude range.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

    Raises:
        ValueError: If ``num_magnitude_bins`` is smaller than 2.
    """

    def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
                 fill: Optional[List[float]] = None) -> None:
        super().__init__()
        # With a single bin the Posterize row divides by (num_bins - 1) / 6 == 0 and casts NaN
        # to int; with zero or fewer bins the magnitude tensors are empty or cannot be built.
        # Fail early with a clear message instead of misbehaving inside forward().
        if num_magnitude_bins < 2:
            raise ValueError(
                "num_magnitude_bins should be at least 2, but got {}.".format(num_magnitude_bins)
            )
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate operations.

        Args:
            num_bins (int): Number of magnitude values generated per operation.

        Returns:
            Dict mapping each op name to ``(magnitudes, signed)``. ``magnitudes`` is a 1-D tensor
            of ``num_bins`` values, or a 0-dim tensor for ops that take no magnitude; ``signed``
            tells whether the sampled magnitude may be negated.
        """
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            # Integer values running from 8 (first bin) down to 2 (last bin).
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            # Values running from 256 (first bin) down to 0 (last bin).
            "Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # For tensor inputs, expand a scalar fill to one float per channel
            # and convert a sequence fill to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        # Pick one operation uniformly at random.
        op_meta = self._augmentation_space(self.num_magnitude_bins)
        op_index = int(torch.randint(len(op_meta), (1,)).item())
        op_name = list(op_meta.keys())[op_index]
        magnitudes, signed = op_meta[op_name]
        # 0-dim entries mark ops without a magnitude; otherwise pick one bin uniformly.
        magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
            if magnitudes.ndim > 0 else 0.0
        # Signed ops get their direction flipped half of the time.
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

    def __repr__(self) -> str:
        """Return ``ClassName(num_magnitude_bins=..., interpolation=..., fill=...)``."""
        s = self.__class__.__name__ + '('
        s += 'num_magnitude_bins={num_magnitude_bins}'
        s += ', interpolation={interpolation}'
        s += ', fill={fill}'
        s += ')'
        # Format only the fields that are shown instead of unpacking the whole module __dict__.
        return s.format(
            num_magnitude_bins=self.num_magnitude_bins,
            interpolation=self.interpolation,
            fill=self.fill,
        )

0 commit comments

Comments
 (0)