Skip to content

Commit 1f6364b

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Refactor AutoAugment to support more augmentations. (#4338)
Reviewed By: fmassa Differential Revision: D30793336 fbshipit-source-id: b9da3cf133f06e44bef6cdd7174bae21a306bec6
1 parent b507e72 commit 1f6364b

File tree

1 file changed

+153
-155
lines changed

1 file changed

+153
-155
lines changed

torchvision/transforms/autoaugment.py

Lines changed: 153 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,45 @@
1010
__all__ = ["AutoAugmentPolicy", "AutoAugment"]
1111

1212

13+
def _apply_op(img: Tensor, op_name: str, magnitude: float,
14+
interpolation: InterpolationMode, fill: Optional[List[float]]):
15+
if op_name == "ShearX":
16+
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
17+
interpolation=interpolation, fill=fill)
18+
elif op_name == "ShearY":
19+
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
20+
interpolation=interpolation, fill=fill)
21+
elif op_name == "TranslateX":
22+
img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
23+
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
24+
elif op_name == "TranslateY":
25+
img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
26+
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
27+
elif op_name == "Rotate":
28+
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
29+
elif op_name == "Brightness":
30+
img = F.adjust_brightness(img, 1.0 + magnitude)
31+
elif op_name == "Color":
32+
img = F.adjust_saturation(img, 1.0 + magnitude)
33+
elif op_name == "Contrast":
34+
img = F.adjust_contrast(img, 1.0 + magnitude)
35+
elif op_name == "Sharpness":
36+
img = F.adjust_sharpness(img, 1.0 + magnitude)
37+
elif op_name == "Posterize":
38+
img = F.posterize(img, int(magnitude))
39+
elif op_name == "Solarize":
40+
img = F.solarize(img, magnitude)
41+
elif op_name == "AutoContrast":
42+
img = F.autocontrast(img)
43+
elif op_name == "Equalize":
44+
img = F.equalize(img)
45+
elif op_name == "Invert":
46+
img = F.invert(img)
47+
else:
48+
raise ValueError("The provided operator {} is not recognized.".format(op_name))
49+
return img
50+
51+
1352
class AutoAugmentPolicy(Enum):
1453
"""AutoAugment policies learned on different datasets.
1554
Available policies are IMAGENET, CIFAR10 and SVHN.
@@ -19,116 +58,6 @@ class AutoAugmentPolicy(Enum):
1958
SVHN = "svhn"
2059

2160

22-
def _get_transforms( # type: ignore[return]
23-
policy: AutoAugmentPolicy
24-
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
25-
if policy == AutoAugmentPolicy.IMAGENET:
26-
return [
27-
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
28-
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
29-
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
30-
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
31-
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
32-
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
33-
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
34-
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
35-
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
36-
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
37-
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
38-
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
39-
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
40-
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
41-
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
42-
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
43-
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
44-
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
45-
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
46-
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
47-
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
48-
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
49-
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
50-
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
51-
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
52-
]
53-
elif policy == AutoAugmentPolicy.CIFAR10:
54-
return [
55-
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
56-
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
57-
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
58-
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
59-
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
60-
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
61-
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
62-
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
63-
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
64-
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
65-
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
66-
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
67-
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
68-
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
69-
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
70-
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
71-
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
72-
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
73-
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
74-
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
75-
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
76-
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
77-
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
78-
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
79-
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
80-
]
81-
elif policy == AutoAugmentPolicy.SVHN:
82-
return [
83-
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
84-
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
85-
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
86-
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
87-
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
88-
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
89-
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
90-
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
91-
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
92-
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
93-
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
94-
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
95-
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
96-
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
97-
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
98-
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
99-
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
100-
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
101-
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
102-
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
103-
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
104-
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
105-
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
106-
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
107-
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
108-
]
109-
110-
111-
def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]:
112-
_BINS = 10
113-
return {
114-
# name: (magnitudes, signed)
115-
"ShearX": (torch.linspace(0.0, 0.3, _BINS), True),
116-
"ShearY": (torch.linspace(0.0, 0.3, _BINS), True),
117-
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
118-
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
119-
"Rotate": (torch.linspace(0.0, 30.0, _BINS), True),
120-
"Brightness": (torch.linspace(0.0, 0.9, _BINS), True),
121-
"Color": (torch.linspace(0.0, 0.9, _BINS), True),
122-
"Contrast": (torch.linspace(0.0, 0.9, _BINS), True),
123-
"Sharpness": (torch.linspace(0.0, 0.9, _BINS), True),
124-
"Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False),
125-
"Solarize": (torch.linspace(256.0, 0.0, _BINS), False),
126-
"AutoContrast": (None, None),
127-
"Equalize": (None, None),
128-
"Invert": (None, None),
129-
}
130-
131-
13261
class AutoAugment(torch.nn.Module):
13362
r"""AutoAugment data augmentation method based on
13463
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -156,11 +85,117 @@ def __init__(
15685
self.policy = policy
15786
self.interpolation = interpolation
15887
self.fill = fill
88+
self.transforms = self._get_transforms(policy)
15989

160-
self.transforms = _get_transforms(policy)
161-
if self.transforms is None:
90+
def _get_transforms(
91+
self,
92+
policy: AutoAugmentPolicy
93+
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
94+
if policy == AutoAugmentPolicy.IMAGENET:
95+
return [
96+
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
97+
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
98+
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
99+
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
100+
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
101+
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
102+
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
103+
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
104+
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
105+
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
106+
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
107+
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
108+
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
109+
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
110+
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
111+
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
112+
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
113+
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
114+
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
115+
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
116+
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
117+
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
118+
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
119+
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
120+
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
121+
]
122+
elif policy == AutoAugmentPolicy.CIFAR10:
123+
return [
124+
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
125+
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
126+
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
127+
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
128+
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
129+
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
130+
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
131+
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
132+
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
133+
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
134+
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
135+
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
136+
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
137+
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
138+
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
139+
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
140+
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
141+
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
142+
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
143+
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
144+
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
145+
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
146+
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
147+
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
148+
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
149+
]
150+
elif policy == AutoAugmentPolicy.SVHN:
151+
return [
152+
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
153+
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
154+
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
155+
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
156+
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
157+
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
158+
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
159+
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
160+
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
161+
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
162+
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
163+
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
164+
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
165+
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
166+
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
167+
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
168+
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
169+
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
170+
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
171+
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
172+
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
173+
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
174+
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
175+
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
176+
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
177+
]
178+
else:
162179
raise ValueError("The provided policy {} is not recognized.".format(policy))
163-
self._op_meta = _get_magnitudes()
180+
181+
def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
182+
return {
183+
# name: (magnitudes, signed)
184+
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
185+
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
186+
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
187+
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
188+
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
189+
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
190+
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
191+
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
192+
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
193+
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
194+
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
195+
"AutoContrast": (torch.tensor(0.0), False),
196+
"Equalize": (torch.tensor(0.0), False),
197+
"Invert": (torch.tensor(0.0), False),
198+
}
164199

165200
@staticmethod
166201
def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
@@ -175,9 +210,6 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
175210

176211
return policy_id, probs, signs
177212

178-
def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
179-
return self._op_meta[name]
180-
181213
def forward(self, img: Tensor) -> Tensor:
182214
"""
183215
img (PIL Image or Tensor): Image to be transformed.
@@ -196,46 +228,12 @@ def forward(self, img: Tensor) -> Tensor:
196228

197229
for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
198230
if probs[i] <= p:
199-
magnitudes, signed = self._get_op_meta(op_name)
200-
magnitude = float(magnitudes[magnitude_id].item()) \
201-
if magnitudes is not None and magnitude_id is not None else 0.0
202-
if signed is not None and signed and signs[i] == 0:
231+
op_meta = self._get_magnitudes(10, F.get_image_size(img))
232+
magnitudes, signed = op_meta[op_name]
233+
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
234+
if signed and signs[i] == 0:
203235
magnitude *= -1.0
204-
205-
if op_name == "ShearX":
206-
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
207-
interpolation=self.interpolation, fill=fill)
208-
elif op_name == "ShearY":
209-
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
210-
interpolation=self.interpolation, fill=fill)
211-
elif op_name == "TranslateX":
212-
img = F.affine(img, angle=0.0, translate=[int(F.get_image_size(img)[0] * magnitude), 0], scale=1.0,
213-
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
214-
elif op_name == "TranslateY":
215-
img = F.affine(img, angle=0.0, translate=[0, int(F.get_image_size(img)[1] * magnitude)], scale=1.0,
216-
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
217-
elif op_name == "Rotate":
218-
img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill)
219-
elif op_name == "Brightness":
220-
img = F.adjust_brightness(img, 1.0 + magnitude)
221-
elif op_name == "Color":
222-
img = F.adjust_saturation(img, 1.0 + magnitude)
223-
elif op_name == "Contrast":
224-
img = F.adjust_contrast(img, 1.0 + magnitude)
225-
elif op_name == "Sharpness":
226-
img = F.adjust_sharpness(img, 1.0 + magnitude)
227-
elif op_name == "Posterize":
228-
img = F.posterize(img, int(magnitude))
229-
elif op_name == "Solarize":
230-
img = F.solarize(img, magnitude)
231-
elif op_name == "AutoContrast":
232-
img = F.autocontrast(img)
233-
elif op_name == "Equalize":
234-
img = F.equalize(img)
235-
elif op_name == "Invert":
236-
img = F.invert(img)
237-
else:
238-
raise ValueError("The provided operator {} is not recognized.".format(op_name))
236+
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
239237

240238
return img
241239

0 commit comments

Comments
 (0)