Skip to content

fix RandomGrayscale for grayscale inputs #5585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2237,5 +2237,17 @@ def test_random_affine():
assert t.interpolation == transforms.InterpolationMode.BILINEAR


def test_random_grayscale_with_grayscale_input():
transform = transforms.RandomGrayscale(p=1.0)

image_tensor = torch.randint(0, 256, (1, 16, 16), dtype=torch.uint8)
output_tensor = transform(image_tensor)
torch.testing.assert_close(output_tensor, image_tensor)

image_pil = F.to_pil_image(image_tensor)
output_pil = transform(image_pil)
torch.testing.assert_close(F.pil_to_tensor(output_pil), image_tensor)


if __name__ == "__main__":
pytest.main([__file__])
21 changes: 14 additions & 7 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1584,16 +1584,16 @@ def __repr__(self) -> str:
class RandomGrayscale(torch.nn.Module):
"""Randomly convert image to grayscale with a probability of p (default 0.1).
If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions

Args:
p (float): probability that image should be converted to grayscale.

Returns:
PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
with probability (1-p).
- If input image is 1 channel: grayscale version is 1 channel
- If input image is 3 channel: grayscale version is 3 channel with r == g == b
- If input image is grayscale (1 channel): copy input is returned
- If input image is RGB (3 channels): grayscale version of input is returned with 3 channels and r == g == b

"""

Expand All @@ -1610,10 +1610,17 @@ def forward(self, img):
Returns:
PIL Image or Tensor: Randomly grayscaled image.
"""
num_output_channels, _, _ = F.get_dimensions(img)
if torch.rand(1) < self.p:
return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
return img
if torch.rand(1) >= self.p:
return img

num_input_channels, _, _ = F.get_dimensions(img)
if num_input_channels == 1:
if isinstance(img, torch.Tensor):
return img.clone()
else: # isinstance(img, PIL.Image.Image)
return img.copy()

return F.rgb_to_grayscale(img, num_output_channels=3)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
Expand Down