Skip to content

Commit 3a455d8

Browse files
committed
address review comments
1 parent 9782f24 commit 3a455d8

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
3939
key = keys[int(torch.randint(len(keys), ()))]
4040
return key, dct[key]
4141

42-
def _check_support(self, input: Any) -> None:
42+
def _check_unsupported(self, input: Any) -> None:
4343
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
4444
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
4545

@@ -52,7 +52,7 @@ def fn(
5252
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
5353
return id, input
5454

55-
self._check_support(input)
55+
self._check_unsupported(input)
5656
return None
5757

5858
images = list(query_recursively(fn, sample))
@@ -444,11 +444,8 @@ def forward(self, *inputs: Any) -> Any:
444444
else:
445445
magnitude = 0.0
446446

447-
return _put_into_sample(
448-
sample,
449-
id,
450-
self._apply_image_transform(sample, transform_id, magnitude, interpolation=self.interpolation, fill=fill),
451-
)
447+
image = self._apply_image_transform(image, transform_id, magnitude, interpolation=self.interpolation, fill=fill)
448+
return _put_into_sample(sample, id, image)
452449

453450

454451
class AugMix(_AutoAugmentBase):
@@ -543,7 +540,7 @@ def forward(self, *inputs: Any) -> Any:
543540
magnitude = 0.0
544541

545542
aug = self._apply_image_transform(
546-
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
543+
aug, transform_id, magnitude, interpolation=self.interpolation, fill=fill
547544
)
548545
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
549546
mix = mix.view(orig_dims).to(dtype=image.dtype)

0 commit comments

Comments
 (0)