Skip to content

More cleanup for prototype transforms #6500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 26, 2022
5 changes: 2 additions & 3 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,11 +1071,10 @@ class TestToPILImage:
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_pil_image")
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")

inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToPILImage()
transform = transforms.ToPILImage()
transform(inpt)
if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int):
assert fn.call_count == 0
Expand Down
2 changes: 2 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,8 @@ def erase_image_tensor():
and name
not in {
"to_image_tensor",
"get_image_num_channels",
"get_image_size",
}
],
)
Expand Down
8 changes: 5 additions & 3 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip

from . import functional # usort: skip

from ._transform import Transform # usort: skip

from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
RandomAdjustSharpness,
Expand Down Expand Up @@ -37,6 +39,6 @@
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
6 changes: 2 additions & 4 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features

from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from torchvision.prototype.transforms import functional as F, InterpolationMode

from ._transform import _RandomApplyTransform
from ._utils import has_any, query_chw
Expand Down Expand Up @@ -279,7 +277,7 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(pil_to_tensor(obj))
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.SegmentationMask):
Expand Down
11 changes: 5 additions & 6 deletions torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@

from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw

from ._utils import _isinstance, get_chw
from ._utils import _isinstance

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -473,7 +472,7 @@ def forward(self, *inputs: Any) -> Any:
if isinstance(orig_image, torch.Tensor):
image = orig_image
else: # isinstance(inpt, PIL.Image.Image):
image = pil_to_tensor(orig_image)
image = F.to_image_tensor(orig_image)

augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

Expand Down Expand Up @@ -516,6 +515,6 @@ def forward(self, *inputs: Any) -> Any:
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
mix = to_pil_image(mix)
mix = F.to_image_pil(mix)

return self._put_into_sample(sample, id, mix)
13 changes: 5 additions & 8 deletions torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms import functional as _F

from ._transform import _RandomApplyTransform
from ._utils import query_chw
Expand Down Expand Up @@ -85,6 +84,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomPhotometricDistort(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)

def __init__(
self,
contrast: Tuple[float, float] = (0.5, 1.5),
Expand Down Expand Up @@ -112,19 +113,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
)

def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(inpt)):
return inpt

image = inpt
if isinstance(inpt, PIL.Image.Image):
image = _F.pil_to_tensor(image)
inpt = F.to_image_tensor(inpt)

output = image[..., permutation, :, :]
output = inpt[..., permutation, :, :]

if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = _F.to_pil_image(output)
output = F.to_image_pil(output)

return output

Expand Down
33 changes: 2 additions & 31 deletions torchvision/prototype/transforms/_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Dict, Optional
from typing import Any, Dict

import numpy as np
import PIL.Image
Expand All @@ -20,43 +20,14 @@ class ToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImageTensor()`."
"Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`."
)
super().__init__()

def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
return _F.to_tensor(inpt)


class PILToTensor(Transform):
    """Deprecated transform that converts its input with ``_F.pil_to_tensor``.

    Instantiating the class issues a warning pointing users to
    ``transforms.ToImageTensor()`` as the replacement.
    """

    # Only PIL images are declared as types this transform acts on.
    _transformed_types = (PIL.Image.Image,)

    def __init__(self) -> None:
        deprecation_msg = (
            "The transform `PILToTensor()` is deprecated and will be removed in a future release. "
            "Instead, please use `transforms.ToImageTensor()`."
        )
        warnings.warn(deprecation_msg)
        super().__init__()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
        # ``params`` is accepted for interface compatibility but not used here.
        return _F.pil_to_tensor(inpt)


class ToPILImage(Transform):
    """Deprecated transform that converts its input with ``_F.to_pil_image``.

    Instantiating the class issues a warning pointing users to
    ``transforms.ToImagePIL()`` as the replacement.

    Args:
        mode: Forwarded as the ``mode`` keyword to ``_F.to_pil_image``.
    """

    # Simple tensors, ``features.Image`` and numpy arrays are declared as types this transform acts on.
    _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)

    def __init__(self, mode: Optional[str] = None) -> None:
        deprecation_msg = (
            "The transform `ToPILImage()` is deprecated and will be removed in a future release. "
            "Instead, please use `transforms.ToImagePIL()`."
        )
        warnings.warn(deprecation_msg)
        super().__init__()
        self.mode = mode

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
        # ``params`` is accepted for interface compatibility but not used here.
        pil_image = _F.to_pil_image(inpt, mode=self.mode)
        return pil_image


class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)

Expand Down
16 changes: 11 additions & 5 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform

from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import has_all, has_any, query_bounding_box, query_chw
from ._utils import (
_check_sequence_input,
_parse_pad_padding,
_setup_angle,
_setup_size,
has_all,
has_any,
query_bounding_box,
query_chw,
)


class RandomHorizontalFlip(_RandomApplyTransform):
Expand Down
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import convert_image_dtype


class ConvertBoundingBoxFormat(Transform):
Expand All @@ -30,7 +29,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = convert_image_dtype(inpt, dtype=self.dtype)
output = F.convert_image_dtype(inpt, dtype=self.dtype)
return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)


Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from torchvision.ops import remove_small_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms._utils import has_any, query_bounding_box
from torchvision.transforms.transforms import _setup_size

from ._utils import _setup_size, has_any, query_bounding_box


class Identity(Transform):
Expand Down
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@ def __init__(self, *, mode: Optional[str] = None) -> None:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode)


# We changed the names to align them with the new naming scheme. Still, `PILToTensor` and `ToPILImage` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
# The old names are bound to the very same class objects as the new ones (no subclass or wrapper), so e.g.
# `PILToTensor.__name__` reports the new name.
PILToTensor = ToImageTensor
ToPILImage = ToImagePIL
18 changes: 3 additions & 15 deletions torchvision/prototype/transforms/_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, Callable, Tuple, Type, Union

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features

from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor
from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.transforms.functional_tensor import _parse_pad_padding # noqa: F401
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401


def query_bounding_box(sample: Any) -> features.BoundingBox:
Expand All @@ -19,19 +20,6 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
return bounding_boxes.pop()


def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
    """Return the ``(channels, height, width)`` triple of ``image``.

    ``features.Image`` inputs are answered from their ``num_channels`` and ``image_size``
    attributes. Simple tensors are delegated to ``get_dimensions_image_tensor`` and PIL
    images to ``get_dimensions_image_pil``.

    Raises:
        TypeError: If ``image`` is none of the supported types.
    """
    if isinstance(image, features.Image):
        num_channels = image.num_channels
        height, width = image.image_size
        return num_channels, height, width

    if features.is_simple_tensor(image):
        num_channels, height, width = get_dimensions_image_tensor(image)
        return num_channels, height, width

    if isinstance(image, PIL.Image.Image):
        num_channels, height, width = get_dimensions_image_pil(image)
        return num_channels, height, width

    raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")


def query_chw(sample: Any) -> Tuple[int, int, int]:
flat_sample, _ = tree_flatten(sample)
chws = {
Expand Down
15 changes: 14 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space,
get_dimensions,
get_image_num_channels,
get_image_size,
) # usort: skip

from ._augment import erase, erase_image_pil, erase_image_tensor
Expand Down Expand Up @@ -68,6 +71,7 @@
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
hflip,
horizontal_flip,
horizontal_flip_bounding_box,
horizontal_flip_image_pil,
Expand Down Expand Up @@ -106,8 +110,17 @@
vertical_flip_image_pil,
vertical_flip_image_tensor,
vertical_flip_segmentation_mask,
vflip,
)
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, to_image_pil, to_image_tensor
from ._type_conversion import (
convert_image_dtype,
decode_image_with_pil,
decode_video_with_av,
pil_to_tensor,
to_image_pil,
to_image_tensor,
to_pil_image,
)

from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip
9 changes: 9 additions & 0 deletions torchvision/prototype/transforms/functional/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

import PIL.Image
import torch

from torchvision.prototype import features
from torchvision.transforms import functional as _F
Expand Down Expand Up @@ -39,3 +40,11 @@ def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
)

return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)


def to_tensor(inpt: Any) -> torch.Tensor:
    """Deprecated wrapper that forwards ``inpt`` to ``_F.to_tensor``.

    A warning is issued through ``warnings.warn`` before delegating; it recommends
    ``to_image_tensor(...)`` followed by ``convert_image_dtype(...)`` as the replacement.
    """
    deprecation_msg = (
        "The function `to_tensor(...)` is deprecated and will be removed in a future release. "
        "Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
    )
    warnings.warn(deprecation_msg)
    return _F.to_tensor(inpt)
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def vertical_flip(inpt: DType) -> DType:
return vertical_flip_image_tensor(inpt)


# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
# The short names are bound to the very same function objects as the long ones (no wrapper), so e.g.
# `hflip.__name__` reports `horizontal_flip`.
hflip = horizontal_flip
vflip = vertical_flip


def resize_image_tensor(
image: torch.Tensor,
size: List[int],
Expand Down
Loading