Skip to content

transforms: randomly grayscaling an image #325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,143 @@ def test_random_rotation(self):
angle = t.get_params(t.degrees)
assert angle > -10 and angle < 10

def test_to_grayscale(self):
    """Check transforms.Grayscale on RGB and single channel inputs.

    Both inputs are run through the 1 channel and the 3 channel variant of
    the transform and compared against PIL's own ``convert('L')`` result.
    """

    x_shape = [2, 2, 3]
    pixel_values = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    rgb_np = np.array(pixel_values, dtype=np.uint8).reshape(x_shape)
    rgb_pil = Image.fromarray(rgb_np, mode='RGB')
    # Reference grayscale image; also used as the single channel input.
    reference_pil = rgb_pil.convert('L')
    reference_np = np.array(reference_pil)

    # First pass: RGB input (cases 1 and 2).
    # Second pass: 1 channel grayscale input (cases 3 and 4).
    for source_pil in (rgb_pil, reference_pil):
        # -> 1 channel grayscale
        single_pil = transforms.Grayscale(num_output_channels=1)(source_pil)
        single_np = np.array(single_pil)
        assert single_pil.mode == 'L', 'mode should be L'
        assert single_np.shape == tuple(x_shape[0:2]), 'should be 1 channel'
        np.testing.assert_equal(reference_np, single_np)

        # -> 3 channel grayscale, all channels must be identical
        triple_pil = transforms.Grayscale(num_output_channels=3)(source_pil)
        triple_np = np.array(triple_pil)
        assert triple_pil.mode == 'RGB', 'mode should be RGB'
        assert triple_np.shape == tuple(x_shape), 'should be 3 channel'
        red, green, blue = (triple_np[:, :, c] for c in range(3))
        np.testing.assert_equal(red, green)
        np.testing.assert_equal(green, blue)
        np.testing.assert_equal(reference_np, red)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_grayscale(self):
    """Unit tests for random grayscale transform.

    Test sets 1 and 2 sample the transform repeatedly and use a binomial
    test on the number of grayscaled outputs; test set 3 checks the
    deterministic cases p=1.0 and p=0.0 explicitly.
    """

    def binom_p_value(successes, trials, prob):
        # scipy.stats.binom_test is deprecated and removed in newer SciPy
        # releases; prefer its replacement binomtest when it is available.
        if hasattr(stats, 'binomtest'):
            return stats.binomtest(successes, trials, p=prob).pvalue
        return stats.binom_test(successes, trials, p=prob)

    num_samples = 250
    x_shape = [2, 2, 3]

    # Test Set 1: RGB -> 3 channel grayscale
    # random.seed() does not seed numpy, so draw the image from a local,
    # seeded generator to keep the input reproducible without touching
    # numpy's global random state.
    x_np = np.random.RandomState(42).randint(0, 256, x_shape, np.uint8)
    x_pil = Image.fromarray(x_np, mode='RGB')
    x_pil_2 = x_pil.convert('L')
    gray_np = np.array(x_pil_2)

    random_state = random.getstate()
    random.seed(42)
    try:
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
            gray_np_2 = np.array(gray_pil_2)
            if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \
                    np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
                    np.array_equal(gray_np, gray_np_2[:, :, 0]):
                num_gray += 1
    finally:
        # Restore the global random state even if the transform raises.
        random.setstate(random_state)

    p_value = binom_p_value(num_gray, num_samples, 0.5)
    assert p_value > 0.0001

    # Test Set 2: grayscale -> 1 channel grayscale
    x_np = np.random.RandomState(42).randint(0, 256, x_shape, np.uint8)
    x_pil = Image.fromarray(x_np, mode='RGB')
    x_pil_2 = x_pil.convert('L')
    gray_np = np.array(x_pil_2)

    random_state = random.getstate()
    random.seed(42)
    try:
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
            gray_np_3 = np.array(gray_pil_3)
            if np.array_equal(gray_np, gray_np_3):
                num_gray += 1
    finally:
        random.setstate(random_state)

    # Note: grayscale is always unchanged
    p_value = binom_p_value(num_gray, num_samples, 1.0)
    assert p_value > 0.0001

    # Test set 3: Explicit tests
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode='RGB')
    x_pil_2 = x_pil.convert('L')
    gray_np = np.array(x_pil_2)

    # Case 3a: RGB -> 3 channel grayscale (grayscaled)
    trans2 = transforms.RandomGrayscale(p=1.0)
    gray_pil_2 = trans2(x_pil)
    gray_np_2 = np.array(gray_pil_2)
    assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
    assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
    np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
    np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
    np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

    # Case 3b: RGB -> 3 channel grayscale (unchanged)
    trans2 = transforms.RandomGrayscale(p=0.0)
    gray_pil_2 = trans2(x_pil)
    gray_np_2 = np.array(gray_pil_2)
    assert gray_pil_2.mode == 'RGB', 'mode should be RGB'
    assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
    np.testing.assert_equal(x_np, gray_np_2)

    # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
    trans3 = transforms.RandomGrayscale(p=1.0)
    gray_pil_3 = trans3(x_pil_2)
    gray_np_3 = np.array(gray_pil_3)
    assert gray_pil_3.mode == 'L', 'mode should be L'
    assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
    np.testing.assert_equal(gray_np, gray_np_3)

    # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
    trans3 = transforms.RandomGrayscale(p=0.0)
    gray_pil_3 = trans3(x_pil_2)
    gray_np_3 = np.array(gray_pil_3)
    assert gray_pil_3.mode == 'L', 'mode should be L'
    assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
    np.testing.assert_equal(gray_np, gray_np_3)


# Run the unittest test runner when this file is executed as a script.
if __name__ == '__main__':
unittest.main()
28 changes: 28 additions & 0 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,35 @@ def rotate(img, angle, resample=False, expand=False, center=None):
Origin is the upper left corner.
Default is the center of the image.
"""

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.rotate(angle, resample, expand, center)


def to_grayscale(img, num_output_channels=1):
    """Convert image to grayscale version of image.

    Args:
        img (PIL Image): Image to be converted to grayscale.
        num_output_channels (int): (1 or 3) number of channels desired for
            the output image. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.
            if num_output_channels == 1 : returned image is single channel
            if num_output_channels == 3 : returned image is 3 channel with r == g == b

    Raises:
        TypeError: if ``img`` is not a PIL Image.
        ValueError: if ``num_output_channels`` is neither 1 nor 3.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # Reject unsupported channel counts before doing any conversion work.
    if num_output_channels not in (1, 3):
        raise ValueError('num_output_channels should be either 1 or 3')

    gray_img = img.convert('L')
    if num_output_channels == 1:
        return gray_img

    # Replicate the single luminance channel three times so r == g == b.
    gray_np = np.array(gray_img, dtype=np.uint8)
    stacked_np = np.dstack([gray_np, gray_np, gray_np])
    return Image.fromarray(stacked_np, 'RGB')
62 changes: 60 additions & 2 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

# Names exported by a star-import of this module.
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
           "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation",
           "Grayscale", "RandomGrayscale"]


class Compose(object):
Expand Down Expand Up @@ -619,7 +620,6 @@ def get_params(degrees):

def __call__(self, img):
"""
Args:
img (PIL Image): Image to be rotated.

Returns:
Expand All @@ -629,3 +629,61 @@ def __call__(self, img):
angle = self.get_params(self.degrees)

return F.rotate(img, angle, self.resample, self.expand, self.center)


class Grayscale(object):
    """Convert image to grayscale.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: grayscale version of the input
            if num_output_channels == 1 : returned image is single channel
            if num_output_channels == 3 : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        self.num_output_channels = num_output_channels

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Grayscaled image.
        """
        return F.to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self):
        return '{}(num_output_channels={})'.format(self.__class__.__name__, self.num_output_channels)


class RandomGrayscale(object):
    """Randomly convert image to grayscale with a probability of p (default 0.1).

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image: grayscale version of the input image with probability p
        and unchanged with probability (1-p)
            - if input image is 1 channel:
                grayscale version is 1 channel
            - if input image is 3 channel:
                grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Randomly grayscaled image.
        """
        # Keep the channel count of the input: mode 'L' stays single channel,
        # every other mode is returned as a 3 channel image.
        num_output_channels = 1 if img.mode == 'L' else 3
        if random.random() < self.p:
            return F.to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)