Skip to content

Commit b29c553

Browse files
authored
port ColorJitter to prototype transforms (#5656)
* port ColorJitter to prototype transforms * make color module private * address review comments
1 parent 11bd2ea commit b29c553

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from ._augment import RandomErasing, RandomMixup, RandomCutmix
66
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
7+
from ._color import ColorJitter
78
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
89
from ._geometry import (
910
Resize,
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import collections.abc
2+
import functools
3+
from typing import Any, Dict, Union, Tuple, Optional, Sequence, Callable, TypeVar
4+
5+
import PIL.Image
6+
import torch
7+
from torchvision.prototype import features
8+
from torchvision.prototype.transforms import Transform, functional as F
9+
10+
from ._utils import is_simple_tensor
11+
12+
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
13+
14+
15+
class ColorJitter(Transform):
16+
def __init__(
17+
self,
18+
brightness: Optional[Union[float, Sequence[float]]] = None,
19+
contrast: Optional[Union[float, Sequence[float]]] = None,
20+
saturation: Optional[Union[float, Sequence[float]]] = None,
21+
hue: Optional[Union[float, Sequence[float]]] = None,
22+
) -> None:
23+
super().__init__()
24+
self.brightness = self._check_input(brightness, "brightness")
25+
self.contrast = self._check_input(contrast, "contrast")
26+
self.saturation = self._check_input(saturation, "saturation")
27+
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
28+
29+
def _check_input(
30+
self,
31+
value: Optional[Union[float, Sequence[float]]],
32+
name: str,
33+
center: float = 1.0,
34+
bound: Tuple[float, float] = (0, float("inf")),
35+
clip_first_on_zero: bool = True,
36+
) -> Optional[Tuple[float, float]]:
37+
if value is None:
38+
return None
39+
40+
if isinstance(value, float):
41+
if value < 0:
42+
raise ValueError(f"If {name} is a single number, it must be non negative.")
43+
value = [center - value, center + value]
44+
if clip_first_on_zero:
45+
value[0] = max(value[0], 0.0)
46+
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
47+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
48+
raise ValueError(f"{name} values should be between {bound}")
49+
else:
50+
raise TypeError(f"{name} should be a single number or a sequence with length 2.")
51+
52+
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
53+
54+
def _image_transform(
55+
self,
56+
input: T,
57+
*,
58+
kernel_tensor: Callable[..., torch.Tensor],
59+
kernel_pil: Callable[..., PIL.Image.Image],
60+
**kwargs: Any,
61+
) -> T:
62+
if isinstance(input, features.Image):
63+
output = kernel_tensor(input, **kwargs)
64+
return features.Image.new_like(input, output)
65+
elif is_simple_tensor(input):
66+
return kernel_tensor(input, **kwargs)
67+
elif isinstance(input, PIL.Image.Image):
68+
return kernel_pil(input, **kwargs) # type: ignore[no-any-return]
69+
else:
70+
raise RuntimeError
71+
72+
def _get_params(self, sample: Any) -> Dict[str, Any]:
73+
image_transforms = []
74+
if self.brightness is not None:
75+
image_transforms.append(
76+
functools.partial(
77+
self._image_transform,
78+
kernel_tensor=F.adjust_brightness_image_tensor,
79+
kernel_pil=F.adjust_brightness_image_pil,
80+
brightness_factor=float(
81+
torch.distributions.Uniform(self.brightness[0], self.brightness[1]).sample()
82+
),
83+
)
84+
)
85+
if self.contrast is not None:
86+
image_transforms.append(
87+
functools.partial(
88+
self._image_transform,
89+
kernel_tensor=F.adjust_contrast_image_tensor,
90+
kernel_pil=F.adjust_contrast_image_pil,
91+
contrast_factor=float(torch.distributions.Uniform(self.contrast[0], self.contrast[1]).sample()),
92+
)
93+
)
94+
if self.saturation is not None:
95+
image_transforms.append(
96+
functools.partial(
97+
self._image_transform,
98+
kernel_tensor=F.adjust_saturation_image_tensor,
99+
kernel_pil=F.adjust_saturation_image_pil,
100+
saturation_factor=float(
101+
torch.distributions.Uniform(self.saturation[0], self.saturation[1]).sample()
102+
),
103+
)
104+
)
105+
if self.hue is not None:
106+
image_transforms.append(
107+
functools.partial(
108+
self._image_transform,
109+
kernel_tensor=F.adjust_hue_image_tensor,
110+
kernel_pil=F.adjust_hue_image_pil,
111+
hue_factor=float(torch.distributions.Uniform(self.hue[0], self.hue[1]).sample()),
112+
)
113+
)
114+
115+
return dict(image_transforms=[image_transforms[idx] for idx in torch.randperm(len(image_transforms))])
116+
117+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
118+
if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
119+
return input
120+
121+
for transform in params["image_transforms"]:
122+
input = transform(input)
123+
124+
return input

0 commit comments

Comments
 (0)