Skip to content

Commit a93e53a

Browse files
committed
Added RandomGrayscale transform
1 parent 005adfd commit a93e53a

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

torchvision/transforms.py

Lines changed: 43 additions & 0 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -494,6 +494,26 @@ def adjust_gamma(img, gamma, gain=1):
494
img = Image.fromarray(np_img, 'RGB').convert(input_mode)
494
img = Image.fromarray(np_img, 'RGB').convert(input_mode)
495
return img
495
return img
496

496

497+
def to_grayscale(img):
498+
"""Convert image to grayscale, repeated over three channels.
499+
500+
Args:
501+
img (PIL Image): Image to be converted to grayscale.
502+
503+
Returns:
504+
PIL Image: Grayscale version of the image, repeated over three channels.
505+
"""
506+
if not _is_pil_image(img):
507+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
508+
509+
input_mode = img.mode
510+
img = img.convert('L')
511+
512+
np_img = np.array(img, dtype=np.uint8)
513+
np_img = np.dstack([np_img] * 3)
514+
515+
img = Image.fromarray(np_img, 'RGB').convert(input_mode)
516+
return img
497

517

498
class Compose(object):
518
class Compose(object):
499
"""Composes several transforms together.
519
"""Composes several transforms together.
@@ -1026,3 +1046,26 @@ def __call__(self, img):
1026
transform = self.get_params(self.brightness, self.contrast,
1046
transform = self.get_params(self.brightness, self.contrast,
1027
self.saturation, self.hue)
1047
self.saturation, self.hue)
1028
return transform(img)
1048
return transform(img)
1049+
1050+
1051+
class RandomGrayscale(object):
1052+
"""Randomly convert image to grayscale with a probability of p (default 0.1).
1053+
Args:
1054+
p (float): probability that image should be converted to grayscale.
1055+
1056+
"""
1057+
1058+
def __init__(self, p=0.1):
1059+
self.p = p
1060+
1061+
def __call__(self, img):
1062+
"""
1063+
Args:
1064+
img (PIL Image): Image to be converted to grayscale.
1065+
1066+
Returns:
1067+
PIL Image: Randomly grayscaled image.
1068+
"""
1069+
if random.random() < self.p:
1070+
return to_grayscale(img)
1071+
return img

0 commit comments

Comments
 (0)