Skip to content

Commit 618072b

Browse files
chsasank authored and fmassa committed
Add Color transforms (#275)
* Add adjust_hue and adjust_saturation
* Add adjust_brightness, adjust_contrast
  Also:
* Change adjust_saturation to use pillow implementation
* Documentation made clear
* Add adjust_gamma
* Add ColorJitter
* Address review comments
* Fix documentation for ColorJitter
* Address review comments 2
* Fallback to adjust_hue in case of BW images
* Add tests
* Fix dtype
1 parent 88e81ce commit 618072b

File tree

2 files changed

+365
-1
lines changed

2 files changed

+365
-1
lines changed

test/test_transforms.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,167 @@ def test_random_horizontal_flip(self):
422422
p_value = stats.binom_test(num_horizontal, 100, p=0.5)
423423
assert p_value > 0.05
424424

425+
def test_adjust_brightness(self):
    # Compare adjust_brightness against reference pixels on a 2x2 RGB image.
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (factor, expected flattened pixels); factor 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]),
        (2, [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]),
    ]
    for factor, expected in cases:
        out_np = np.array(transforms.adjust_brightness(src_pil, factor))
        want_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(out_np, want_np)
449+
450+
def test_adjust_contrast(self):
    # Compare adjust_contrast against reference pixels on a 2x2 RGB image.
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (factor, expected flattened pixels); factor 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]),
        (2, [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]),
    ]
    for factor, expected in cases:
        out_np = np.array(transforms.adjust_contrast(src_pil, factor))
        want_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(out_np, want_np)
474+
475+
def test_adjust_saturation(self):
    # Compare adjust_saturation against reference pixels on a 2x2 RGB image.
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (factor, expected flattened pixels); factor 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]),
        (2, [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]),
    ]
    for factor, expected in cases:
        out_np = np.array(transforms.adjust_saturation(src_pil, factor))
        want_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(out_np, want_np)
499+
500+
def test_adjust_hue(self):
    # Check range validation and reference outputs of adjust_hue on a
    # 2x2 RGB image.
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode='RGB')

    # Each out-of-range factor needs its own assertRaises context: a single
    # shared block exits on the first raise, so a second call placed in the
    # same block would never run and would go unchecked.
    for bad_factor in (-0.7, 1):
        with self.assertRaises(ValueError):
            transforms.adjust_hue(x_pil, bad_factor)

    cases = [
        # factor 0: almost the same as x_data but not exact, probably
        # because of the hsv <-> rgb floating point ops
        (0, [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]),
        (0.25, [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]),
        (-0.25, [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]),
    ]
    for hue_factor, y_ans in cases:
        y_np = np.array(transforms.adjust_hue(x_pil, hue_factor))
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        assert np.allclose(y_np, y_ans)
531+
532+
def test_adjust_gamma(self):
    # Compare adjust_gamma against reference pixels on a 2x2 RGB image.
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # (gamma, expected flattened pixels); gamma 1 must reproduce the input.
    cases = [
        (1, pixels),
        (0.5, [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]),
        (2, [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]),
    ]
    for gamma, expected in cases:
        out_np = np.array(transforms.adjust_gamma(src_pil, gamma))
        want_np = np.array(expected, dtype=np.uint8).reshape(shape)
        assert np.allclose(out_np, want_np)
556+
557+
def test_adjusts_L_mode(self):
    # Every adjust_* function must hand back an 'L' image when given one.
    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    rgb_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    gray = Image.fromarray(rgb_np, mode='RGB').convert('L')

    calls = [
        (transforms.adjust_brightness, 2),
        (transforms.adjust_saturation, 2),
        (transforms.adjust_contrast, 2),
        (transforms.adjust_hue, 0.4),
        (transforms.adjust_gamma, 0.5),
    ]
    for adjust, amount in calls:
        assert adjust(gray, amount).mode == 'L'
569+
570+
def test_color_jitter(self):
    # ColorJitter must preserve the image mode for both RGB and 'L' inputs.
    jitter = transforms.ColorJitter(2, 2, 2, 0.1)

    shape = [2, 2, 3]
    pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    rgb_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    rgb_img = Image.fromarray(rgb_np, mode='RGB')
    gray_img = rgb_img.convert('L')

    # Jitter is random, so repeat the check several times; the RGB sample is
    # always processed before the grayscale one within an iteration.
    for _ in range(10):
        for sample in (rgb_img, gray_img):
            assert jitter(sample).mode == sample.mode
585+
425586

426587
if __name__ == '__main__':
    # Run the test suite when this file is executed directly as a script.
    unittest.main()

torchvision/transforms.py

Lines changed: 204 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import math
44
import random
5-
from PIL import Image, ImageOps
5+
from PIL import Image, ImageOps, ImageEnhance
66
try:
77
import accimage
88
except ImportError:
@@ -355,6 +355,145 @@ def ten_crop(img, size, vertical_flip=False):
355355
return first_five + second_five
356356

357357

358+
def adjust_brightness(img, brightness_factor):
    """Scale the brightness of a PIL image.

    Args:
        img (PIL.Image): Image to adjust.
        brightness_factor (float): Non-negative brightness multiplier.
            0 yields a black image, 1 yields the original image and 2
            doubles the brightness.

    Returns:
        PIL.Image: The brightness-adjusted image.

    Raises:
        TypeError: If ``img`` is not recognised as a PIL image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Brightness(img).enhance(brightness_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
376+
377+
378+
def adjust_contrast(img, contrast_factor):
    """Scale the contrast of a PIL image.

    Args:
        img (PIL.Image): Image to adjust.
        contrast_factor (float): Non-negative contrast multiplier.
            0 yields a solid gray image, 1 yields the original image and 2
            doubles the contrast.

    Returns:
        PIL.Image: The contrast-adjusted image.

    Raises:
        TypeError: If ``img`` is not recognised as a PIL image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Contrast(img).enhance(contrast_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
396+
397+
398+
def adjust_saturation(img, saturation_factor):
    """Scale the color saturation of a PIL image.

    Args:
        img (PIL.Image): Image to adjust.
        saturation_factor (float): Saturation multiplier. 0 yields a black
            and white image, 1 yields the original image and 2 doubles the
            saturation.

    Returns:
        PIL.Image: The saturation-adjusted image.

    Raises:
        TypeError: If ``img`` is not recognised as a PIL image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Color(img).enhance(saturation_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
416+
417+
418+
def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL.Image: Hue adjusted image. Images in modes 'L', '1', 'I' and 'F'
        are returned unchanged.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` is not a PIL Image.
    """
    if not(-0.5 <= hue_factor <= 0.5):
        # Report the offending value so the caller can see what was passed.
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    # These modes are handed back untouched instead of going through HSV.
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # Build the shift with Python ints: truncate toward zero, then wrap into
    # [0, 255]. Casting a negative float straight to uint8 is not portable,
    # whereas the modulo gives the same cyclic shift on every platform.
    shift = np.uint8(int(hue_factor * 255) % 256)
    # uint8 addition take cares of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += shift
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
461+
462+
463+
def adjust_gamma(img, gamma, gain=1):
    """Apply gamma correction (power law transform) to an image.

    The image is converted to RGB, each intensity is mapped through

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    and the result is converted back to the input image mode.

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        gamma (float): Non negative real number. Values larger than 1 make
            the shadows darker, values smaller than 1 make dark regions
            lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL.Image: Gamma corrected image, in the same mode as ``img``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
        ValueError: If ``gamma`` is negative.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    original_mode = img.mode
    rgb = np.array(img.convert('RGB'), dtype=np.float32)
    corrected = 255 * gain * ((rgb / 255) ** gamma)
    # Clamp to the valid 8-bit range before casting back to uint8.
    pixels = np.uint8(np.clip(corrected, 0, 255))

    return Image.fromarray(pixels, 'RGB').convert(original_mode)
495+
496+
358497
class Compose(object):
359498
"""Composes several transforms together.
360499
@@ -777,3 +916,67 @@ def __init__(self, size, vertical_flip=False):
777916

778917
def __call__(self, img):
779918
return ten_crop(img, self.size, self.vertical_flip)
919+
920+
921+
class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness, self.contrast = brightness, contrast
        self.saturation, self.hue = saturation, hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__. A jitter whose amount is
        not greater than 0 is left out of the returned transform.

        Returns:
            Transform which randomly adjusts brightness, contrast,
            saturation and hue in a random order.
        """
        pending = []

        def _queue(adjust, low, high):
            # Draw the factor now; the nested scope gives every queued
            # lambda its own private ``factor``.
            factor = np.random.uniform(low, high)
            pending.append(Lambda(lambda img: adjust(img, factor)))

        if brightness > 0:
            _queue(adjust_brightness, max(0, 1 - brightness), 1 + brightness)
        if contrast > 0:
            _queue(adjust_contrast, max(0, 1 - contrast), 1 + contrast)
        if saturation > 0:
            _queue(adjust_saturation, max(0, 1 - saturation), 1 + saturation)
        if hue > 0:
            _queue(adjust_hue, -hue, hue)

        np.random.shuffle(pending)
        return Compose(pending)

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Input image.

        Returns:
            PIL.Image: Color jittered image.
        """
        jitter = self.get_params(self.brightness, self.contrast,
                                 self.saturation, self.hue)
        return jitter(img)

0 commit comments

Comments
 (0)