
Commit 1505603

Yosua Michael Maranatha authored and facebook-github-bot committed
[fbsync] Adding support of Video to remaining Transforms and Kernels (#6724)
Summary:

* Adding support of Video to missed Transforms and Kernels
* Fixing Grayscale Transform.
* Fixing FiveCrop and TenCrop Transforms.
* Fix Linter
* Fix more kernels.
* Add `five_crop_video` and `ten_crop_video` kernels
* Added a TODO.
* Missed Video isinstance
* nits
* Fix bug on AugMix
* Nits and TODOs.
* Reapply Philip's recommendation
* Fix mypy and JIT
* Fixing test

Reviewed By: NicolasHug

Differential Revision: D40427468

fbshipit-source-id: e7f699aee86b80ea3f614dc4e09ae1aaf22fc37d
1 parent 8543a62 commit 1505603

File tree

14 files changed: +88 additions, -41 deletions
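
As an illustrative, non-authoritative sketch of what this commit enables (assuming the prototype `features`/`transforms` namespaces as of this revision), a clip wrapped as `features.Video` can now flow through transforms such as FiveCrop directly:

import torch
from torchvision.prototype import features, transforms

# Hedged usage sketch; exact prototype API details may differ slightly.
video = features.Video(torch.rand(4, 3, 64, 64))  # (T, C, H, W) clip

crops = transforms.FiveCrop(32)(video)
print(len(crops), crops[0].shape)  # expected: 5 torch.Size([4, 3, 32, 32])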

torchvision/prototype/features/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -13,4 +13,14 @@
 )
 from ._label import Label, OneHotLabel
 from ._mask import Mask
-from ._video import ImageOrVideoType, ImageOrVideoTypeJIT, TensorImageOrVideoType, TensorImageOrVideoTypeJIT, Video
+from ._video import (
+    ImageOrVideoType,
+    ImageOrVideoTypeJIT,
+    LegacyVideoType,
+    LegacyVideoTypeJIT,
+    TensorImageOrVideoType,
+    TensorImageOrVideoTypeJIT,
+    Video,
+    VideoType,
+    VideoTypeJIT,
+)

torchvision/prototype/features/_video.py

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
 TensorVideoType = Union[torch.Tensor, Video]
 TensorVideoTypeJIT = torch.Tensor

+# TODO: decide if we should do definitions for both Images and Videos or use unions in the methods
 ImageOrVideoType = Union[ImageType, VideoType]
 ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
 TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) ->
         return inpt


+# TODO: Add support for Video: https://github.com/pytorch/vision/issues/6731
 class _BaseMixupCutmix(_RandomApplyTransform):
     def __init__(self, alpha: float, p: float = 0.5) -> None:
         super().__init__(p=p)

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 3 additions & 2 deletions
@@ -483,7 +483,8 @@ def forward(self, *inputs: Any) -> Any:
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

         orig_dims = list(image_or_video.shape)
-        batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims)
+        expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
+        batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

         # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
@@ -520,7 +521,7 @@ def forward(self, *inputs: Any) -> Any:
         mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)

         if isinstance(orig_image_or_video, (features.Image, features.Video)):
-            mix = type(orig_image_or_video).wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
+            mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_image_pil(mix)
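
The `expected_dim` change above exists because a Video carries a leading temporal dimension, so the batched view must be 5D rather than 4D. A small standalone sketch of the same arithmetic (illustrative only, not part of the diff):

import torch

image = torch.rand(3, 32, 32)      # (C, H, W): needs a 4D batch view
video = torch.rand(4, 3, 32, 32)   # (T, C, H, W): needs a 5D batch view

for inpt, expected_dim in ((image, 4), (video, 5)):
    orig_dims = list(inpt.shape)
    # Same view logic as in the hunk above: prepend singleton dims up to expected_dim.
    batch = inpt.view([1] * max(expected_dim - inpt.ndim, 0) + orig_dims)
    print(batch.shape)  # torch.Size([1, 3, 32, 32]) then torch.Size([1, 4, 3, 32, 32])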

torchvision/prototype/transforms/_color.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def _permute_channels(
         output = inpt[..., permutation, :, :]

         if isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]

         elif isinstance(inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
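
Several hunks in this commit replace `type(inpt).wrap_like(...)` with `inpt.wrap_like(...)`. Since `wrap_like` is a classmethod on the feature classes, calling it through the instance already resolves to the concrete subclass (Image or Video) without the explicit `type(...)` lookup. A minimal, hypothetical analogue (not torchvision code; `Feature`, `Image`, `Video` here are stand-ins):

import torch

class Feature(torch.Tensor):
    @classmethod
    def wrap_like(cls, other, tensor):
        # Re-wrap a plain tensor as the same feature subclass as `other`.
        return tensor.as_subclass(cls)

class Image(Feature):
    pass

class Video(Feature):
    pass

inpt = torch.rand(4, 3, 8, 8).as_subclass(Video)
out = inpt.wrap_like(inpt, inpt.flip(-1))  # equivalent to type(inpt).wrap_like(...)
print(type(out).__name__)  # Video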

torchvision/prototype/transforms/_deprecated.py

Lines changed: 8 additions & 8 deletions
@@ -29,7 +29,7 @@ def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str,


 class Grayscale(Transform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

     def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         deprecation_msg = (
@@ -52,15 +52,15 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         super().__init__()
         self.num_output_channels = num_output_channels

-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
-        if isinstance(inpt, features.Image):
-            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        if isinstance(inpt, (features.Image, features.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output


 class RandomGrayscale(_RandomApplyTransform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

     def __init__(self, p: float = 0.1) -> None:
         warnings.warn(
@@ -81,8 +81,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         num_input_channels, _, _ = query_chw(sample)
         return dict(num_input_channels=num_input_channels)

-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
-        if isinstance(inpt, features.Image):
-            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        if isinstance(inpt, (features.Image, features.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output
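
A rough usage sketch for the updated deprecated transform (assuming the prototype API at this commit; the transform still emits its deprecation warning, and the exact output shape is an expectation, not verified here):

import torch
from torchvision.prototype import features, transforms

video = features.Video(torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8))
gray = transforms.Grayscale()(video)  # now dispatches on Video as well as Image
print(type(gray).__name__, gray.shape)  # expected: Video torch.Size([4, 1, 16, 16])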

torchvision/prototype/transforms/_geometry.py

Lines changed: 17 additions & 10 deletions
@@ -155,12 +155,13 @@ class FiveCrop(Transform):
     """
     Example:
         >>> class BatchMultiCrop(transforms.Transform):
-        ...     def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
-        ...         images, labels = sample
-        ...         batch_size = len(images)
-        ...         images = features.Image.wrap_like(images[0], torch.stack(images))
+        ...     def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
+        ...         images_or_videos, labels = sample
+        ...         batch_size = len(images_or_videos)
+        ...         image_or_video = images_or_videos[0]
+        ...         images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
         ...         labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
-        ...         return images, labels
+        ...         return images_or_videos, labels
         ...
         >>> image = features.Image(torch.rand(3, 256, 256))
         >>> label = features.Label(0)
@@ -172,15 +173,21 @@ class FiveCrop(Transform):
     torch.Size([5])
     """

-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(
-        self, inpt: features.ImageType, params: Dict[str, Any]
-    ) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
+        self, inpt: features.ImageOrVideoType, params: Dict[str, Any]
+    ) -> Tuple[
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+    ]:
         return F.five_crop(inpt, self.size)

     def forward(self, *inputs: Any) -> Any:
@@ -194,14 +201,14 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """

-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip

-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> List[features.ImageOrVideoType]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

     def forward(self, *inputs: Any) -> Any:

torchvision/prototype/transforms/_meta.py

Lines changed: 5 additions & 5 deletions
@@ -22,18 +22,18 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> feat


 class ConvertImageDtype(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image)
+    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)

     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype

-    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
+    def _transform(
+        self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]
+    ) -> features.TensorImageOrVideoType:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
         return (
-            output
-            if features.is_simple_tensor(inpt)
-            else features.Image.wrap_like(inpt, output)  # type: ignore[arg-type]
+            output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)  # type: ignore[attr-defined]
         )
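
Illustrative sketch of the widened dispatch (same hedged assumptions about the prototype API as above):

import torch
from torchvision.prototype import features, transforms

video = features.Video(torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8))
out = transforms.ConvertImageDtype(torch.float32)(video)  # Video is now a transformed type
print(type(out).__name__, out.dtype)  # expected: Video torch.float32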

torchvision/prototype/transforms/_misc.py

Lines changed: 1 addition & 0 deletions
@@ -140,6 +140,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.gaussian_blur(inpt, self.kernel_size, **params)


+# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
 class ToDtype(Lambda):
     def __init__(self, dtype: torch.dtype, *types: Type) -> None:
         self.dtype = dtype

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,7 @@
     five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
+    five_crop_video,
     hflip,  # TODO: Consider moving all pure alias definitions at the bottom of the file
     horizontal_flip,
     horizontal_flip_bounding_box,
@@ -136,6 +137,7 @@
     ten_crop,
     ten_crop_image_pil,
     ten_crop_image_tensor,
+    ten_crop_video,
     vertical_flip,
     vertical_flip_bounding_box,
     vertical_flip_image_pil,

torchvision/prototype/transforms/functional/_augment.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def erase(
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
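
A hedged sketch of the functional erase on a video feature (names taken from the prototype namespace shown in this diff; output type is an expectation):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

video = features.Video(torch.rand(4, 3, 16, 16))
patch = torch.zeros(3, 4, 4)  # value written into the erased region
out = F.erase(video, i=2, j=2, h=4, w=4, v=patch)
print(type(out).__name__)  # expected: Video, re-wrapped via inpt.wrap_like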

torchvision/prototype/transforms/functional/_deprecated.py

Lines changed: 7 additions & 4 deletions
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, List
+from typing import Any, List, Union

 import PIL.Image
 import torch
@@ -22,10 +22,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)


-def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
+def rgb_to_grayscale(
+    inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
+) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
     old_color_space = (
         features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-        if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
+        if isinstance(inpt, torch.Tensor)
+        and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
         else None
     )

@@ -56,7 +59,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)


-def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
+def get_image_size(inpt: features.ImageOrVideoTypeJIT) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 25 additions & 8 deletions
@@ -1376,16 +1376,27 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center


+def five_crop_video(
+    video: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    return five_crop_image_tensor(video, size)
+
+
 def five_crop(
-    inpt: features.ImageTypeJIT, size: List[int]
+    inpt: features.ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[
-    features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
 ]:
-    # TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
+    # TODO: consider breaking BC here to return List[features.ImageOrVideoTypeJIT] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
-        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[assignment]
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
+            output = tmp  # type: ignore[assignment]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
@@ -1418,11 +1429,17 @@ def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: b
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


-def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
+def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
+
+
+def ten_crop(
+    inpt: features.ImageOrVideoTypeJIT, size: List[int], vertical_flip: bool = False
+) -> List[features.ImageOrVideoTypeJIT]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = [features.Image.wrap_like(inpt, item) for item in output]
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            output = [inpt.wrap_like(inpt, item) for item in output]  # type: ignore[arg-type]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
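
A quick sketch of the new kernels (they delegate to the image-tensor kernels, which crop the trailing (H, W) dims, so a (T, C, H, W) clip should work as-is; shapes below are expectations, not verified output):

import torch
from torchvision.prototype.transforms import functional as F

clip = torch.rand(8, 3, 64, 64)  # (T, C, H, W)
tl, tr, bl, br, center = F.five_crop_video(clip, [32, 32])
crops = F.ten_crop_video(clip, [32, 32], vertical_flip=False)
print(center.shape, len(crops))  # expected: torch.Size([8, 3, 32, 32]) 10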

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 5 additions & 1 deletion
@@ -55,6 +55,10 @@ def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
     return [height, width]


+# TODO: Should we have get_spatial_size_video here? How about masks/bbox etc? What is the criterion for deciding when
+# a kernel will be created?
+
+
 def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return get_spatial_size_image_tensor(inpt)
@@ -246,7 +250,7 @@ def convert_color_space(
     ):
         if old_color_space is None:
             raise RuntimeError(
-                "In order to convert the color space of simple tensor images, "
+                "In order to convert the color space of simple tensors, "
                 "the `old_color_space=...` parameter needs to be passed."
             )
         return convert_color_space_image_tensor(
