
Commit b6375e6

refactor auto augment subclasses to only transform a single image
1 parent 67c3056 · commit b6375e6
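
Before this commit, each auto augment subclass pushed the chosen transform through the whole sample via apply_recursively, touching every item. After it, the base class extracts the single image from the (possibly nested) sample, transforms only that image, and writes the result back at its original position. A minimal, self-contained sketch of that extract-transform-reinsert flow (the helpers below are illustrative stand-ins, not the torchvision API):

from typing import Any, Dict, Tuple

def extract_image(sample: Dict[str, Any]) -> Tuple[Tuple[str, ...], Any]:
    # Stand-in for _extract_image: find the single image and remember its path.
    for key, value in sample.items():
        if key == "image":
            return (key,), value
    raise TypeError("Found no image in the sample.")

def put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
    # Stand-in for _put_into_sample: write the transformed image back in place.
    parent = sample
    for key in id[:-1]:
        parent = parent[key]
    parent[id[-1]] = item
    return sample

sample = {"image": "pixels", "label": 7}
id, image = extract_image(sample)
sample = put_into_sample(sample, id, image.upper())  # .upper() mimics an augmentation
print(sample)  # {'image': 'PIXELS', 'label': 7}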

File tree

3 files changed: +126 −110 lines


torchvision/prototype/transforms/_auto_augment.py

Lines changed: 107 additions & 91 deletions
@@ -1,20 +1,31 @@
-import functools
 import math
 from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union
 
 import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
-from torchvision.prototype.utils._internal import apply_recursively
+from torchvision.prototype.utils._internal import query_recursively
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
-from ._utils import query_images, get_image_dimensions
+from ._utils import get_image_dimensions
 
 K = TypeVar("K")
 V = TypeVar("V")
 
 
+def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
+    if not id:
+        return item
+
+    parent = sample
+    for key in id[:-1]:
+        parent = parent[key]
+
+    parent[id[-1]] = item
+    return sample
+
+
 class _AutoAugmentBase(Transform):
     def __init__(
         self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None
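
The new _put_into_sample helper walks the id path produced during extraction and replaces the image in place; an empty id means the sample itself was the image. A quick demonstration with a made-up nested sample, using _put_into_sample exactly as defined in the hunk above:

sample = {"inputs": {"image": "old", "mask": "m"}, "label": 3}
assert _put_into_sample(sample, ("inputs", "image"), "new")["inputs"]["image"] == "new"
assert _put_into_sample("old", (), "new") == "new"  # empty id: replace wholesale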
@@ -28,68 +39,77 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]
 
-    def _query_image(self, sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
-        images = list(query_images(sample))
+    def _check_support(self, input: Any) -> None:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+
+    def _extract_image(
+        self, sample: Any
+    ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+        def fn(
+            id: Tuple[Any, ...], input: Any
+        ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
+            if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
+                return id, input
+
+            self._check_support(input)
+            return None
+
+        images = list(query_recursively(fn, sample))
+        if not images:
+            raise TypeError("Found no image in the sample.")
         if len(images) > 1:
             raise TypeError(
                 f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
             )
         return images[0]
 
-    def _parse_fill(self, sample: Any) -> Optional[List[float]]:
+    def _parse_fill(
+        self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
+    ) -> Optional[List[float]]:
         fill = self.fill
 
-        if fill is None:
-            return fill
-
-        image = self._query_image(sample)
-
-        if not isinstance(image, torch.Tensor):
+        if isinstance(image, PIL.Image.Image) or fill is None:
             return fill
 
         if isinstance(fill, (int, float)):
-            num_channels, *_ = get_image_dimensions(image)
             fill = [float(fill)] * num_channels
         else:
             fill = [float(f) for f in fill]
 
         return fill
 
-    def _dispatch(
+    def _dispatch_image_kernels(
         self,
         image_tensor_kernel: Callable,
         image_pil_kernel: Callable,
         input: Any,
         *args: Any,
         **kwargs: Any,
     ) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = image_tensor_kernel(input, *args, **kwargs)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
             return image_tensor_kernel(input, *args, **kwargs)
-        elif isinstance(input, PIL.Image.Image):
+        else:  # isinstance(input, PIL.Image.Image):
             return image_pil_kernel(input, *args, **kwargs)
-        else:
-            return input
 
-    def _apply_transform_to_item(
+    def _apply_image_transform(
         self,
-        item: Any,
+        image: Any,
         transform_id: str,
         magnitude: float,
         interpolation: InterpolationMode,
         fill: Optional[List[float]],
     ) -> Any:
         if transform_id == "Identity":
-            return item
+            return image
         elif transform_id == "ShearX":
-            return self._dispatch(
+            return self._dispatch_image_kernels(
                 F.affine_image_tensor,
                 F.affine_image_pil,
-                item,
+                image,
                 angle=0.0,
                 translate=[0, 0],
                 scale=1.0,
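
_extract_image leans on query_recursively to walk the sample and call fn with an (id, input) pair for every leaf, keeping only the non-None results. That utility is not part of this diff; the following is a plausible minimal sketch of its behavior, inferred from how it is called here rather than taken from the torchvision source:

import collections.abc
from typing import Any, Callable, Iterator, Optional, Tuple, TypeVar

D = TypeVar("D")

def query_recursively(
    fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
) -> Iterator[D]:
    # Recurse into mappings and sequences, threading the key path through `id`.
    if isinstance(obj, collections.abc.Mapping):
        for key, item in obj.items():
            yield from query_recursively(fn, item, id=(*id, key))
    elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        for idx, item in enumerate(obj):
            yield from query_recursively(fn, item, id=(*id, idx))
    else:
        result = fn(id, obj)
        if result is not None:
            yield result

# Collect the path of every string leaf in a nested sample.
leaves = list(
    query_recursively(lambda id, x: (id, x) if isinstance(x, str) else None, {"a": ["x", {"b": "y"}], "c": 0})
)
print(leaves)  # [(('a', 0), 'x'), (('a', 1, 'b'), 'y')]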
@@ -98,10 +118,10 @@ def _apply_transform_to_item(
                 fill=fill,
             )
         elif transform_id == "ShearY":
-            return self._dispatch(
+            return self._dispatch_image_kernels(
                 F.affine_image_tensor,
                 F.affine_image_pil,
-                item,
+                image,
                 angle=0.0,
                 translate=[0, 0],
                 scale=1.0,
@@ -110,10 +130,10 @@ def _apply_transform_to_item(
                 fill=fill,
             )
         elif transform_id == "TranslateX":
-            return self._dispatch(
+            return self._dispatch_image_kernels(
                 F.affine_image_tensor,
                 F.affine_image_pil,
-                item,
+                image,
                 angle=0.0,
                 translate=[int(magnitude), 0],
                 scale=1.0,
@@ -122,10 +142,10 @@ def _apply_transform_to_item(
                 fill=fill,
             )
         elif transform_id == "TranslateY":
-            return self._dispatch(
+            return self._dispatch_image_kernels(
                 F.affine_image_tensor,
                 F.affine_image_pil,
-                item,
+                image,
                 angle=0.0,
                 translate=[0, int(magnitude)],
                 scale=1.0,
@@ -134,57 +154,49 @@ def _apply_transform_to_item(
                 fill=fill,
             )
         elif transform_id == "Rotate":
-            return self._dispatch(F.rotate_image_tensor, F.rotate_image_pil, item, angle=magnitude)
+            return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude)
         elif transform_id == "Brightness":
-            return self._dispatch(
+            return self._dispatch_image_kernels(
                 F.adjust_brightness_image_tensor,
                 F.adjust_brightness_image_pil,
-                item,
+                image,
                 brightness_factor=1.0 + magnitude,
             )
         elif transform_id == "Color":
-            return self._dispatch(
+            return self._dispatch_image_kernels(
                 F.adjust_saturation_image_tensor,
                 F.adjust_saturation_image_pil,
-                item,
+                image,
                 saturation_factor=1.0 + magnitude,
             )
         elif transform_id == "Contrast":
-            return self._dispatch(
-                F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, item, contrast_factor=1.0 + magnitude
+            return self._dispatch_image_kernels(
+                F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude
             )
         elif transform_id == "Sharpness":
-            return self._dispatch(
+            return self._dispatch_image_kernels(
                 F.adjust_sharpness_image_tensor,
                 F.adjust_sharpness_image_pil,
-                item,
+                image,
                 sharpness_factor=1.0 + magnitude,
             )
         elif transform_id == "Posterize":
-            return self._dispatch(F.posterize_image_tensor, F.posterize_image_pil, item, bits=int(magnitude))
+            return self._dispatch_image_kernels(
+                F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude)
+            )
         elif transform_id == "Solarize":
-            return self._dispatch(F.solarize_image_tensor, F.solarize_image_pil, item, threshold=magnitude)
+            return self._dispatch_image_kernels(
+                F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude
+            )
         elif transform_id == "AutoContrast":
-            return self._dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, item)
+            return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image)
         elif transform_id == "Equalize":
-            return self._dispatch(F.equalize_image_tensor, F.equalize_image_pil, item)
+            return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image)
         elif transform_id == "Invert":
-            return self._dispatch(F.invert_image_tensor, F.invert_image_pil, item)
+            return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image)
         else:
             raise ValueError(f"No transform available for {transform_id}")
 
-    def _apply_transform_to_sample(self, sample: Any, transform_id: str, magnitude: float) -> Any:
-        return apply_recursively(
-            functools.partial(
-                self._apply_transform_to_item,
-                transform_id=transform_id,
-                magnitude=magnitude,
-                interpolation=self.interpolation,
-                fill=self._parse_fill(sample),
-            ),
-            sample,
-        )
-
 
 class AutoAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
@@ -307,8 +319,9 @@ def _get_policies(
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
-        image = self._query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_image_dimensions(image)
+        fill = self._parse_fill(image, num_channels)
 
         policy = self._policies[int(torch.randint(len(self._policies), ()))]
 
@@ -326,9 +339,11 @@ def forward(self, *inputs: Any) -> Any:
             else:
                 magnitude = 0.0
 
-            sample = self._apply_transform_to_sample(sample, transform_id, magnitude)
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
+            )
 
-        return sample
+        return _put_into_sample(sample, id, image)
 
 
 class RandAugment(_AutoAugmentBase):
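
With these changes, AutoAugment.forward augments only the image entry of a structured sample and leaves the other items untouched. A hypothetical usage sketch, assuming the prototype transforms namespace is importable as in this file (the exact prototype import path is an assumption):

import torch
from torchvision.prototype import transforms

augment = transforms.AutoAugment()
sample = {"image": torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8), "label": 0}
out = augment(sample)  # only sample["image"] is transformed
assert out["label"] == 0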
@@ -363,8 +378,9 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins:
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
-        image = self._query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_image_dimensions(image)
+        fill = self._parse_fill(image, num_channels)
 
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -377,9 +393,11 @@ def forward(self, *inputs: Any) -> Any:
             else:
                 magnitude = 0.0
 
-            sample = self._apply_transform_to_sample(sample, transform_id, magnitude)
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
+            )
 
-        return sample
+        return _put_into_sample(sample, id, image)
 
 
 class TrivialAugmentWide(_AutoAugmentBase):
@@ -412,8 +430,9 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
-        image = self._query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_image_dimensions(image)
+        fill = self._parse_fill(image, num_channels)
 
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
 
@@ -425,7 +444,11 @@ def forward(self, *inputs: Any) -> Any:
         else:
             magnitude = 0.0
 
-        return self._apply_transform_to_sample(sample, transform_id, magnitude)
+        return _put_into_sample(
+            sample,
+            id,
+            self._apply_image_transform(sample, transform_id, magnitude, interpolation=self.interpolation, fill=fill),
+        )
 
 
 class AugMix(_AutoAugmentBase):
@@ -476,20 +499,18 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
         # Must be on a separate method so that we can overwrite it in tests.
         return torch._sample_dirichlet(params)
 
-    def _apply_augmix(self, input: Any) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, torch.Tensor):
-            image = input
-        elif isinstance(input, PIL.Image.Image):
-            image = pil_to_tensor(input)
-        else:
-            return input
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        id, orig_image = self._extract_image(sample)
+        num_channels, height, width = get_image_dimensions(orig_image)
+        fill = self._parse_fill(orig_image, num_channels)
 
-        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
+        if isinstance(orig_image, torch.Tensor):
+            image = orig_image
+        else:  # isinstance(input, PIL.Image.Image):
+            image = pil_to_tensor(orig_image)
 
-        _, height, width = get_image_dimensions(image)
-        fill = self._parse_fill(image)
+        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
 
         orig_dims = list(image.shape)
         batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims)
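
AugMix mixes in tensor space, so a PIL input is converted up front with pil_to_tensor and converted back with to_pil_image once the mix is computed. A small round-trip sketch of that conversion:

import torch
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

pil_image = to_pil_image(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))
tensor = pil_to_tensor(pil_image)  # uint8 tensor of shape (3, 4, 4)
restored = to_pil_image(tensor)    # back to PIL, as done after mixing
assert tuple(tensor.shape) == (3, 4, 4)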
@@ -521,20 +542,15 @@ def _apply_augmix(self, input: Any) -> Any:
                 else:
                     magnitude = 0.0
 
-                aug = self._apply_transform_to_item(
+                aug = self._apply_image_transform(
                     image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
                 )
             mix.add_(combined_weights[:, i].view(batch_dims) * aug)
         mix = mix.view(orig_dims).to(dtype=image.dtype)
 
-        if isinstance(input, features.Image):
-            return features.Image.new_like(input, mix)
-        elif isinstance(input, torch.Tensor):
-            return mix
-        else:  # isinstance(input, PIL.Image.Image):
-            return to_pil_image(mix)
+        if isinstance(orig_image, features.Image):
+            mix = features.Image.new_like(orig_image, mix)
+        elif isinstance(orig_image, PIL.Image.Image):
+            mix = to_pil_image(mix)
 
-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        self._query_image(sample)
-        return apply_recursively(self._apply_augmix, sample)
+        return _put_into_sample(sample, id, mix)
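
The mix accumulated in the loop above is a convex combination of the augmented views of the image, weighted by the draws from _sample_dirichlet. A toy version of that weighted mixing, with fixed weights standing in for the sampled ones:

import torch

image = torch.rand(3, 8, 8)
augmented = [image.flip(-1), image.roll(1, dims=-2)]  # stand-ins for augmented chains
weights = torch.tensor([0.5, 0.25, 0.25])             # the real weights are sampled

mix = weights[0] * image
for w, aug in zip(weights[1:], augmented):
    mix = mix + w * aug
assert mix.shape == image.shape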
