Skip to content

Commit 18553f8

Browse files
committed
add scriptable transforms: rgb_to_grayscale
1 parent a9ab4a3 commit 18553f8

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

torchvision/transforms/functional_tensor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,25 @@ def crop(img, top, left, height, width):
4949

5050
return img[..., top:top + height, left:left + width]
5151

52-
def to_grayscale(img, num_output_channels = 3):
52+
53+
def rgb_to_grayscale(img, num_output_channels=3):
5354
"""Convert the given RGB Image Tensor to Grayscale.
5455
5556
Args
5657
img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
5758
num_output_channels (int): denotes the number of channels to return after conversion
5859
Returns:
5960
Tensor: Grayscale image.
61+
62+
For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which is
63+
L = R * 0.2989 + G * 0.5870 + B * 0.1140
6064
"""
6165
if not F._is_tensor_image(img):
6266
raise TypeError('tensor is not a torch image.')
6367

6468
if img.size()[0] != 3:
6569
raise TypeError('Input Image does not contain 3 Channels')
6670

71+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72+
img = img.to(device)
6773
return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)

0 commit comments

Comments
 (0)