def test_to_grayscale(self):
    """Unit tests for grayscale transform.

    Checks ``transforms.Grayscale`` for every combination of RGB / 'L'
    input and 1 / 3 output channels against PIL's own ``convert('L')``.
    """

    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode='RGB')
    x_pil_2 = x_pil.convert('L')
    # Reference result every case below is compared against.
    gray_np = np.array(x_pil_2)

    # Test Set: Grayscale an image with desired number of output channels
    # Case 1: RGB -> 1 channel grayscale
    trans1 = transforms.Grayscale(num_output_channels=1)
    gray_pil_1 = trans1(x_pil)
    gray_np_1 = np.array(gray_pil_1)
    assert gray_pil_1.mode == 'L', 'mode should be L'
    assert gray_np_1.shape == tuple(x_shape[0:2]), 'should be 1 channel'
    np.testing.assert_equal(gray_np, gray_np_1)

    # Case 2: RGB -> 3 channel grayscale
    trans2 = transforms.Grayscale(num_output_channels=3)
    gray_pil_2 = trans2(x_pil)
    gray_np_2 = np.array(gray_pil_2)
    assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
    assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
    np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
    np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
    np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

    # Case 3: 1 channel grayscale -> 1 channel grayscale
    trans3 = transforms.Grayscale(num_output_channels=1)
    gray_pil_3 = trans3(x_pil_2)
    gray_np_3 = np.array(gray_pil_3)
    assert gray_pil_3.mode == 'L', 'mode should be L'
    assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
    np.testing.assert_equal(gray_np, gray_np_3)

    # Case 4: 1 channel grayscale -> 3 channel grayscale
    trans4 = transforms.Grayscale(num_output_channels=3)
    gray_pil_4 = trans4(x_pil_2)
    gray_np_4 = np.array(gray_pil_4)
    assert gray_pil_4.mode == 'RGB', 'mode should be RGB'
    assert gray_np_4.shape == tuple(x_shape), 'should be 3 channel'
    np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
    np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
    np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])


@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_grayscale(self):
    """Unit tests for random grayscale transform.

    Test sets 1 and 2 sample ``RandomGrayscale(p=0.5)`` repeatedly and use a
    binomial test on the number of grayscaled outputs; test set 3 checks the
    deterministic extremes ``p=1.0`` and ``p=0.0`` explicitly.
    """

    def binom_p_value(successes, trials, p):
        # scipy.stats.binom_test was deprecated and later removed in favour
        # of binomtest; prefer the new API and fall back for old SciPy.
        if hasattr(stats, 'binomtest'):
            return stats.binomtest(successes, trials, p=p).pvalue
        return stats.binom_test(successes, trials, p=p)

    def is_gray_version(out_np, expected_gray_np):
        # 2-D output: single channel image, compare directly.
        if out_np.ndim == 2:
            return np.array_equal(expected_gray_np, out_np)
        # 3-D output: every channel must equal the reference grayscale.
        return all(np.array_equal(expected_gray_np, out_np[:, :, c]) for c in range(3))

    def count_gray(img, expected_gray_np, num_samples):
        # Seed the generator RandomGrayscale draws from, and always restore
        # the previous state so a failure here cannot leak the fixed seed
        # into other tests.
        random_state = random.getstate()
        try:
            random.seed(42)
            return sum(
                1 for _ in range(num_samples)
                if is_gray_version(np.array(transforms.RandomGrayscale(p=0.5)(img)), expected_gray_np)
            )
        finally:
            random.setstate(random_state)

    num_samples = 250
    x_shape = [2, 2, 3]

    # A local, seeded NumPy generator makes the sample image reproducible
    # without touching NumPy's global random state.
    x_np = np.random.RandomState(42).randint(0, 256, x_shape, np.uint8)
    x_pil = Image.fromarray(x_np, mode='RGB')
    x_pil_2 = x_pil.convert('L')
    gray_np = np.array(x_pil_2)

    # Test Set 1: RGB -> 3 channel grayscale
    num_gray = count_gray(x_pil, gray_np, num_samples)
    p_value = binom_p_value(num_gray, num_samples, p=0.5)
    assert p_value > 0.0001

    # Test Set 2: grayscale -> 1 channel grayscale
    num_gray = count_gray(x_pil_2, gray_np, num_samples)
    p_value = binom_p_value(num_gray, num_samples, p=1.0)  # Note: grayscale is always unchanged
    assert p_value > 0.0001

    # Test set 3: Explicit tests
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode='RGB')
    x_pil_2 = x_pil.convert('L')
    gray_np = np.array(x_pil_2)

    # Case 3a: RGB -> 3 channel grayscale (grayscaled)
    trans2 = transforms.RandomGrayscale(p=1.0)
    gray_pil_2 = trans2(x_pil)
    gray_np_2 = np.array(gray_pil_2)
    assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
    assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
    np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
    np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
    np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

    # Case 3b: RGB -> 3 channel grayscale (unchanged)
    trans2 = transforms.RandomGrayscale(p=0.0)
    gray_pil_2 = trans2(x_pil)
    gray_np_2 = np.array(gray_pil_2)
    assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
    assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
    np.testing.assert_equal(x_np, gray_np_2)

    # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
    trans3 = transforms.RandomGrayscale(p=1.0)
    gray_pil_3 = trans3(x_pil_2)
    gray_np_3 = np.array(gray_pil_3)
    assert gray_pil_3.mode == 'L', 'mode should be L'
    assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
    np.testing.assert_equal(gray_np, gray_np_3)

    # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
    trans3 = transforms.RandomGrayscale(p=0.0)
    gray_pil_3 = trans3(x_pil_2)
    gray_np_3 = np.array(gray_pil_3)
    assert gray_pil_3.mode == 'L', 'mode should be L'
    assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
    np.testing.assert_equal(gray_np, gray_np_3)
def to_grayscale(img, num_output_channels=1):
    """Convert image to grayscale version of image.

    Args:
        img (PIL Image): Image to be converted to grayscale.
        num_output_channels (int): Number of channels of the returned image,
            either 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.
            if num_output_channels == 1 : returned image is single channel
            if num_output_channels == 3 : returned image is 3 channel with r == g == b

    Raises:
        TypeError: If ``img`` is not a PIL Image.
        ValueError: If ``num_output_channels`` is neither 1 nor 3.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # Reject an unsupported channel count before doing any conversion work.
    if num_output_channels not in (1, 3):
        raise ValueError('num_output_channels should be either 1 or 3')

    gray = img.convert('L')
    if num_output_channels == 1:
        return gray

    # Use the single luminance band for R, G and B. Merging inside PIL gives
    # an RGB image with r == g == b without a round trip through NumPy.
    return Image.merge('RGB', (gray, gray, gray))
class Grayscale(object):
    """Convert image to grayscale.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image.
            Default is 1.

    Returns:
        PIL Image: grayscale version of the input
            - if num_output_channels == 1 : returned image is single channel
            - if num_output_channels == 3 : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        self.num_output_channels = num_output_channels

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Grayscaled image.
        """
        return F.to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self):
        return '{0}(num_output_channels={1})'.format(self.__class__.__name__, self.num_output_channels)


class RandomGrayscale(object):
    """Randomly convert image to grayscale with a probability of p (default 0.1).

    Args:
        p (float): probability that image should be converted to grayscale.
            Default is 0.1.

    Returns:
        PIL Image: grayscale version of the input image with probability p
            and unchanged with probability (1-p)
            - if input image is 1 channel:
                grayscale version is 1 channel
            - if input image is 3 channel:
                grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Randomly grayscaled image.
        """
        # Keep the channel count of the input: mode 'L' stays single channel,
        # every other mode is returned as 3 channel. The mode is read before
        # the random draw, so it is accessed on every call.
        num_output_channels = 1 if img.mode == 'L' else 3
        if random.random() < self.p:
            return F.to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)