
Commit f96f083

YosuaMichael authored and facebook-github-bot committed
[fbsync] More cleanup for prototype transforms (#6500)
Summary:
* add aliases for hflip and vflip
* reduce imports from torchvision.transforms in torchvision.prototype.transforms
* add aliases for to_pil_image and pil_to_tensor
* deprecate to_tensor
* add some FIXME cleanup comments
* address reviews
* add dimension getters
* undeprecate PILToTensor and ToPILImage
* address review
* fix test

Reviewed By: NicolasHug

Differential Revision: D39131018

fbshipit-source-id: a1fc5d9dfd1cd587f273674716105f59a01e6cf0
1 parent dc8fa0e commit f96f083

17 files changed: +118 -81 lines changed

test/test_prototype_transforms.py

Lines changed: 2 additions & 3 deletions
@@ -1071,11 +1071,10 @@ class TestToPILImage:
         [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
     )
     def test__transform(self, inpt_type, mocker):
-        fn = mocker.patch("torchvision.transforms.functional.to_pil_image")
+        fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")

         inpt = mocker.MagicMock(spec=inpt_type)
-        with pytest.warns(UserWarning, match="deprecated and will be removed"):
-            transform = transforms.ToPILImage()
+        transform = transforms.ToPILImage()
         transform(inpt)
         if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int):
             assert fn.call_count == 0

test/test_prototype_transforms_functional.py

Lines changed: 2 additions & 0 deletions
@@ -674,6 +674,8 @@ def erase_image_tensor():
         and name
         not in {
             "to_image_tensor",
+            "get_image_num_channels",
+            "get_image_size",
         }
     ],
 )

torchvision/prototype/transforms/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -1,9 +1,11 @@
+from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
+
 from . import functional  # usort: skip

 from ._transform import Transform  # usort: skip

 from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
-from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
+from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
     RandomAdjustSharpness,
@@ -37,6 +39,6 @@
 )
 from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
 from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
-from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
+from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

-from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor  # usort: skip
+from ._deprecated import Grayscale, RandomGrayscale, ToTensor  # usort: skip

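The header import added above re-exports AutoAugmentPolicy and InterpolationMode through the prototype namespace, so downstream prototype code no longer needs to reach into torchvision.transforms for them. A minimal sketch, assuming a torchvision build that ships the prototype namespace (the enum members used below are standard torchvision values, not part of this diff):

from torchvision.prototype.transforms import AutoAugmentPolicy, InterpolationMode

policy = AutoAugmentPolicy.IMAGENET          # re-exported enum, previously imported from torchvision.transforms
interpolation = InterpolationMode.BILINEAR   # likewise re-exported
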
torchvision/prototype/transforms/_augment.py

Lines changed: 2 additions & 4 deletions
@@ -8,9 +8,7 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import features
-
-from torchvision.prototype.transforms import functional as F
-from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
+from torchvision.prototype.transforms import functional as F, InterpolationMode

 from ._transform import _RandomApplyTransform
 from ._utils import has_any, query_chw
@@ -279,7 +277,7 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis
             if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
                 images.append(obj)
             elif isinstance(obj, PIL.Image.Image):
-                images.append(pil_to_tensor(obj))
+                images.append(F.to_image_tensor(obj))
             elif isinstance(obj, features.BoundingBox):
                 bboxes.append(obj)
             elif isinstance(obj, features.SegmentationMask):

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 5 additions & 6 deletions
@@ -7,11 +7,10 @@

 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
-from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.transforms.autoaugment import AutoAugmentPolicy
-from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
+from torchvision.prototype.transforms.functional._meta import get_chw

-from ._utils import _isinstance, get_chw
+from ._utils import _isinstance

 K = TypeVar("K")
 V = TypeVar("V")
@@ -473,7 +472,7 @@ def forward(self, *inputs: Any) -> Any:
         if isinstance(orig_image, torch.Tensor):
             image = orig_image
         else:  # isinstance(inpt, PIL.Image.Image):
-            image = pil_to_tensor(orig_image)
+            image = F.to_image_tensor(orig_image)

         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

@@ -516,6 +515,6 @@ def forward(self, *inputs: Any) -> Any:
         if isinstance(orig_image, features.Image):
             mix = features.Image.new_like(orig_image, mix)
         elif isinstance(orig_image, PIL.Image.Image):
-            mix = to_pil_image(mix)
+            mix = F.to_image_pil(mix)

         return self._put_into_sample(sample, id, mix)

torchvision/prototype/transforms/_color.py

Lines changed: 5 additions & 8 deletions
@@ -5,7 +5,6 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.transforms import functional as _F

 from ._transform import _RandomApplyTransform
 from ._utils import query_chw
@@ -85,6 +84,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


 class RandomPhotometricDistort(Transform):
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+
     def __init__(
         self,
         contrast: Tuple[float, float] = (0.5, 1.5),
@@ -112,19 +113,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         )

     def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(inpt)):
-            return inpt
-
-        image = inpt
         if isinstance(inpt, PIL.Image.Image):
-            image = _F.pil_to_tensor(image)
+            inpt = F.to_image_tensor(inpt)

-        output = image[..., permutation, :, :]
+        output = inpt[..., permutation, :, :]

         if isinstance(inpt, features.Image):
             output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
         elif isinstance(inpt, PIL.Image.Image):
-            output = _F.to_pil_image(output)
+            output = F.to_image_pil(output)

         return output


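The _permute_channels rewrite above leans on the prototype conversion helpers instead of torchvision.transforms.functional. A standalone sketch of that round trip, assuming the prototype namespace is importable (the input image below is illustrative only):

import PIL.Image
import torch
from torchvision.prototype.transforms import functional as F

pil_img = PIL.Image.new("RGB", (8, 8))
img = F.to_image_tensor(pil_img)       # PIL image -> CHW tensor image
permutation = torch.randperm(3)
out = img[..., permutation, :, :]      # reorder the channel dimension, as _permute_channels does
out_pil = F.to_image_pil(out)          # tensor image -> PIL, mirroring the original input type
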
torchvision/prototype/transforms/_deprecated.py

Lines changed: 2 additions & 31 deletions
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Dict

 import numpy as np
 import PIL.Image
@@ -20,43 +20,14 @@ class ToTensor(Transform):
     def __init__(self) -> None:
         warnings.warn(
             "The transform `ToTensor()` is deprecated and will be removed in a future release. "
-            "Instead, please use `transforms.ToImageTensor()`."
+            "Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`."
         )
         super().__init__()

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         return _F.to_tensor(inpt)


-class PILToTensor(Transform):
-    _transformed_types = (PIL.Image.Image,)
-
-    def __init__(self) -> None:
-        warnings.warn(
-            "The transform `PILToTensor()` is deprecated and will be removed in a future release. "
-            "Instead, please use `transforms.ToImageTensor()`."
-        )
-        super().__init__()
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
-        return _F.pil_to_tensor(inpt)
-
-
-class ToPILImage(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
-
-    def __init__(self, mode: Optional[str] = None) -> None:
-        warnings.warn(
-            "The transform `ToPILImage()` is deprecated and will be removed in a future release. "
-            "Instead, please use `transforms.ToImagePIL()`."
-        )
-        super().__init__()
-        self.mode = mode
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
-        return _F.to_pil_image(inpt, mode=self.mode)
-
-
 class Grayscale(Transform):
     _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)


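The updated warning points users at a two-step pipeline rather than a single replacement transform. A minimal sketch of that migration, assuming the prototype namespace exposes Compose as the warning suggests (the input image is illustrative only):

import PIL.Image
from torchvision.prototype import transforms

pipeline = transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])
out = pipeline(PIL.Image.new("RGB", (8, 8)))  # roughly what the deprecated ToTensor() produced: a float tensor image
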
torchvision/prototype/transforms/_geometry.py

Lines changed: 11 additions & 5 deletions
@@ -7,15 +7,21 @@
 import torch
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features
-from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.transforms.functional import InterpolationMode
-from torchvision.transforms.functional_tensor import _parse_pad_padding
-from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
+from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform

 from typing_extensions import Literal

 from ._transform import _RandomApplyTransform
-from ._utils import has_all, has_any, query_bounding_box, query_chw
+from ._utils import (
+    _check_sequence_input,
+    _parse_pad_padding,
+    _setup_angle,
+    _setup_size,
+    has_all,
+    has_any,
+    query_bounding_box,
+    query_chw,
+)


 class RandomHorizontalFlip(_RandomApplyTransform):

torchvision/prototype/transforms/_meta.py

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.transforms.functional import convert_image_dtype


 class ConvertBoundingBoxFormat(Transform):
@@ -30,7 +29,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = convert_image_dtype(inpt, dtype=self.dtype)
+        output = F.convert_image_dtype(inpt, dtype=self.dtype)
         return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)


torchvision/prototype/transforms/_misc.py

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@
 from torchvision.ops import remove_small_boxes
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.prototype.transforms._utils import has_any, query_bounding_box
-from torchvision.transforms.transforms import _setup_size
+
+from ._utils import _setup_size, has_any, query_bounding_box


 class Identity(Transform):

torchvision/prototype/transforms/_type_conversion.py

Lines changed: 6 additions & 0 deletions
@@ -52,3 +52,9 @@ def __init__(self, *, mode: Optional[str] = None) -> None:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
         return F.to_image_pil(inpt, mode=self.mode)
+
+
+# We changed the names to align them with the new naming scheme. Still, `PILToTensor` and `ToPILImage` are
+# prevalent and well understood. Thus, we just alias them without deprecating the old names.
+PILToTensor = ToImageTensor
+ToPILImage = ToImagePIL

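Because these are plain aliases, and both spellings are re-exported from the package __init__ above, the un-deprecated names resolve to the very same classes. A quick sanity check, assuming the prototype namespace is importable:

from torchvision.prototype import transforms

assert transforms.PILToTensor is transforms.ToImageTensor
assert transforms.ToPILImage is transforms.ToImagePIL
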
torchvision/prototype/transforms/_utils.py

Lines changed: 3 additions & 15 deletions
@@ -1,12 +1,13 @@
 from typing import Any, Callable, Tuple, Type, Union

 import PIL.Image
-import torch
 from torch.utils._pytree import tree_flatten
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features

-from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor
+from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.transforms.functional_tensor import _parse_pad_padding  # noqa: F401
+from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401


 def query_bounding_box(sample: Any) -> features.BoundingBox:
@@ -19,19 +20,6 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
     return bounding_boxes.pop()


-def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
-    if isinstance(image, features.Image):
-        channels = image.num_channels
-        height, width = image.image_size
-    elif features.is_simple_tensor(image):
-        channels, height, width = get_dimensions_image_tensor(image)
-    elif isinstance(image, PIL.Image.Image):
-        channels, height, width = get_dimensions_image_pil(image)
-    else:
-        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
-    return channels, height, width
-
-
 def query_chw(sample: Any) -> Tuple[int, int, int]:
     flat_sample, _ = tree_flatten(sample)
     chws = {

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 14 additions & 1 deletion
@@ -5,6 +5,9 @@
     convert_color_space_image_tensor,
     convert_color_space_image_pil,
     convert_color_space,
+    get_dimensions,
+    get_image_num_channels,
+    get_image_size,
 )  # usort: skip

 from ._augment import erase, erase_image_pil, erase_image_tensor
@@ -68,6 +71,7 @@
     five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
+    hflip,
     horizontal_flip,
     horizontal_flip_bounding_box,
     horizontal_flip_image_pil,
@@ -106,8 +110,17 @@
     vertical_flip_image_pil,
     vertical_flip_image_tensor,
     vertical_flip_segmentation_mask,
+    vflip,
 )
 from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
-from ._type_conversion import decode_image_with_pil, decode_video_with_av, to_image_pil, to_image_tensor
+from ._type_conversion import (
+    convert_image_dtype,
+    decode_image_with_pil,
+    decode_video_with_av,
+    pil_to_tensor,
+    to_image_pil,
+    to_image_tensor,
+    to_pil_image,
+)

 from ._deprecated import rgb_to_grayscale, to_grayscale  # usort: skip

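A short sketch of the newly exported dimension getters, assuming the prototype namespace is importable (the sample tensor is illustrative, and the exact return container is not shown in this diff):

import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 480, 640)                    # CHW tensor standing in for an image
channels, height, width = F.get_dimensions(img)  # (channels, height, width)
num_channels = F.get_image_num_channels(img)
size = F.get_image_size(img)                     # torchvision has historically reported image size as [width, height]
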
torchvision/prototype/transforms/functional/_deprecated.py

Lines changed: 9 additions & 0 deletions
@@ -2,6 +2,7 @@
 from typing import Any

 import PIL.Image
+import torch

 from torchvision.prototype import features
 from torchvision.transforms import functional as _F
@@ -39,3 +40,11 @@ def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
     )

     return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
+
+
+def to_tensor(inpt: Any) -> torch.Tensor:
+    warnings.warn(
+        "The function `to_tensor(...)` is deprecated and will be removed in a future release. "
+        "Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
+    )
+    return _F.to_tensor(inpt)

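The warning above names the replacement pair explicitly; a minimal sketch of that migration, assuming the prototype functional namespace is importable (the input image is illustrative only):

import PIL.Image
import torch
from torchvision.prototype.transforms import functional as F

pil_img = PIL.Image.new("RGB", (4, 4))
img = F.to_image_tensor(pil_img)                  # keeps the integer dtype of the source image
img = F.convert_image_dtype(img, torch.float32)   # rescales to float, matching what to_tensor(...) did in one call
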
torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 6 additions & 0 deletions
@@ -89,6 +89,12 @@ def vertical_flip(inpt: DType) -> DType:
     return vertical_flip_image_tensor(inpt)


+# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
+# prevalent and well understood. Thus, we just alias them without deprecating the old names.
+hflip = horizontal_flip
+vflip = vertical_flip
+
+
 def resize_image_tensor(
     image: torch.Tensor,
     size: List[int],

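As with the transform-class aliases, hflip and vflip simply bind the existing callables to the familiar names. A one-line sanity check, assuming the prototype namespace is importable:

from torchvision.prototype.transforms import functional as F

assert F.hflip is F.horizontal_flip
assert F.vflip is F.vertical_flip
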