import collections.abc
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F

from ._utils import is_simple_tensor

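# `T` constrains the image-dispatch helper below: the image type that goes in is the type that comes out.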
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)


class ColorJitter(Transform):
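    """Randomly change the brightness, contrast, saturation, and hue of the images in a sample."""
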
    def __init__(
        self,
        brightness: Optional[Union[float, Sequence[float]]] = None,
        contrast: Optional[Union[float, Sequence[float]]] = None,
        saturation: Optional[Union[float, Sequence[float]]] = None,
        hue: Optional[Union[float, Sequence[float]]] = None,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

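    # A single number v is normalized to the sampling range (center - v, center + v);
    # a two-element sequence is validated against `bound`. A degenerate range that
    # collapses onto `center` means "no jitter" and is returned as None.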
    def _check_input(
        self,
        value: Optional[Union[float, Sequence[float]]],
        name: str,
        center: float = 1.0,
        bound: Tuple[float, float] = (0, float("inf")),
        clip_first_on_zero: bool = True,
    ) -> Optional[Tuple[float, float]]:
        if value is None:
            return None

        if isinstance(value, (int, float)):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non-negative.")
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError(f"{name} values should be between {bound}")
        else:
            raise TypeError(f"{name} should be a single number or a sequence with length 2.")

        return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

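    # Dispatch a single image to the matching kernel: `features.Image` and plain
    # tensors go to the tensor kernel, `PIL.Image.Image` to the PIL kernel.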
    def _image_transform(
        self,
        input: T,
        *,
        kernel_tensor: Callable[..., torch.Tensor],
        kernel_pil: Callable[..., PIL.Image.Image],
        **kwargs: Any,
    ) -> T:
        if isinstance(input, features.Image):
            output = kernel_tensor(input, **kwargs)
            return features.Image.new_like(input, output)
        elif is_simple_tensor(input):
            return kernel_tensor(input, **kwargs)
        elif isinstance(input, PIL.Image.Image):
            return kernel_pil(input, **kwargs)  # type: ignore[no-any-return]
        else:
            raise RuntimeError(f"Unsupported input type: {type(input).__name__}")

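    # Sample one factor per enabled adjustment and bind it into a partial, so the
    # same factors are reused for every image in the sample.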
    def _get_params(self, sample: Any) -> Dict[str, Any]:
        image_transforms = []
        if self.brightness is not None:
            image_transforms.append(
                functools.partial(
                    self._image_transform,
                    kernel_tensor=F.adjust_brightness_image_tensor,
                    kernel_pil=F.adjust_brightness_image_pil,
                    brightness_factor=float(
                        torch.distributions.Uniform(self.brightness[0], self.brightness[1]).sample()
                    ),
                )
            )
        if self.contrast is not None:
            image_transforms.append(
                functools.partial(
                    self._image_transform,
                    kernel_tensor=F.adjust_contrast_image_tensor,
                    kernel_pil=F.adjust_contrast_image_pil,
                    contrast_factor=float(torch.distributions.Uniform(self.contrast[0], self.contrast[1]).sample()),
                )
            )
        if self.saturation is not None:
            image_transforms.append(
                functools.partial(
                    self._image_transform,
                    kernel_tensor=F.adjust_saturation_image_tensor,
                    kernel_pil=F.adjust_saturation_image_pil,
                    saturation_factor=float(
                        torch.distributions.Uniform(self.saturation[0], self.saturation[1]).sample()
                    ),
                )
            )
        if self.hue is not None:
            image_transforms.append(
                functools.partial(
                    self._image_transform,
                    kernel_tensor=F.adjust_hue_image_tensor,
                    kernel_pil=F.adjust_hue_image_pil,
                    hue_factor=float(torch.distributions.Uniform(self.hue[0], self.hue[1]).sample()),
                )
            )

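        # Shuffle so that no single adjustment is systematically applied first.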
        return dict(image_transforms=[image_transforms[idx] for idx in torch.randperm(len(image_transforms))])

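    # Non-image inputs (e.g. bounding boxes or labels) pass through untouched.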
    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
            return input

        for transform in params["image_transforms"]:
            input = transform(input)

        return input
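
# Minimal usage sketch (illustrative only; the image path is hypothetical):
#
#     transform = ColorJitter(brightness=0.5, contrast=(0.8, 1.2), hue=0.1)
#     jittered = transform(PIL.Image.open("example.jpg"))
#
# Each call draws fresh factors via `_get_params` and applies them in random order.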