@@ -49,24 +49,28 @@ def crop(img, top, left, height, width):
49
49
50
50
return img [..., top :top + height , left :left + width ]
51
51
52
def rgb_to_grayscale(img, num_output_channels=3):
    """Convert the given RGB Image Tensor to Grayscale.

    For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which is
    L = R * 0.2989 + G * 0.5870 + B * 0.1140

    Args:
        img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
            Must have exactly 3 channels.
        num_output_channels (int): denotes the number of channels to return after
            conversion. ``1`` returns a single luma plane; any other value returns
            the luma plane replicated over 3 channels.

    Returns:
        Tensor: Grayscale image on the same device as ``img``. The shape is [H, W]
        when ``num_output_channels == 1`` and [3, H, W] otherwise. Floating point
        inputs keep their dtype; other dtypes are computed and returned as float32.

    Raises:
        TypeError: if ``img`` is not accepted by ``F._is_tensor_image`` or does not
            have exactly 3 channels in its first dimension.
    """
    if not F._is_tensor_image(img):
        raise TypeError('tensor is not a torch image.')
    if img.size(0) != 3:
        raise TypeError("Input Image doesn't have 3 Channels")

    # Keep the computation on the caller's device: silently moving the image to
    # CUDA while the weights stay on the CPU makes the contraction fail.
    compute_dtype = img.dtype if img.is_floating_point() else torch.float32
    weights = torch.tensor([0.2989, 0.5870, 0.1140], dtype=compute_dtype, device=img.device)

    # Contract the channel axis only. The result is always [H, W]; no squeeze() is
    # needed, so images with a height or width of 1 keep their spatial dimensions.
    gray = torch.tensordot(weights, img.to(compute_dtype), dims=1)

    if num_output_channels == 1:
        return gray
    return gray.repeat(3, 1, 1)
0 commit comments