Skip to content

More cleanup for prototype transforms #6500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits on Aug 26, 2022
4 changes: 3 additions & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip

from . import functional # usort: skip

from ._transform import Transform # usort: skip

from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
RandomAdjustSharpness,
Expand Down
6 changes: 2 additions & 4 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features

from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from torchvision.prototype.transforms import functional as F, InterpolationMode

from ._transform import _RandomApplyTransform
from ._utils import has_any, is_simple_tensor, query_chw
Expand Down Expand Up @@ -278,7 +276,7 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis
if isinstance(obj, features.Image) or is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(pil_to_tensor(obj))
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.SegmentationMask):
Expand Down
8 changes: 3 additions & 5 deletions torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform

from ._utils import _isinstance, get_chw, is_simple_tensor

Expand Down Expand Up @@ -473,7 +471,7 @@ def forward(self, *inputs: Any) -> Any:
if isinstance(orig_image, torch.Tensor):
image = orig_image
else: # isinstance(inpt, PIL.Image.Image):
image = pil_to_tensor(orig_image)
image = F.to_image_tensor(orig_image)

augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

Expand Down Expand Up @@ -516,6 +514,6 @@ def forward(self, *inputs: Any) -> Any:
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
mix = to_pil_image(mix)
mix = F.to_image_pil(mix)

return self._put_into_sample(sample, id, mix)
5 changes: 2 additions & 3 deletions torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms import functional as _F

from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor, query_chw
Expand Down Expand Up @@ -117,14 +116,14 @@ def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:

image = inpt
if isinstance(inpt, PIL.Image.Image):
image = _F.pil_to_tensor(image)
image = F.to_image_tensor(image)

output = image[..., permutation, :, :]

if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = _F.to_pil_image(output)
output = F.to_image_pil(output)

return output

Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ToTensor(Transform):
_transformed_types = (PIL.Image.Image, np.ndarray)

def __init__(self) -> None:
# FIXME: should the replacement be a `Compose` with `ConvertImageDtype`?
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImageTensor()`."
Expand Down
17 changes: 12 additions & 5 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform

from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_chw
from ._utils import (
_check_sequence_input,
_parse_pad_padding,
_setup_angle,
_setup_size,
has_all,
has_any,
is_simple_tensor,
query_bounding_box,
query_chw,
)


class RandomHorizontalFlip(_RandomApplyTransform):
Expand Down
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import convert_image_dtype

from ._utils import is_simple_tensor

Expand All @@ -32,7 +31,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = convert_image_dtype(inpt, dtype=self.dtype)
output = F.convert_image_dtype(inpt, dtype=self.dtype)
return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)


Expand Down
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms._utils import query_bounding_box
from torchvision.transforms.transforms import _setup_size

from ._utils import is_simple_tensor
from ._utils import _setup_size, is_simple_tensor


class Identity(Transform):
Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
# FIXME: should we move this wrapping into the functional?
output = F.decode_image_with_pil(inpt)
return features.Image(output)

Expand Down Expand Up @@ -43,6 +44,7 @@ class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
# FIXME: should we move this wrapping into the functional?
output = F.to_image_tensor(inpt)
return features.Image(output)

Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.transforms.functional_tensor import _parse_pad_padding # noqa: F401
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401

from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor

Expand Down
12 changes: 11 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
hflip,
horizontal_flip,
horizontal_flip_bounding_box,
horizontal_flip_image_pil,
Expand Down Expand Up @@ -106,8 +107,17 @@
vertical_flip_image_pil,
vertical_flip_image_tensor,
vertical_flip_segmentation_mask,
vflip,
)
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, to_image_pil, to_image_tensor
from ._type_conversion import (
convert_image_dtype,
decode_image_with_pil,
decode_video_with_av,
pil_to_tensor,
to_image_pil,
to_image_tensor,
to_pil_image,
)

from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip
10 changes: 10 additions & 0 deletions torchvision/prototype/transforms/functional/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

import PIL.Image
import torch

from torchvision.prototype import features
from torchvision.transforms import functional as _F
Expand Down Expand Up @@ -41,3 +42,12 @@ def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
)

return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)


def to_tensor(inpt: Any) -> torch.Tensor:
# FIXME: should we keep the "if needed" phrase or unconditionally recommend `convert_image_dtype`?
warnings.warn(
"The function `to_tensor(...)` is deprecated and will be removed in a future release. "
"Instead, please use `to_image_tensor(...)` and if needed use `convert_image_dtype(...)` afterwards."
)
return _F.to_tensor(inpt)
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def vertical_flip(inpt: DType) -> DType:
return vertical_flip_image_tensor(inpt)


# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip


def resize_image_tensor(
image: torch.Tensor,
size: List[int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,10 @@ def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) ->


to_image_pil = _F.to_pil_image

# We changed the names to align them with the new naming scheme. Still, `to_pil_image` and `pil_to_tensor` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
to_pil_image = to_image_pil
pil_to_tensor = to_image_tensor

convert_image_dtype = _F.convert_image_dtype