Skip to content

Commit 425c52d

Browse files
committed
Specialize TA to TAwide
1 parent 226998c commit 425c52d

File tree

4 files changed

+17
-41
lines changed

4 files changed

+17
-41
lines changed

references/classification/presets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
99
if hflip_prob > 0:
1010
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
1111
if auto_augment_policy is not None:
12-
if auto_augment_policy == autoaugment.AugmentationSpace.TA_WIDE.value:
13-
trans.append(autoaugment.TrivialAugment())
12+
if auto_augment_policy == "ta_wide":
13+
trans.append(autoaugment.TrivialAugmentWide())
1414
else:
1515
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
1616
trans.append(autoaugment.AutoAugment(policy=aa_policy))

test/test_transforms.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,14 +1490,12 @@ def test_autoaugment(policy, fill):
14901490
transform.__repr__()
14911491

14921492

1493-
@pytest.mark.parametrize('augmentation_space', [space for space in transforms.AugmentationSpace])
14941493
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
14951494
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
1496-
def test_trivialaugment(augmentation_space, fill, num_magnitude_bins):
1495+
def test_trivialaugmentwide(fill, num_magnitude_bins):
14971496
random.seed(42)
14981497
img = Image.open(GRACE_HOPPER)
1499-
transform = transforms.TrivialAugment(augmentation_space=augmentation_space,
1500-
fill=fill, num_magnitude_bins=num_magnitude_bins)
1498+
transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
15011499
for _ in range(100):
15021500
img = transform(img)
15031501
transform.__repr__()

test/test_transforms_tensor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -540,22 +540,20 @@ def test_autoaugment_save(tmpdir):
540540

541541

542542
@pytest.mark.parametrize('device', cpu_and_gpu())
543-
@pytest.mark.parametrize('augmentation_space', [space for space in T.AugmentationSpace])
544543
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
545-
def test_trivialaugment(device, augmentation_space, fill):
544+
def test_trivialaugmentwide(device, fill):
546545
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
547546
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
548547

549-
s_transform = None
550-
transform = T.TrivialAugment(augmentation_space=augmentation_space, fill=fill)
548+
transform = T.TrivialAugmentWide(fill=fill)
551549
s_transform = torch.jit.script(transform)
552550
for _ in range(25):
553551
_test_transform_vs_scripted(transform, s_transform, tensor)
554552
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
555553

556554

557-
def test_trivialaugment_save(tmpdir):
558-
transform = T.TrivialAugment()
555+
def test_trivialaugmentwide_save(tmpdir):
556+
transform = T.TrivialAugmentWide()
559557
s_transform = torch.jit.script(transform)
560558
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
561559

torchvision/transforms/autoaugment.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import functional as F, InterpolationMode
99

10-
__all__ = ["AutoAugmentPolicy", "AutoAugment", "AugmentationSpace", "TrivialAugment"]
10+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugmentWide"]
1111

1212

1313
def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -178,8 +178,7 @@ def _get_transforms(
178178
else:
179179
raise ValueError("The provided policy {} is not recognized.".format(policy))
180180

181-
@staticmethod
182-
def _get_magnitudes(num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
181+
def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
183182
return {
184183
# name: (magnitudes, signed)
185184
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
@@ -243,24 +242,14 @@ def __repr__(self) -> str:
243242
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
244243

245244

246-
class AugmentationSpace(Enum):
247-
"""The augmentation space to use.
248-
Available spaces are `AA` for AutoAugment and `TA_WIDE` for the TrivialAugment.
249-
"""
250-
AA = "aa"
251-
TA_WIDE = "ta_wide"
252-
253-
254-
class TrivialAugment(torch.nn.Module):
255-
r"""Dataset-independent data-augmentation with TrivialAugment, as described in
245+
class TrivialAugmentWide(torch.nn.Module):
246+
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
256247
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
257248
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
258249
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
259250
If img is PIL Image, it is expected to be in mode "L" or "RGB".
260251
261252
Args:
262-
augmentation_space (AugmentationSpace): Desired augmentation space enum defined by
263-
:class:`torchvision.transforms.autoaugment.AugmentationSpace`. Default is ``AugmentationSpace.TA_WIDE``.
264253
num_magnitude_bins (int): The number of different magnitude values.
265254
interpolation (InterpolationMode): Desired interpolation enum defined by
266255
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
@@ -269,17 +258,14 @@ class TrivialAugment(torch.nn.Module):
269258
image. If given a number, the value is used for all bands respectively.
270259
"""
271260

272-
def __init__(self, augmentation_space: AugmentationSpace = AugmentationSpace.TA_WIDE, num_magnitude_bins: int = 30,
273-
interpolation: InterpolationMode = InterpolationMode.NEAREST,
261+
def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
274262
fill: Optional[List[float]] = None) -> None:
275263
super().__init__()
276-
self.augmentation_space = augmentation_space
277264
self.num_magnitude_bins = num_magnitude_bins
278265
self.interpolation = interpolation
279266
self.fill = fill
280267

281-
@staticmethod
282-
def _get_magnitudes(num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
268+
def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
283269
return {
284270
# name: (magnitudes, signed)
285271
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
@@ -303,7 +289,7 @@ def forward(self, img: Tensor):
303289
img (PIL Image or Tensor): Image to be transformed.
304290
305291
Returns:
306-
PIL Image or Tensor: TrivialAugmented image.
292+
PIL Image or Tensor: Transformed image.
307293
"""
308294
fill = self.fill
309295
if isinstance(img, Tensor):
@@ -312,12 +298,7 @@ def forward(self, img: Tensor):
312298
elif fill is not None:
313299
fill = [float(f) for f in fill]
314300

315-
if self.augmentation_space == AugmentationSpace.AA:
316-
op_meta = AutoAugment._get_magnitudes(self.num_magnitude_bins, F.get_image_size(img))
317-
elif self.augmentation_space == AugmentationSpace.TA_WIDE:
318-
op_meta = self._get_magnitudes(self.num_magnitude_bins)
319-
else:
320-
raise ValueError(f"Provided augmentation_space arguments {self.augmentation_space} not available.")
301+
op_meta = self._get_magnitudes(self.num_magnitude_bins)
321302
op_index = int(torch.randint(len(op_meta), (1,)).item())
322303
op_name = list(op_meta.keys())[op_index]
323304
magnitudes, signed = op_meta[op_name]
@@ -330,8 +311,7 @@ def forward(self, img: Tensor):
330311

331312
def __repr__(self) -> str:
332313
s = self.__class__.__name__ + '('
333-
s += 'augmentation_space={augmentation_space}'
334-
s += ', num_magnitude_bins={num_magnitude_bins}'
314+
s += 'num_magnitude_bins={num_magnitude_bins}'
335315
s += ', interpolation={interpolation}'
336316
s += ', fill={fill}'
337317
s += ')'

0 commit comments

Comments
 (0)