Skip to content

Commit d60bd48

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Fix RandAugment and TrivialAugment bugs (#4370)
Summary: * Fix RA bugs. * Fix bins for TA. Reviewed By: fmassa Differential Revision: D30793321 fbshipit-source-id: ebed731f85daae298c07963ed3115b9dda8403ea
1 parent d6a64d9 commit d60bd48

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchvision/transforms/autoaugment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ class RandAugment(torch.nn.Module):
263263
image. If given a number, the value is used for all bands respectively.
264264
"""
265265

266-
def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
266+
def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
267267
interpolation: InterpolationMode = InterpolationMode.NEAREST,
268268
fill: Optional[List[float]] = None) -> None:
269269
super().__init__()
@@ -276,6 +276,7 @@ def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int
276276
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
277277
return {
278278
# op_name: (magnitudes, signed)
279+
"Identity": (torch.tensor(0.0), False),
279280
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
280281
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
281282
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@@ -289,7 +290,6 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str,
289290
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
290291
"AutoContrast": (torch.tensor(0.0), False),
291292
"Equalize": (torch.tensor(0.0), False),
292-
"Invert": (torch.tensor(0.0), False),
293293
}
294294

295295
def forward(self, img: Tensor) -> Tensor:
@@ -345,7 +345,7 @@ class TrivialAugmentWide(torch.nn.Module):
345345
image. If given a number, the value is used for all bands respectively.
346346
"""
347347

348-
def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
348+
def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
349349
fill: Optional[List[float]] = None) -> None:
350350
super().__init__()
351351
self.num_magnitude_bins = num_magnitude_bins

0 commit comments

Comments
 (0)