Skip to content

Commit b03d237

Browse files
committed
add scriptable transforms: rgb_to_grayscale
1 parent b19c4c0 commit b03d237

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

torchvision/transforms/functional_tensor.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,24 +49,28 @@ def crop(img, top, left, height, width):
4949

5050
return img[..., top:top + height, left:left + width]
5151

52-
def to_grayscale(img, num_output_channels = 3):
52+
53+
def rgb_to_grayscale(img, num_output_channels=3):
5354
"""Convert the given RGB Image Tensor to Grayscale.
5455
5556
Args
5657
img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
5758
num_output_channels (int): denotes the number of channels to return after conversion
5859
Returns:
5960
Tensor: Grayscale image.
61+
62+
For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which is
63+
L = R * 0.2989 + G * 0.5870 + B * 0.1140
6064
"""
6165
if not F._is_tensor_image(img):
6266
raise TypeError('tensor is not a torch image.')
63-
67+
if img.size()[0] != 3:
68+
raise TypeError('Input Image doesn\'t have 3 Channels')
69+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70+
img = img.to(device)
71+
weights = torch.tensor([0.2989, 0.5870, 0.1140])
72+
res = torch.tensordot(img, weights[:, None, None], [[0], [0]]).squeeze()
73+
if num_output_channels == 1:
74+
return res
6475
else:
65-
hwc_img = img.transpose(1, 2).transpose(0, 2)
66-
weights = torch.tensor([0.2989, 0.5870, 0.1140])
67-
res = hwc_img[:,:,:3] * weights[None, :]
68-
gray_img = res[:, :, 0] + res[:, :, 1] + res[:, :, 2]
69-
if num_output_channels == 1:
70-
return gray_img.int()
71-
else:
72-
return torch.cat((gray_img.int(), gray_img.int(), gray_img.int())).transpose(1, 2)
76+
return res.repeat(3, 1, 1)

0 commit comments

Comments
 (0)