Skip to content

Commit c9e876b

Browse files
committed
Merge remote-tracking branch 'origin/prototype/video_corrections' into prototype/video_corrections
2 parents 1f08c38 + 584ccee commit c9e876b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ class FiveCrop(Transform):
158158
... def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
159159
... images_or_videos, labels = sample
160160
... batch_size = len(images_or_videos)
161-
... images_or_videos = features.Image.wrap_like(images_or_videos[0], torch.stack(images_or_videos))
161+
... image_or_video = images_or_videos[0]
162+
... images_or_videos = type(image_or_video).wrap_like(image_or_video, torch.stack(images_or_videos))
162163
... labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
163164
... return images_or_videos, labels
164165
...

0 commit comments

Comments
 (0)