@@ -494,6 +494,26 @@ def adjust_gamma(img, gamma, gain=1):
494
img = Image .fromarray (np_img , 'RGB' ).convert (input_mode )
494
img = Image .fromarray (np_img , 'RGB' ).convert (input_mode )
495
return img
495
return img
496
496
497
+ def to_grayscale (img ):
498
+ """Convert image to grayscale, repeated over three channels.
499
+
500
+ Args:
501
+ img (PIL Image): Image to be converted to grayscale.
502
+
503
+ Returns:
504
+ PIL Image: Grayscale version of the image, repeated over three channels.
505
+ """
506
+ if not _is_pil_image (img ):
507
+ raise TypeError ('img should be PIL Image. Got {}' .format (type (img )))
508
+
509
+ input_mode = img .mode
510
+ img = img .convert ('L' )
511
+
512
+ np_img = np .array (img , dtype = np .uint8 )
513
+ np_img = np .dstack ([np_img ] * 3 )
514
+
515
+ img = Image .fromarray (np_img , 'RGB' ).convert (input_mode )
516
+ return img
497
517
498
class Compose (object ):
518
class Compose (object ):
499
"""Composes several transforms together.
519
"""Composes several transforms together.
@@ -1026,3 +1046,26 @@ def __call__(self, img):
1026
transform = self .get_params (self .brightness , self .contrast ,
1046
transform = self .get_params (self .brightness , self .contrast ,
1027
self .saturation , self .hue )
1047
self .saturation , self .hue )
1028
return transform (img )
1048
return transform (img )
1049
+
1050
+
1051
+ class RandomGrayscale (object ):
1052
+ """Randomly convert image to grayscale with a probability of p (default 0.1).
1053
+ Args:
1054
+ p (float): probability that image should be converted to grayscale.
1055
+
1056
+ """
1057
+
1058
+ def __init__ (self , p = 0.1 ):
1059
+ self .p = p
1060
+
1061
+ def __call__ (self , img ):
1062
+ """
1063
+ Args:
1064
+ img (PIL Image): Image to be converted to grayscale.
1065
+
1066
+ Returns:
1067
+ PIL Image: Randomly grayscaled image.
1068
+ """
1069
+ if random .random () < self .p :
1070
+ return to_grayscale (img )
1071
+ return img
0 commit comments