diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index e5e547c20d7..e318420102b 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -3,6 +3,7 @@
 import torchvision.transforms as transforms
 import torchvision.transforms.functional_tensor as F_t
 import torchvision.transforms.functional as F
+import numpy as np
 import unittest
 import random
 
@@ -68,6 +69,16 @@ def test_adjustments(self):
             max_diff = (ft_img - f_img).abs().max()
             self.assertLess(max_diff, 5 / 255 + 1e-5)
 
+    def test_rgb_to_grayscale(self):
+        """Check F_t.rgb_to_grayscale against the PIL-based F.to_grayscale."""
+        # torch.randint's upper bound is exclusive; use 256 so 255 is sampled too.
+        img_tensor = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
+        grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
+        grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
+        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
+        # Allow an off-by-one difference between the two conversions.
+        self.assertLess(max_diff, 1.0001)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 7ef83c1086b..c741ab2e7e8 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -26,7 +26,6 @@ def hflip(img_tensor):
     Returns:
         Tensor: Horizontally flipped image Tensor.
     """
-
     if not F._is_tensor_image(img_tensor):
         raise TypeError('tensor is not a torch image.')
 
@@ -35,12 +34,14 @@ def hflip(img_tensor):
 
 def crop(img, top, left, height, width):
     """Crop the given Image Tensor.
+
     Args:
         img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.
         width (int): Width of the crop box.
+
     Returns:
         Tensor: Cropped image.
""" @@ -50,6 +51,24 @@ def crop(img, top, left, height, width): return img[..., top:top + height, left:left + width] +def rgb_to_grayscale(img): + """Convert the given RGB Image Tensor to Grayscale. + For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which + is L = R * 0.2989 + G * 0.5870 + B * 0.1140 + + Args: + img (Tensor): Image to be converted to Grayscale in the form [C, H, W]. + + Returns: + Tensor: Grayscale image. + + """ + if img.shape[0] != 3: + raise TypeError('Input Image does not contain 3 Channels') + + return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype) + + def adjust_brightness(img, brightness_factor): """Adjust brightness of an RGB image. @@ -83,7 +102,7 @@ def adjust_contrast(img, contrast_factor): if not F._is_tensor_image(img): raise TypeError('tensor is not a torch image.') - mean = torch.mean(_rgb_to_grayscale(img).to(torch.float)) + mean = torch.mean(rgb_to_grayscale(img).to(torch.float)) return _blend(img, mean, contrast_factor) @@ -103,14 +122,9 @@ def adjust_saturation(img, saturation_factor): if not F._is_tensor_image(img): raise TypeError('tensor is not a torch image.') - return _blend(img, _rgb_to_grayscale(img), saturation_factor) + return _blend(img, rgb_to_grayscale(img), saturation_factor) def _blend(img1, img2, ratio): bound = 1 if img1.dtype.is_floating_point else 255 return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) - - -def _rgb_to_grayscale(img): - # ITU-R 601-2 luma transform, as used in PIL. - return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)