diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 067359cac2b..eaa92094ad7 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -99,7 +99,6 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config): f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors." ) - @pytest.mark.xfail @parametrize_dataset_mocks(DATASET_MOCKS) def test_transformable(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py new file mode 100644 index 00000000000..9aa0688e7a0 --- /dev/null +++ b/test/test_prototype_transforms.py @@ -0,0 +1,186 @@ +import itertools + +import PIL.Image +import pytest +import torch +from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels +from torchvision.prototype import transforms, features +from torchvision.transforms.functional import to_pil_image + + +def make_vanilla_tensor_images(*args, **kwargs): + for image in make_images(*args, **kwargs): + if image.ndim > 3: + continue + yield image.data + + +def make_pil_images(*args, **kwargs): + for image in make_vanilla_tensor_images(*args, **kwargs): + yield to_pil_image(image) + + +def make_vanilla_tensor_bounding_boxes(*args, **kwargs): + for bounding_box in make_bounding_boxes(*args, **kwargs): + yield bounding_box.data + + +INPUT_CREATIONS_FNS = { + features.Image: make_images, + features.BoundingBox: make_bounding_boxes, + features.OneHotLabel: make_one_hot_labels, + torch.Tensor: make_vanilla_tensor_images, + PIL.Image.Image: make_pil_images, +} + + +def parametrize(transforms_with_inputs): + return pytest.mark.parametrize( + ("transform", "input"), + [ + pytest.param( + transform, + input, + id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}", + ) + for transform, inputs in transforms_with_inputs + for idx, input in enumerate(inputs) + ], + ) + + +def parametrize_from_transforms(*transforms): + transforms_with_inputs = [] + for transform in transforms: + dispatcher = transform._DISPATCHER + if dispatcher is None: + continue + + for type_ in dispatcher._kernels: + try: + inputs = INPUT_CREATIONS_FNS[type_]() + except KeyError: + continue + + transforms_with_inputs.append((transform, inputs)) + + return parametrize(transforms_with_inputs) + + +class TestSmoke: + @parametrize_from_transforms( + transforms.RandomErasing(), + transforms.HorizontalFlip(), + transforms.Resize([16, 16]), + transforms.CenterCrop([16, 16]), + transforms.ConvertImageDtype(), + ) + def test_common(self, transform, input): + transform(input) + + @parametrize( + [ + ( + transform, + [ + dict( + image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float), + one_hot_label=features.OneHotLabel.new_like( + one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float + ), + ) + for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels()) + ], + ) + for transform in [ + transforms.RandomMixup(alpha=1.0), + transforms.RandomCutmix(alpha=1.0), + ] + ] + ) + def test_mixup_cutmix(self, transform, input): + transform(input) + + @parametrize( + [ + ( + transform, + itertools.chain.from_iterable( + fn(dtypes=[torch.uint8], extra_dims=[(4,)]) + for fn in [ + make_images, + make_vanilla_tensor_images, + make_pil_images, + ] + ), + ) + for transform in ( + transforms.RandAugment(), + 
transforms.TrivialAugmentWide(), + transforms.AutoAugment(), + ) + ] + ) + def test_auto_augment(self, transform, input): + transform(input) + + @parametrize( + [ + ( + transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), + itertools.chain.from_iterable( + fn(color_spaces=["rgb"], dtypes=[torch.float32]) + for fn in [ + make_images, + make_vanilla_tensor_images, + ] + ), + ), + ] + ) + def test_normalize(self, transform, input): + transform(input) + + @parametrize( + [ + ( + transforms.ConvertColorSpace("grayscale"), + itertools.chain( + make_images(), + make_vanilla_tensor_images(color_spaces=["rgb"]), + make_pil_images(color_spaces=["rgb"]), + ), + ) + ] + ) + def test_convert_bounding_color_space(self, transform, input): + transform(input) + + @parametrize( + [ + ( + transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"), + itertools.chain( + make_bounding_boxes(), + make_vanilla_tensor_bounding_boxes(formats=["xywh"]), + ), + ) + ] + ) + def test_convert_bounding_box_format(self, transform, input): + transform(input) + + @parametrize( + [ + ( + transforms.RandomResizedCrop([16, 16]), + itertools.chain( + make_images(extra_dims=[(4,)]), + make_vanilla_tensor_images(), + make_pil_images(), + ), + ) + ] + ) + def test_random_resized_crop(self, transform, input): + transform(input) diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_kernels.py index b83febd8915..fb436a6a830 100644 --- a/test/test_prototype_transforms_kernels.py +++ b/test/test_prototype_transforms_kernels.py @@ -5,6 +5,7 @@ import torch.testing import torchvision.prototype.transforms.kernels as K from torch import jit +from torch.nn.functional import one_hot from torchvision.prototype import features make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") @@ -39,10 +40,10 @@ def make_images( extra_dims=((4,), (2, 3)), ): for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes): - yield make_image(size, color_space=color_space) + yield make_image(size, color_space=color_space, dtype=dtype) - for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims): - yield make_image(color_space=color_space, extra_dims=extra_dims_) + for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims): + yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype) def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): @@ -106,6 +107,27 @@ def make_bounding_boxes( yield make_bounding_box(format=format, extra_dims=extra_dims_) +def make_label(size=(), *, categories=("category0", "category1")): + return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories) + + +def make_one_hot_label(*args, **kwargs): + label = make_label(*args, **kwargs) + return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories) + + +def make_one_hot_labels( + *, + num_categories=(1, 2, 10), + extra_dims=((4,), (2, 3)), +): + for num_categories_ in num_categories: + yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)]) + + for extra_dims_ in extra_dims: + yield make_one_hot_label(extra_dims_) + + class SampleInput: def __init__(self, *args, **kwargs): self.args = args diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index c9988be1930..420db8a4b4f 100644 --- a/torchvision/prototype/transforms/__init__.py +++ 
b/torchvision/prototype/transforms/__init__.py @@ -1,5 +1,13 @@ +from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip from . import kernels # usort: skip from . import functional # usort: skip -from .kernels import InterpolationMode # usort: skip +from ._transform import Transform # usort: skip +from ._augment import RandomErasing, RandomMixup, RandomCutmix +from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment +from ._container import Compose, RandomApply, RandomChoice, RandomOrder +from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop +from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace +from ._misc import Identity, Normalize, ToDtype, Lambda from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval +from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py new file mode 100644 index 00000000000..08e6704088b --- /dev/null +++ b/torchvision/prototype/transforms/_augment.py @@ -0,0 +1,140 @@ +import math +import numbers +import warnings +from typing import Any, Dict, Tuple + +import torch +from torchvision.prototype.transforms import Transform, functional as F + +from ._utils import query_image + + +class RandomErasing(Transform): + _DISPATCHER = F.erase + + def __init__( + self, + p: float = 0.5, + scale: Tuple[float, float] = (0.02, 0.33), + ratio: Tuple[float, float] = (0.3, 3.3), + value: float = 0, + ): + super().__init__() + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, (tuple, list)): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, (tuple, list)): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + if p < 0 or p > 1: + raise ValueError("Random erasing probability should be between 0 and 1") + # TODO: deprecate p in favor of wrapping the transform in a RandomApply + self.p = p + self.scale = scale + self.ratio = ratio + self.value = value + + def _get_params(self, sample: Any) -> Dict[str, Any]: + image = query_image(sample) + img_h, img_w = F.get_image_size(image) + img_c = F.get_image_num_channels(image) + + if isinstance(self.value, (int, float)): + value = [self.value] + elif isinstance(self.value, str): + value = None + elif isinstance(self.value, tuple): + value = list(self.value) + else: + value = self.value + + if value is not None and not (len(value) in (1, img_c)): + raise ValueError( + f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)" + ) + + area = img_h * img_w + + log_ratio = torch.log(torch.tensor(self.ratio)) + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + 
if value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() + break + else: + i, j, h, w, v = 0, 0, img_h, img_w, image + + return dict(zip("ijhwv", (i, j, h, w, v))) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if torch.rand(1) >= self.p: + return input + + return super()._transform(input, params) + + +class RandomMixup(Transform): + _DISPATCHER = F.mixup + + def __init__(self, *, alpha: float) -> None: + super().__init__() + self.alpha = alpha + self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + + def _get_params(self, sample: Any) -> Dict[str, Any]: + return dict(lam=float(self._dist.sample(()))) + + +class RandomCutmix(Transform): + _DISPATCHER = F.cutmix + + def __init__(self, *, alpha: float) -> None: + super().__init__() + self.alpha = alpha + self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + + def _get_params(self, sample: Any) -> Dict[str, Any]: + lam = float(self._dist.sample(())) + + image = query_image(sample) + H, W = F.get_image_size(image) + + r_x = torch.randint(W, ()) + r_y = torch.randint(H, ()) + + r = 0.5 * math.sqrt(1.0 - lam) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + box = (x1, y1, x2, y2) + + lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + return dict(box=box, lam_adjusted=lam_adjusted) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py new file mode 100644 index 00000000000..3fab0f92926 --- /dev/null +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -0,0 +1,356 @@ +import math +from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar + +import PIL.Image +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F +from torchvision.prototype.utils._internal import apply_recursively + +from ._utils import query_image + +K = TypeVar("K") +V = TypeVar("V") + + +class _AutoAugmentBase(Transform): + def __init__( + self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None + ) -> None: + super().__init__() + self.interpolation = interpolation + self.fill = fill + + _DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = { + "Identity": lambda input, magnitude, interpolation, fill: input, + "ShearX": lambda input, magnitude, interpolation, fill: F.affine( + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, + fill=fill, + ), + "ShearY": lambda input, magnitude, interpolation, fill: F.affine( + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, + fill=fill, + ), + "TranslateX": lambda input, magnitude, interpolation, fill: F.affine( + input, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ), + "TranslateY": lambda input, magnitude, interpolation, fill: F.affine( + input, + 
angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ), + "Rotate": lambda input, magnitude, interpolation, fill: F.rotate(input, angle=magnitude), + "Brightness": lambda input, magnitude, interpolation, fill: F.adjust_brightness( + input, brightness_factor=1.0 + magnitude + ), + "Color": lambda input, magnitude, interpolation, fill: F.adjust_saturation( + input, saturation_factor=1.0 + magnitude + ), + "Contrast": lambda input, magnitude, interpolation, fill: F.adjust_contrast( + input, contrast_factor=1.0 + magnitude + ), + "Sharpness": lambda input, magnitude, interpolation, fill: F.adjust_sharpness( + input, sharpness_factor=1.0 + magnitude + ), + "Posterize": lambda input, magnitude, interpolation, fill: F.posterize(input, bits=int(magnitude)), + "Solarize": lambda input, magnitude, interpolation, fill: F.solarize(input, threshold=magnitude), + "AutoContrast": lambda input, magnitude, interpolation, fill: F.autocontrast(input), + "Equalize": lambda input, magnitude, interpolation, fill: F.equalize(input), + "Invert": lambda input, magnitude, interpolation, fill: F.invert(input), + } + + def _is_supported(self, obj: Any) -> bool: + return type(obj) in {features.Image, torch.Tensor} or isinstance(obj, PIL.Image.Image) + + def _get_params(self, sample: Any) -> Dict[str, Any]: + image = query_image(sample) + num_channels = F.get_image_num_channels(image) + + fill = self.fill + if isinstance(fill, (int, float)): + fill = [float(fill)] * num_channels + elif fill is not None: + fill = [float(f) for f in fill] + + return dict(interpolation=self.interpolation, fill=fill) + + def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: + keys = tuple(dct.keys()) + key = keys[int(torch.randint(len(keys), ()))] + return key, dct[key] + + def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any: + dispatcher = self._DISPATCHER_MAP[transform_id] + + def transform(input: Any) -> Any: + if not self._is_supported(input): + return input + + return dispatcher(input, magnitude, params["interpolation"], params["fill"]) + + return apply_recursively(transform, sample) + + +class AutoAugment(_AutoAugmentBase): + _AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + .round() + .int(), + False, + ), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + "Invert": (lambda num_bins, image_size: None, False), + } + + def 
__init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.policy = policy + self._policies = self._get_policies(policy) + + def _get_policies( + self, policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), 
("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError(f"The provided policy {policy} is not recognized.") + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + params = params or self._get_params(sample) + + image = query_image(sample) + image_size = F.get_image_size(image) + + policy = self._policies[int(torch.randint(len(self._policies), ()))] + + for transform_id, probability, magnitude_idx in policy: + if not torch.rand(()) <= probability: + continue + + magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] + + magnitudes = magnitudes_fn(10, image_size) + if magnitudes is not None: + magnitude = float(magnitudes[magnitude_idx]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + sample = self._apply_transform(sample, params, transform_id, magnitude) + + return sample + + +class RandAugment(_AutoAugmentBase): + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, image_size: None, False), + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + .round() + .int(), + False, + ), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + } + + def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + params = params or self._get_params(sample) + + image = query_image(sample) + image_size = F.get_image_size(image) + + for _ in range(self.num_ops): + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + + 
magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + sample = self._apply_transform(sample, params, transform_id, magnitude) + + return sample + + +class TrivialAugmentWide(_AutoAugmentBase): + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, image_size: None, False), + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": ( + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) + .round() + .int(), + False, + ), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + } + + def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): + super().__init__(**kwargs) + self.num_magnitude_bins = num_magnitude_bins + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + params = params or self._get_params(sample) + + image = query_image(sample) + image_size = F.get_image_size(image) + + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + + magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + return self._apply_transform(sample, params, transform_id, magnitude) diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py new file mode 100644 index 00000000000..d2a0d642626 --- /dev/null +++ b/torchvision/prototype/transforms/_container.py @@ -0,0 +1,63 @@ +from typing import Any, Optional, Dict + +import torch + +from ._transform import Transform + + +class Compose(Transform): + def __init__(self, *transforms: Transform): + super().__init__() + self.transforms = transforms + for idx, transform in enumerate(transforms): + self.add_module(str(idx), transform) + + def forward(self, *inputs: Any) -> Any: # type: ignore[override] + sample = inputs if len(inputs) > 1 else inputs[0] + for transform in self.transforms: + sample = transform(sample) + return sample + + +class RandomApply(Transform): + def __init__(self, transform: Transform, *, p: float = 0.5) -> None: + super().__init__() + self.transform = transform + self.p = p + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if 
float(torch.rand(())) >= self.p: + return sample + + return self.transform(sample, params=params) + + def extra_repr(self) -> str: + return f"p={self.p}" + + +class RandomChoice(Transform): + def __init__(self, *transforms: Transform): + super().__init__() + self.transforms = transforms + for idx, transform in enumerate(transforms): + self.add_module(str(idx), transform) + + def forward(self, *inputs: Any) -> Any: # type: ignore[override] + idx = int(torch.randint(len(self.transforms), size=())) + transform = self.transforms[idx] + return transform(*inputs) + + +class RandomOrder(Transform): + def __init__(self, *transforms: Transform): + super().__init__() + self.transforms = transforms + for idx, transform in enumerate(transforms): + self.add_module(str(idx), transform) + + def forward(self, *inputs: Any) -> Any: # type: ignore[override] + for idx in torch.randperm(len(self.transforms)): + transform = self.transforms[idx] + inputs = transform(*inputs) + return inputs diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py new file mode 100644 index 00000000000..1b321286461 --- /dev/null +++ b/torchvision/prototype/transforms/_geometry.py @@ -0,0 +1,115 @@ +import math +import warnings +from typing import Any, Dict, List, Union, Sequence, Tuple, cast + +import torch +from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F +from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int + +from ._utils import query_image + + +class HorizontalFlip(Transform): + _DISPATCHER = F.horizontal_flip + + +class Resize(Transform): + _DISPATCHER = F.resize + + def __init__( + self, + size: Union[int, Sequence[int]], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.size = size + self.interpolation = interpolation + + def _get_params(self, sample: Any) -> Dict[str, Any]: + return dict(size=self.size, interpolation=self.interpolation) + + +class CenterCrop(Transform): + _DISPATCHER = F.center_crop + + def __init__(self, output_size: List[int]): + super().__init__() + self.output_size = output_size + + def _get_params(self, sample: Any) -> Dict[str, Any]: + return dict(output_size=self.output_size) + + +class RandomResizedCrop(Transform): + _DISPATCHER = F.resized_crop + + def __init__( + self, + size: Union[int, Sequence[int]], + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + scale = cast(Tuple[float, float], scale) + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + ratio = cast(Tuple[float, float], ratio) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationMode instead of int. " + "Please, use InterpolationMode enum."
+ ) + interpolation = _interpolation_modes_from_int(interpolation) + + self.size = size + self.scale = scale + self.ratio = ratio + self.interpolation = interpolation + + def _get_params(self, sample: Any) -> Dict[str, Any]: + image = query_image(sample) + height, width = F.get_image_size(image) + area = height * width + + log_ratio = torch.log(torch.tensor(self.ratio)) + for _ in range(10): + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + break + else: + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + + return dict(top=i, left=j, height=h, width=w, size=self.size) diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta_conversion.py new file mode 100644 index 00000000000..d7a1e6a76fa --- /dev/null +++ b/torchvision/prototype/transforms/_meta_conversion.py @@ -0,0 +1,53 @@ +from typing import Union, Any, Dict, Optional + +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, functional as F +from torchvision.transforms.functional import convert_image_dtype + + +class ConvertBoundingBoxFormat(Transform): + _DISPATCHER = F.convert_format + + def __init__( + self, + format: Union[str, features.BoundingBoxFormat], + old_format: Optional[Union[str, features.BoundingBoxFormat]] = None, + ) -> None: + super().__init__() + if isinstance(format, str): + format = features.BoundingBoxFormat[format] + self.format = format + + if isinstance(old_format, str): + old_format = features.BoundingBoxFormat[old_format] + self.old_format = old_format + + def _get_params(self, sample: Any) -> Dict[str, Any]: + return dict(format=self.format, old_format=self.old_format) + + +class ConvertImageDtype(Transform): + def __init__(self, dtype: torch.dtype = torch.float32) -> None: + super().__init__() + self.dtype = dtype + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if not isinstance(input, features.Image): + return input + + output = convert_image_dtype(input, dtype=self.dtype) + return features.Image.new_like(input, output, dtype=self.dtype) + + +class ConvertColorSpace(Transform): + _DISPATCHER = F.convert_color_space + + def __init__(self, color_space: Union[str, features.ColorSpace]) -> None: + super().__init__() + if isinstance(color_space, str): + color_space = features.ColorSpace[color_space] + self.color_space = color_space + + def _get_params(self, sample: Any) -> Dict[str, Any]: + return dict(color_space=self.color_space) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py new file mode 100644 index 00000000000..42502e74874 --- /dev/null +++ b/torchvision/prototype/transforms/_misc.py @@ -0,0 +1,52 @@ +import functools +from typing import Any, List, Type, Callable, Dict + +import torch +from 
torchvision.prototype.transforms import Transform, functional as F + + +class Identity(Transform): + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + return input + + +class Lambda(Transform): + def __init__(self, fn: Callable[[Any], Any], *types: Type): + super().__init__() + self.fn = fn + self.types = types + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if not isinstance(input, self.types): + return input + + return self.fn(input) + + def extra_repr(self) -> str: + extras = [] + name = getattr(self.fn, "__name__", None) + if name: + extras.append(name) + extras.append(f"types={[type.__name__ for type in self.types]}") + return ", ".join(extras) + + +class Normalize(Transform): + _DISPATCHER = F.normalize + + def __init__(self, mean: List[float], std: List[float]): + super().__init__() + self.mean = mean + self.std = std + + def _get_params(self, sample: Any) -> Dict[str, Any]: + return dict(mean=self.mean, std=self.std) + + +class ToDtype(Lambda): + def __init__(self, dtype: torch.dtype, *types: Type) -> None: + self.dtype = dtype + super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types) + + def extra_repr(self) -> str: + return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"]) diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py new file mode 100644 index 00000000000..a96bed31329 --- /dev/null +++ b/torchvision/prototype/transforms/_transform.py @@ -0,0 +1,46 @@ +import enum +import functools +from typing import Any, Dict, Optional + +from torch import nn +from torchvision.prototype.utils._internal import apply_recursively +from torchvision.utils import _log_api_usage_once + +from .functional._utils import Dispatcher + + +class Transform(nn.Module): + _DISPATCHER: Optional[Dispatcher] = None + + def __init__(self) -> None: + super().__init__() + _log_api_usage_once(self) + + def _get_params(self, sample: Any) -> Dict[str, Any]: + return dict() + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if not self._DISPATCHER: + raise NotImplementedError() + + if input not in self._DISPATCHER: + return input + + return self._DISPATCHER(input, **params) + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + return apply_recursively(functools.partial(self._transform, params=params or self._get_params(sample)), sample) + + def extra_repr(self) -> str: + extra = [] + for name, value in self.__dict__.items(): + if name.startswith("_") or name == "training": + continue + + if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)): + continue + + extra.append(f"{name}={value}") + + return ", ".join(extra) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py new file mode 100644 index 00000000000..3f4afc571d9 --- /dev/null +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -0,0 +1,35 @@ +from typing import Any, Dict + +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, kernels as K + + +class DecodeImage(Transform): + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if not isinstance(input, features.EncodedImage): + return input + + return features.Image(K.decode_image_with_pil(input)) + + +class LabelToOneHot(Transform): + def __init__(self, num_categories: int = -1): + super().__init__() + 
self.num_categories = num_categories + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if not isinstance(input, features.Label): + return input + + num_categories = self.num_categories + if num_categories == -1 and input.categories is not None: + num_categories = len(input.categories) + return features.OneHotLabel( + K.label_to_one_hot(input, num_categories=num_categories), categories=input.categories + ) + + def extra_repr(self) -> str: + if self.num_categories == -1: + return "" + + return f"num_categories={self.num_categories}" diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py new file mode 100644 index 00000000000..7f29d817499 --- /dev/null +++ b/torchvision/prototype/transforms/_utils.py @@ -0,0 +1,19 @@ +from typing import Any, Union, Optional + +import PIL.Image +import torch +from torchvision.prototype import features +from torchvision.prototype.utils._internal import query_recursively + + +def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: + def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: + if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): + return input + + return None + + try: + return next(query_recursively(fn, sample)) + except StopIteration: + raise TypeError("No image was found in the sample") diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 9f05f16df2d..9ebe46989d3 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -11,4 +11,5 @@ invert, ) from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate -from ._misc import normalize +from ._meta_conversion import convert_color_space, convert_format +from ._misc import normalize, get_image_size, get_image_num_channels diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index e5e93aa0b4f..4f9835bfe01 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Any +from typing import Any import torch from torchvision.prototype import features @@ -7,8 +7,6 @@ from ._utils import dispatch -T = TypeVar("T", bound=features._Feature) - @dispatch( { @@ -16,7 +14,7 @@ features.Image: K.erase_image, } ) -def erase(input: T, *args: Any, **kwargs: Any) -> T: +def erase(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -27,18 +25,18 @@ def erase(input: T, *args: Any, **kwargs: Any) -> T: features.OneHotLabel: K.mixup_one_hot_label, } ) -def mixup(input: T, *args: Any, **kwargs: Any) -> T: +def mixup(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @dispatch( { - features.Image: K.cutmix_image, - features.OneHotLabel: K.cutmix_one_hot_label, + features.Image: None, + features.OneHotLabel: None, } ) -def cutmix(input: T, *args: Any, **kwargs: Any) -> T: +def cutmix(input: Any, *args: Any, **kwargs: Any) -> Any: """Perform the CutMix operation as introduced in the paper `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. 
@@ -54,4 +52,13 @@ def cutmix(input: T, *args: Any, **kwargs: Any) -> T: Please refer to the kernel documentations for a detailed explanation of the functionality and parameters. """ - ... + if isinstance(input, features.Image): + kwargs.pop("lam_adjusted", None) + output = K.cutmix_image(input, **kwargs) + return features.Image.new_like(input, output) + elif isinstance(input, features.OneHotLabel): + kwargs.pop("box", None) + output = K.cutmix_one_hot_label(input, **kwargs) + return features.OneHotLabel.new_like(input, output) + + raise RuntimeError diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 290ae7094ce..6c4f9d33a28 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Any +from typing import Any import PIL.Image import torch @@ -8,8 +8,6 @@ from ._utils import dispatch -T = TypeVar("T", bound=features._Feature) - @dispatch( { @@ -18,7 +16,7 @@ features.Image: K.adjust_brightness_image, } ) -def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T: +def adjust_brightness(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -30,7 +28,7 @@ def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.adjust_saturation_image, } ) -def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T: +def adjust_saturation(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -42,7 +40,7 @@ def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.adjust_contrast_image, } ) -def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T: +def adjust_contrast(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -54,7 +52,7 @@ def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.adjust_sharpness_image, } ) -def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T: +def adjust_sharpness(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -66,7 +64,7 @@ def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.posterize_image, } ) -def posterize(input: T, *args: Any, **kwargs: Any) -> T: +def posterize(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -78,7 +76,7 @@ def posterize(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.solarize_image, } ) -def solarize(input: T, *args: Any, **kwargs: Any) -> T: +def solarize(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -90,7 +88,7 @@ def solarize(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.autocontrast_image, } ) -def autocontrast(input: T, *args: Any, **kwargs: Any) -> T: +def autocontrast(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -102,7 +100,7 @@ def autocontrast(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.equalize_image, } ) -def equalize(input: T, *args: Any, **kwargs: Any) -> T: +def equalize(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -114,7 +112,7 @@ def equalize(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.invert_image, } ) -def invert(input: T, *args: Any, **kwargs: Any) -> T: +def invert(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... 
@@ -126,7 +124,7 @@ def invert(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.adjust_hue_image, } ) -def adjust_hue(input: T, *args: Any, **kwargs: Any) -> T: +def adjust_hue(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -138,6 +136,6 @@ def adjust_hue(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.adjust_gamma_image, } ) -def adjust_gamma(input: T, *args: Any, **kwargs: Any) -> T: +def adjust_gamma(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ae930bfc5f1..06ecd38dac0 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Any, cast +from typing import Any import PIL.Image import torch @@ -8,8 +8,6 @@ from ._utils import dispatch -T = TypeVar("T", bound=features._Feature) - @dispatch( { @@ -19,11 +17,11 @@ features.BoundingBox: None, }, ) -def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T: +def horizontal_flip(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" if isinstance(input, features.BoundingBox): output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) - return cast(T, features.BoundingBox.new_like(input, output)) + return features.BoundingBox.new_like(input, output) raise RuntimeError @@ -37,12 +35,12 @@ def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T: features.BoundingBox: None, } ) -def resize(input: T, *args: Any, **kwargs: Any) -> T: +def resize(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" if isinstance(input, features.BoundingBox): size = kwargs.pop("size") output = K.resize_bounding_box(input, size=size, image_size=input.image_size) - return cast(T, features.BoundingBox.new_like(input, output, image_size=size)) + return features.BoundingBox.new_like(input, output, image_size=size) raise RuntimeError @@ -54,7 +52,7 @@ def resize(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.center_crop_image, } ) -def center_crop(input: T, *args: Any, **kwargs: Any) -> T: +def center_crop(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -66,7 +64,7 @@ def center_crop(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.resized_crop_image, } ) -def resized_crop(input: T, *args: Any, **kwargs: Any) -> T: +def resized_crop(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -78,7 +76,7 @@ def resized_crop(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.affine_image, } ) -def affine(input: T, *args: Any, **kwargs: Any) -> T: +def affine(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -90,7 +88,7 @@ def affine(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.rotate_image, } ) -def rotate(input: T, *args: Any, **kwargs: Any) -> T: +def rotate(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -102,7 +100,7 @@ def rotate(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.pad_image, } ) -def pad(input: T, *args: Any, **kwargs: Any) -> T: +def pad(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... 
@@ -114,7 +112,7 @@ def pad(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.crop_image, } ) -def crop(input: T, *args: Any, **kwargs: Any) -> T: +def crop(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -126,7 +124,7 @@ def crop(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.perspective_image, } ) -def perspective(input: T, *args: Any, **kwargs: Any) -> T: +def perspective(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -138,7 +136,7 @@ def perspective(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.vertical_flip_image, } ) -def vertical_flip(input: T, *args: Any, **kwargs: Any) -> T: +def vertical_flip(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -150,7 +148,7 @@ def vertical_flip(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.five_crop_image, } ) -def five_crop(input: T, *args: Any, **kwargs: Any) -> T: +def five_crop(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -162,6 +160,6 @@ def five_crop(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.ten_crop_image, } ) -def ten_crop(input: T, *args: Any, **kwargs: Any) -> T: +def ten_crop(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta_conversion.py new file mode 100644 index 00000000000..bbda3ea939a --- /dev/null +++ b/torchvision/prototype/transforms/functional/_meta_conversion.py @@ -0,0 +1,50 @@ +from typing import Any + +import PIL.Image +import torch +from torchvision.ops import box_convert +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F + +from ._utils import dispatch + + +@dispatch( + { + torch.Tensor: None, + features.BoundingBox: None, + } +) +def convert_format(input: Any, *args: Any, **kwargs: Any) -> Any: + format = kwargs["format"] + if type(input) is torch.Tensor: + old_format = kwargs.get("old_format") + if old_format is None: + raise TypeError("For vanilla tensors the `old_format` needs to be provided.") + return box_convert(input, in_fmt=kwargs["old_format"].name.lower(), out_fmt=format.name.lower()) + elif isinstance(input, features.BoundingBox): + output = K.convert_bounding_box_format(input, old_format=input.format, new_format=kwargs["format"]) + return features.BoundingBox.new_like(input, output, format=format) + + raise RuntimeError + + +@dispatch( + { + torch.Tensor: None, + PIL.Image.Image: None, + features.Image: None, + } +) +def convert_color_space(input: Any, *args: Any, **kwargs: Any) -> Any: + color_space = kwargs["color_space"] + if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): + if color_space != features.ColorSpace.GRAYSCALE: + raise ValueError("For vanilla tensors and PIL images only RGB to grayscale is supported") + return _F.rgb_to_grayscale(input) + elif isinstance(input, features.Image): + output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=color_space) + return features.Image.new_like(input, output, color_space=color_space) + + raise RuntimeError diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 45e1bdefd3d..212492230ea 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ 
b/torchvision/prototype/transforms/functional/_misc.py @@ -1,15 +1,21 @@ -from typing import TypeVar, Any +from typing import Any import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K from torchvision.transforms import functional as _F +from torchvision.transforms.functional_pil import ( + get_image_size as _get_image_size_pil, + get_image_num_channels as _get_image_num_channels_pil, +) +from torchvision.transforms.functional_tensor import ( + get_image_size as _get_image_size_tensor, + get_image_num_channels as _get_image_num_channels_tensor, +) from ._utils import dispatch -T = TypeVar("T", bound=features._Feature) - @dispatch( { @@ -17,7 +23,7 @@ features.Image: K.normalize_image, } ) -def normalize(input: T, *args: Any, **kwargs: Any) -> T: +def normalize(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... @@ -29,6 +35,35 @@ def normalize(input: T, *args: Any, **kwargs: Any) -> T: features.Image: K.gaussian_blur_image, } ) -def ten_gaussian_blur(input: T, *args: Any, **kwargs: Any) -> T: +def gaussian_blur(input: Any, *args: Any, **kwargs: Any) -> Any: """TODO: add docstring""" ... + + +@dispatch( + { + torch.Tensor: _get_image_size_tensor, + PIL.Image.Image: _get_image_size_pil, + features.Image: None, + features.BoundingBox: None, + } +) +def get_image_size(input: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(input, (features.Image, features.BoundingBox)): + return list(input.image_size) + + raise RuntimeError + + +@dispatch( + { + torch.Tensor: _get_image_num_channels_tensor, + PIL.Image.Image: _get_image_num_channels_pil, + features.Image: None, + } +) +def get_image_num_channels(input: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(input, features.Image): + return input.num_channels + + raise RuntimeError diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 591f9a83101..abdee565bc4 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -1,6 +1,5 @@ -import functools import inspect -from typing import Any, Optional, Callable, TypeVar, Dict +from typing import Any, Optional, Callable, TypeVar, Mapping, Type import torch import torch.overrides @@ -9,10 +8,10 @@ F = TypeVar("F", bound=features._Feature) -def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]: - """Decorates a function to automatically dispatch to registered kernels based on the call arguments. +class Dispatcher: + """Wrap a function to automatically dispatch to registered kernels based on the call arguments. - The dispatch function should have this signature + The wrapped function should have this signature .. code:: python @@ -34,7 +33,19 @@ def dispatch_fn(input, *args, **kwargs): TypeError: If the decorated function is called with an input that cannot be dispatched. """ - def check_kernel(kernel: Any) -> bool: + def __init__(self, fn: Callable, kernels: Mapping[Type, Optional[Callable]]): + self._fn = fn + + for feature_type, kernel in kernels.items(): + if not self._check_kernel(kernel): + raise TypeError( + f"Kernel for feature type {feature_type.__name__} is not callable with " + f"kernel(input, *args, **kwargs)." 
+ ) + + self._kernels = kernels + + def _check_kernel(self, kernel: Optional[Callable]) -> bool: if kernel is None: return True @@ -47,43 +58,45 @@ def check_kernel(kernel: Any) -> bool: return params[0].kind != inspect.Parameter.KEYWORD_ONLY - for feature_type, kernel in kernels.items(): - if not check_kernel(kernel): - raise TypeError( - f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)." - ) - - def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]: - @functools.wraps(dispatch_fn) - def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F: - feature_type = type(input) + def _resolve(self, feature_type: Type) -> Optional[Callable]: + try: + return self._kernels[feature_type] + except KeyError: try: - kernel = kernels[feature_type] - except KeyError: - try: - feature_type, kernel = next( - (feature_type, kernel) - for feature_type, kernel in kernels.items() - if isinstance(input, feature_type) - ) - except StopIteration: - raise TypeError(f"No support for {type(input).__name__}") from None - - if kernel is None: - output = dispatch_fn(input, *args, **kwargs) - if output is None: - raise RuntimeError( - f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} " - f"although it was configured to do so." - ) - else: - output = kernel(input, *args, **kwargs) - - if issubclass(feature_type, features._Feature) and type(output) is torch.Tensor: - output = feature_type.new_like(input, output) - - return output - - return inner_wrapper - - return outer_wrapper + return next( + kernel + for registered_feature_type, kernel in self._kernels.items() + if issubclass(feature_type, registered_feature_type) + ) + except StopIteration: + raise TypeError(f"No support for feature type {feature_type.__name__}") from None + + def __contains__(self, obj: Any) -> bool: + try: + self._resolve(type(obj)) + return True + except TypeError: + return False + + def __call__(self, input: Any, *args: Any, **kwargs: Any) -> Any: + kernel = self._resolve(type(input)) + + if kernel is None: + output = self._fn(input, *args, **kwargs) + if output is None: + raise RuntimeError( + f"{self._fn.__name__}() did not handle inputs of type {type(input).__name__} " + f"although it was configured to do so." + ) + else: + output = kernel(input, *args, **kwargs) + + if isinstance(input, features._Feature) and type(output) is torch.Tensor: + output = type(input).new_like(input, output) + + return output + + +def dispatch(kernels: Mapping[Type, Optional[Callable]]) -> Callable[[Callable], Dispatcher]: + """Decorates a function and turns it into a :class:`Dispatcher`.""" + return lambda fn: Dispatcher(fn, kernels) diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index fe75c19eb75..2e38471ea65 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -24,8 +24,7 @@ Tuple, TypeVar, Union, - List, - Dict, + Optional, ) import numpy as np @@ -42,6 +41,7 @@ "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", + "query_recursively", ] @@ -305,22 +305,22 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # "a" == "a"[0][0]... 
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): - sequence: List[Any] = [] - for item in obj: - result = apply_recursively(fn, item) - if isinstance(result, collections.abc.Sequence) and hasattr(result, "__inline__"): - sequence.extend(result) - else: - sequence.append(result) - return sequence + return [apply_recursively(fn, item) for item in obj] elif isinstance(obj, collections.abc.Mapping): - mapping: Dict[Any, Any] = {} - for name, item in obj.items(): - result = apply_recursively(fn, item) - if isinstance(result, collections.abc.Mapping) and hasattr(result, "__inline__"): - mapping.update(result) - else: - mapping[name] = result - return mapping + return {key: apply_recursively(fn, item) for key, item in obj.items()} else: return fn(obj) + + +def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]: + # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: + # "a" == "a"[0][0]... + if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance( + obj, collections.abc.Mapping + ): + for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj: + yield from query_recursively(fn, item) + else: + result = fn(obj) + if result is not None: + yield result
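
As a usage sketch (not part of the diff): the `Transform` base class added above can also be subclassed without a dispatcher by overriding `_get_params` and `_transform` directly. The `AddGaussianNoise` transform below is purely illustrative (the class name and noise logic are not in this PR) and assumes floating-point `features.Image` inputs; every other input in the sample passes through unchanged.

```python
from typing import Any, Dict

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform


class AddGaussianNoise(Transform):
    """Illustrative only: add Gaussian noise to floating-point prototype Images."""

    def __init__(self, sigma: float = 0.1) -> None:
        super().__init__()
        self.sigma = sigma

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        # The returned params are computed once per call; Transform.forward shares
        # them across every input in the (possibly nested) sample.
        return dict(sigma=self.sigma)

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        # Only touch floating-point prototype Images, leave everything else as-is.
        if not isinstance(input, features.Image) or not input.is_floating_point():
            return input
        output = input + params["sigma"] * torch.randn_like(input)
        return features.Image.new_like(input, output)


transform = AddGaussianNoise(sigma=0.05)
sample = dict(
    image=features.Image(torch.rand(3, 16, 16)),
    label=features.Label(torch.tensor(1), categories=["cat", "dog"]),
)
output = transform(sample)  # the image is perturbed, the label is returned untouched
```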