diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index abdf4f131f1..29650b5bc5f 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -217,11 +217,9 @@ def test_pad(self):
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
 
-    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
+    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
         script_fn = torch.jit.script(fn)
 
-        torch.manual_seed(15)
-
         tensor, pil_img = self._create_data(26, 34, device=self.device)
 
         for dt in [None, torch.float32, torch.float64]:
@@ -230,7 +228,6 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
                 tensor = F.convert_image_dtype(tensor, dt)
 
             for config in configs:
-
                 adjusted_tensor = fn_t(tensor, **config)
                 adjusted_pil = fn_pil(pil_img, **config)
                 scripted_result = script_fn(tensor, **config)
@@ -245,9 +242,12 @@ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
 
                 # Check that max difference does not exceed 2 in [0, 255] range
                 # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
-                tol = 2.0 + 1e-10
-                self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol, msg=msg, agg_method="max")
-                self.assertTrue(adjusted_tensor.allclose(scripted_result), msg=msg)
+                self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)
+
+                atol = 1e-6
+                if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
+                    atol = 1.0
+                self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
 
     def test_adjust_brightness(self):
         self._test_adjust_fn(
@@ -273,6 +273,16 @@ def test_adjust_saturation(self):
             [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
         )
 
+    def test_adjust_hue(self):
+        self._test_adjust_fn(
+            F.adjust_hue,
+            F_pil.adjust_hue,
+            F_t.adjust_hue,
+            [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
+            tol=0.1,
+            agg_method="mean"
+        )
+
     def test_adjust_gamma(self):
         self._test_adjust_fn(
             F.adjust_gamma,
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index f92ebfdcd9d..d82c4f6b309 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -60,24 +60,36 @@ def test_random_vertical_flip(self):
 
     def test_color_jitter(self):
         tol = 1.0 + 1e-10
-        for f in [0.1, 0.5, 1.0, 1.34]:
+        for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
             meth_kwargs = {"brightness": f}
             self._test_class_op(
                 "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
             )
 
-        for f in [0.2, 0.5, 1.0, 1.5]:
+        for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
             meth_kwargs = {"contrast": f}
             self._test_class_op(
                 "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
             )
 
-        for f in [0.5, 0.75, 1.0, 1.25]:
+        for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
             meth_kwargs = {"saturation": f}
             self._test_class_op(
                 "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
             )
 
+        for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
+            meth_kwargs = {"hue": f}
+            self._test_class_op(
+                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
+            )
+
+        # All 4 parameters together
+        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
+        self._test_class_op(
+            "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
+        )
+
     def test_pad(self):
 
         # Test functional.pad (PIL and Tensor) with padding as single int
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 4a36e0b05e6..0f884c9edf0 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -736,7 +736,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     .. _Hue: https://en.wikipedia.org/wiki/Hue
 
     Args:
-        img (PIL Image): PIL Image to be adjusted.
+        img (PIL Image or Tensor): Image to be adjusted.
         hue_factor (float): How much to shift the hue channel. Should be in
             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
             HSV space in positive and negative direction respectively.
@@ -744,12 +744,12 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
             with complementary colors while 0 gives the original image.
 
     Returns:
-        PIL Image: Hue adjusted image.
+        PIL Image or Tensor: Hue adjusted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_hue(img, hue_factor)
 
-    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    return F_t.adjust_hue(img, hue_factor)
 
 
 def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 02da5fec206..73aa020b637 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -157,7 +157,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     return _blend(img, mean, contrast_factor)
 
 
-def adjust_hue(img, hue_factor):
+def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     """Adjust hue of an image.
 
     The image hue is adjusted by converting the image to HSV and
@@ -185,8 +185,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     if not (-0.5 <= hue_factor <= 0.5):
         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
 
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+        raise TypeError('img should be Tensor image. Got {}'.format(type(img)))
 
     orig_dtype = img.dtype
     if img.dtype == torch.uint8:
@@ -194,8 +194,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
 
     img = _rgb2hsv(img)
     h, s, v = img.unbind(0)
-    h += hue_factor
-    h = h % 1.0
+    h = (h + hue_factor) % 1.0
     img = torch.stack((h, s, v))
     img_hue_adj = _hsv2rgb(img)
 
@@ -408,6 +407,8 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
 
 
 def _rgb2hsv(img):
     r, g, b = img.unbind(0)
+    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
+    # src/libImaging/Convert.c#L330
     maxc = torch.max(img, dim=0).values
     minc = torch.min(img, dim=0).values
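For reviewers, a minimal usage sketch of what this patch enables, assuming a torchvision build with the change applied; the image sizes and factor values below are illustrative, not taken from the tests:

```python
import torch
from torchvision import transforms
from torchvision.transforms import functional as F

# Illustrative 3-channel uint8 image tensor in (C, H, W) layout.
img = torch.randint(0, 256, (3, 26, 34), dtype=torch.uint8)

# Before this patch the tensor path raised TypeError; now it dispatches to
# functional_tensor.adjust_hue. hue_factor must lie in [-0.5, 0.5].
out = F.adjust_hue(img, 0.25)
assert out.shape == img.shape and out.dtype == img.dtype

# ColorJitter with a hue range now also accepts tensors, which is what the
# test_color_jitter additions above exercise (including the scripted path).
jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
out = jitter(img)
```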
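The `_rgb2hsv` comment added in the last hunk points at Pillow's C implementation. For intuition, here is a self-contained sketch of that conversion in PyTorch; `rgb_to_hsv_sketch` is a hypothetical name rather than the library helper, and the zero-denominator guards follow the standard colorsys formulas:

```python
import torch

def rgb_to_hsv_sketch(img: torch.Tensor) -> torch.Tensor:
    """Convert a float (3, H, W) RGB image with values in [0, 1] to HSV in [0, 1]."""
    r, g, b = img.unbind(0)
    maxc = torch.max(img, dim=0).values
    minc = torch.min(img, dim=0).values
    v = maxc
    cr = maxc - minc
    ones = torch.ones_like(maxc)
    # Guard the denominators: black pixels (maxc == 0) get s = 0 and grey
    # pixels (cr == 0) get h = 0, so the substituted 1.0 never leaks out.
    s = cr / torch.where(maxc == 0, ones, maxc)
    cr_div = torch.where(cr == 0, ones, cr)
    rc = (maxc - r) / cr_div
    gc = (maxc - g) / cr_div
    bc = (maxc - b) / cr_div
    # Sector selection as in colorsys: whichever channel holds the max
    # decides which 60-degree sector the hue falls in.
    h = torch.where(
        maxc == r, bc - gc,
        torch.where(maxc == g, 2.0 + rc - bc, 4.0 + gc - rc),
    )
    h = (h / 6.0) % 1.0
    return torch.stack((h, s, v))

hsv = rgb_to_hsv_sketch(torch.rand(3, 26, 34))
h, s, v = hsv.unbind(0)
h = (h + 0.25) % 1.0  # the hue shift from the patch, with an illustrative factor
```

The patch's change from `h += hue_factor; h = h % 1.0` to `h = (h + hue_factor) % 1.0` keeps the same wrap-around semantics while avoiding an in-place update on the view returned by `unbind`.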