From 1674954b2f6401e1295e93c75e1d1fc8a26dc918 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 6 Sep 2021 10:24:46 +0100 Subject: [PATCH 1/2] Fix RA bugs. --- torchvision/transforms/autoaugment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 44c7990482b..a4e43846567 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -263,7 +263,7 @@ class RandAugment(torch.nn.Module): image. If given a number, the value is used for all bands respectively. """ - def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30, + def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None) -> None: super().__init__() @@ -276,6 +276,7 @@ def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: return { # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), @@ -289,7 +290,6 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), "AutoContrast": (torch.tensor(0.0), False), "Equalize": (torch.tensor(0.0), False), - "Invert": (torch.tensor(0.0), False), } def forward(self, img: Tensor) -> Tensor: From 1e3e7c677e38309455d0a7eca2139b77598dea47 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 6 Sep 2021 10:47:13 +0100 Subject: [PATCH 2/2] Fix bins for TA. --- torchvision/transforms/autoaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index a4e43846567..bffc4a24f67 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -345,7 +345,7 @@ class TrivialAugmentWide(torch.nn.Module): image. If given a number, the value is used for all bands respectively. """ - def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST, + def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None) -> None: super().__init__() self.num_magnitude_bins = num_magnitude_bins