From d3e8b08a0e5f22013ee38721e1dc0b88c56ae812 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 9 Feb 2022 18:08:50 +0000 Subject: [PATCH 01/10] Adding basic augmix implementation. --- torchvision/transforms/autoaugment.py | 102 ++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index d58077c9b14..fe5ba227d47 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -458,3 +458,105 @@ def __repr__(self) -> str: f")" ) return s + + +class AugMix(torch.nn.Module): + # TODO: Documentation + _PARAMETER_MAX: int = 10 + + def __init__( + self, + severity: int = 1, + mixture_width: int = 3, + chain_depth: int = -1, + alpha: float = 1.0, + all_ops: bool = False, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, + ) -> None: + super().__init__() + if not (1 <= severity <= self._PARAMETER_MAX): + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") + self.severity = severity + self.mixture_width = mixture_width + self.chain_depth = chain_depth + self.alpha = alpha + self.all_ops = all_ops + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + s = { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + if self.all_ops: + s.update({ + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + }) + return s + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. 
+        """
+        device = None
+        fill = self.fill
+        if isinstance(img, Tensor):
+            device = img.device
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        mixing_weights = torch._sample_dirichlet(torch.tensor([self.alpha] * self.mixture_width, device=device))
+        m = torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha], device=device))[0]
+        op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))
+
+        mix = torch.zeros()  # TODO: add size, device, handle PIL/Tensors/Batches etc
+        for i in range(self.mixture_width):
+            img_aug = img.clone()  # TODO: handle PIL
+            depth = self.chain_depth if self.chain_depth > 0 else torch.randint(low=1, high=4, size=(1,)).item()
+            for _ in range(depth):
+                op_index = int(torch.randint(len(op_meta), (1,)).item())
+                op_name = list(op_meta.keys())[op_index]
+                magnitudes, signed = op_meta[op_name]
+                magnitude = (
+                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
+                    if magnitudes.ndim > 0
+                    else 0.0
+                )
+                if signed and torch.randint(2, (1,)):
+                    magnitude *= -1.0
+                img_aug = _apply_op(img_aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+            mix += mixing_weights[i] * img_aug  # TODO: handle PIL
+
+        return (1.0 - m) * img + m * mix  # TODO: handle PIL
+
+    def __repr__(self) -> str:
+        s = (
+            f"{self.__class__.__name__}("
+            f"severity={self.severity}"
+            f", mixture_width={self.mixture_width}"
+            f", chain_depth={self.chain_depth}"
+            f", alpha={self.alpha}"
+            f", all_ops={self.all_ops}"
+            f", interpolation={self.interpolation}"
+            f", fill={self.fill}"
+            f")"
+        )
+        return s

From 6c1a3887b1713b8b70fb0c4e773bc77959ab07a3 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 11 Feb 2022 14:26:15 +0000
Subject: [PATCH 02/10] Finish the implementation.

---
 torchvision/transforms/autoaugment.py | 79 +++++++++++++++++++--------
 1 file changed, 57 insertions(+), 22 deletions(-)

diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index fe5ba227d47..b35dd756594 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -7,7 +7,7 @@
 from . import functional as F, InterpolationMode
 
 
-__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
+__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
 
 
 def _apply_op(
@@ -461,7 +461,25 @@ def __repr__(self) -> str:
 
 
 class AugMix(torch.nn.Module):
-    # TODO: Documentation
+    r"""AugMix data augmentation method based on
+    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        severity (int): The severity of base augmentation operators. Default is ``1``.
+        mixture_width (int): The number of augmentation chains. Default is ``3``.
+        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth in [1, 3].
+            Default is ``-1``.
+        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
+        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+    """
     _PARAMETER_MAX: int = 10
 
     def __init__(
@@ -470,7 +488,7 @@ def __init__(
         mixture_width: int = 3,
         chain_depth: int = -1,
         alpha: float = 1.0,
-        all_ops: bool = False,
+        all_ops: bool = True,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         fill: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -499,38 +517,50 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str,
             "Equalize": (torch.tensor(0.0), False),
         }
         if self.all_ops:
-            s.update({
-                "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
-                "Color": (torch.linspace(0.0, 0.9, num_bins), True),
-                "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
-                "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
-            })
+            s.update(
+                {
+                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+                }
+            )
         return s
 
+    @torch.jit.unused
+    def _pil_to_tensor(self, img) -> Tensor:
+        return F.pil_to_tensor(img)
+
+    @torch.jit.unused
+    def _tensor_to_pil(self, img: Tensor):
+        return F.to_pil_image(img)
+
-    def forward(self, img: Tensor) -> Tensor:
+    def forward(self, orig_img: Tensor) -> Tensor:
         """
             img (PIL Image or Tensor): Image to be transformed.
 
         Returns:
             PIL Image or Tensor: Transformed image.
""" - device = None fill = self.fill - if isinstance(img, Tensor): - device = img.device + if isinstance(orig_img, Tensor): + img = orig_img if isinstance(fill, (int, float)): fill = [float(fill)] * F.get_image_num_channels(img) elif fill is not None: fill = [float(f) for f in fill] + else: + img = self._pil_to_tensor(orig_img) - mixing_weights = torch._sample_dirichlet(torch.tensor([self.alpha] * self.mixture_width, device=device)) - m = torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha], device=device))[0] + mixing_weights = torch._sample_dirichlet(torch.tensor([self.alpha] * self.mixture_width, device=img.device)) + m = torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha], device=img.device))[0] op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img)) - mix = torch.zeros() # TODO: add size, device, handle PIL/Tensors/Batches etc + batch = img[(None,) * max(4 - img.ndim, 0)] + mix = torch.zeros_like(batch, dtype=torch.float) for i in range(self.mixture_width): - img_aug = img.clone() # TODO: handle PIL - depth = self.chain_depth if self.chain_depth > 0 else torch.randint(low=1, high=4, size=(1,)).item() + aug = batch.clone() + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) for _ in range(depth): op_index = int(torch.randint(len(op_meta), (1,)).item()) op_name = list(op_meta.keys())[op_index] @@ -542,10 +572,15 @@ def forward(self, img: Tensor) -> Tensor: ) if signed and torch.randint(2, (1,)): magnitude *= -1.0 - img_aug = _apply_op(img_aug, op_name, magnitude, interpolation=self.interpolation, fill=fill) - mix += mixing_weights[i] * img_aug # TODO: handle PIL - - return (1.0 - m) * img + m * mix # TODO: handle PIL + aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill) + mix.add_(aug.mul_(mixing_weights[i])) + mix.mul_(m).add_((1 - m) * batch) + mix = mix.to(dtype=img.dtype) + mix = mix[(0,) * max(4 - img.ndim, 0)] + + if not isinstance(orig_img, Tensor): + return self._tensor_to_pil(mix) + return mix def __repr__(self) -> str: s = ( From bc5667c0a0c56002820379d3f378b761f1d11f27 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Feb 2022 14:26:30 +0000 Subject: [PATCH 03/10] Add tests and documentation. --- docs/source/transforms.rst | 1 + gallery/plot_transforms.py | 8 ++++++++ references/classification/presets.py | 2 ++ test/test_transforms.py | 19 +++++++++++++++++++ test/test_transforms_tensor.py | 28 +++++++++++++++++++++++++++- 5 files changed, 57 insertions(+), 1 deletion(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 7f835267200..cae53728b96 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -198,6 +198,7 @@ The new transform can be used standalone or mixed-and-matched with existing tran AutoAugment RandAugment TrivialAugmentWide + AugMix .. _functional_transforms: diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py index ab0cb892b16..d781f8f35ed 100644 --- a/gallery/plot_transforms.py +++ b/gallery/plot_transforms.py @@ -263,6 +263,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): imgs = [augmenter(orig_img) for _ in range(4)] plot(imgs) +#################################### +# AugMix +# ~~~~~~ +# The :class:`~torchvision.transforms.AugMix` transform automatically augments the data. 
+augmenter = T.AugMix() +imgs = [augmenter(orig_img) for _ in range(4)] +plot(imgs) + #################################### # Randomly-applied transforms # --------------------------- diff --git a/references/classification/presets.py b/references/classification/presets.py index 6e1000174ab..418ef3e2e07 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -22,6 +22,8 @@ def __init__( trans.append(autoaugment.RandAugment(interpolation=interpolation)) elif auto_augment_policy == "ta_wide": trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) + elif auto_augment_policy == "augmix": + trans.append(autoaugment.AugMix(interpolation=interpolation)) else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) diff --git a/test/test_transforms.py b/test/test_transforms.py index 160e4407d8b..b0559927caf 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1588,6 +1588,25 @@ def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale): transform.__repr__() +@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize("severity", [1, 10]) +@pytest.mark.parametrize("mixture_width", [1, 2]) +@pytest.mark.parametrize("chain_depth", [-1, 2]) +@pytest.mark.parametrize("all_ops", [True, False]) +@pytest.mark.parametrize("grayscale", [True, False]) +def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale): + random.seed(42) + img = Image.open(GRACE_HOPPER) + if grayscale: + img, fill = _get_grayscale_test_image(img, fill) + transform = transforms.AugMix( + fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops + ) + for _ in range(100): + img = transform(img) + transform.__repr__() + + def test_random_crop(): height = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2 diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 9bc499467b7..c9f30c7e560 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -720,7 +720,33 @@ def test_trivialaugmentwide(device, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fill", + [ + None, + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) +def test_augmix(device, fill): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) + + transform = T.AugMix(fill=fill) + s_transform = torch.jit.script(transform) + for _ in range(25): + _test_transform_vs_scripted(transform, s_transform, tensor) + _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + + +@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix]) def test_autoaugment_save(augmentation, tmpdir): transform = augmentation() s_transform = torch.jit.script(transform) From bf1a17e8d2d9c77db6b8d93c92c19783452005a9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Feb 2022 15:08:37 +0000 Subject: [PATCH 04/10] Fix tests. 
--- torchvision/transforms/autoaugment.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index b35dd756594..dc270ba7f1b 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -480,7 +480,6 @@ class AugMix(torch.nn.Module): fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. """ - _PARAMETER_MAX: int = 10 def __init__( self, @@ -493,6 +492,7 @@ def __init__( fill: Optional[List[float]] = None, ) -> None: super().__init__() + self._PARAMETER_MAX = 10 if not (1 <= severity <= self._PARAMETER_MAX): raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") self.severity = severity @@ -556,7 +556,10 @@ def forward(self, orig_img: Tensor) -> Tensor: m = torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha], device=img.device))[0] op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img)) - batch = img[(None,) * max(4 - img.ndim, 0)] + expand_dims = max(4 - img.ndim, 0) + batch = img + for _ in range(expand_dims): + batch = batch.unsqueeze(0) mix = torch.zeros_like(batch, dtype=torch.float) for i in range(self.mixture_width): aug = batch.clone() @@ -573,10 +576,11 @@ def forward(self, orig_img: Tensor) -> Tensor: if signed and torch.randint(2, (1,)): magnitude *= -1.0 aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill) - mix.add_(aug.mul_(mixing_weights[i])) + mix.add_(aug * mixing_weights[i]) mix.mul_(m).add_((1 - m) * batch) mix = mix.to(dtype=img.dtype) - mix = mix[(0,) * max(4 - img.ndim, 0)] + for _ in range(expand_dims): + mix = mix.squeeze(0) if not isinstance(orig_img, Tensor): return self._tensor_to_pil(mix) From 649305e368e696f1af37e51c0ae195b7f8952bd7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Feb 2022 15:19:21 +0000 Subject: [PATCH 05/10] Simplify code. --- torchvision/transforms/autoaugment.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index dc270ba7f1b..12effd0d007 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -556,10 +556,8 @@ def forward(self, orig_img: Tensor) -> Tensor: m = torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha], device=img.device))[0] op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img)) - expand_dims = max(4 - img.ndim, 0) - batch = img - for _ in range(expand_dims): - batch = batch.unsqueeze(0) + orig_dims = list(img.shape) + batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims) mix = torch.zeros_like(batch, dtype=torch.float) for i in range(self.mixture_width): aug = batch.clone() @@ -578,9 +576,7 @@ def forward(self, orig_img: Tensor) -> Tensor: aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill) mix.add_(aug * mixing_weights[i]) mix.mul_(m).add_((1 - m) * batch) - mix = mix.to(dtype=img.dtype) - for _ in range(expand_dims): - mix = mix.squeeze(0) + mix = mix.view(orig_dims).to(dtype=img.dtype) if not isinstance(orig_img, Tensor): return self._tensor_to_pil(mix) From 547bcf93aa9b18042447a61bde057dadca07355a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 12 Feb 2022 09:49:02 +0000 Subject: [PATCH 06/10] Speed optimizations. 
---
 torchvision/transforms/autoaugment.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index f4de7e544e7..8f515ab31ed 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -558,9 +558,9 @@ def forward(self, orig_img: Tensor) -> Tensor:
 
         orig_dims = list(img.shape)
         batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
-        mix = torch.zeros_like(batch, dtype=torch.float)
+        mix = (1.0 - m) * batch
         for i in range(self.mixture_width):
-            aug = batch.clone()
+            aug = batch
             depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
             for _ in range(depth):
                 op_index = int(torch.randint(len(op_meta), (1,)).item())
@@ -574,8 +574,7 @@ def forward(self, orig_img: Tensor) -> Tensor:
             if signed and torch.randint(2, (1,)):
                 magnitude *= -1.0
             aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
-            mix.add_(aug * mixing_weights[i])
-        mix.mul_(m).add_((1 - m) * batch)
+            mix.add_((mixing_weights[i] * m) * aug)
         mix = mix.view(orig_dims).to(dtype=img.dtype)
 
         if not isinstance(orig_img, Tensor):

From 9c284cc64802687d85d71beee015849e731eab4d Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 18 Feb 2022 13:17:23 +0000
Subject: [PATCH 07/10] Per image weights instead of per batch.

---
 torchvision/transforms/autoaugment.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index 8f515ab31ed..d4418a92d4f 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -552,13 +552,24 @@ def forward(self, orig_img: Tensor) -> Tensor:
         else:
             img = self._pil_to_tensor(orig_img)
 
-        mixing_weights = torch._sample_dirichlet(torch.tensor([self.alpha] * self.mixture_width, device=img.device))
-        m = torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha], device=img.device))[0]
         op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))
 
         orig_dims = list(img.shape)
         batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
-        mix = (1.0 - m) * batch
+        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
+
+        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
+        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of the augmented image.
+        m = torch._sample_dirichlet(
+            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
+        )
+
+        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
+        combined_weights = torch._sample_dirichlet(
+            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
+        ) * m[:, 1].view([batch_dims[0], -1])
+
+        mix = m[:, 0].view(batch_dims) * batch
         for i in range(self.mixture_width):
             aug = batch
             depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
@@ -574,7 +585,7 @@ def forward(self, orig_img: Tensor) -> Tensor:
             if signed and torch.randint(2, (1,)):
                 magnitude *= -1.0
             aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
-            mix.add_((mixing_weights[i] * m) * aug)
+            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
         mix = mix.view(orig_dims).to(dtype=img.dtype)
 
         if not isinstance(orig_img, Tensor):

From e4b62be7cf5225d0b5bc1471af54c54281ddbe7d Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 18 Feb 2022 14:21:53 +0000
Subject: [PATCH 08/10] Fix tests.

---
 test/test_transforms_tensor.py        | 7 ++++++-
 torchvision/transforms/autoaugment.py | 8 ++++++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index c9f30c7e560..fb6bec5bb9b 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -739,7 +739,12 @@ def test_augmix(device, fill):
     tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
     batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
 
-    transform = T.AugMix(fill=fill)
+    class DeterministicAugMix(T.AugMix):
+        def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
+            # patch the method to ensure that the order of rand calls doesn't affect the outcome
+            return params.softmax(dim=-1)
+
+    transform = DeterministicAugMix(fill=fill)
     s_transform = torch.jit.script(transform)
     for _ in range(25):
         _test_transform_vs_scripted(transform, s_transform, tensor)

diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index d4418a92d4f..cd4bb75ea50 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -535,6 +535,10 @@ def _pil_to_tensor(self, img) -> Tensor:
     def _tensor_to_pil(self, img: Tensor):
         return F.to_pil_image(img)
 
+    def _sample_dirichlet(self, params: Tensor) -> Tensor:
+        # Must be in a separate method so that we can override it in tests.
+        return torch._sample_dirichlet(params)
+
     def forward(self, orig_img: Tensor) -> Tensor:
         """
             img (PIL Image or Tensor): Image to be transformed.
@@ -560,12 +564,12 @@ def forward(self, orig_img: Tensor) -> Tensor:
         # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
         # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of the augmented image.
-        m = torch._sample_dirichlet(
+        m = self._sample_dirichlet(
             torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
         )
 
         # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
-        combined_weights = torch._sample_dirichlet(
+        combined_weights = self._sample_dirichlet(
             torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
         ) * m[:, 1].view([batch_dims[0], -1])

From c722c02d5778f5a9a974c22a9ff7d2bd813c0753 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 18 Feb 2022 14:22:28 +0000
Subject: [PATCH 09/10] Update torchvision/transforms/autoaugment.py

Co-authored-by: vfdev
---
 torchvision/transforms/autoaugment.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index cd4bb75ea50..69e3c8dfcd4 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -470,7 +470,7 @@ class AugMix(torch.nn.Module):
     Args:
         severity (int): The severity of base augmentation operators. Default is ``1``.
         mixture_width (int): The number of augmentation chains. Default is ``3``.
-        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth in [1, 3].
+        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.

From ecc598e25e4867a6d09551fce6215c160e6c9fdd Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 18 Feb 2022 16:15:00 +0000
Subject: [PATCH 10/10] Changing the default severity value to match the
 default strength of RandAugment.

---
 torchvision/transforms/autoaugment.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index 69e3c8dfcd4..d820e5126a1 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -468,7 +468,7 @@ class AugMix(torch.nn.Module):
     If img is PIL Image, it is expected to be in mode "L" or "RGB".
 
     Args:
-        severity (int): The severity of base augmentation operators. Default is ``1``.
+        severity (int): The severity of base augmentation operators. Default is ``3``.
         mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
@@ -483,7 +483,7 @@ class AugMix(torch.nn.Module):
 
     def __init__(
         self,
-        severity: int = 1,
+        severity: int = 3,
         mixture_width: int = 3,
         chain_depth: int = -1,
         alpha: float = 1.0,
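
For reference, a minimal usage sketch of the transform as it behaves after PATCH 10/10. This assumes a torchvision build that includes the patches above; the input tensor is a synthetic stand-in for a real image.

import torch
from torchvision import transforms

# AugMix accepts uint8 tensors of shape [..., 1 or 3, H, W] (or a PIL image).
img = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)

# Defaults after this series: severity=3, mixture_width=3, chain_depth=-1, alpha=1.0.
augmenter = transforms.AugMix()
out = augmenter(img)  # blended image with the same dtype and shape as the input
print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.uint8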
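
The per-image weighting introduced in PATCH 07/10 can also be checked in isolation. The sketch below mirrors the sampling logic with hypothetical stand-in names (alpha, mixture_width and batch_size are illustrative values, and combined plays the role of combined_weights); it uses the same private torch._sample_dirichlet op as the patch and verifies that, for each image, the original-image weight plus the per-chain weights sum to 1, i.e. the final blend is a convex combination.

import torch

alpha, mixture_width, batch_size = 1.0, 3, 4

# Beta(alpha, alpha) via a 2-parameter Dirichlet: per row, column 0 weighs the
# original image and column 1 weighs the sum of the augmented chains.
m = torch._sample_dirichlet(torch.tensor([alpha, alpha]).expand(batch_size, -1))

# Per-chain Dirichlet mixing weights, scaled by each image's augmented-side weight.
combined = torch._sample_dirichlet(
    torch.tensor([alpha] * mixture_width).expand(batch_size, -1)
) * m[:, 1].view(batch_size, -1)

# Each image's weights form a convex combination: m[:, 0] + sum(combined) == 1.
total = m[:, 0] + combined.sum(dim=-1)
print(torch.allclose(total, torch.ones(batch_size)))  # True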