From c3055815fb52a0c473357294c36a67c845ceabba Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 25 Oct 2019 14:03:19 +0100 Subject: [PATCH 1/5] Add adjustment operations for RGB Tensor Images. Right now, we have operations on PIL images, but we want to have a version of the opeartions that act directly on Tensor images. Here, we add such operations for adjust_brightness, adjust_contrast and adjust_saturation. In PIL, those functions are implemented by generating an degenerate image from the first, and then interpolating them together. - https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py - https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Blend.c A few caveats: * Since PIL operates on uint8, and the tensor operations might be on float, we can get slightly different values because of int truncation. * We assume here the images are RGB; in particular, to handle an alpha channel, we need to check whether it is present, in which case we copy it to the final image. --- test/test_functional_tensor.py | 23 +++++++ torchvision/transforms/functional_tensor.py | 71 +++++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 903798d19e3..39497c20a9a 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -36,6 +36,29 @@ def test_crop(self): self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)), "functional_tensor crop not working") + def test_adjustments(self): + fns = ((F.adjust_brightness, F_t.adjust_brightness), + (F.adjust_contrast, F_t.adjust_contrast), + (F.adjust_saturation, F_t.adjust_saturation)) + + for _ in range(20): + channels = 3 + dims = torch.randint(1, 50, (2,)) + img = torch.rand(channels, *dims, dtype=torch.float) + factor = 3 * torch.rand(1) + for f, ft in fns: + + ft_img = ft(img, factor) + + img_pil = transforms.ToPILImage()(img) + f_img_pil = f(img_pil, factor) + f_img = transforms.ToTensor()(f_img_pil) + + # F uses uint8 and F_t uses float, so there is a small + # difference in values caused by truncations. + l1_diff = (ft_img - f_img).norm(p=1) + self.assertLess(l1_diff, 0.01 * img.nelement()) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 5027958164c..a0be93011ce 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -48,3 +48,74 @@ def crop(img, top, left, height, width): raise TypeError('tensor is not a torch image.') return img[..., top:top + height, left:left + width] + + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an RGB image. + + Args: + img (Tensor): Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + Tensor: Brightness adjusted image. + """ + if not F._is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + + black_img = torch.zeros(img.shape, device=img.device) + + return _blend(img, black_img, brightness_factor) + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an RGB image. + + Args: + img (Tensor): Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + Tensor: Contrast adjusted image. + """ + if not F._is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + + mean = torch.mean(_rgb_to_grayscale(img)) + mean_img = mean * torch.ones(img.shape, device=img.device) + + return _blend(img, mean_img, contrast_factor) + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an RGB image. + + Args: + img (Tensor): Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + Tensor: Saturation adjusted image. + """ + if not F._is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + + num_channels = img.shape[0] + gray_img = _rgb_to_grayscale(img).repeat(num_channels, 1, 1) + + return _blend(img, gray_img, saturation_factor) + + +def _blend(img1, img2, ratio): + return (ratio * img1 + (1 - ratio) * img2).clamp(0, 1) + + +def _rgb_to_grayscale(img): + # ITU-R 601-2 luma transform, as used in PIL. + return 0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2] From eeedaf9d5a7aad02dfd6249e354e39c02582a3b2 Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 25 Oct 2019 17:24:38 +0100 Subject: [PATCH 2/5] Keep dtype and use broadcast in adjust operations - We make our operations have input.dtype == output.dtype, at the cost of adding a few type checks and branches. - By using Tensor broadcast, we can simplify the calls to _blend. --- test/test_functional_tensor.py | 16 ++++++++++---- torchvision/transforms/functional_tensor.py | 23 ++++++++++----------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 39497c20a9a..3dcceb95e7d 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -44,20 +44,28 @@ def test_adjustments(self): for _ in range(20): channels = 3 dims = torch.randint(1, 50, (2,)) - img = torch.rand(channels, *dims, dtype=torch.float) + shape = (channels, *dims) + + if torch.randint(0, 2, (1,)) == 0: + img = torch.rand(*shape, dtype=torch.float) + else: + img = torch.randint(0, 256, shape, dtype=torch.uint8) + factor = 3 * torch.rand(1) for f, ft in fns: ft_img = ft(img, factor) + if img.dtype == torch.uint8: + ft_img = ft_img.to(torch.float) / 255 img_pil = transforms.ToPILImage()(img) f_img_pil = f(img_pil, factor) f_img = transforms.ToTensor()(f_img_pil) # F uses uint8 and F_t uses float, so there is a small - # difference in values caused by truncations. - l1_diff = (ft_img - f_img).norm(p=1) - self.assertLess(l1_diff, 0.01 * img.nelement()) + # difference in values caused by (at most 5) truncations. + max_diff = (ft_img - f_img).abs().max() + self.assertLess(max_diff, 5 / 255 + 1e-5) if __name__ == '__main__': diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index a0be93011ce..5ea80069476 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -65,9 +65,7 @@ def adjust_brightness(img, brightness_factor): if not F._is_tensor_image(img): raise TypeError('tensor is not a torch image.') - black_img = torch.zeros(img.shape, device=img.device) - - return _blend(img, black_img, brightness_factor) + return _blend(img, 0, brightness_factor) def adjust_contrast(img, contrast_factor): @@ -85,10 +83,9 @@ def adjust_contrast(img, contrast_factor): if not F._is_tensor_image(img): raise TypeError('tensor is not a torch image.') - mean = torch.mean(_rgb_to_grayscale(img)) - mean_img = mean * torch.ones(img.shape, device=img.device) + mean = torch.mean(_rgb_to_grayscale(img).to(torch.float)) - return _blend(img, mean_img, contrast_factor) + return _blend(img, mean, contrast_factor) def adjust_saturation(img, saturation_factor): @@ -106,16 +103,18 @@ def adjust_saturation(img, saturation_factor): if not F._is_tensor_image(img): raise TypeError('tensor is not a torch image.') - num_channels = img.shape[0] - gray_img = _rgb_to_grayscale(img).repeat(num_channels, 1, 1) - - return _blend(img, gray_img, saturation_factor) + return _blend(img, _rgb_to_grayscale(img), saturation_factor) def _blend(img1, img2, ratio): - return (ratio * img1 + (1 - ratio) * img2).clamp(0, 1) + bound = 255 if _is_int(img1.dtype) else 1 + return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) def _rgb_to_grayscale(img): # ITU-R 601-2 luma transform, as used in PIL. - return 0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2] + return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype) + + +def _is_int(dtype): + return dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) From eae132f6288d0481b779a64c893aa1eee9ed5d1f Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 25 Oct 2019 17:42:08 +0100 Subject: [PATCH 3/5] Use is_floating_point to check dtype. --- test/test_functional_tensor.py | 2 +- torchvision/transforms/functional_tensor.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 3dcceb95e7d..82d53d9ff68 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -55,7 +55,7 @@ def test_adjustments(self): for f, ft in fns: ft_img = ft(img, factor) - if img.dtype == torch.uint8: + if not img.dtype.is_floating_point: ft_img = ft_img.to(torch.float) / 255 img_pil = transforms.ToPILImage()(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 5ea80069476..7ef83c1086b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -107,14 +107,10 @@ def adjust_saturation(img, saturation_factor): def _blend(img1, img2, ratio): - bound = 255 if _is_int(img1.dtype) else 1 + bound = 1 if img1.dtype.is_floating_point else 255 return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) def _rgb_to_grayscale(img): # ITU-R 601-2 luma transform, as used in PIL. return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype) - - -def _is_int(dtype): - return dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) From 879f3d83650f725cef377e3afaeffd70d1ea9053 Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 25 Oct 2019 17:57:30 +0100 Subject: [PATCH 4/5] Remove unpacking in tuple It seems Python 2 does not support this type of unpacking, so it broke Python 2 builds. This should fix it. --- test/test_functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 82d53d9ff68..32826b2c9d0 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -44,7 +44,7 @@ def test_adjustments(self): for _ in range(20): channels = 3 dims = torch.randint(1, 50, (2,)) - shape = (channels, *dims) + shape = (channels, dims[0], dims[1]) if torch.randint(0, 2, (1,)) == 0: img = torch.rand(*shape, dtype=torch.float) From 6eb0b60005eb569c3e05174a5b5be5fbcf3467bb Mon Sep 17 00:00:00 2001 From: Pedro Freire Date: Fri, 25 Oct 2019 20:11:23 +0100 Subject: [PATCH 5/5] Add from __future__ import division for Python 2 --- test/test_functional_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 32826b2c9d0..e5e547c20d7 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1,3 +1,4 @@ +from __future__ import division import torch import torchvision.transforms as transforms import torchvision.transforms.functional_tensor as F_t