diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index bdf2e2455ad..4f3de3e7fc3 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -1,16 +1,17 @@
 import math
 import numbers
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import PIL.Image
 import torch
 
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms.autoaugment import AutoAugmentPolicy
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
-from ._utils import is_simple_tensor, query_chw
+from ._utils import _isinstance, get_chw, is_simple_tensor
 
 K = TypeVar("K")
 V = TypeVar("V")
@@ -35,9 +36,31 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]
 
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, height, width = query_chw(sample)
-        return dict(height=height, width=width)
+    def _extract_image(
+        self,
+        sample: Any,
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
+    ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+        sample_flat, _ = tree_flatten(sample)
+        images = []
+        for id, inpt in enumerate(sample_flat):
+            if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)):
+                images.append((id, inpt))
+            elif isinstance(inpt, unsupported_types):
+                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
+
+        if not images:
+            raise TypeError("Found no image in the sample.")
+        if len(images) > 1:
+            raise TypeError(
+                f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
+            )
+        return images[0]
+
+    def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
+        sample_flat, spec = tree_flatten(sample)
+        sample_flat[id] = item
+        return tree_unflatten(sample_flat, spec)
 
     def _apply_image_transform(
         self,
@@ -242,22 +265,21 @@ def _get_policies(
         else:
             raise ValueError(f"The provided policy {policy} is not recognized.")
 
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        params = super(AutoAugment, self)._get_params(sample)
-        params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
-        return params
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+        policy = self._policies[int(torch.randint(len(self._policies), ()))]
 
-        for transform_id, probability, magnitude_idx in params["policy"]:
+        for transform_id, probability, magnitude_idx in policy:
             if not torch.rand(()) <= probability:
                 continue
 
             magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
 
-            magnitudes = magnitudes_fn(10, params["height"], params["width"])
+            magnitudes = magnitudes_fn(10, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[magnitude_idx])
                 if signed and torch.rand(()) <= 0.5:
@@ -265,11 +287,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             else:
                 magnitude = 0.0
 
-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )
 
-        return inpt
+        return self._put_into_sample(sample, id, image)
 
 
 class RandAugment(_AutoAugmentBase):
@@ -315,14 +337,16 @@ def __init__(
         self.magnitude = magnitude
         self.num_magnitude_bins = num_magnitude_bins
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)
 
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
 
-            magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                 if signed and torch.rand(()) <= 0.5:
@@ -330,11 +354,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             else:
                 magnitude = 0.0
 
-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )
 
-        return inpt
+        return self._put_into_sample(sample, id, image)
 
 
 class TrivialAugmentWide(_AutoAugmentBase):
@@ -370,13 +394,15 @@ def __init__(
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)
 
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
 
-        magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
         if magnitudes is not None:
             magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
             if signed and torch.rand(()) <= 0.5:
@@ -384,9 +410,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         else:
             magnitude = 0.0
 
-        return self._apply_image_transform(
-            inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+        image = self._apply_image_transform(
+            image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
         )
+        return self._put_into_sample(sample, id, image)
 
 
 class AugMix(_AutoAugmentBase):
@@ -438,13 +465,15 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
         # Must be on a separate method so that we can overwrite it in tests.
         return torch._sample_dirichlet(params)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
-            image = inpt
-        elif isinstance(inpt, PIL.Image.Image):
-            image = pil_to_tensor(inpt)
-        else:
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        id, orig_image = self._extract_image(sample)
+        num_channels, height, width = get_chw(orig_image)
+
+        if isinstance(orig_image, torch.Tensor):
+            image = orig_image
+        else:  # isinstance(inpt, PIL.Image.Image):
+            image = pil_to_tensor(orig_image)
 
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
 
@@ -470,7 +499,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
         for _ in range(depth):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
-            magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
+            magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                 if signed and torch.rand(()) <= 0.5:
@@ -484,9 +513,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
                 mix.add_(combined_weights[:, i].view(batch_dims) * aug)
         mix = mix.view(orig_dims).to(dtype=image.dtype)
 
-        if isinstance(inpt, features.Image):
-            mix = features.Image.new_like(inpt, mix)
-        elif isinstance(inpt, PIL.Image.Image):
+        if isinstance(orig_image, features.Image):
+            mix = features.Image.new_like(orig_image, mix)
+        elif isinstance(orig_image, PIL.Image.Image):
             mix = to_pil_image(mix)
 
-        return mix
+        return self._put_into_sample(sample, id, mix)
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index 35e6635ce06..c816562d5da 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -4,15 +4,14 @@
 import numpy as np
 import PIL.Image
 import torch
-import torchvision.prototype.transforms.functional as F
+
 from torchvision.prototype import features
-from torchvision.prototype.features import ColorSpace
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import is_simple_tensor
+from ._utils import is_simple_tensor, query_chw
 
 
 class ToTensor(Transform):
@@ -59,6 +58,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image:
 
 
 class Grayscale(Transform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         deprecation_msg = (
             f"The transform `Grayscale(num_output_channels={num_output_channels})` "
@@ -81,13 +82,12 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         self.num_output_channels = num_output_channels
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        if self.num_output_channels == 3:
-            output = F.convert_color_space(inpt, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
-        return output
+        return _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
 
 
 class RandomGrayscale(_RandomApplyTransform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, p: float = 0.1) -> None:
         warnings.warn(
             "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
@@ -103,6 +103,9 @@ def __init__(self, p: float = 0.1) -> None:
 
         super().__init__(p=p)
 
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        num_input_channels, _, _ = query_chw(sample)
+        return dict(num_input_channels=num_input_channels)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
+        return _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index d3006fe9e09..85f25d8051f 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -156,22 +156,14 @@ class FiveCrop(Transform):
         torch.Size([5])
     """
 
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # TODO: returning a list is technically BC breaking since FiveCrop returned a tuple before. We switched to a
-        # list here to align it with TenCrop.
-        if isinstance(inpt, features.Image):
-            output = F.five_crop_image_tensor(inpt, self.size)
-            return tuple(features.Image.new_like(inpt, o) for o in output)
-        elif is_simple_tensor(inpt):
-            return F.five_crop_image_tensor(inpt, self.size)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.five_crop_image_pil(inpt, self.size)
-        else:
-            return inpt
+        return F.five_crop(inpt, self.size)
 
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
@@ -185,21 +177,15 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """
 
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-            return [features.Image.new_like(inpt, o) for o in output]
-        elif is_simple_tensor(inpt):
-            return F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip)
-        else:
-            return inpt
+        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index 37ca7857674..67a94fa6b04 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -22,7 +22,7 @@ class Lambda(Transform):
     def __init__(self, fn: Callable[[Any], Any], *types: Type):
         super().__init__()
         self.fn = fn
-        self.types = types
+        self.types = types or (object,)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if type(inpt) in self.types:
@@ -137,7 +137,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 class ToDtype(Lambda):
     def __init__(self, dtype: torch.dtype, *types: Type) -> None:
         self.dtype = dtype
-        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
+        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))
 
     def extra_repr(self) -> str:
         return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 92707ffce58..788738f6f86 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -65,6 +65,7 @@
     elastic_image_tensor,
     elastic_segmentation_mask,
    elastic_transform,
+    five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
     horizontal_flip,
@@ -97,6 +98,7 @@
     rotate_image_pil,
     rotate_image_tensor,
     rotate_segmentation_mask,
+    ten_crop,
     ten_crop_image_pil,
     ten_crop_image_tensor,
     vertical_flip,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 1aa79a56727..15d58fd03cb 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1078,6 +1078,17 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center
 
 
+def five_crop(inpt: DType, size: List[int]) -> Tuple[DType, DType, DType, DType, DType]:
+    # TODO: consider breaking BC here to return List[DType] to align this op with `ten_crop`
+    if isinstance(inpt, torch.Tensor):
+        output = five_crop_image_tensor(inpt, size)
+        if isinstance(inpt, features.Image):
+            output = tuple(features.Image.new_like(inpt, item) for item in output)  # type: ignore[assignment]
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return five_crop_image_pil(inpt, size)
+
+
 def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
     tl, tr, bl, br, center = five_crop_image_tensor(img, size)
 
@@ -1102,3 +1113,13 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
     tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)
 
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+
+
+def ten_crop(inpt: DType, size: List[int], *, vertical_flip: bool = False) -> List[DType]:
+    if isinstance(inpt, torch.Tensor):
+        output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+        if isinstance(inpt, features.Image):
+            output = [features.Image.new_like(inpt, item) for item in output]
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)