Skip to content

Commit 11bc8c2

Browse files
committed
Nits and TODOs.
1 parent 9cbec19 commit 11bc8c2

File tree

7 files changed

+9
-9
lines changed

7 files changed

+9
-9
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) ->
9999
return inpt
100100

101101

102+
# TODO: Add support for Video: https://github.com/pytorch/vision/issues/6731
102103
class _BaseMixupCutmix(_RandomApplyTransform):
103104
def __init__(self, alpha: float, p: float = 0.5) -> None:
104105
super().__init__(p=p)

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def forward(self, *inputs: Any) -> Any:
521521
mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)
522522

523523
if isinstance(orig_image_or_video, (features.Image, features.Video)):
524-
mix = type(orig_image_or_video).wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
524+
mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
525525
elif isinstance(orig_image_or_video, PIL.Image.Image):
526526
mix = F.to_image_pil(mix)
527527

torchvision/prototype/transforms/_color.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _permute_channels(
119119
output = inpt[..., permutation, :, :]
120120

121121
if isinstance(inpt, (features.Image, features.Video)):
122-
output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type]
122+
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type]
123123

124124
elif isinstance(inpt, PIL.Image.Image):
125125
output = F.to_image_pil(output)

torchvision/prototype/transforms/_deprecated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
5555
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
5656
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
5757
if isinstance(inpt, (features.Image, features.Video)):
58-
output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
58+
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
5959
return output
6060

6161

@@ -84,5 +84,5 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
8484
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
8585
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
8686
if isinstance(inpt, (features.Image, features.Video)):
87-
output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
87+
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
8888
return output

torchvision/prototype/transforms/_misc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
140140
return F.gaussian_blur(inpt, self.kernel_size, **params)
141141

142142

143+
# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
143144
class ToDtype(Lambda):
144145
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
145146
self.dtype = dtype

torchvision/prototype/transforms/functional/_augment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def erase(
3535
if isinstance(inpt, torch.Tensor):
3636
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
3737
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
38-
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
38+
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
3939
return output
4040
else: # isinstance(inpt, PIL.Image.Image):
4141
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,8 +1424,7 @@ def five_crop(
14241424
if isinstance(inpt, torch.Tensor):
14251425
output = five_crop_image_tensor(inpt, size)
14261426
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
1427-
cls = type(inpt)
1428-
output = tuple(cls.wrap_like(inpt, item) for item in output) # type: ignore[assignment,arg-type]
1427+
output = (inpt.wrap_like(inpt, item) for item in output) # type: ignore[arg-type]
14291428
return output
14301429
else: # isinstance(inpt, PIL.Image.Image):
14311430
return five_crop_image_pil(inpt, size)
@@ -1468,8 +1467,7 @@ def ten_crop(
14681467
if isinstance(inpt, torch.Tensor):
14691468
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
14701469
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
1471-
cls = type(inpt)
1472-
output = [cls.wrap_like(inpt, item) for item in output] # type: ignore[arg-type]
1470+
output = [inpt.wrap_like(inpt, item) for item in output] # type: ignore[arg-type]
14731471
return output
14741472
else: # isinstance(inpt, PIL.Image.Image):
14751473
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)

0 commit comments

Comments (0)