diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 6c5dac72d53..fbf4522be93 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, Tuple, Union, Optional

 import torch
@@ -15,7 +17,6 @@ class BoundingBoxFormat(StrEnum):


 class BoundingBox(Feature):
-    formats = BoundingBoxFormat
     format: BoundingBoxFormat
     image_size: Tuple[int, int]

@@ -27,7 +28,7 @@ def __new__(
         device: Optional[torch.device] = None,
         format: Union[BoundingBoxFormat, str],
         image_size: Tuple[int, int],
-    ):
+    ) -> BoundingBox:
         bounding_box = super().__new__(cls, data, dtype=dtype, device=device)

         if isinstance(format, str):
@@ -37,7 +38,7 @@ def __new__(

         return bounding_box

-    def to_format(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
+    def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         # import at runtime to avoid cyclic imports
         from torchvision.prototype.transforms.functional import convert_bounding_box_format

diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index a07da277314..4a445548f64 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import warnings
 from typing import Any, Optional, Union, Tuple, cast

@@ -20,7 +22,6 @@ class ColorSpace(StrEnum):


 class Image(Feature):
-    color_spaces = ColorSpace
     color_space: ColorSpace

     def __new__(
@@ -79,5 +80,5 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
     def show(self) -> None:
         to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()

-    def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> "Image":
+    def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
         return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py
index 3ce1da647e7..b386d17ea69 100644
--- a/torchvision/prototype/features/_label.py
+++ b/torchvision/prototype/features/_label.py
@@ -26,7 +26,6 @@ def __new__(

     @classmethod
     def from_category(cls, category: str, *, categories: Sequence[str]):
-        categories = list(categories)
         return cls(categories.index(category), categories=categories)

     def to_categories(self):
@@ -45,7 +44,7 @@ def __new__(
         *,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
-        like: Optional["Label"] = None,
+        like: Optional[Label] = None,
         categories: Optional[Sequence[str]] = None,
     ):
         one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)
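The change that runs through all three feature files: with `from __future__ import annotations` at the top of a module (PEP 563), annotations are stored as strings and never evaluated at class-definition time, so a method can name its own class without the quoted forward references removed above. A minimal standalone sketch of the pattern, independent of the torchvision classes:

    from __future__ import annotations

    import torch

    class Box:
        def __init__(self, data: torch.Tensor) -> None:
            self.data = data

        # Without the __future__ import this would have to be spelled "-> 'Box'",
        # because Box is not yet bound while the class body executes.
        def clone(self) -> Box:
            return Box(self.data.clone())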
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 963bdebc7ed..1fe3d010b28 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -1,8 +1,4 @@
 from . import functional
 from .functional import InterpolationMode  # usort: skip

-from ._transform import Transform
-from ._container import Compose, RandomApply, RandomChoice, RandomOrder  # usort: skip
-from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
-from ._misc import Identity, Normalize
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py
deleted file mode 100644
index 86d7804dd17..00000000000
--- a/torchvision/prototype/transforms/_container.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from typing import Any, List
-
-import torch
-from torch import nn
-from torchvision.prototype.transforms import Transform
-
-
-class ContainerTransform(nn.Module):
-    def supports(self, obj: Any) -> bool:
-        raise NotImplementedError()
-
-    def forward(self, *inputs: Any) -> Any:
-        raise NotImplementedError()
-
-    def _make_repr(self, lines: List[str]) -> str:
-        extra_repr = self.extra_repr()
-        if extra_repr:
-            lines = [self.extra_repr(), *lines]
-        head = f"{type(self).__name__}("
-        tail = ")"
-        body = [f" {line.rstrip()}" for line in lines]
-        return "\n".join([head, *body, tail])
-
-
-class WrapperTransform(ContainerTransform):
-    def __init__(self, transform: Transform):
-        super().__init__()
-        self._transform = transform
-
-    def supports(self, obj: Any) -> bool:
-        return self._transform.supports(obj)
-
-    def __repr__(self) -> str:
-        return self._make_repr(repr(self._transform).splitlines())
-
-
-class MultiTransform(ContainerTransform):
-    def __init__(self, *transforms: Transform) -> None:
-        super().__init__()
-        self._transforms = transforms
-
-    def supports(self, obj: Any) -> bool:
-        return all(transform.supports(obj) for transform in self._transforms)
-
-    def __repr__(self) -> str:
-        lines = []
-        for idx, transform in enumerate(self._transforms):
-            partial_lines = repr(transform).splitlines()
-            lines.append(f"({idx:d}): {partial_lines[0]}")
-            lines.extend(partial_lines[1:])
-        return self._make_repr(lines)
-
-
-class Compose(MultiTransform):
-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        for transform in self._transforms:
-            sample = transform(sample)
-        return sample
-
-
-class RandomApply(WrapperTransform):
-    def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
-        super().__init__(transform)
-        self._p = p
-
-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if float(torch.rand(())) < self._p:
-            return sample
-
-        return self._transform(sample)
-
-    def extra_repr(self) -> str:
-        return f"p={self._p}"
-
-
-class RandomChoice(MultiTransform):
-    def forward(self, *inputs: Any) -> Any:
-        idx = int(torch.randint(len(self._transforms), size=()))
-        transform = self._transforms[idx]
-        return transform(*inputs)
-
-
-class RandomOrder(MultiTransform):
-    def forward(self, *inputs: Any) -> Any:
-        for idx in torch.randperm(len(self._transforms)):
-            transform = self._transforms[idx]
-            inputs = transform(*inputs)
-        return inputs
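The deleted containers operated on whole samples rather than single tensors, and their sampling logic lived in `forward`: `Compose` threaded the sample through each transform, `RandomChoice` drew one index with `torch.randint`, and `RandomOrder` shuffled with `torch.randperm`. Note that this `RandomApply` returned the sample unchanged when the draw fell *below* `p`, i.e. `p` was effectively the probability of skipping the wrapped transform. A hypothetical usage sketch against this now-removed API (sample values invented):

    import torch
    from torchvision.prototype import features, transforms

    sample = dict(
        image=features.Image(torch.rand(3, 32, 32)),
        box=features.BoundingBox(torch.tensor([0.0, 0.0, 16.0, 16.0]), format="xyxy", image_size=(32, 32)),
    )
    pipeline = transforms.Compose(  # took *transforms, not a list
        transforms.HorizontalFlip(),
        transforms.RandomApply(transforms.Resize(16), p=0.3),
    )
    sample = pipeline(sample)  # every feature in the dict is transformed consistently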
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
deleted file mode 100644
index f34e5daa063..00000000000
--- a/torchvision/prototype/transforms/_geometry.py
+++ /dev/null
@@ -1,138 +0,0 @@
-from typing import Any, Dict, Tuple, Union
-
-import torch
-from torch.nn.functional import interpolate
-from torchvision.prototype.datasets.utils import SampleQuery
-from torchvision.prototype.features import BoundingBox, Image, Label
-from torchvision.prototype.transforms import Transform
-
-
-class HorizontalFlip(Transform):
-    NO_OP_FEATURE_TYPES = {Label}
-
-    @staticmethod
-    def image(input: Image) -> Image:
-        return Image(input.flip((-1,)), like=input)
-
-    @staticmethod
-    def bounding_box(input: BoundingBox) -> BoundingBox:
-        x, y, w, h = input.convert("xywh").to_parts()
-        x = input.image_size[1] - (x + w)
-        return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh").convert(input.format)
-
-
-class Resize(Transform):
-    NO_OP_FEATURE_TYPES = {Label}
-
-    def __init__(
-        self,
-        size: Union[int, Tuple[int, int]],
-        *,
-        interpolation_mode: str = "nearest",
-    ) -> None:
-        super().__init__()
-        self.size = (size, size) if isinstance(size, int) else size
-        self.interpolation_mode = interpolation_mode
-
-    def get_params(self, sample: Any) -> Dict[str, Any]:
-        return dict(size=self.size, interpolation_mode=self.interpolation_mode)
-
-    @staticmethod
-    def image(input: Image, *, size: Tuple[int, int], interpolation_mode: str = "nearest") -> Image:
-        return Image(interpolate(input.unsqueeze(0), size, mode=interpolation_mode).squeeze(0), like=input)
-
-    @staticmethod
-    def bounding_box(input: BoundingBox, *, size: Tuple[int, int], **_: Any) -> BoundingBox:
-        old_height, old_width = input.image_size
-        new_height, new_width = size
-
-        height_scale = new_height / old_height
-        width_scale = new_width / old_width
-
-        old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts()
-
-        new_x1 = old_x1 * width_scale
-        new_y1 = old_y1 * height_scale
-
-        new_x2 = old_x2 * width_scale
-        new_y2 = old_y2 * height_scale
-
-        return BoundingBox.from_parts(
-            new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=size
-        ).convert(input.format)
-
-    def extra_repr(self) -> str:
-        extra_repr = f"size={self.size}"
-        if self.interpolation_mode != "bilinear":
-            extra_repr += f", interpolation_mode={self.interpolation_mode}"
-        return extra_repr
-
-
-class RandomResize(Transform, wraps=Resize):
-    def __init__(self, min_size: Union[int, Tuple[int, int]], max_size: Union[int, Tuple[int, int]]) -> None:
-        super().__init__()
-        self.min_size = (min_size, min_size) if isinstance(min_size, int) else min_size
-        self.max_size = (max_size, max_size) if isinstance(max_size, int) else max_size
-
-    def get_params(self, sample: Any) -> Dict[str, Any]:
-        min_height, min_width = self.min_size
-        max_height, max_width = self.max_size
-        height = int(torch.randint(min_height, max_height + 1, size=()))
-        width = int(torch.randint(min_width, max_width + 1, size=()))
-        return dict(size=(height, width))
-
-    def extra_repr(self) -> str:
-        return f"min_size={self.min_size}, max_size={self.max_size}"
-
-
-class Crop(Transform):
-    NO_OP_FEATURE_TYPES = {BoundingBox, Label}
-
-    def __init__(self, crop_box: BoundingBox) -> None:
-        super().__init__()
-        self.crop_box = crop_box.convert("xyxy")
-
-    def get_params(self, sample: Any) -> Dict[str, Any]:
-        return dict(crop_box=self.crop_box)
-
-    @staticmethod
-    def image(input: Image, *, crop_box: BoundingBox) -> Image:
-        # FIXME: pad input in case it is smaller than crop_box
-        x1, y1, x2, y2 = crop_box.convert("xyxy").to_parts()
-        return Image(input[..., y1 : y2 + 1, x1 : x2 + 1], like=input)  # type: ignore[misc]
-
-
-class CenterCrop(Transform, wraps=Crop):
-    def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None:
-        super().__init__()
-        self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size
-
-    def get_params(self, sample: Any) -> Dict[str, Any]:
-        image_size = SampleQuery(sample).image_size()
-        image_height, image_width = image_size
-        cx = image_width // 2
-        cy = image_height // 2
-        h, w = self.crop_size
-        crop_box = BoundingBox.from_parts(cx, cy, w, h, image_size=image_size, format="cxcywh")
-        return dict(crop_box=crop_box)
-
-    def extra_repr(self) -> str:
-        return f"crop_size={self.crop_size}"
-
-
-class RandomCrop(Transform, wraps=Crop):
-    def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None:
-        super().__init__()
-        self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size
-
-    def get_params(self, sample: Any) -> Dict[str, Any]:
-        image_size = SampleQuery(sample).image_size()
-        image_height, image_width = image_size
-        crop_height, crop_width = self.crop_size
-        x = torch.randint(0, image_width - crop_width + 1, size=()) if crop_width < image_width else 0
-        y = torch.randint(0, image_height - crop_height + 1, size=()) if crop_height < image_height else 0
-        crop_box = BoundingBox.from_parts(x, y, crop_width, crop_height, image_size=image_size, format="xywh")
-        return dict(crop_box=crop_box)
-
-    def extra_repr(self) -> str:
-        return f"crop_size={self.crop_size}"
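The `bounding_box` flip above leans on the identity x' = W - (x + w): mirroring maps the box's right edge onto the new left edge while the width and the y-extent are unchanged. A quick standalone check with plain tensors and made-up numbers:

    import torch

    W = 100  # image width
    x, y, w, h = torch.tensor([10.0, 20.0, 30.0, 40.0]).unbind()  # xywh box
    flipped = torch.stack([W - (x + w), y, w, h])
    print(flipped)  # tensor([60., 20., 30., 40.]): x-span [10, 40] mirrors to [60, 90]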
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
deleted file mode 100644
index 47062aeaf03..00000000000
--- a/torchvision/prototype/transforms/_misc.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from typing import Any, Dict, Sequence
-
-import torch
-from torchvision.prototype.features import Image, BoundingBox, Label
-from torchvision.prototype.transforms import Transform
-
-
-class Identity(Transform):
-    """Identity transform that supports all built-in :class:`~torchvision.prototype.features.Feature`'s."""
-
-    def __init__(self):
-        super().__init__()
-        for feature_type in self._BUILTIN_FEATURE_TYPES:
-            self.register_feature_transform(feature_type, lambda input, **params: input)
-
-
-class Normalize(Transform):
-    NO_OP_FEATURE_TYPES = {BoundingBox, Label}
-
-    def __init__(self, mean: Sequence[float], std: Sequence[float]):
-        super().__init__()
-        self.mean = mean
-        self.std = std
-
-    def get_params(self, sample: Any) -> Dict[str, Any]:
-        return dict(mean=self.mean, std=self.std)
-
-    @staticmethod
-    def _channel_stats_to_tensor(stats: Sequence[float], *, like: torch.Tensor) -> torch.Tensor:
-        return torch.as_tensor(stats, device=like.device, dtype=like.dtype).view(-1, 1, 1)
-
-    @staticmethod
-    def image(input: Image, *, mean: Sequence[float], std: Sequence[float]) -> Image:
-        mean_t = Normalize._channel_stats_to_tensor(mean, like=input)
-        std_t = Normalize._channel_stats_to_tensor(std, like=input)
-        return Image((input - mean_t) / std_t, like=input)
-
-    def extra_repr(self) -> str:
-        return f"mean={tuple(self.mean)}, std={tuple(self.std)}"
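In the deleted `Normalize`, the `view(-1, 1, 1)` reshape is what lets per-channel statistics broadcast against a CHW image: a length-3 vector becomes shape (3, 1, 1) and stretches across height and width. The same arithmetic in isolation:

    import torch

    image = torch.rand(3, 224, 224)  # CHW
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)  # (3,) -> (3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    normalized = (image - mean) / std  # broadcasts over H and W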
diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py
deleted file mode 100644
index 8062ff0fad0..00000000000
--- a/torchvision/prototype/transforms/_transform.py
+++ /dev/null
@@ -1,406 +0,0 @@
-import collections.abc
-import inspect
-import re
-from typing import Any, Callable, Dict, Optional, Type, Union, cast, Set, Collection
-
-import torch
-from torch import nn
-from torchvision.prototype import features
-from torchvision.prototype.utils._internal import add_suggestion
-
-
-class Transform(nn.Module):
-    """Base class for transforms.
-
-    A transform operates on a full sample at once, which might be a nested container of elements to transform. The
-    non-container elements of the sample will be dispatched to feature transforms based on their type in case it is
-    supported by the transform. Each transform needs to define at least one feature transform, which is canonical done
-    as static method:
-
-    .. code-block::
-
-        class ImageIdentity(Transform):
-            @staticmethod
-            def image(input):
-                return input
-
-    To achieve correct results for a complete sample, each transform should implement feature transforms for every
-    :class:`Feature` it can handle:
-
-    .. code-block::
-
-        class Identity(Transform):
-            @staticmethod
-            def image(input):
-                return input
-
-            @staticmethod
-            def bounding_box(input):
-                return input
-
-            ...
-
-    If the name of a static method in camel-case matches the name of a :class:`Feature`, the feature transform is
-    auto-registered. Supported pairs are:
-
-    +----------------+----------------+
-    | method name    | `Feature`      |
-    +================+================+
-    | `image`        | `Image`        |
-    +----------------+----------------+
-    | `bounding_box` | `BoundingBox`  |
-    +----------------+----------------+
-    | `label`        | `Label`        |
-    +----------------+----------------+
-
-    If you don't want to stick to this scheme, you can disable the auto-registration and perform it manually:
-
-    .. code-block::
-
-        def my_image_transform(input):
-            ...
-
-        class MyTransform(Transform, auto_register=False):
-            def __init__(self):
-                super().__init__()
-                self.register_feature_transform(Image, my_image_transform)
-                self.register_feature_transform(BoundingBox, self.my_bounding_box_transform)
-
-            @staticmethod
-            def my_bounding_box_transform(input):
-                ...
-
-    In any case, the registration will assert that the feature transform can be invoked with
-    ``feature_transform(input, **params)``.
-
-    .. warning::
-
-        Feature transforms are **registered on the class and not on the instance**. This means you cannot have two
-        instances of the same :class:`Transform` with different feature transforms.
-
-    If the feature transforms needs additional parameters, you need to
-    overwrite the :meth:`~Transform.get_params` method. It needs to return the parameter dictionary that will be
-    unpacked and its contents passed to each feature transform:
-
-    .. code-block::
-
-        class Rotate(Transform):
-            def __init__(self, degrees):
-                super().__init__()
-                self.degrees = degrees
-
-            def get_params(self, sample):
-                return dict(degrees=self.degrees)
-
-            def image(input, *, degrees):
-                ...
-
-    The :meth:`~Transform.get_params` method will be invoked once per sample. Thus, in case of randomly sampled
-    parameters they will be the same for all features of the whole sample.
-
-    .. code-block::
-
-        class RandomRotate(Transform)
-            def __init__(self, range):
-                super().__init__()
-                self._dist = torch.distributions.Uniform(range)
-
-            def get_params(self, sample):
-                return dict(degrees=self._dist.sample().item())
-
-            @staticmethod
-            def image(input, *, degrees):
-                ...
-
-    In case the sampling depends on one or more features at runtime, the complete ``sample`` gets passed to the
-    :meth:`Transform.get_params` method. Derivative transforms that only changes the parameter sampling, but the
-    feature transformations are identical, can simply wrap the transform they dispatch to:
-
-    .. code-block::
-
-        class RandomRotate(Transform, wraps=Rotate):
-            def get_params(self, sample):
-                return dict(degrees=float(torch.rand(())) * 30.0)
-
-    To transform a sample, you simply call an instance of the transform with it:
-
-    .. code-block::
-
-        transform = MyTransform()
-        sample = dict(input=Image(torch.tensor(...)), target=BoundingBox(torch.tensor(...)), ...)
-        transformed_sample = transform(sample)
-
-    .. note::
-
-        To use a :class:`Transform` with a dataset, simply use it as map:
-
-        .. code-block::
-
-            torchvision.datasets.load(...).map(MyTransform())
-    """
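Assembled from the docstring's pieces, a complete transform pair under this (now removed) API would have looked roughly as follows; `Rotate`, `RandomRotate`, and the elided rotation math are illustrative only:

    import torch
    from torchvision.prototype.transforms import Transform  # pre-PR base class

    class Rotate(Transform):
        def __init__(self, degrees):
            super().__init__()
            self.degrees = degrees

        def get_params(self, sample):
            # invoked once per sample, so image and bounding box get the same angle
            return dict(degrees=self.degrees)

        @staticmethod
        def image(input, *, degrees):  # auto-registered for Image via its name
            ...  # rotation math elided

    class RandomRotate(Transform, wraps=Rotate):  # inherits Rotate's feature transforms
        def get_params(self, sample):
            return dict(degrees=float(torch.rand(())) * 30.0)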
-
-    _BUILTIN_FEATURE_TYPES = (
-        features.BoundingBox,
-        features.Image,
-        features.Label,
-    )
-    _FEATURE_NAME_MAP = {
-        "_".join([part.lower() for part in re.findall("[A-Z][^A-Z]*", feature_type.__name__)]): feature_type
-        for feature_type in _BUILTIN_FEATURE_TYPES
-    }
-    _feature_transforms: Dict[Type[features.Feature], Callable]
-
-    NO_OP_FEATURE_TYPES: Collection[Type[features.Feature]] = ()
-
-    def __init_subclass__(
-        cls, *, wraps: Optional[Type["Transform"]] = None, auto_register: bool = True, verbose: bool = False
-    ):
-        cls._feature_transforms = {} if wraps is None else wraps._feature_transforms.copy()
-        if wraps:
-            cls.NO_OP_FEATURE_TYPES = wraps.NO_OP_FEATURE_TYPES
-        if auto_register:
-            cls._auto_register(verbose=verbose)
-
-    @staticmethod
-    def _has_allowed_signature(feature_transform: Callable) -> bool:
-        """Checks if ``feature_transform`` can be invoked with ``feature_transform(input, **params)``"""
-
-        parameters = tuple(inspect.signature(feature_transform).parameters.values())
-        if not parameters:
-            return False
-        elif len(parameters) == 1:
-            return parameters[0].kind != inspect.Parameter.KEYWORD_ONLY
-        else:
-            return parameters[1].kind != inspect.Parameter.POSITIONAL_ONLY
-
-    @classmethod
-    def register_feature_transform(cls, feature_type: Type[features.Feature], transform: Callable) -> None:
-        """Registers a transform for given feature on the class.
-
-        If a transform object is called or :meth:`Transform.apply` is invoked, inputs are dispatched to the registered
-        transforms based on their type.
-
-        Args:
-            feature_type: Feature type the transformation is registered for.
-            transform: Feature transformation.
-
-        Raises:
-            TypeError: If ``transform`` cannot be invoked with ``transform(input, **params)``.
-        """
-        if not cls._has_allowed_signature(transform):
-            raise TypeError("Feature transform cannot be invoked with transform(input, **params)")
-        cls._feature_transforms[feature_type] = transform
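The `_FEATURE_NAME_MAP` comprehension is what makes the name-based auto-registration work: it derives the expected method name from each feature class name by splitting on capital letters. The mapping in isolation:

    import re

    def method_name(feature_name: str) -> str:
        # "BoundingBox" -> ["Bounding", "Box"] -> "bounding_box"
        return "_".join(part.lower() for part in re.findall("[A-Z][^A-Z]*", feature_name))

    assert method_name("Image") == "image"
    assert method_name("BoundingBox") == "bounding_box"
    assert method_name("Label") == "label"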
-
-    @classmethod
-    def _auto_register(cls, *, verbose: bool = False) -> None:
-        """Auto-registers methods on the class as feature transforms if they meet the following criteria:
-
-        1. They are static.
-        2. They can be invoked with `cls.feature_transform(input, **params)`.
-        3. They are public.
-        4. Their name in camel case matches the name of a builtin feature, e.g. 'bounding_box' and 'BoundingBox'.
-
-        The name from 4. determines for which feature the method is registered.
-
-        .. note::
-
-            The ``auto_register`` and ``verbose`` flags need to be passed as keyword arguments to the class:
-
-            .. code-block::
-
-                class MyTransform(Transform, auto_register=True, verbose=True):
-                    ...
-
-        Args:
-            verbose: If ``True``, prints to STDOUT which methods were registered or why a method was not registered
-        """
-        for name, value in inspect.getmembers(cls):
-            # check if attribute is a static method and was defined in the subclass
-            # TODO: this needs to be revisited to allow subclassing of custom transforms
-            if not (name in cls.__dict__ and inspect.isfunction(value)):
-                continue
-
-            not_registered_prefix = f"{cls.__name__}.{name}() was not registered as feature transform, because"
-
-            if not cls._has_allowed_signature(value):
-                if verbose:
-                    print(f"{not_registered_prefix} it cannot be invoked with {name}(input, **params).")
-                continue
-
-            if name.startswith("_"):
-                if verbose:
-                    print(f"{not_registered_prefix} it is private.")
-                continue
-
-            try:
-                feature_type = cls._FEATURE_NAME_MAP[name]
-            except KeyError:
-                if verbose:
-                    print(
-                        add_suggestion(
-                            f"{not_registered_prefix} its name doesn't match any known feature type.",
-                            word=name,
-                            possibilities=cls._FEATURE_NAME_MAP.keys(),
-                            close_match_hint=lambda close_match: (
-                                f"Did you mean to name it '{close_match}' "
-                                f"to be registered for type '{cls._FEATURE_NAME_MAP[close_match]}'?"
-                            ),
-                        )
-                    )
-                continue
-
-            cls.register_feature_transform(feature_type, value)
-            if verbose:
-                print(
-                    f"{cls.__name__}.{name}() was registered as feature transform for type '{feature_type.__name__}'."
-                )
-
-    @classmethod
-    def from_callable(
-        cls,
-        feature_transform: Union[Callable, Dict[Type[features.Feature], Callable]],
-        *,
-        name: str = "FromCallable",
-        get_params: Optional[Union[Dict[str, Any], Callable[[Any], Dict[str, Any]]]] = None,
-    ) -> "Transform":
-        """Creates a new transform from a callable.
-
-        Args:
-            feature_transform: Feature transform that will be registered to handle :class:`Image`'s. Can be passed as
-                dictionary in which case each key-value-pair is needs to consists of a ``Feature`` type and the
-                corresponding transform.
-            name: Name of the transform.
-            get_params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``.
-                Can be passed as callable in which case it will be called with the transform instance (``self``) and
-                the input of the transform.
-
-        Raises:
-            TypeError: If ``feature_transform`` cannot be invoked with ``feature_transform(input, **params)``.
-        """
-        if get_params is None:
-            get_params = dict()
-        attributes = dict(
-            get_params=get_params if callable(get_params) else lambda self, sample: get_params,  # type: ignore[misc]
-        )
-        transform_cls = cast(Type[Transform], type(name, (cls,), attributes))
-
-        if callable(feature_transform):
-            feature_transform = {features.Image: feature_transform}
-        for feature_type, transform in feature_transform.items():
-            transform_cls.register_feature_transform(feature_type, transform)
-
-        return transform_cls()
-
-    @classmethod
-    def supported_feature_types(cls) -> Set[Type[features.Feature]]:
-        return set(cls._feature_transforms.keys())
-
-    @classmethod
-    def supports(cls, obj: Any) -> bool:
-        """Checks if object or type is supported.
-
-        Args:
-            obj: Object or type.
-        """
-        # TODO: should this handle containers?
-        feature_type = obj if isinstance(obj, type) else type(obj)
-        return feature_type is torch.Tensor or feature_type in cls.supported_feature_types()
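Per its docstring, `from_callable` manufactured a one-off `Transform` subclass at runtime via `type(...)`, registered the callable (for `Image` by default), and returned an instance of it. Usage would have looked roughly like this (names and values invented); note the callable must accept `(input, **params)` to pass the signature check:

    import torch
    from torchvision.prototype.transforms import Transform  # pre-PR API

    scale = Transform.from_callable(
        lambda input, *, factor: input * factor,
        name="Scale",
        get_params=dict(factor=2.0),
    )
    doubled = scale(torch.rand(3, 4, 4))  # plain tensors are dispatched as images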
-
-    @classmethod
-    def transform(cls, input: Union[torch.Tensor, features.Feature], **params: Any) -> torch.Tensor:
-        """Applies the registered feature transform to the input based on its type.
-
-        This can be uses as feature type generic functional interface:
-
-        .. code-block::
-
-            transform = Rotate.transform
-            transformed_image = transform(Image(torch.tensor(...)), degrees=30.0)
-            transformed_bbox = transform(BoundingBox(torch.tensor(...)), degrees=-10.0)
-
-        Args:
-            input: ``input`` in ``feature_transform(input, **params)``
-            **params: Parameter dictionary ``params`` in ``feature_transform(input, **params)``.
-
-        Returns:
-            Transformed input.
-        """
-        feature_type = type(input)
-        if not cls.supports(feature_type):
-            raise TypeError(f"{cls.__name__}() is not able to handle inputs of type {feature_type}.")
-
-        if feature_type is torch.Tensor:
-            # To keep BC, we treat all regular torch.Tensor's as images
-            feature_type = features.Image
-            input = feature_type(input)
-        feature_type = cast(Type[features.Feature], feature_type)
-
-        feature_transform = cls._feature_transforms[feature_type]
-        output = feature_transform(input, **params)
-
-        if type(output) is torch.Tensor:
-            output = feature_type(output, like=input)
-        return output
-
-    def _transform_recursively(self, sample: Any, *, params: Dict[str, Any]) -> Any:
-        """Recurses through a sample and invokes :meth:`Transform.transform` on non-container elements.
-
-        If an element is not supported by the transform, it is returned untransformed.
-
-        Args:
-            sample: Sample.
-            params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``.
-        """
-        # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
-        # "a" == "a"[0][0]...
-        if isinstance(sample, collections.abc.Sequence) and not isinstance(sample, str):
-            return [self._transform_recursively(item, params=params) for item in sample]
-        elif isinstance(sample, collections.abc.Mapping):
-            return {name: self._transform_recursively(item, params=params) for name, item in sample.items()}
-        else:
-            feature_type = type(sample)
-            if not self.supports(feature_type):
-                if (
-                    not issubclass(feature_type, features.Feature)
-                    # issubclass is not a strict check, but also allows the type checked against. Thus, we need to
-                    # check it separately
-                    or feature_type is features.Feature
-                    or feature_type in self.NO_OP_FEATURE_TYPES
-                ):
-                    return sample
-
-                raise TypeError(
-                    f"{type(self).__name__}() is not able to handle inputs of type {feature_type}. "
-                    f"If you want it to be a no-op, add the feature type to {type(self).__name__}.NO_OP_FEATURE_TYPES."
-                )
-
-            return self.transform(cast(Union[torch.Tensor, features.Feature], sample), **params)
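The traversal contract of `_transform_recursively`: containers are rebuilt with the same shape, leaves are dispatched by type, and `str` is deliberately excluded from the sequence branch since indexing a string yields strings forever. The skeleton of that recursion, runnable on its own:

    from collections.abc import Mapping, Sequence

    def transform_recursively(sample, leaf_fn):
        # mirrors the deleted method: lists/dicts are rebuilt, str is not a container here
        if isinstance(sample, Sequence) and not isinstance(sample, str):
            return [transform_recursively(item, leaf_fn) for item in sample]
        if isinstance(sample, Mapping):
            return {key: transform_recursively(value, leaf_fn) for key, value in sample.items()}
        return leaf_fn(sample)

    print(transform_recursively({"a": [1, 2], "b": "keep"}, leaf_fn=str))
    # {'a': ['1', '2'], 'b': 'keep'}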
- """ - return dict() - - def forward( - self, - *inputs: Any, - params: Optional[Dict[str, Any]] = None, - ) -> Any: - if not self._feature_transforms: - raise RuntimeError(f"{type(self).__name__}() has no registered feature transform.") - - sample = inputs if len(inputs) > 1 else inputs[0] - if params is None: - params = self.get_params(sample) - return self._transform_recursively(sample, params=params) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 814c34e5b00..842ff0cd5d6 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -16,25 +16,37 @@ def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> to def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + return _mixup(image_batch, -4, lam, inplace) def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + return _mixup(one_hot_label_batch, -2, lam, inplace) -def cutmix_image(image: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: +def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + if not inplace: - image = image.clone() + image_batch = image_batch.clone() x1, y1, x2, y2 = box - image_rolled = image.roll(1, -4) + image_rolled = image_batch.roll(1, -4) - image[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return image + image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return image_batch def cutmix_one_hot_label( one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False ) -> torch.Tensor: - return mixup_one_hot_label(one_hot_label_batch, lam=lam_adjusted, inplace=inplace) + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c8142742fa8..26334527241 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,31 +1,18 @@ from typing import Tuple, List, Optional import torch -from torchvision.prototype.features import BoundingBoxFormat from torchvision.transforms import ( # noqa: F401 functional as _F, InterpolationMode, ) -from ._meta_conversion import convert_bounding_box_format - - -def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: - return image.flip((-1,)) +horizontal_flip_image = _F.hflip def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tuple[int, int]) -> torch.Tensor: - x, y, w, h = convert_bounding_box_format( - bounding_box, - old_format=BoundingBoxFormat.XYXY, - new_format=BoundingBoxFormat.XYWH, - ).unbind(-1) - x = image_size[1] - (x + w) - return convert_bounding_box_format( - torch.stack((x, y, w, h), dim=-1), - old_format=BoundingBoxFormat.XYWH, - new_format=BoundingBoxFormat.XYXY, - ) + bounding_box = bounding_box.clone() + bounding_box[..., (0, 2)] = image_size[1] - bounding_box[..., (2, 0)] + return 
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index c8142742fa8..26334527241 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1,31 +1,18 @@
 from typing import Tuple, List, Optional

 import torch
-from torchvision.prototype.features import BoundingBoxFormat
 from torchvision.transforms import (  # noqa: F401
     functional as _F,
     InterpolationMode,
 )

-from ._meta_conversion import convert_bounding_box_format
-
-
-def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
-    return image.flip((-1,))
+horizontal_flip_image = _F.hflip


 def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tuple[int, int]) -> torch.Tensor:
-    x, y, w, h = convert_bounding_box_format(
-        bounding_box,
-        old_format=BoundingBoxFormat.XYXY,
-        new_format=BoundingBoxFormat.XYWH,
-    ).unbind(-1)
-    x = image_size[1] - (x + w)
-    return convert_bounding_box_format(
-        torch.stack((x, y, w, h), dim=-1),
-        old_format=BoundingBoxFormat.XYWH,
-        new_format=BoundingBoxFormat.XYXY,
-    )
+    bounding_box = bounding_box.clone()
+    bounding_box[..., (0, 2)] = image_size[1] - bounding_box[..., (2, 0)]
+    return bounding_box


 _resize_image = _F.resize
@@ -71,11 +58,8 @@ def resize_bounding_box(
 ) -> torch.Tensor:
     old_height, old_width = old_image_size
     new_height, new_width = new_image_size
-    return (
-        bounding_box.view(-1, 2, 2)
-        .mul(torch.tensor([new_width / old_width, new_height / old_height]))
-        .view(bounding_box.shape)
-    )
+    ratios = torch.tensor((new_width / old_width, new_height / old_height))
+    return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)


 center_crop_image = _F.center_crop
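The rewritten flip uses fancy indexing to mirror and swap the x-coordinates in one step: the new x1 is derived from the old x2 and vice versa, so the result stays a valid xyxy box with x1 <= x2. The resize hunk plays a similar shape trick, viewing each box as two (x, y) points and multiplying by (width_ratio, height_ratio). A quick standalone check of the flip with made-up numbers:

    import torch

    image_size = (100, 100)  # height, width
    box = torch.tensor([[10.0, 20.0, 40.0, 50.0]])  # xyxy
    flipped = box.clone()
    flipped[..., (0, 2)] = image_size[1] - box[..., (2, 0)]
    print(flipped)  # tensor([[60., 20., 90., 50.]]): mirrored, still x1 <= x2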