Skip to content

Commit 9cbec19

Browse files
committed
Fix bug on AugMix
1 parent 7e7701d commit 9cbec19

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,8 @@ def forward(self, *inputs: Any) -> Any:
483483
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
484484

485485
orig_dims = list(image_or_video.shape)
486-
batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims)
486+
expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
487+
batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
487488
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
488489

489490
# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a

0 commit comments

Comments
 (0)