@@ -714,6 +714,143 @@ def test_random_rotation(self):
714
714
angle = t .get_params (t .degrees )
715
715
assert angle > - 10 and angle < 10
716
716
717
+ def test_to_grayscale (self ):
718
+ """Unit tests for grayscale transform"""
719
+
720
+ x_shape = [2 , 2 , 3 ]
721
+ x_data = [0 , 5 , 13 , 54 , 135 , 226 , 37 , 8 , 234 , 90 , 255 , 1 ]
722
+ x_np = np .array (x_data , dtype = np .uint8 ).reshape (x_shape )
723
+ x_pil = Image .fromarray (x_np , mode = 'RGB' )
724
+ x_pil_2 = x_pil .convert ('L' )
725
+ gray_np = np .array (x_pil_2 )
726
+
727
+ # Test Set: Grayscale an image with desired number of output channels
728
+ # Case 1: RGB -> 1 channel grayscale
729
+ trans1 = transforms .Grayscale (num_output_channels = 1 )
730
+ gray_pil_1 = trans1 (x_pil )
731
+ gray_np_1 = np .array (gray_pil_1 )
732
+ assert gray_pil_1 .mode == 'L' , 'mode should be L'
733
+ assert gray_np_1 .shape == tuple (x_shape [0 :2 ]), 'should be 1 channel'
734
+ np .testing .assert_equal (gray_np , gray_np_1 )
735
+
736
+ # Case 2: RGB -> 3 channel grayscale
737
+ trans2 = transforms .Grayscale (num_output_channels = 3 )
738
+ gray_pil_2 = trans2 (x_pil )
739
+ gray_np_2 = np .array (gray_pil_2 )
740
+ assert gray_pil_2 .mode == 'RGB' , 'mode should be RGB'
741
+ assert gray_np_2 .shape == tuple (x_shape ), 'should be 3 channel'
742
+ np .testing .assert_equal (gray_np_2 [:, :, 0 ], gray_np_2 [:, :, 1 ])
743
+ np .testing .assert_equal (gray_np_2 [:, :, 1 ], gray_np_2 [:, :, 2 ])
744
+ np .testing .assert_equal (gray_np , gray_np_2 [:, :, 0 ])
745
+
746
+ # Case 3: 1 channel grayscale -> 1 channel grayscale
747
+ trans3 = transforms .Grayscale (num_output_channels = 1 )
748
+ gray_pil_3 = trans3 (x_pil_2 )
749
+ gray_np_3 = np .array (gray_pil_3 )
750
+ assert gray_pil_3 .mode == 'L' , 'mode should be L'
751
+ assert gray_np_3 .shape == tuple (x_shape [0 :2 ]), 'should be 1 channel'
752
+ np .testing .assert_equal (gray_np , gray_np_3 )
753
+
754
+ # Case 4: 1 channel grayscale -> 3 channel grayscale
755
+ trans4 = transforms .Grayscale (num_output_channels = 3 )
756
+ gray_pil_4 = trans4 (x_pil_2 )
757
+ gray_np_4 = np .array (gray_pil_4 )
758
+ assert gray_pil_4 .mode == 'RGB' , 'mode should be RGB'
759
+ assert gray_np_4 .shape == tuple (x_shape ), 'should be 3 channel'
760
+ np .testing .assert_equal (gray_np_4 [:, :, 0 ], gray_np_4 [:, :, 1 ])
761
+ np .testing .assert_equal (gray_np_4 [:, :, 1 ], gray_np_4 [:, :, 2 ])
762
+ np .testing .assert_equal (gray_np , gray_np_4 [:, :, 0 ])
763
+
764
+ @unittest .skipIf (stats is None , 'scipy.stats not available' )
765
+ def test_random_grayscale (self ):
766
+ """Unit tests for random grayscale transform"""
767
+
768
+ # Test Set 1: RGB -> 3 channel grayscale
769
+ random_state = random .getstate ()
770
+ random .seed (42 )
771
+ x_shape = [2 , 2 , 3 ]
772
+ x_np = np .random .randint (0 , 256 , x_shape , np .uint8 )
773
+ x_pil = Image .fromarray (x_np , mode = 'RGB' )
774
+ x_pil_2 = x_pil .convert ('L' )
775
+ gray_np = np .array (x_pil_2 )
776
+
777
+ num_samples = 250
778
+ num_gray = 0
779
+ for _ in range (num_samples ):
780
+ gray_pil_2 = transforms .RandomGrayscale (p = 0.5 )(x_pil )
781
+ gray_np_2 = np .array (gray_pil_2 )
782
+ if np .array_equal (gray_np_2 [:, :, 0 ], gray_np_2 [:, :, 1 ]) and \
783
+ np .array_equal (gray_np_2 [:, :, 1 ], gray_np_2 [:, :, 2 ]) and \
784
+ np .array_equal (gray_np , gray_np_2 [:, :, 0 ]):
785
+ num_gray = num_gray + 1
786
+
787
+ p_value = stats .binom_test (num_gray , num_samples , p = 0.5 )
788
+ random .setstate (random_state )
789
+ assert p_value > 0.0001
790
+
791
+ # Test Set 2: grayscale -> 1 channel grayscale
792
+ random_state = random .getstate ()
793
+ random .seed (42 )
794
+ x_shape = [2 , 2 , 3 ]
795
+ x_np = np .random .randint (0 , 256 , x_shape , np .uint8 )
796
+ x_pil = Image .fromarray (x_np , mode = 'RGB' )
797
+ x_pil_2 = x_pil .convert ('L' )
798
+ gray_np = np .array (x_pil_2 )
799
+
800
+ num_samples = 250
801
+ num_gray = 0
802
+ for _ in range (num_samples ):
803
+ gray_pil_3 = transforms .RandomGrayscale (p = 0.5 )(x_pil_2 )
804
+ gray_np_3 = np .array (gray_pil_3 )
805
+ if np .array_equal (gray_np , gray_np_3 ):
806
+ num_gray = num_gray + 1
807
+
808
+ p_value = stats .binom_test (num_gray , num_samples , p = 1.0 ) # Note: grayscale is always unchanged
809
+ random .setstate (random_state )
810
+ assert p_value > 0.0001
811
+
812
+ # Test set 3: Explicit tests
813
+ x_shape = [2 , 2 , 3 ]
814
+ x_data = [0 , 5 , 13 , 54 , 135 , 226 , 37 , 8 , 234 , 90 , 255 , 1 ]
815
+ x_np = np .array (x_data , dtype = np .uint8 ).reshape (x_shape )
816
+ x_pil = Image .fromarray (x_np , mode = 'RGB' )
817
+ x_pil_2 = x_pil .convert ('L' )
818
+ gray_np = np .array (x_pil_2 )
819
+
820
+ # Case 3a: RGB -> 3 channel grayscale (grayscaled)
821
+ trans2 = transforms .RandomGrayscale (p = 1.0 )
822
+ gray_pil_2 = trans2 (x_pil )
823
+ gray_np_2 = np .array (gray_pil_2 )
824
+ assert gray_pil_2 .mode == 'RGB' , 'mode should be RGB'
825
+ assert gray_np_2 .shape == tuple (x_shape ), 'should be 3 channel'
826
+ np .testing .assert_equal (gray_np_2 [:, :, 0 ], gray_np_2 [:, :, 1 ])
827
+ np .testing .assert_equal (gray_np_2 [:, :, 1 ], gray_np_2 [:, :, 2 ])
828
+ np .testing .assert_equal (gray_np , gray_np_2 [:, :, 0 ])
829
+
830
+ # Case 3b: RGB -> 3 channel grayscale (unchanged)
831
+ trans2 = transforms .RandomGrayscale (p = 0.0 )
832
+ gray_pil_2 = trans2 (x_pil )
833
+ gray_np_2 = np .array (gray_pil_2 )
834
+ assert gray_pil_2 .mode == 'RGB' , 'mode should be RGB'
835
+ assert gray_np_2 .shape == tuple (x_shape ), 'should be 3 channel'
836
+ np .testing .assert_equal (x_np , gray_np_2 )
837
+
838
+ # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
839
+ trans3 = transforms .RandomGrayscale (p = 1.0 )
840
+ gray_pil_3 = trans3 (x_pil_2 )
841
+ gray_np_3 = np .array (gray_pil_3 )
842
+ assert gray_pil_3 .mode == 'L' , 'mode should be L'
843
+ assert gray_np_3 .shape == tuple (x_shape [0 :2 ]), 'should be 1 channel'
844
+ np .testing .assert_equal (gray_np , gray_np_3 )
845
+
846
+ # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
847
+ trans3 = transforms .RandomGrayscale (p = 0.0 )
848
+ gray_pil_3 = trans3 (x_pil_2 )
849
+ gray_np_3 = np .array (gray_pil_3 )
850
+ assert gray_pil_3 .mode == 'L' , 'mode should be L'
851
+ assert gray_np_3 .shape == tuple (x_shape [0 :2 ]), 'should be 1 channel'
852
+ np .testing .assert_equal (gray_np , gray_np_3 )
853
+
717
854
718
855
if __name__ == '__main__' :
719
856
unittest .main ()
0 commit comments