1
- import functools
2
1
import math
3
2
from typing import Any , Dict , Tuple , Optional , Callable , List , cast , TypeVar , Union
4
3
5
4
import PIL .Image
6
5
import torch
7
6
from torchvision .prototype import features
8
7
from torchvision .prototype .transforms import Transform , InterpolationMode , AutoAugmentPolicy , functional as F
9
- from torchvision .prototype .utils ._internal import apply_recursively
8
+ from torchvision .prototype .utils ._internal import query_recursively
10
9
from torchvision .transforms .functional import pil_to_tensor , to_pil_image
11
10
12
- from ._utils import query_images , get_image_dimensions
11
+ from ._utils import get_image_dimensions
13
12
14
13
K = TypeVar ("K" )
15
14
V = TypeVar ("V" )
16
15
17
16
17
+ def _put_into_sample (sample : Any , id : Tuple [Any , ...], item : Any ) -> Any :
18
+ if not id :
19
+ return item
20
+
21
+ parent = sample
22
+ for key in id [:- 1 ]:
23
+ parent = parent [key ]
24
+
25
+ parent [id [- 1 ]] = item
26
+ return sample
27
+
28
+
18
29
class _AutoAugmentBase (Transform ):
19
30
def __init__ (
20
31
self , * , interpolation : InterpolationMode = InterpolationMode .NEAREST , fill : Optional [List [float ]] = None
@@ -28,68 +39,77 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
28
39
key = keys [int (torch .randint (len (keys ), ()))]
29
40
return key , dct [key ]
30
41
31
- def _query_image (self , sample : Any ) -> Union [PIL .Image .Image , torch .Tensor , features .Image ]:
32
- images = list (query_images (sample ))
42
def _check_support(self, input: Any) -> None:
    """Reject feature types that the auto-augment transforms cannot handle.

    Raises:
        TypeError: If ``input`` is a ``features.BoundingBox`` or a
            ``features.SegmentationMask``. Any other input passes silently.
    """
    unsupported_types = (features.BoundingBox, features.SegmentationMask)
    if not isinstance(input, unsupported_types):
        return

    raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
45
+
46
def _extract_image(
    self, sample: Any
) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
    """Find the one image contained in ``sample`` and return it with its id.

    The sample is searched with ``query_recursively``. An item counts as an
    image when its exact type is ``torch.Tensor`` or ``features.Image``, or
    when it is a PIL image. Every other item is handed to
    ``self._check_support`` (which raises for unsupported feature types) and
    is otherwise ignored.

    Returns:
        The ``(id, image)`` pair reported for the single image found.

    Raises:
        TypeError: If the sample contains no image, or more than one.
    """

    # NOTE(review): parameter names are kept as ``id`` / ``input`` in case
    # ``query_recursively`` passes them by keyword - confirm against the helper.
    def fn(
        id: Tuple[Any, ...], input: Any
    ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
        # Exact type match on purpose: subclasses of ``torch.Tensor`` other
        # than ``features.Image`` must not be picked up as images.
        is_image = isinstance(input, PIL.Image.Image) or type(input) in {torch.Tensor, features.Image}
        if is_image:
            return id, input

        self._check_support(input)
        return None

    found = list(query_recursively(fn, sample))
    if len(found) == 0:
        raise TypeError("Found no image in the sample.")
    if len(found) > 1:
        raise TypeError(
            f"Auto augment transformations are only properly defined for a single image, but found {len(found)}."
        )

    (match,) = found
    return match
38
66
39
def _parse_fill(
    self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
) -> Optional[List[float]]:
    """Normalise ``self.fill`` for the kernels that will process ``image``.

    Args:
        image: The image that is about to be transformed.
        num_channels: Number of channels of ``image``; used to expand a
            scalar fill into one value per channel.

    Returns:
        ``self.fill`` unchanged when it is ``None`` or when ``image`` is a
        PIL image. Otherwise a list of floats: a scalar fill repeated
        ``num_channels`` times, or each element of a sequence fill converted
        to ``float``.
    """
    raw_fill = self.fill

    # Nothing to normalise without a fill value, or for PIL images.
    if raw_fill is None or isinstance(image, PIL.Image.Image):
        return raw_fill

    if not isinstance(raw_fill, (int, float)):
        return [float(value) for value in raw_fill]

    return [float(raw_fill)] * num_channels
57
81
58
def _dispatch_image_kernels(
    self,
    image_tensor_kernel: Callable,
    image_pil_kernel: Callable,
    input: Any,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Run the kernel that matches the type of ``input``.

    Args:
        image_tensor_kernel: Kernel used for ``features.Image`` and
            ``torch.Tensor`` inputs.
        image_pil_kernel: Kernel used for every other input, which callers
            are expected to restrict to PIL images.
        input: The image to transform.
        *args: Forwarded to the selected kernel after ``input``.
        **kwargs: Forwarded to the selected kernel.

    Returns:
        The kernel output. For a ``features.Image`` input the output is
        wrapped again via ``features.Image.new_like``.
    """
    if isinstance(input, features.Image):
        result = image_tensor_kernel(input, *args, **kwargs)
        return features.Image.new_like(input, result)

    # No validation here: anything that is not a tensor goes to the PIL kernel.
    kernel = image_tensor_kernel if isinstance(input, torch.Tensor) else image_pil_kernel
    return kernel(input, *args, **kwargs)
77
97
78
- def _apply_transform_to_item (
98
+ def _apply_image_transform (
79
99
self ,
80
- item : Any ,
100
+ image : Any ,
81
101
transform_id : str ,
82
102
magnitude : float ,
83
103
interpolation : InterpolationMode ,
84
104
fill : Optional [List [float ]],
85
105
) -> Any :
86
106
if transform_id == "Identity" :
87
- return item
107
+ return image
88
108
elif transform_id == "ShearX" :
89
- return self ._dispatch (
109
+ return self ._dispatch_image_kernels (
90
110
F .affine_image_tensor ,
91
111
F .affine_image_pil ,
92
- item ,
112
+ image ,
93
113
angle = 0.0 ,
94
114
translate = [0 , 0 ],
95
115
scale = 1.0 ,
@@ -98,10 +118,10 @@ def _apply_transform_to_item(
98
118
fill = fill ,
99
119
)
100
120
elif transform_id == "ShearY" :
101
- return self ._dispatch (
121
+ return self ._dispatch_image_kernels (
102
122
F .affine_image_tensor ,
103
123
F .affine_image_pil ,
104
- item ,
124
+ image ,
105
125
angle = 0.0 ,
106
126
translate = [0 , 0 ],
107
127
scale = 1.0 ,
@@ -110,10 +130,10 @@ def _apply_transform_to_item(
110
130
fill = fill ,
111
131
)
112
132
elif transform_id == "TranslateX" :
113
- return self ._dispatch (
133
+ return self ._dispatch_image_kernels (
114
134
F .affine_image_tensor ,
115
135
F .affine_image_pil ,
116
- item ,
136
+ image ,
117
137
angle = 0.0 ,
118
138
translate = [int (magnitude ), 0 ],
119
139
scale = 1.0 ,
@@ -122,10 +142,10 @@ def _apply_transform_to_item(
122
142
fill = fill ,
123
143
)
124
144
elif transform_id == "TranslateY" :
125
- return self ._dispatch (
145
+ return self ._dispatch_image_kernels (
126
146
F .affine_image_tensor ,
127
147
F .affine_image_pil ,
128
- item ,
148
+ image ,
129
149
angle = 0.0 ,
130
150
translate = [0 , int (magnitude )],
131
151
scale = 1.0 ,
@@ -134,57 +154,49 @@ def _apply_transform_to_item(
134
154
fill = fill ,
135
155
)
136
156
elif transform_id == "Rotate" :
137
- return self ._dispatch (F .rotate_image_tensor , F .rotate_image_pil , item , angle = magnitude )
157
+ return self ._dispatch_image_kernels (F .rotate_image_tensor , F .rotate_image_pil , image , angle = magnitude )
138
158
elif transform_id == "Brightness" :
139
- return self ._dispatch (
159
+ return self ._dispatch_image_kernels (
140
160
F .adjust_brightness_image_tensor ,
141
161
F .adjust_brightness_image_pil ,
142
- item ,
162
+ image ,
143
163
brightness_factor = 1.0 + magnitude ,
144
164
)
145
165
elif transform_id == "Color" :
146
- return self ._dispatch (
166
+ return self ._dispatch_image_kernels (
147
167
F .adjust_saturation_image_tensor ,
148
168
F .adjust_saturation_image_pil ,
149
- item ,
169
+ image ,
150
170
saturation_factor = 1.0 + magnitude ,
151
171
)
152
172
elif transform_id == "Contrast" :
153
- return self ._dispatch (
154
- F .adjust_contrast_image_tensor , F .adjust_contrast_image_pil , item , contrast_factor = 1.0 + magnitude
173
+ return self ._dispatch_image_kernels (
174
+ F .adjust_contrast_image_tensor , F .adjust_contrast_image_pil , image , contrast_factor = 1.0 + magnitude
155
175
)
156
176
elif transform_id == "Sharpness" :
157
- return self ._dispatch (
177
+ return self ._dispatch_image_kernels (
158
178
F .adjust_sharpness_image_tensor ,
159
179
F .adjust_sharpness_image_pil ,
160
- item ,
180
+ image ,
161
181
sharpness_factor = 1.0 + magnitude ,
162
182
)
163
183
elif transform_id == "Posterize" :
164
- return self ._dispatch (F .posterize_image_tensor , F .posterize_image_pil , item , bits = int (magnitude ))
184
+ return self ._dispatch_image_kernels (
185
+ F .posterize_image_tensor , F .posterize_image_pil , image , bits = int (magnitude )
186
+ )
165
187
elif transform_id == "Solarize" :
166
- return self ._dispatch (F .solarize_image_tensor , F .solarize_image_pil , item , threshold = magnitude )
188
+ return self ._dispatch_image_kernels (
189
+ F .solarize_image_tensor , F .solarize_image_pil , image , threshold = magnitude
190
+ )
167
191
elif transform_id == "AutoContrast" :
168
- return self ._dispatch (F .autocontrast_image_tensor , F .autocontrast_image_pil , item )
192
+ return self ._dispatch_image_kernels (F .autocontrast_image_tensor , F .autocontrast_image_pil , image )
169
193
elif transform_id == "Equalize" :
170
- return self ._dispatch (F .equalize_image_tensor , F .equalize_image_pil , item )
194
+ return self ._dispatch_image_kernels (F .equalize_image_tensor , F .equalize_image_pil , image )
171
195
elif transform_id == "Invert" :
172
- return self ._dispatch (F .invert_image_tensor , F .invert_image_pil , item )
196
+ return self ._dispatch_image_kernels (F .invert_image_tensor , F .invert_image_pil , image )
173
197
else :
174
198
raise ValueError (f"No transform available for { transform_id } " )
175
199
176
- def _apply_transform_to_sample (self , sample : Any , transform_id : str , magnitude : float ) -> Any :
177
- return apply_recursively (
178
- functools .partial (
179
- self ._apply_transform_to_item ,
180
- transform_id = transform_id ,
181
- magnitude = magnitude ,
182
- interpolation = self .interpolation ,
183
- fill = self ._parse_fill (sample ),
184
- ),
185
- sample ,
186
- )
187
-
188
200
189
201
class AutoAugment (_AutoAugmentBase ):
190
202
_AUGMENTATION_SPACE = {
@@ -307,8 +319,9 @@ def _get_policies(
307
319
def forward (self , * inputs : Any ) -> Any :
308
320
sample = inputs if len (inputs ) > 1 else inputs [0 ]
309
321
310
- image = self ._query_image (sample )
311
- _ , height , width = get_image_dimensions (image )
322
+ id , image = self ._extract_image (sample )
323
+ num_channels , height , width = get_image_dimensions (image )
324
+ fill = self ._parse_fill (image , num_channels )
312
325
313
326
policy = self ._policies [int (torch .randint (len (self ._policies ), ()))]
314
327
@@ -326,9 +339,11 @@ def forward(self, *inputs: Any) -> Any:
326
339
else :
327
340
magnitude = 0.0
328
341
329
- sample = self ._apply_transform_to_sample (sample , transform_id , magnitude )
342
+ image = self ._apply_image_transform (
343
+ image , transform_id , magnitude , interpolation = self .interpolation , fill = fill
344
+ )
330
345
331
- return sample
346
+ return _put_into_sample ( sample , id , image )
332
347
333
348
334
349
class RandAugment (_AutoAugmentBase ):
@@ -363,8 +378,9 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins:
363
378
def forward (self , * inputs : Any ) -> Any :
364
379
sample = inputs if len (inputs ) > 1 else inputs [0 ]
365
380
366
- image = self ._query_image (sample )
367
- _ , height , width = get_image_dimensions (image )
381
+ id , image = self ._extract_image (sample )
382
+ num_channels , height , width = get_image_dimensions (image )
383
+ fill = self ._parse_fill (image , num_channels )
368
384
369
385
for _ in range (self .num_ops ):
370
386
transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
@@ -377,9 +393,11 @@ def forward(self, *inputs: Any) -> Any:
377
393
else :
378
394
magnitude = 0.0
379
395
380
- sample = self ._apply_transform_to_sample (sample , transform_id , magnitude )
396
+ image = self ._apply_image_transform (
397
+ image , transform_id , magnitude , interpolation = self .interpolation , fill = fill
398
+ )
381
399
382
- return sample
400
+ return _put_into_sample ( sample , id , image )
383
401
384
402
385
403
class TrivialAugmentWide (_AutoAugmentBase ):
@@ -412,8 +430,9 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
412
430
def forward (self , * inputs : Any ) -> Any :
413
431
sample = inputs if len (inputs ) > 1 else inputs [0 ]
414
432
415
- image = self ._query_image (sample )
416
- _ , height , width = get_image_dimensions (image )
433
+ id , image = self ._extract_image (sample )
434
+ num_channels , height , width = get_image_dimensions (image )
435
+ fill = self ._parse_fill (image , num_channels )
417
436
418
437
transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
419
438
@@ -425,7 +444,11 @@ def forward(self, *inputs: Any) -> Any:
425
444
else :
426
445
magnitude = 0.0
427
446
428
- return self ._apply_transform_to_sample (sample , transform_id , magnitude )
447
+ return _put_into_sample (
448
+ sample ,
449
+ id ,
450
+ self ._apply_image_transform (sample , transform_id , magnitude , interpolation = self .interpolation , fill = fill ),
451
+ )
429
452
430
453
431
454
class AugMix (_AutoAugmentBase ):
@@ -476,20 +499,18 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
476
499
# Must be on a separate method so that we can overwrite it in tests.
477
500
return torch ._sample_dirichlet (params )
478
501
479
- def _apply_augmix (self , input : Any ) -> Any :
480
- if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
481
- raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
482
- elif isinstance (input , torch .Tensor ):
483
- image = input
484
- elif isinstance (input , PIL .Image .Image ):
485
- image = pil_to_tensor (input )
486
- else :
487
- return input
502
+ def forward (self , * inputs : Any ) -> Any :
503
+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
504
+ id , orig_image = self ._extract_image (sample )
505
+ num_channels , height , width = get_image_dimensions (orig_image )
506
+ fill = self ._parse_fill (orig_image , num_channels )
488
507
489
- augmentation_space = self ._AUGMENTATION_SPACE if self .all_ops else self ._PARTIAL_AUGMENTATION_SPACE
508
+ if isinstance (orig_image , torch .Tensor ):
509
+ image = orig_image
510
+ else : # isinstance(input, PIL.Image.Image):
511
+ image = pil_to_tensor (orig_image )
490
512
491
- _ , height , width = get_image_dimensions (image )
492
- fill = self ._parse_fill (image )
513
+ augmentation_space = self ._AUGMENTATION_SPACE if self .all_ops else self ._PARTIAL_AUGMENTATION_SPACE
493
514
494
515
orig_dims = list (image .shape )
495
516
batch = image .view ([1 ] * max (4 - image .ndim , 0 ) + orig_dims )
@@ -521,20 +542,15 @@ def _apply_augmix(self, input: Any) -> Any:
521
542
else :
522
543
magnitude = 0.0
523
544
524
- aug = self ._apply_transform_to_item (
545
+ aug = self ._apply_image_transform (
525
546
image , transform_id , magnitude , interpolation = self .interpolation , fill = fill
526
547
)
527
548
mix .add_ (combined_weights [:, i ].view (batch_dims ) * aug )
528
549
mix = mix .view (orig_dims ).to (dtype = image .dtype )
529
550
530
- if isinstance (input , features .Image ):
531
- return features .Image .new_like (input , mix )
532
- elif isinstance (input , torch .Tensor ):
533
- return mix
534
- else : # isinstance(input, PIL.Image.Image):
535
- return to_pil_image (mix )
551
+ if isinstance (orig_image , features .Image ):
552
+ mix = features .Image .new_like (orig_image , mix )
553
+ elif isinstance (orig_image , PIL .Image .Image ):
554
+ mix = to_pil_image (mix )
536
555
537
- def forward (self , * inputs : Any ) -> Any :
538
- sample = inputs if len (inputs ) > 1 else inputs [0 ]
539
- self ._query_image (sample )
540
- return apply_recursively (self ._apply_augmix , sample )
556
+ return _put_into_sample (sample , id , mix )
0 commit comments