@@ -39,7 +39,7 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
39
39
key = keys [int (torch .randint (len (keys ), ()))]
40
40
return key , dct [key ]
41
41
42
- def _check_support (self , input : Any ) -> None :
42
+ def _check_unsupported (self , input : Any ) -> None :
43
43
if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
44
44
raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
45
45
@@ -52,7 +52,7 @@ def fn(
52
52
if type (input ) in {torch .Tensor , features .Image } or isinstance (input , PIL .Image .Image ):
53
53
return id , input
54
54
55
- self ._check_support (input )
55
+ self ._check_unsupported (input )
56
56
return None
57
57
58
58
images = list (query_recursively (fn , sample ))
@@ -444,11 +444,8 @@ def forward(self, *inputs: Any) -> Any:
444
444
else :
445
445
magnitude = 0.0
446
446
447
- return _put_into_sample (
448
- sample ,
449
- id ,
450
- self ._apply_image_transform (sample , transform_id , magnitude , interpolation = self .interpolation , fill = fill ),
451
- )
447
+ image = self ._apply_image_transform (image , transform_id , magnitude , interpolation = self .interpolation , fill = fill )
448
+ return _put_into_sample (sample , id , image )
452
449
453
450
454
451
class AugMix (_AutoAugmentBase ):
@@ -543,7 +540,7 @@ def forward(self, *inputs: Any) -> Any:
543
540
magnitude = 0.0
544
541
545
542
aug = self ._apply_image_transform (
546
- image , transform_id , magnitude , interpolation = self .interpolation , fill = fill
543
+ aug , transform_id , magnitude , interpolation = self .interpolation , fill = fill
547
544
)
548
545
mix .add_ (combined_weights [:, i ].view (batch_dims ) * aug )
549
546
mix = mix .view (orig_dims ).to (dtype = image .dtype )
0 commit comments