Skip to content

[WIP] Add Scriptable Transform: Grayscale #1505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 5, 2019
8 changes: 8 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional as F
import numpy as np
import unittest
import random

Expand Down Expand Up @@ -68,6 +69,13 @@ def test_adjustments(self):
max_diff = (ft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)

def test_rgb_to_grayscale(self):
    """Check that the tensor grayscale conversion agrees with the PIL one.

    A random uint8 RGB image is converted with ``F_t.rgb_to_grayscale`` and
    with the PIL-based ``F.to_grayscale``; the two results may differ by at
    most one intensity level per pixel.
    """
    # torch.randint's upper bound is exclusive, so 256 is required for the
    # random image to cover the full uint8 range, including 255.
    img_tensor = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
    # Cast both results to a signed integer type before subtracting so the
    # difference cannot wrap around as it would in uint8 arithmetic.
    grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
    grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
    max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
    self.assertLess(max_diff, 1.0001)


if __name__ == '__main__':
unittest.main()
30 changes: 22 additions & 8 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def hflip(img_tensor):
Returns:
Tensor: Horizontally flipped image Tensor.
"""

if not F._is_tensor_image(img_tensor):
raise TypeError('tensor is not a torch image.')

Expand All @@ -35,12 +34,14 @@ def hflip(img_tensor):

def crop(img, top, left, height, width):
"""Crop the given Image Tensor.

Args:
img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.

Returns:
Tensor: Cropped image.
"""
Expand All @@ -50,6 +51,24 @@ def crop(img, top, left, height, width):
return img[..., top:top + height, left:left + width]


def rgb_to_grayscale(img):
    """Convert the given RGB Image Tensor to Grayscale.

    For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
    is L = R * 0.2989 + G * 0.5870 + B * 0.1140

    Args:
        img (Tensor): Image to be converted to Grayscale in the form [..., C, H, W]
            with C == 3. Any leading dimensions are kept as-is.

    Returns:
        Tensor: Grayscale image in the form [..., H, W], with the same dtype as ``img``.

    Raises:
        TypeError: If ``img`` has fewer than 3 dimensions or its channel
            dimension (third from last) does not have size 3.
    """
    # Locate the channel axis from the end so that both a single [C, H, W]
    # image and inputs with leading dimensions are handled the same way.
    if img.dim() < 3 or img.shape[-3] != 3:
        raise TypeError('Input Image does not contain 3 Channels')

    red = img[..., 0, :, :]
    green = img[..., 1, :, :]
    blue = img[..., 2, :, :]

    # The weighted sum is computed in floating point; casting back to the
    # input dtype truncates (does not round) for integer images.
    return (0.2989 * red + 0.5870 * green + 0.1140 * blue).to(img.dtype)


def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an RGB image.

Expand Down Expand Up @@ -83,7 +102,7 @@ def adjust_contrast(img, contrast_factor):
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')

mean = torch.mean(_rgb_to_grayscale(img).to(torch.float))
mean = torch.mean(rgb_to_grayscale(img).to(torch.float))

return _blend(img, mean, contrast_factor)

Expand All @@ -103,14 +122,9 @@ def adjust_saturation(img, saturation_factor):
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')

return _blend(img, _rgb_to_grayscale(img), saturation_factor)
return _blend(img, rgb_to_grayscale(img), saturation_factor)


def _blend(img1, img2, ratio):
bound = 1 if img1.dtype.is_floating_point else 255
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def _rgb_to_grayscale(img):
# ITU-R 601-2 luma transform, as used in PIL.
return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)