Skip to content

Commit 1078e1d

Browse files
committed
Fixes bug with fp input
1 parent 50da7be commit 1078e1d

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

torchvision/transforms/functional_tensor.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,16 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
104104
r = img[..., 0, :, :].float()
105105
g = img[..., 1, :, :].float()
106106
b = img[..., 2, :, :].float()
107-
# According to PIL docs: PIL grayscale L mode is L = R * 299/1000 + G * 587/1000 + B * 114/1000
108-
# but implementation is slightly different:
109-
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
110-
# src/libImaging/Convert.c#L47
111-
# ((rgb)[0]*19595 + (rgb)[1]*38470 + (rgb)[2]*7471 + 0x8000) >> 16
112-
l_img = torch.floor((19595 * r + 38470 * g + 7471 * b + 2 ** 15) / 2 ** 16).to(img.dtype)
107+
if not img.is_floating_point():
108+
# According to PIL docs: PIL grayscale L mode is L = R * 299/1000 + G * 587/1000 + B * 114/1000
109+
# but implementation is slightly different:
110+
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
111+
# src/libImaging/Convert.c#L47
112+
# ((rgb)[0]*19595 + (rgb)[1]*38470 + (rgb)[2]*7471 + 0x8000) >> 16
113+
# l_img = ((19595 * r + 38470 * g + 7471 * b + 2 ** 15) / 2 ** 16).to(img.dtype)
114+
l_img = torch.floor((19595 * r + 38470 * g + 7471 * b + 2 ** 15) / 2 ** 16).to(img.dtype)
115+
else:
116+
l_img = (0.299 * r + 0.587 * g + 0.114 * b).to(img.dtype)
113117

114118
if num_output_channels == 3:
115119
l_img = torch.stack([l_img, l_img, l_img], dim=-3)
@@ -407,8 +411,8 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
407411

408412

409413
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
410-
bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
411-
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
414+
bound = 1.0 if img1.is_floating_point() else 255.0
415+
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
412416

413417

414418
def _rgb2hsv(img):

0 commit comments

Comments
 (0)