Skip to content

Commit bb88c45

Browse files
adjust_hue now supports inputs of type Tensor (#2566)
adjust_hue now supports inputs of type Tensor (#2566) * adjust_hue now supports inputs of type Tensor * Added a comparison between the original adjust_hue and its Tensor and torch.jit.script versions. * Added a few type checks related to adjust_hue in functional_tensor.py in hopes of making F_t.adjust_hue scriptable... but to no avail. * Changed the implementation of _rgb2hsv and removed a useless type declaration according to the PR's review. * Handled the range of hue_factor in the assertions and temporarily increased the assertLess bound to make sure that no other test fails. * Fixed some lint issues with CircleCI and added type hints in functional_pil.py as well. * Corrected type hint mistakes. * Followed PR review recommendations and added a test for the class interface with hue. * Refactored test_functional_tensor.py to match vfdev-5's d016cab branch by simple copy/paste, and added the test_adjust_hue and ColorJitter class interface tests in the same style (the class interface test was removed in vfdev-5's branch for some reason). * Removed test_adjustments from test_transforms_tensor.py and moved the ColorJitter class interface test into test_transforms_tensor.py. * Added CUDA test cases for test_adjustments and tried to fix a conflict. * Updated tests - adjust hue - color jitter * Fixes incompatible devices * Increased tolerance for CUDA tests * Fixes a potential issue with an in-place op - fixes an irreproducible failing test on Travis CI * Reverted fmod -> % Co-authored-by: vfdev-5 <[email protected]>
1 parent ac3ba94 commit bb88c45

File tree

4 files changed

+41
-18
lines changed

4 files changed

+41
-18
lines changed

test/test_functional_tensor.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,9 @@ def test_pad(self):
217217
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
218218
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
219219

220-
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
220+
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
221221
script_fn = torch.jit.script(fn)
222-
223222
torch.manual_seed(15)
224-
225223
tensor, pil_img = self._create_data(26, 34, device=self.device)
226224

227225
for dt in [None, torch.float32, torch.float64]:
@@ -230,7 +228,6 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
230228
tensor = F.convert_image_dtype(tensor, dt)
231229

232230
for config in configs:
233-
234231
adjusted_tensor = fn_t(tensor, **config)
235232
adjusted_pil = fn_pil(pil_img, **config)
236233
scripted_result = script_fn(tensor, **config)
@@ -245,9 +242,12 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
245242

246243
# Check that max difference does not exceed 2 in [0, 255] range
247244
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
248-
tol = 2.0 + 1e-10
249-
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol, msg=msg, agg_method="max")
250-
self.assertTrue(adjusted_tensor.allclose(scripted_result), msg=msg)
245+
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)
246+
247+
atol = 1e-6
248+
if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
249+
atol = 1.0
250+
self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
251251

252252
def test_adjust_brightness(self):
253253
self._test_adjust_fn(
@@ -273,6 +273,16 @@ def test_adjust_saturation(self):
273273
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
274274
)
275275

276+
def test_adjust_hue(self):
277+
self._test_adjust_fn(
278+
F.adjust_hue,
279+
F_pil.adjust_hue,
280+
F_t.adjust_hue,
281+
[{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
282+
tol=0.1,
283+
agg_method="mean"
284+
)
285+
276286
def test_adjust_gamma(self):
277287
self._test_adjust_fn(
278288
F.adjust_gamma,

test/test_transforms_tensor.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,36 @@ def test_random_vertical_flip(self):
6060
def test_color_jitter(self):
6161

6262
tol = 1.0 + 1e-10
63-
for f in [0.1, 0.5, 1.0, 1.34]:
63+
for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
6464
meth_kwargs = {"brightness": f}
6565
self._test_class_op(
6666
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
6767
)
6868

69-
for f in [0.2, 0.5, 1.0, 1.5]:
69+
for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
7070
meth_kwargs = {"contrast": f}
7171
self._test_class_op(
7272
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
7373
)
7474

75-
for f in [0.5, 0.75, 1.0, 1.25]:
75+
for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
7676
meth_kwargs = {"saturation": f}
7777
self._test_class_op(
7878
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
7979
)
8080

81+
for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
82+
meth_kwargs = {"hue": f}
83+
self._test_class_op(
84+
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
85+
)
86+
87+
# All 4 parameters together
88+
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
89+
self._test_class_op(
90+
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
91+
)
92+
8193
def test_pad(self):
8294

8395
# Test functional.pad (PIL and Tensor) with padding as single int

torchvision/transforms/functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -736,20 +736,20 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
736736
.. _Hue: https://en.wikipedia.org/wiki/Hue
737737
738738
Args:
739-
img (PIL Image): PIL Image to be adjusted.
739+
img (PIL Image or Tensor): Image to be adjusted.
740740
hue_factor (float): How much to shift the hue channel. Should be in
741741
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
742742
HSV space in positive and negative direction respectively.
743743
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
744744
with complementary colors while 0 gives the original image.
745745
746746
Returns:
747-
PIL Image: Hue adjusted image.
747+
PIL Image or Tensor: Hue adjusted image.
748748
"""
749749
if not isinstance(img, torch.Tensor):
750750
return F_pil.adjust_hue(img, hue_factor)
751751

752-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
752+
return F_t.adjust_hue(img, hue_factor)
753753

754754

755755
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:

torchvision/transforms/functional_tensor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
157157
return _blend(img, mean, contrast_factor)
158158

159159

160-
def adjust_hue(img, hue_factor):
160+
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
161161
"""Adjust hue of an image.
162162
163163
The image hue is adjusted by converting the image to HSV and
@@ -185,17 +185,16 @@ def adjust_hue(img, hue_factor):
185185
if not (-0.5 <= hue_factor <= 0.5):
186186
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
187187

188-
if not _is_tensor_a_torch_image(img):
189-
raise TypeError('tensor is not a torch image.')
188+
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
189+
raise TypeError('img should be Tensor image. Got {}'.format(type(img)))
190190

191191
orig_dtype = img.dtype
192192
if img.dtype == torch.uint8:
193193
img = img.to(dtype=torch.float32) / 255.0
194194

195195
img = _rgb2hsv(img)
196196
h, s, v = img.unbind(0)
197-
h += hue_factor
198-
h = h % 1.0
197+
h = (h + hue_factor) % 1.0
199198
img = torch.stack((h, s, v))
200199
img_hue_adj = _hsv2rgb(img)
201200

@@ -408,6 +407,8 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
408407
def _rgb2hsv(img):
409408
r, g, b = img.unbind(0)
410409

410+
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
411+
# src/libImaging/Convert.c#L330
411412
maxc = torch.max(img, dim=0).values
412413
minc = torch.min(img, dim=0).values
413414

0 commit comments

Comments (0)