Skip to content

Commit 50da7be

Browse files
committed
Unified inputs for grayscale op and transforms
- deprecated F.to_grayscale in favor of F.rgb_to_grayscale
1 parent 05c3a0a commit 50da7be

File tree

6 files changed

+105
-63
lines changed

6 files changed

+105
-63
lines changed

test/test_functional_tensor.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
import random
32
import colorsys
43
import math
54

@@ -23,7 +22,10 @@ def _create_data(self, height=3, width=3, channels=3):
2322
return tensor, pil_img
2423

2524
def compareTensorToPIL(self, tensor, pil_image, msg=None):
26-
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
25+
np_pil_image = np.array(pil_image)
26+
if np_pil_image.ndim == 2:
27+
np_pil_image = np_pil_image[:, :, None]
28+
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
2729
if msg is None:
2830
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
2931
self.assertTrue(tensor.equal(pil_tensor), msg)
@@ -187,17 +189,21 @@ def test_adjustments(self):
187189
scripted_fn(img)
188190

189191
def test_rgb_to_grayscale(self):
190-
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
191-
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
192-
img_tensor_clone = img_tensor.clone()
193-
grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
194-
grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
195-
max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
196-
self.assertLess(max_diff, 1.0001)
197-
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
198-
# scriptable function test
199-
grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
200-
self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
192+
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
193+
194+
img_tensor, pil_img = self._create_data(32, 34)
195+
196+
for num_output_channels in (3, 1):
197+
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
198+
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
199+
200+
if num_output_channels == 1:
201+
print(gray_tensor.shape)
202+
203+
self.compareTensorToPIL(gray_tensor, gray_pil_image)
204+
205+
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
206+
self.assertTrue(s_gray_tensor.equal(gray_tensor))
201207

202208
def test_center_crop(self):
203209
script_center_crop = torch.jit.script(F.center_crop)

test/test_transforms_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,10 @@ def test_random_perspective(self):
324324
def test_to_grayscale(self):
325325

326326
fn_kwargs = meth_kwargs = {"num_output_channels": 1}
327-
self._test_op("to_grayscale", "Grayscale", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
327+
self._test_op("rgb_to_grayscale", "Grayscale", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
328328

329329
fn_kwargs = meth_kwargs = {"num_output_channels": 3}
330-
self._test_op("to_grayscale", "Grayscale", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
330+
self._test_op("rgb_to_grayscale", "Grayscale", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
331331

332332
meth_kwargs = {}
333333
self._test_class_op("RandomGrayscale", meth_kwargs=meth_kwargs)

torchvision/transforms/functional.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -959,12 +959,39 @@ def affine(
959959

960960

961961
def to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
962-
"""Convert image to grayscale version of image.
962+
"""DEPRECATED. Convert RGB image to grayscale version of image.
963+
The image can be a PIL Image or a Tensor, in which case it is expected
964+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
965+
966+
.. warning::
967+
968+
This method is deprecated and will be removed in future releases.
969+
Please, use ``F.rgb_to_grayscale`` instead.
970+
971+
972+
Args:
973+
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
974+
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
975+
976+
Returns:
977+
PIL Image or Tensor: Grayscale version of the image.
978+
if num_output_channels = 1 : returned image is single channel
979+
980+
if num_output_channels = 3 : returned image is 3 channel with r = g = b
981+
"""
982+
warnings.warn("The use of the F.to_grayscale transform is deprecated, " +
983+
"please use F.rgb_to_grayscale instead.")
984+
985+
return rgb_to_grayscale(img, num_output_channels)
986+
987+
988+
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
989+
"""Convert RGB image to grayscale version of image.
963990
The image can be a PIL Image or a Tensor, in which case it is expected
964991
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
965992
966993
Args:
967-
img (PIL Image or Tensor): Image to be converted to grayscale.
994+
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
968995
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
969996
970997
Returns:
@@ -974,9 +1001,9 @@ def to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
9741001
if num_output_channels = 3 : returned image is 3 channel with r = g = b
9751002
"""
9761003
if not isinstance(img, torch.Tensor):
977-
return F_pil.to_grayscale(img, num_output_channels)
1004+
return F_pil.rgb_to_grayscale(img, num_output_channels)
9781005

979-
return F_t.to_grayscale(img, num_output_channels)
1006+
return F_t.rgb_to_grayscale(img, num_output_channels)
9801007

9811008

9821009
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:

torchvision/transforms/functional_pil.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numbers
2+
import warnings
23
from typing import Any, List, Sequence
34

45
import numpy as np
@@ -491,12 +492,31 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)
491492

492493
@torch.jit.unused
493494
def to_grayscale(img, num_output_channels):
494-
"""Convert image to grayscale version of image.
495+
"""DEPRECATED. Convert RGB image to grayscale version of image.
495496
496497
Args:
497498
img (PIL Image): Image to be converted to grayscale.
498499
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
499500
501+
Returns:
502+
PIL Image: Grayscale version of the image.
503+
if num_output_channels = 1 : returned image is single channel
504+
505+
if num_output_channels = 3 : returned image is 3 channel with r = g = b
506+
"""
507+
warnings.warn("The use of the F_pil.to_grayscale transform is deprecated, " +
508+
"please use F.rgb_to_grayscale instead.")
509+
return rgb_to_grayscale(img, num_output_channels)
510+
511+
512+
@torch.jit.unused
513+
def rgb_to_grayscale(img, num_output_channels):
514+
"""Convert RGB image to grayscale version of image.
515+
516+
Args:
517+
img (PIL Image): RGB Image to be converted to grayscale.
518+
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
519+
500520
Returns:
501521
PIL Image: Grayscale version of the image.
502522
if num_output_channels = 1 : returned image is single channel

torchvision/transforms/functional_tensor.py

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,47 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
7676
return img[..., top:top + height, left:left + width]
7777

7878

79-
def rgb_to_grayscale(img: Tensor) -> Tensor:
79+
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
8080
"""Convert the given RGB Image Tensor to Grayscale.
8181
For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
8282
is L = R * 0.2989 + G * 0.5870 + B * 0.1140
8383
8484
Args:
8585
img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
86+
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
8687
8788
Returns:
88-
Tensor: Grayscale image.
89+
Tensor: Grayscale version of the image.
90+
if num_output_channels = 1 : returned image is single channel
91+
92+
if num_output_channels = 3 : returned image is 3 channel with r = g = b
8993
9094
"""
91-
if img.shape[0] != 3:
92-
raise TypeError('Input Image does not contain 3 Channels')
95+
if img.ndim < 3:
96+
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
97+
c = img.shape[-3]
98+
if c != 3:
99+
raise TypeError("Input image tensor should 3 channels, but found {}".format(c))
93100

94-
return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)
101+
if num_output_channels not in (1, 3):
102+
raise ValueError('num_output_channels should be either 1 or 3')
103+
104+
r = img[..., 0, :, :].float()
105+
g = img[..., 1, :, :].float()
106+
b = img[..., 2, :, :].float()
107+
# According to PIL docs: PIL grayscale L mode is L = R * 299/1000 + G * 587/1000 + B * 114/1000
108+
# but implementation is slightly different:
109+
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
110+
# src/libImaging/Convert.c#L47
111+
# ((rgb)[0]*19595 + (rgb)[1]*38470 + (rgb)[2]*7471 + 0x8000) >> 16
112+
l_img = torch.floor((19595 * r + 38470 * g + 7471 * b + 2 ** 15) / 2 ** 16).to(img.dtype)
113+
114+
if num_output_channels == 3:
115+
l_img = torch.stack([l_img, l_img, l_img], dim=-3)
116+
else:
117+
l_img = l_img.unsqueeze(dim=-3)
118+
119+
return l_img
95120

96121

97122
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
@@ -893,39 +918,3 @@ def perspective(
893918
mode = _interpolation_modes[interpolation]
894919

895920
return _apply_grid_transform(img, grid, mode)
896-
897-
898-
def to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
899-
"""Convert image to grayscale version of image.
900-
901-
Args:
902-
img (Tensor): Image to be converted to grayscale. We assume (..., 3, H, W) layout.
903-
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
904-
905-
Returns:
906-
Tensor: Grayscale version of the image.
907-
if num_output_channels = 1 : returned image is single channel
908-
909-
if num_output_channels = 3 : returned image is 3 channel with r = g = b
910-
"""
911-
if img.ndim < 3:
912-
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
913-
c = img.shape[-3]
914-
if c != 3:
915-
raise TypeError("Input image tensor should 3 channels, but found {}".format(c))
916-
917-
if num_output_channels not in (1, 3):
918-
raise ValueError('num_output_channels should be either 1 or 3')
919-
920-
# PIL grayscale L mode is L = R * 299/1000 + G * 587/1000 + B * 114/1000
921-
r = img[..., 0, :, :]
922-
g = img[..., 1, :, :]
923-
b = img[..., 2, :, :]
924-
l_img = (0.299 * r + 0.587 * g + 0.114 * b + 0.5).to(img.dtype)
925-
926-
if num_output_channels == 3:
927-
l_img = torch.stack([l_img, l_img, l_img], dim=-3)
928-
else:
929-
l_img = l_img.unsqueeze(dim=-3)
930-
931-
return l_img

torchvision/transforms/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,7 @@ def forward(self, img: Tensor) -> Tensor:
13821382
Returns:
13831383
PIL Image or Tensor: Grayscaled image.
13841384
"""
1385-
return F.to_grayscale(img, num_output_channels=self.num_output_channels)
1385+
return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
13861386

13871387
def __repr__(self):
13881388
return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
@@ -1419,7 +1419,7 @@ def forward(self, img: Tensor) -> Tensor:
14191419
"""
14201420
num_output_channels = F._get_image_num_channels(img)
14211421
if torch.rand(1) < self.p:
1422-
return F.to_grayscale(img, num_output_channels=num_output_channels)
1422+
return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
14231423
return img
14241424

14251425
def __repr__(self):

0 commit comments

Comments
 (0)