Skip to content

Commit 627ce70

Browse files
sourabhdalykhantejani
authored andcommitted
transforms: randomly grayscaling an image (#325)
* add to_grayscale + randomGrayscale
1 parent b5f0b6e commit 627ce70

File tree

3 files changed

+225
-2
lines changed

3 files changed

+225
-2
lines changed

test/test_transforms.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,143 @@ def test_random_rotation(self):
714714
angle = t.get_params(t.degrees)
715715
assert angle > -10 and angle < 10
716716

717+
def test_to_grayscale(self):
718+
"""Unit tests for grayscale transform"""
719+
720+
x_shape = [2, 2, 3]
721+
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
722+
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
723+
x_pil = Image.fromarray(x_np, mode='RGB')
724+
x_pil_2 = x_pil.convert('L')
725+
gray_np = np.array(x_pil_2)
726+
727+
# Test Set: Grayscale an image with desired number of output channels
728+
# Case 1: RGB -> 1 channel grayscale
729+
trans1 = transforms.Grayscale(num_output_channels=1)
730+
gray_pil_1 = trans1(x_pil)
731+
gray_np_1 = np.array(gray_pil_1)
732+
assert gray_pil_1.mode == 'L', 'mode should be L'
733+
assert gray_np_1.shape == tuple(x_shape[0:2]), 'should be 1 channel'
734+
np.testing.assert_equal(gray_np, gray_np_1)
735+
736+
# Case 2: RGB -> 3 channel grayscale
737+
trans2 = transforms.Grayscale(num_output_channels=3)
738+
gray_pil_2 = trans2(x_pil)
739+
gray_np_2 = np.array(gray_pil_2)
740+
assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
741+
assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
742+
np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
743+
np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
744+
np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])
745+
746+
# Case 3: 1 channel grayscale -> 1 channel grayscale
747+
trans3 = transforms.Grayscale(num_output_channels=1)
748+
gray_pil_3 = trans3(x_pil_2)
749+
gray_np_3 = np.array(gray_pil_3)
750+
assert gray_pil_3.mode == 'L', 'mode should be L'
751+
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
752+
np.testing.assert_equal(gray_np, gray_np_3)
753+
754+
# Case 4: 1 channel grayscale -> 3 channel grayscale
755+
trans4 = transforms.Grayscale(num_output_channels=3)
756+
gray_pil_4 = trans4(x_pil_2)
757+
gray_np_4 = np.array(gray_pil_4)
758+
assert gray_pil_4.mode == 'RGB', 'mode should be RGB'
759+
assert gray_np_4.shape == tuple(x_shape), 'should be 3 channel'
760+
np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
761+
np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
762+
np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])
763+
764+
@unittest.skipIf(stats is None, 'scipy.stats not available')
765+
def test_random_grayscale(self):
766+
"""Unit tests for random grayscale transform"""
767+
768+
# Test Set 1: RGB -> 3 channel grayscale
769+
random_state = random.getstate()
770+
random.seed(42)
771+
x_shape = [2, 2, 3]
772+
x_np = np.random.randint(0, 256, x_shape, np.uint8)
773+
x_pil = Image.fromarray(x_np, mode='RGB')
774+
x_pil_2 = x_pil.convert('L')
775+
gray_np = np.array(x_pil_2)
776+
777+
num_samples = 250
778+
num_gray = 0
779+
for _ in range(num_samples):
780+
gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
781+
gray_np_2 = np.array(gray_pil_2)
782+
if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \
783+
np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
784+
np.array_equal(gray_np, gray_np_2[:, :, 0]):
785+
num_gray = num_gray + 1
786+
787+
p_value = stats.binom_test(num_gray, num_samples, p=0.5)
788+
random.setstate(random_state)
789+
assert p_value > 0.0001
790+
791+
# Test Set 2: grayscale -> 1 channel grayscale
792+
random_state = random.getstate()
793+
random.seed(42)
794+
x_shape = [2, 2, 3]
795+
x_np = np.random.randint(0, 256, x_shape, np.uint8)
796+
x_pil = Image.fromarray(x_np, mode='RGB')
797+
x_pil_2 = x_pil.convert('L')
798+
gray_np = np.array(x_pil_2)
799+
800+
num_samples = 250
801+
num_gray = 0
802+
for _ in range(num_samples):
803+
gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
804+
gray_np_3 = np.array(gray_pil_3)
805+
if np.array_equal(gray_np, gray_np_3):
806+
num_gray = num_gray + 1
807+
808+
p_value = stats.binom_test(num_gray, num_samples, p=1.0) # Note: grayscale is always unchanged
809+
random.setstate(random_state)
810+
assert p_value > 0.0001
811+
812+
# Test set 3: Explicit tests
813+
x_shape = [2, 2, 3]
814+
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
815+
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
816+
x_pil = Image.fromarray(x_np, mode='RGB')
817+
x_pil_2 = x_pil.convert('L')
818+
gray_np = np.array(x_pil_2)
819+
820+
# Case 3a: RGB -> 3 channel grayscale (grayscaled)
821+
trans2 = transforms.RandomGrayscale(p=1.0)
822+
gray_pil_2 = trans2(x_pil)
823+
gray_np_2 = np.array(gray_pil_2)
824+
assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
825+
assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
826+
np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
827+
np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
828+
np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])
829+
830+
# Case 3b: RGB -> 3 channel grayscale (unchanged)
831+
trans2 = transforms.RandomGrayscale(p=0.0)
832+
gray_pil_2 = trans2(x_pil)
833+
gray_np_2 = np.array(gray_pil_2)
834+
assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
835+
assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
836+
np.testing.assert_equal(x_np, gray_np_2)
837+
838+
# Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
839+
trans3 = transforms.RandomGrayscale(p=1.0)
840+
gray_pil_3 = trans3(x_pil_2)
841+
gray_np_3 = np.array(gray_pil_3)
842+
assert gray_pil_3.mode == 'L', 'mode should be L'
843+
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
844+
np.testing.assert_equal(gray_np, gray_np_3)
845+
846+
# Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
847+
trans3 = transforms.RandomGrayscale(p=0.0)
848+
gray_pil_3 = trans3(x_pil_2)
849+
gray_np_3 = np.array(gray_pil_3)
850+
assert gray_pil_3.mode == 'L', 'mode should be L'
851+
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
852+
np.testing.assert_equal(gray_np, gray_np_3)
853+
717854

718855
if __name__ == '__main__':
719856
unittest.main()

torchvision/transforms/functional.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,35 @@ def rotate(img, angle, resample=False, expand=False, center=None):
545545
Origin is the upper left corner.
546546
Default is the center of the image.
547547
"""
548+
548549
if not _is_pil_image(img):
549550
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
550551

551552
return img.rotate(angle, resample, expand, center)
553+
554+
555+
def to_grayscale(img, num_output_channels=1):
556+
"""Convert image to grayscale version of image.
557+
558+
Args:
559+
img (PIL Image): Image to be converted to grayscale.
560+
561+
Returns:
562+
PIL Image: Grayscale version of the image.
563+
if num_output_channels == 1 : returned image is single channel
564+
if num_output_channels == 3 : returned image is 3 channel with r == g == b
565+
"""
566+
if not _is_pil_image(img):
567+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
568+
569+
if num_output_channels == 1:
570+
img = img.convert('L')
571+
elif num_output_channels == 3:
572+
img = img.convert('L')
573+
np_img = np.array(img, dtype=np.uint8)
574+
np_img = np.dstack([np_img, np_img, np_img])
575+
img = Image.fromarray(np_img, 'RGB')
576+
else:
577+
raise ValueError('num_output_channels should be either 1 or 3')
578+
579+
return img

torchvision/transforms/transforms.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
1919
"Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
20-
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation"]
20+
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation",
21+
"Grayscale", "RandomGrayscale"]
2122

2223

2324
class Compose(object):
@@ -619,7 +620,6 @@ def get_params(degrees):
619620

620621
def __call__(self, img):
621622
"""
622-
Args:
623623
img (PIL Image): Image to be rotated.
624624
625625
Returns:
@@ -629,3 +629,61 @@ def __call__(self, img):
629629
angle = self.get_params(self.degrees)
630630

631631
return F.rotate(img, angle, self.resample, self.expand, self.center)
632+
633+
634+
class Grayscale(object):
635+
"""Convert image to grayscale.
636+
Args:
637+
num_output_channels (int): (1 or 3) number of channels desired for output image
638+
639+
Returns:
640+
PIL Image: grayscale version of the input
641+
if num_output_channels == 1 : returned image is single channel
642+
if num_output_channels == 3 : returned image is 3 channel with r == g == b
643+
644+
"""
645+
646+
def __init__(self, num_output_channels=1):
647+
self.num_output_channels = num_output_channels
648+
649+
def __call__(self, img):
650+
"""
651+
Args:
652+
img (PIL Image): Image to be converted to grayscale.
653+
654+
Returns:
655+
PIL Image: Randomly grayscaled image.
656+
"""
657+
return F.to_grayscale(img, num_output_channels=self.num_output_channels)
658+
659+
660+
class RandomGrayscale(object):
661+
"""Randomly convert image to grayscale with a probability of p (default 0.1).
662+
Args:
663+
p (float): probability that image should be converted to grayscale.
664+
665+
Returns:
666+
PIL Image: grayscale version of the input image with probability p
667+
and unchanged with probability (1-p)
668+
- if input image is 1 channel:
669+
grayscale version is 1 channel
670+
- if input image is 3 channel:
671+
grayscale version is 3 channel with r == g == b
672+
673+
"""
674+
675+
def __init__(self, p=0.1):
676+
self.p = p
677+
678+
def __call__(self, img):
679+
"""
680+
Args:
681+
img (PIL Image): Image to be converted to grayscale.
682+
683+
Returns:
684+
PIL Image: Randomly grayscaled image.
685+
"""
686+
num_output_channels = 1 if img.mode == 'L' else 3
687+
if random.random() < self.p:
688+
return F.to_grayscale(img, num_output_channels=num_output_channels)
689+
return img

0 commit comments

Comments
 (0)