From dc886585dba8be07e102dec6c5c295d21fe84a16 Mon Sep 17 00:00:00 2001 From: Alykhan Tejani Date: Sat, 28 Oct 2017 15:44:36 +0100 Subject: [PATCH 1/2] refactor transforms and add transforms/functional.py --- test/test_transforms.py | 50 +- torchvision/transforms.py | 1061 -------------------------- torchvision/transforms/__init__.py | 1 + torchvision/transforms/functional.py | 526 +++++++++++++ torchvision/transforms/transforms.py | 543 +++++++++++++ 5 files changed, 1095 insertions(+), 1086 deletions(-) delete mode 100644 torchvision/transforms.py create mode 100644 torchvision/transforms/__init__.py create mode 100644 torchvision/transforms/functional.py create mode 100644 torchvision/transforms/transforms.py diff --git a/test/test_transforms.py b/test/test_transforms.py index 57c6d2ac947..6c12dcb07f0 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,5 +1,6 @@ import torch import torchvision.transforms as transforms +import torchvision.transforms.functional as F import unittest import random import numpy as np @@ -14,7 +15,6 @@ except ImportError: stats = None - GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg' @@ -347,7 +347,7 @@ def verify_img_data(img_data, expected_output, mode): assert img.mode == mode split = img.split() for i in range(3): - assert np.allclose(expected_output[i].numpy(), transforms.to_tensor(split[i]).numpy()) + assert np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()) img_data = torch.Tensor(3, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) @@ -391,7 +391,7 @@ def verify_img_data(img_data, expected_output, mode): split = img.split() for i in range(4): - assert np.allclose(expected_output[i].numpy(), transforms.to_tensor(split[i]).numpy()) + assert np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()) img_data = torch.Tensor(4, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) @@ -491,19 +491,19 @@ def test_adjust_brightness(self): x_pil = Image.fromarray(x_np, mode='RGB') # test 0 - y_pil = transforms.adjust_brightness(x_pil, 1) + y_pil = F.adjust_brightness(x_pil, 1) y_np = np.array(y_pil) assert np.allclose(y_np, x_np) # test 1 - y_pil = transforms.adjust_brightness(x_pil, 0.5) + y_pil = F.adjust_brightness(x_pil, 0.5) y_np = np.array(y_pil) y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) assert np.allclose(y_np, y_ans) # test 2 - y_pil = transforms.adjust_brightness(x_pil, 2) + y_pil = F.adjust_brightness(x_pil, 2) y_np = np.array(y_pil) y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) @@ -516,19 +516,19 @@ def test_adjust_contrast(self): x_pil = Image.fromarray(x_np, mode='RGB') # test 0 - y_pil = transforms.adjust_contrast(x_pil, 1) + y_pil = F.adjust_contrast(x_pil, 1) y_np = np.array(y_pil) assert np.allclose(y_np, x_np) # test 1 - y_pil = transforms.adjust_contrast(x_pil, 0.5) + y_pil = F.adjust_contrast(x_pil, 0.5) y_np = np.array(y_pil) y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) assert np.allclose(y_np, y_ans) # test 2 - y_pil = transforms.adjust_contrast(x_pil, 2) + y_pil = F.adjust_contrast(x_pil, 2) y_np = np.array(y_pil) y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) @@ -541,19 +541,19 @@ def test_adjust_saturation(self): x_pil = Image.fromarray(x_np, mode='RGB') # test 
0 - y_pil = transforms.adjust_saturation(x_pil, 1) + y_pil = F.adjust_saturation(x_pil, 1) y_np = np.array(y_pil) assert np.allclose(y_np, x_np) # test 1 - y_pil = transforms.adjust_saturation(x_pil, 0.5) + y_pil = F.adjust_saturation(x_pil, 0.5) y_np = np.array(y_pil) y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) assert np.allclose(y_np, y_ans) # test 2 - y_pil = transforms.adjust_saturation(x_pil, 2) + y_pil = F.adjust_saturation(x_pil, 2) y_np = np.array(y_pil) y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) @@ -566,26 +566,26 @@ def test_adjust_hue(self): x_pil = Image.fromarray(x_np, mode='RGB') with self.assertRaises(ValueError): - transforms.adjust_hue(x_pil, -0.7) - transforms.adjust_hue(x_pil, 1) + F.adjust_hue(x_pil, -0.7) + F.adjust_hue(x_pil, 1) # test 0: almost same as x_data but not exact. # probably because hsv <-> rgb floating point ops - y_pil = transforms.adjust_hue(x_pil, 0) + y_pil = F.adjust_hue(x_pil, 0) y_np = np.array(y_pil) y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) assert np.allclose(y_np, y_ans) # test 1 - y_pil = transforms.adjust_hue(x_pil, 0.25) + y_pil = F.adjust_hue(x_pil, 0.25) y_np = np.array(y_pil) y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) assert np.allclose(y_np, y_ans) # test 2 - y_pil = transforms.adjust_hue(x_pil, -0.25) + y_pil = F.adjust_hue(x_pil, -0.25) y_np = np.array(y_pil) y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) @@ -598,19 +598,19 @@ def test_adjust_gamma(self): x_pil = Image.fromarray(x_np, mode='RGB') # test 0 - y_pil = transforms.adjust_gamma(x_pil, 1) + y_pil = F.adjust_gamma(x_pil, 1) y_np = np.array(y_pil) assert np.allclose(y_np, x_np) # test 1 - y_pil = transforms.adjust_gamma(x_pil, 0.5) + y_pil = F.adjust_gamma(x_pil, 0.5) y_np = np.array(y_pil) y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) assert np.allclose(y_np, y_ans) # test 2 - y_pil = transforms.adjust_gamma(x_pil, 2) + y_pil = F.adjust_gamma(x_pil, 2) y_np = np.array(y_pil) y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) @@ -623,11 +623,11 @@ def test_adjusts_L_mode(self): x_rgb = Image.fromarray(x_np, mode='RGB') x_l = x_rgb.convert('L') - assert transforms.adjust_brightness(x_l, 2).mode == 'L' - assert transforms.adjust_saturation(x_l, 2).mode == 'L' - assert transforms.adjust_contrast(x_l, 2).mode == 'L' - assert transforms.adjust_hue(x_l, 0.4).mode == 'L' - assert transforms.adjust_gamma(x_l, 0.5).mode == 'L' + assert F.adjust_brightness(x_l, 2).mode == 'L' + assert F.adjust_saturation(x_l, 2).mode == 'L' + assert F.adjust_contrast(x_l, 2).mode == 'L' + assert F.adjust_hue(x_l, 0.4).mode == 'L' + assert F.adjust_gamma(x_l, 0.5).mode == 'L' def test_color_jitter(self): color_jitter = transforms.ColorJitter(2, 2, 2, 0.1) diff --git a/torchvision/transforms.py b/torchvision/transforms.py deleted file mode 100644 index 9d52dab7e95..00000000000 --- a/torchvision/transforms.py +++ /dev/null @@ -1,1061 +0,0 @@ -from __future__ import division -import torch -import math -import random -from PIL import Image, ImageOps, ImageEnhance -try: - import accimage -except ImportError: - accimage = None -import 
numpy as np -import numbers -import types -import collections -import warnings - - -def _is_pil_image(img): - if accimage is not None: - return isinstance(img, (Image.Image, accimage.Image)) - else: - return isinstance(img, Image.Image) - - -def _is_tensor_image(img): - return torch.is_tensor(img) and img.ndimension() == 3 - - -def _is_numpy_image(img): - return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) - - -def to_tensor(pic): - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. - - See ``ToTensor`` for more details. - - Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. - - Returns: - Tensor: Converted image. - """ - if not(_is_pil_image(pic) or _is_numpy_image(pic)): - raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) - - if isinstance(pic, np.ndarray): - # handle numpy array - img = torch.from_numpy(pic.transpose((2, 0, 1))) - # backward compatibility - return img.float().div(255) - - if accimage is not None and isinstance(pic, accimage.Image): - nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) - pic.copyto(nppic) - return torch.from_numpy(nppic) - - # handle PIL Image - if pic.mode == 'I': - img = torch.from_numpy(np.array(pic, np.int32, copy=False)) - elif pic.mode == 'I;16': - img = torch.from_numpy(np.array(pic, np.int16, copy=False)) - else: - img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK - if pic.mode == 'YCbCr': - nchannel = 3 - elif pic.mode == 'I;16': - nchannel = 1 - else: - nchannel = len(pic.mode) - img = img.view(pic.size[1], pic.size[0], nchannel) - # put it from HWC to CHW format - # yikes, this transpose takes 80% of the loading time/CPU - img = img.transpose(0, 1).transpose(0, 2).contiguous() - if isinstance(img, torch.ByteTensor): - return img.float().div(255) - else: - return img - - -def to_pil_image(pic, mode=None): - """Convert a tensor or an ndarray to PIL Image. - - See :class:`~torchvision.transforms.ToPIlImage` for more details. - - Args: - pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. - mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). - - .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes - - Returns: - PIL Image: Image converted to PIL Image. - """ - if not(_is_numpy_image(pic) or _is_tensor_image(pic)): - raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) - - npimg = pic - if isinstance(pic, torch.FloatTensor): - pic = pic.mul(255).byte() - if torch.is_tensor(pic): - npimg = np.transpose(pic.numpy(), (1, 2, 0)) - - if not isinstance(npimg, np.ndarray): - raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + - 'not {}'.format(type(npimg))) - - if npimg.shape[2] == 1: - expected_mode = None - npimg = npimg[:, :, 0] - if npimg.dtype == np.uint8: - expected_mode = 'L' - if npimg.dtype == np.int16: - expected_mode = 'I;16' - if npimg.dtype == np.int32: - expected_mode = 'I' - elif npimg.dtype == np.float32: - expected_mode = 'F' - if mode is not None and mode != expected_mode: - raise ValueError("Incorrect mode ({}) supplied for input type {}. 
Should be {}" - .format(mode, np.dtype, expected_mode)) - mode = expected_mode - - elif npimg.shape[2] == 4: - permitted_4_channel_modes = ['RGBA', 'CMYK'] - if mode is not None and mode not in permitted_4_channel_modes: - raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) - - if mode is None and npimg.dtype == np.uint8: - mode = 'RGBA' - else: - permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] - if mode is not None and mode not in permitted_3_channel_modes: - raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) - if mode is None and npimg.dtype == np.uint8: - mode = 'RGB' - - if mode is None: - raise TypeError('Input type {} is not supported'.format(npimg.dtype)) - - return Image.fromarray(npimg, mode=mode) - - -def normalize(tensor, mean, std): - """Normalize a tensor image with mean and standard deviation. - - See ``Normalize`` for more details. - - Args: - tensor (Tensor): Tensor image of size (C, H, W) to be normalized. - mean (sequence): Sequence of means for each channel. - std (sequence): Sequence of standard deviations for each channely. - - Returns: - Tensor: Normalized Tensor image. - """ - if not _is_tensor_image(tensor): - raise TypeError('tensor is not a torch image.') - # TODO: make efficient - for t, m, s in zip(tensor, mean, std): - t.sub_(m).div_(s) - return tensor - - -def resize(img, size, interpolation=Image.BILINEAR): - """Resize the input PIL Image to the given size. - - Args: - img (PIL Image): Image to be resized. - size (sequence or int): Desired output size. If size is a sequence like - (h, w), the output size will be matched to this. If size is an int, - the smaller edge of the image will be matched to this number maintaing - the aspect ratio. i.e, if height > width, then image will be rescaled to - (size * height / width, size) - interpolation (int, optional): Desired interpolation. Default is - ``PIL.Image.BILINEAR`` - - Returns: - PIL Image: Resized image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): - raise TypeError('Got inappropriate size arg: {}'.format(size)) - - if isinstance(size, int): - w, h = img.size - if (w <= h and w == size) or (h <= w and h == size): - return img - if w < h: - ow = size - oh = int(size * h / w) - return img.resize((ow, oh), interpolation) - else: - oh = size - ow = int(size * w / h) - return img.resize((ow, oh), interpolation) - else: - return img.resize(size[::-1], interpolation) - - -def scale(*args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") - return resize(*args, **kwargs) - - -def pad(img, padding, fill=0): - """Pad the given PIL Image on all sides with the given "pad" value. - - Args: - img (PIL Image): Image to be padded. - padding (int or tuple): Padding on each border. If a single int is provided this - is used to pad all borders. If tuple of length 2 is provided this is the padding - on left/right and top/bottom respectively. If a tuple of length 4 is provided - this is the padding for the left, top, right and bottom borders - respectively. - fill: Pixel fill value. Default is 0. If a tuple of - length 3, it is used to fill R, G, B channels respectively. - - Returns: - PIL Image: Padded image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) - - if not isinstance(padding, (numbers.Number, tuple)): - raise TypeError('Got inappropriate padding arg') - if not isinstance(fill, (numbers.Number, str, tuple)): - raise TypeError('Got inappropriate fill arg') - - if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: - raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) - - return ImageOps.expand(img, border=padding, fill=fill) - - -def crop(img, i, j, h, w): - """Crop the given PIL Image. - - Args: - img (PIL Image): Image to be cropped. - i: Upper pixel coordinate. - j: Left pixel coordinate. - h: Height of the cropped image. - w: Width of the cropped image. - - Returns: - PIL Image: Cropped image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - return img.crop((j, i, j + w, i + h)) - - -def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): - """Crop the given PIL Image and resize it to desired size. - - Notably used in RandomResizedCrop. - - Args: - img (PIL Image): Image to be cropped. - i: Upper pixel coordinate. - j: Left pixel coordinate. - h: Height of the cropped image. - w: Width of the cropped image. - size (sequence or int): Desired output size. Same semantics as ``scale``. - interpolation (int, optional): Desired interpolation. Default is - ``PIL.Image.BILINEAR``. - Returns: - PIL Image: Cropped image. - """ - assert _is_pil_image(img), 'img should be PIL Image' - img = crop(img, i, j, h, w) - img = resize(img, size, interpolation) - return img - - -def hflip(img): - """Horizontally flip the given PIL Image. - - Args: - img (PIL Image): Image to be flipped. - - Returns: - PIL Image: Horizontall flipped image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - return img.transpose(Image.FLIP_LEFT_RIGHT) - - -def vflip(img): - """Vertically flip the given PIL Image. - - Args: - img (PIL Image): Image to be flipped. - - Returns: - PIL Image: Vertically flipped image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - return img.transpose(Image.FLIP_TOP_BOTTOM) - - -def five_crop(img, size): - """Crop the given PIL Image into four corners and the central crop. - - .. Note:: - This transform returns a tuple of images and there may be a - mismatch in the number of inputs and targets your ``Dataset`` returns. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. - Returns: - tuple: tuple (tl, tr, bl, br, center) corresponding top left, - top right, bottom left, bottom right and center crop. - """ - if isinstance(size, numbers.Number): - size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
- - w, h = img.size - crop_h, crop_w = size - if crop_w > w or crop_h > h: - raise ValueError("Requested crop size {} is bigger than input size {}".format(size, - (h, w))) - tl = img.crop((0, 0, crop_w, crop_h)) - tr = img.crop((w - crop_w, 0, w, crop_h)) - bl = img.crop((0, h - crop_h, crop_w, h)) - br = img.crop((w - crop_w, h - crop_h, w, h)) - center = CenterCrop((crop_h, crop_w))(img) - return (tl, tr, bl, br, center) - - -def ten_crop(img, size, vertical_flip=False): - """Crop the given PIL Image into four corners and the central crop plus the - flipped version of these (horizontal flipping is used by default). - - .. Note:: - This transform returns a tuple of images and there may be a - mismatch in the number of inputs and targets your ``Dataset`` returns. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. - vertical_flip (bool): Use vertical flipping instead of horizontal - - Returns: - tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, - br_flip, center_flip) corresponding top left, top right, - bottom left, bottom right and center crop and same for the - flipped image. - """ - if isinstance(size, numbers.Number): - size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." - - first_five = five_crop(img, size) - - if vertical_flip: - img = vflip(img) - else: - img = hflip(img) - - second_five = five_crop(img, size) - return first_five + second_five - - -def adjust_brightness(img, brightness_factor): - """Adjust brightness of an Image. - - Args: - img (PIL Image): PIL Image to be adjusted. - brightness_factor (float): How much to adjust the brightness. Can be - any non negative number. 0 gives a black image, 1 gives the - original image while 2 increases the brightness by a factor of 2. - - Returns: - PIL Image: Brightness adjusted image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - enhancer = ImageEnhance.Brightness(img) - img = enhancer.enhance(brightness_factor) - return img - - -def adjust_contrast(img, contrast_factor): - """Adjust contrast of an Image. - - Args: - img (PIL Image): PIL Image to be adjusted. - contrast_factor (float): How much to adjust the contrast. Can be any - non negative number. 0 gives a solid gray image, 1 gives the - original image while 2 increases the contrast by a factor of 2. - - Returns: - PIL Image: Contrast adjusted image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - enhancer = ImageEnhance.Contrast(img) - img = enhancer.enhance(contrast_factor) - return img - - -def adjust_saturation(img, saturation_factor): - """Adjust color saturation of an image. - - Args: - img (PIL Image): PIL Image to be adjusted. - saturation_factor (float): How much to adjust the saturation. 0 will - give a black and white image, 1 will give the original image while - 2 will enhance the saturation by a factor of 2. - - Returns: - PIL Image: Saturation adjusted image. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - enhancer = ImageEnhance.Color(img) - img = enhancer.enhance(saturation_factor) - return img - - -def adjust_hue(img, hue_factor): - """Adjust hue of an image. - - The image hue is adjusted by converting the image to HSV and - cyclically shifting the intensities in the hue channel (H). 
- The image is then converted back to original image mode. - - `hue_factor` is the amount of shift in H channel and must be in the - interval `[-0.5, 0.5]`. - - See https://en.wikipedia.org/wiki/Hue for more details on Hue. - - Args: - img (PIL Image): PIL Image to be adjusted. - hue_factor (float): How much to shift the hue channel. Should be in - [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in - HSV space in positive and negative direction respectively. - 0 means no shift. Therefore, both -0.5 and 0.5 will give an image - with complementary colors while 0 gives the original image. - - Returns: - PIL Image: Hue adjusted image. - """ - if not(-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) - - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - input_mode = img.mode - if input_mode in {'L', '1', 'I', 'F'}: - return img - - h, s, v = img.convert('HSV').split() - - np_h = np.array(h, dtype=np.uint8) - # uint8 addition take cares of rotation across boundaries - with np.errstate(over='ignore'): - np_h += np.uint8(hue_factor * 255) - h = Image.fromarray(np_h, 'L') - - img = Image.merge('HSV', (h, s, v)).convert(input_mode) - return img - - -def adjust_gamma(img, gamma, gain=1): - """Perform gamma correction on an image. - - Also known as Power Law Transform. Intensities in RGB mode are adjusted - based on the following equation: - - I_out = 255 * gain * ((I_in / 255) ** gamma) - - See https://en.wikipedia.org/wiki/Gamma_correction for more details. - - Args: - img (PIL Image): PIL Image to be adjusted. - gamma (float): Non negative real number. gamma larger than 1 make the - shadows darker, while gamma smaller than 1 make dark regions - lighter. - gain (float): The constant multiplier. - """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - if gamma < 0: - raise ValueError('Gamma should be a non-negative real number') - - input_mode = img.mode - img = img.convert('RGB') - - np_img = np.array(img, dtype=np.float32) - np_img = 255 * gain * ((np_img / 255) ** gamma) - np_img = np.uint8(np.clip(np_img, 0, 255)) - - img = Image.fromarray(np_img, 'RGB').convert(input_mode) - return img - - -class Compose(object): - """Composes several transforms together. - - Args: - transforms (list of ``Transform`` objects): list of transforms to compose. - - Example: - >>> transforms.Compose([ - >>> transforms.CenterCrop(10), - >>> transforms.ToTensor(), - >>> ]) - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, img): - for t in self.transforms: - img = t(img) - return img - - -class ToTensor(object): - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. - - Converts a PIL Image or numpy.ndarray (H x W x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. - """ - - def __call__(self, pic): - """ - Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. - - Returns: - Tensor: Converted image. - """ - return to_tensor(pic) - - -class ToPILImage(object): - """Convert a tensor or an ndarray to PIL Image. - - Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape - H x W x C to a PIL Image while preserving the value range. - - Args: - mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). - If ``mode`` is ``None`` (default) there are some assumptions made about the input data: - 1. 
If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. - 2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. - 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e, - ``int``, ``float``, ``short``). - - .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes - """ - def __init__(self, mode=None): - self.mode = mode - - def __call__(self, pic): - """ - Args: - pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. - - Returns: - PIL Image: Image converted to PIL Image. - - """ - return to_pil_image(pic, self.mode) - - -class Normalize(object): - """Normalize an tensor image with mean and standard deviation. - Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform - will normalize each channel of the input ``torch.*Tensor`` i.e. - ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` - - Args: - mean (sequence): Sequence of means for each channel. - std (sequence): Sequence of standard deviations for each channel. - """ - - def __init__(self, mean, std): - self.mean = mean - self.std = std - - def __call__(self, tensor): - """ - Args: - tensor (Tensor): Tensor image of size (C, H, W) to be normalized. - - Returns: - Tensor: Normalized Tensor image. - """ - return normalize(tensor, self.mean, self.std) - - -class Resize(object): - """Resize the input PIL Image to the given size. - - Args: - size (sequence or int): Desired output size. If size is a sequence like - (h, w), output size will be matched to this. If size is an int, - smaller edge of the image will be matched to this number. - i.e, if height > width, then image will be rescaled to - (size * height / width, size) - interpolation (int, optional): Desired interpolation. Default is - ``PIL.Image.BILINEAR`` - """ - - def __init__(self, size, interpolation=Image.BILINEAR): - assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) - self.size = size - self.interpolation = interpolation - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be scaled. - - Returns: - PIL Image: Rescaled image. - """ - return resize(img, self.size, self.interpolation) - - -class Scale(Resize): - """ - Note: This transform is deprecated in favor of Resize. - """ - def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") - super(Scale, self).__init__(*args, **kwargs) - - -class CenterCrop(object): - """Crops the given PIL Image at the center. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. - """ - - def __init__(self, size): - if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) - else: - self.size = size - - @staticmethod - def get_params(img, output_size): - """Get parameters for ``crop`` for center crop. - - Args: - img (PIL Image): Image to be cropped. - output_size (tuple): Expected output size of the crop. - - Returns: - tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. - """ - w, h = img.size - th, tw = output_size - i = int(round((h - th) / 2.)) - j = int(round((w - tw) / 2.)) - return i, j, th, tw - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be cropped. - - Returns: - PIL Image: Cropped image. 
- """ - i, j, h, w = self.get_params(img, self.size) - return crop(img, i, j, h, w) - - -class Pad(object): - """Pad the given PIL Image on all sides with the given "pad" value. - - Args: - padding (int or tuple): Padding on each border. If a single int is provided this - is used to pad all borders. If tuple of length 2 is provided this is the padding - on left/right and top/bottom respectively. If a tuple of length 4 is provided - this is the padding for the left, top, right and bottom borders - respectively. - fill: Pixel fill value. Default is 0. If a tuple of - length 3, it is used to fill R, G, B channels respectively. - """ - - def __init__(self, padding, fill=0): - assert isinstance(padding, (numbers.Number, tuple)) - assert isinstance(fill, (numbers.Number, str, tuple)) - if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: - raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) - - self.padding = padding - self.fill = fill - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be padded. - - Returns: - PIL Image: Padded image. - """ - return pad(img, self.padding, self.fill) - - -class Lambda(object): - """Apply a user-defined lambda as a transform. - - Args: - lambd (function): Lambda/function to be used for transform. - """ - - def __init__(self, lambd): - assert isinstance(lambd, types.LambdaType) - self.lambd = lambd - - def __call__(self, img): - return self.lambd(img) - - -class RandomCrop(object): - """Crop the given PIL Image at a random location. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. - padding (int or sequence, optional): Optional padding on each border - of the image. Default is 0, i.e no padding. If a sequence of length - 4 is provided, it is used to pad left, top, right, bottom borders - respectively. - """ - - def __init__(self, size, padding=0): - if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) - else: - self.size = size - self.padding = padding - - @staticmethod - def get_params(img, output_size): - """Get parameters for ``crop`` for a random crop. - - Args: - img (PIL Image): Image to be cropped. - output_size (tuple): Expected output size of the crop. - - Returns: - tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. - """ - w, h = img.size - th, tw = output_size - if w == tw and h == th: - return 0, 0, h, w - - i = random.randint(0, h - th) - j = random.randint(0, w - tw) - return i, j, th, tw - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be cropped. - - Returns: - PIL Image: Cropped image. - """ - if self.padding > 0: - img = pad(img, self.padding) - - i, j, h, w = self.get_params(img, self.size) - - return crop(img, i, j, h, w) - - -class RandomHorizontalFlip(object): - """Horizontally flip the given PIL Image randomly with a probability of 0.5.""" - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be flipped. - - Returns: - PIL Image: Randomly flipped image. - """ - if random.random() < 0.5: - return hflip(img) - return img - - -class RandomVerticalFlip(object): - """Vertically flip the given PIL Image randomly with a probability of 0.5.""" - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be flipped. - - Returns: - PIL Image: Randomly flipped image. 
- """ - if random.random() < 0.5: - return vflip(img) - return img - - -class RandomResizedCrop(object): - """Crop the given PIL Image to random size and aspect ratio. - - A crop of random size of (0.08 to 1.0) of the original size and a random - aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop - is finally resized to given size. - This is popularly used to train the Inception networks. - - Args: - size: expected output size of each edge - interpolation: Default: PIL.Image.BILINEAR - """ - - def __init__(self, size, interpolation=Image.BILINEAR): - self.size = (size, size) - self.interpolation = interpolation - - @staticmethod - def get_params(img): - """Get parameters for ``crop`` for a random sized crop. - - Args: - img (PIL Image): Image to be cropped. - - Returns: - tuple: params (i, j, h, w) to be passed to ``crop`` for a random - sized crop. - """ - for attempt in range(10): - area = img.size[0] * img.size[1] - target_area = random.uniform(0.08, 1.0) * area - aspect_ratio = random.uniform(3. / 4, 4. / 3) - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - if random.random() < 0.5: - w, h = h, w - - if w <= img.size[0] and h <= img.size[1]: - i = random.randint(0, img.size[1] - h) - j = random.randint(0, img.size[0] - w) - return i, j, h, w - - # Fallback - w = min(img.size[0], img.size[1]) - i = (img.size[1] - w) // 2 - j = (img.size[0] - w) // 2 - return i, j, w, w - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be flipped. - - Returns: - PIL Image: Randomly cropped and resize image. - """ - i, j, h, w = self.get_params(img) - return resized_crop(img, i, j, h, w, self.size, self.interpolation) - - -class RandomSizedCrop(RandomResizedCrop): - """ - Note: This transform is deprecated in favor of RandomResizedCrop. - """ - def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + - "please use transforms.RandomResizedCrop instead.") - super(RandomSizedCrop, self).__init__(*args, **kwargs) - - -class FiveCrop(object): - """Crop the given PIL Image into four corners and the central crop.abs - - Note: this transform returns a tuple of images and there may be a mismatch in the number of - inputs and targets your `Dataset` returns. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. - """ - - def __init__(self, size): - self.size = size - if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." - self.size = size - - def __call__(self, img): - return five_crop(img, self.size) - - -class TenCrop(object): - """Crop the given PIL Image into four corners and the central crop plus the - flipped version of these (horizontal flipping is used by default) - - Note: this transform returns a tuple of images and there may be a mismatch in the number of - inputs and targets your `Dataset` returns. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. 
- vertical_flip(bool): Use vertical flipping instead of horizontal - """ - - def __init__(self, size, vertical_flip=False): - self.size = size - if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) - else: - assert len(size) == 2, "Please provide only two dimensions (h, w) for size." - self.size = size - self.vertical_flip = vertical_flip - - def __call__(self, img): - return ten_crop(img, self.size, self.vertical_flip) - - -class LinearTransformation(object): - """Transform a tensor image with a square transformation matrix computed - offline. - - Given transformation_matrix, will flatten the torch.*Tensor, compute the dot - product with the transformation matrix and reshape the tensor to its - original shape. - - Applications: - - whitening: zero-center the data, compute the data covariance matrix - [D x D] with np.dot(X.T, X), perform SVD on this matrix and - pass it as transformation_matrix. - - Args: - transformation_matrix (Tensor): tensor [D x D], D = C x H x W - """ - - def __init__(self, transformation_matrix): - if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError("transformation_matrix should be square. Got " + - "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) - self.transformation_matrix = transformation_matrix - - def __call__(self, tensor): - """ - Args: - tensor (Tensor): Tensor image of size (C, H, W) to be whitened. - - Returns: - Tensor: Transformed image. - """ - if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): - raise ValueError("tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(*tensor.size()) + - "{}".format(self.transformation_matrix.size(0))) - flat_tensor = tensor.view(1, -1) - transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) - tensor = transformed_tensor.view(tensor.size()) - return tensor - - -class ColorJitter(object): - """Randomly change the brightness, contrast and saturation of an image. - - Args: - brightness (float): How much to jitter brightness. brightness_factor - is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. - contrast (float): How much to jitter contrast. contrast_factor - is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. - saturation (float): How much to jitter saturation. saturation_factor - is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. - hue(float): How much to jitter hue. hue_factor is chosen uniformly from - [-hue, hue]. Should be >=0 and <= 0.5. - """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - self.brightness = brightness - self.contrast = contrast - self.saturation = saturation - self.hue = hue - - @staticmethod - def get_params(brightness, contrast, saturation, hue): - """Get a randomized transform to be applied on image. - - Arguments are same as that of __init__. - - Returns: - Transform which randomly adjusts brightness, contrast and - saturation in a random order. 
- """ - transforms = [] - if brightness > 0: - brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) - transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) - - if contrast > 0: - contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) - transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) - - if saturation > 0: - saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) - transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) - - if hue > 0: - hue_factor = np.random.uniform(-hue, hue) - transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) - - np.random.shuffle(transforms) - transform = Compose(transforms) - - return transform - - def __call__(self, img): - """ - Args: - img (PIL Image): Input image. - - Returns: - PIL Image: Color jittered image. - """ - transform = self.get_params(self.brightness, self.contrast, - self.saturation, self.hue) - return transform(img) diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py new file mode 100644 index 00000000000..7986cdd6429 --- /dev/null +++ b/torchvision/transforms/__init__.py @@ -0,0 +1 @@ +from .transforms import * diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py new file mode 100644 index 00000000000..966404fd340 --- /dev/null +++ b/torchvision/transforms/functional.py @@ -0,0 +1,526 @@ +from __future__ import division +import torch +import math +import random +from PIL import Image, ImageOps, ImageEnhance +try: + import accimage +except ImportError: + accimage = None +import numpy as np +import numbers +import types +import collections +import warnings + + +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +def _is_tensor_image(img): + return torch.is_tensor(img) and img.ndimension() == 3 + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def to_tensor(pic): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + + See ``ToTensor`` for more details. + + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not(_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError('pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + # backward compatibility + return img.float().div(255) + + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + pic.copyto(nppic) + return torch.from_numpy(nppic) + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float().div(255) + else: + return img + + +def to_pil_image(pic, mode=None): + """Convert a tensor or an ndarray to PIL Image. + + See :class:`~torchvision.transforms.ToPIlImage` for more details. + + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + + .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes + + Returns: + PIL Image: Image converted to PIL Image. + """ + if not(_is_numpy_image(pic) or _is_tensor_image(pic)): + raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) + + npimg = pic + if isinstance(pic, torch.FloatTensor): + pic = pic.mul(255).byte() + if torch.is_tensor(pic): + npimg = np.transpose(pic.numpy(), (1, 2, 0)) + + if not isinstance(npimg, np.ndarray): + raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + + 'not {}'.format(type(npimg))) + + if npimg.shape[2] == 1: + expected_mode = None + npimg = npimg[:, :, 0] + if npimg.dtype == np.uint8: + expected_mode = 'L' + if npimg.dtype == np.int16: + expected_mode = 'I;16' + if npimg.dtype == np.int32: + expected_mode = 'I' + elif npimg.dtype == np.float32: + expected_mode = 'F' + if mode is not None and mode != expected_mode: + raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" + .format(mode, np.dtype, expected_mode)) + mode = expected_mode + + elif npimg.shape[2] == 4: + permitted_4_channel_modes = ['RGBA', 'CMYK'] + if mode is not None and mode not in permitted_4_channel_modes: + raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) + + if mode is None and npimg.dtype == np.uint8: + mode = 'RGBA' + else: + permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] + if mode is not None and mode not in permitted_3_channel_modes: + raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) + if mode is None and npimg.dtype == np.uint8: + mode = 'RGB' + + if mode is None: + raise TypeError('Input type {} is not supported'.format(npimg.dtype)) + + return Image.fromarray(npimg, mode=mode) + + +def normalize(tensor, mean, std): + """Normalize a tensor image with mean and standard deviation. + + See ``Normalize`` for more details. + + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + mean (sequence): Sequence of means for each channel. 
+ std (sequence): Sequence of standard deviations for each channely. + + Returns: + Tensor: Normalized Tensor image. + """ + if not _is_tensor_image(tensor): + raise TypeError('tensor is not a torch image.') + # TODO: make efficient + for t, m, s in zip(tensor, mean, std): + t.sub_(m).div_(s) + return tensor + + +def resize(img, size, interpolation=Image.BILINEAR): + """Resize the input PIL Image to the given size. + + Args: + img (PIL Image): Image to be resized. + size (sequence or int): Desired output size. If size is a sequence like + (h, w), the output size will be matched to this. If size is an int, + the smaller edge of the image will be matched to this number maintaing + the aspect ratio. i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + + Returns: + PIL Image: Resized image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + if isinstance(size, int): + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), interpolation) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), interpolation) + else: + return img.resize(size[::-1], interpolation) + + +def scale(*args, **kwargs): + warnings.warn("The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead.") + return resize(*args, **kwargs) + + +def pad(img, padding, fill=0): + """Pad the given PIL Image on all sides with the given "pad" value. + + Args: + img (PIL Image): Image to be padded. + padding (int or tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill: Pixel fill value. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + + Returns: + PIL Image: Padded image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if not isinstance(padding, (numbers.Number, tuple)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError('Got inappropriate fill arg') + + if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: + raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + return ImageOps.expand(img, border=padding, fill=fill) + + +def crop(img, i, j, h, w): + """Crop the given PIL Image. + + Args: + img (PIL Image): Image to be cropped. + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + + Returns: + PIL Image: Cropped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + return img.crop((j, i, j + w, i + h)) + + +def center_crop(img, output_size): + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + w, h = img.size + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return crop(img, i, j, th, tw) + + +def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): + """Crop the given PIL Image and resize it to desired size. + + Notably used in RandomResizedCrop. + + Args: + img (PIL Image): Image to be cropped. + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + size (sequence or int): Desired output size. Same semantics as ``scale``. + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR``. + Returns: + PIL Image: Cropped image. + """ + assert _is_pil_image(img), 'img should be PIL Image' + img = crop(img, i, j, h, w) + img = resize(img, size, interpolation) + return img + + +def hflip(img): + """Horizontally flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Horizontall flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def vflip(img): + """Vertically flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Vertically flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +def five_crop(img, size): + """Crop the given PIL Image into four corners and the central crop. + + .. Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + Returns: + tuple: tuple (tl, tr, bl, br, center) corresponding top left, + top right, bottom left, bottom right and center crop. + """ + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + + w, h = img.size + crop_h, crop_w = size + if crop_w > w or crop_h > h: + raise ValueError("Requested crop size {} is bigger than input size {}".format(size, + (h, w))) + tl = img.crop((0, 0, crop_w, crop_h)) + tr = img.crop((w - crop_w, 0, w, crop_h)) + bl = img.crop((0, h - crop_h, crop_w, h)) + br = img.crop((w - crop_w, h - crop_h, w, h)) + center = center_crop(img, (crop_h, crop_w)) + return (tl, tr, bl, br, center) + + +def ten_crop(img, size, vertical_flip=False): + """Crop the given PIL Image into four corners and the central crop plus the + flipped version of these (horizontal flipping is used by default). + + .. Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. 
+ vertical_flip (bool): Use vertical flipping instead of horizontal + + Returns: + tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, + br_flip, center_flip) corresponding top left, top right, + bottom left, bottom right and center crop and same for the + flipped image. + """ + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + + first_five = five_crop(img, size) + + if vertical_flip: + img = vflip(img) + else: + img = hflip(img) + + second_five = five_crop(img, size) + return first_five + second_five + + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL Image: Contrast adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL Image: Saturation adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See https://en.wikipedia.org/wiki/Hue for more details on Hue. + + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL Image: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + +def adjust_gamma(img, gamma, gain=1): + """Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + I_out = 255 * gain * ((I_in / 255) ** gamma) + + See https://en.wikipedia.org/wiki/Gamma_correction for more details. + + Args: + img (PIL Image): PIL Image to be adjusted. + gamma (float): Non negative real number. gamma larger than 1 make the + shadows darker, while gamma smaller than 1 make dark regions + lighter. + gain (float): The constant multiplier. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + + np_img = np.array(img, dtype=np.float32) + np_img = 255 * gain * ((np_img / 255) ** gamma) + np_img = np.uint8(np.clip(np_img, 0, 255)) + + img = Image.fromarray(np_img, 'RGB').convert(input_mode) + return img diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py new file mode 100644 index 00000000000..7882341003f --- /dev/null +++ b/torchvision/transforms/transforms.py @@ -0,0 +1,543 @@ +from __future__ import division +import torch +import math +import random +from PIL import Image, ImageOps, ImageEnhance +try: + import accimage +except ImportError: + accimage = None +import numpy as np +import numbers +import types +import collections +import warnings + +from . import functional as F + + +class Compose(object): + """Composes several transforms together. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + +class ToTensor(object): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return F.to_tensor(pic) + + +class ToPILImage(object): + """Convert a tensor or an ndarray to PIL Image. + + Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape + H x W x C to a PIL Image while preserving the value range. + + Args: + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + If ``mode`` is ``None`` (default) there are some assumptions made about the input data: + 1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. + 2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. + 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e, + ``int``, ``float``, ``short``). + + .. 
_PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes + """ + def __init__(self, mode=None): + self.mode = mode + + def __call__(self, pic): + """ + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + + Returns: + PIL Image: Image converted to PIL Image. + + """ + return F.to_pil_image(pic, self.mode) + + +class Normalize(object): + """Normalize an tensor image with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + + Returns: + Tensor: Normalized Tensor image. + """ + return F.normalize(tensor, self.mean, self.std) + + +class Resize(object): + """Resize the input PIL Image to the given size. + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) + self.size = size + self.interpolation = interpolation + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be scaled. + + Returns: + PIL Image: Rescaled image. + """ + return F.resize(img, self.size, self.interpolation) + + +class Scale(Resize): + """ + Note: This transform is deprecated in favor of Resize. + """ + def __init__(self, *args, **kwargs): + warnings.warn("The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead.") + super(Scale, self).__init__(*args, **kwargs) + + +class CenterCrop(object): + """Crops the given PIL Image at the center. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped. + + Returns: + PIL Image: Cropped image. + """ + return F.center_crop(img, self.size) + + +class Pad(object): + """Pad the given PIL Image on all sides with the given "pad" value. + + Args: + padding (int or tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill: Pixel fill value. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. 
+    """
+
+    def __init__(self, padding, fill=0):
+        assert isinstance(padding, (numbers.Number, tuple))
+        assert isinstance(fill, (numbers.Number, str, tuple))
+        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
+            raise ValueError("Padding must be an int or a 2- or 4-element tuple, not a " +
+                             "{}-element tuple".format(len(padding)))
+
+        self.padding = padding
+        self.fill = fill
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be padded.
+
+        Returns:
+            PIL Image: Padded image.
+        """
+        return F.pad(img, self.padding, self.fill)
+
+
+class Lambda(object):
+    """Apply a user-defined lambda as a transform.
+
+    Args:
+        lambd (function): Lambda/function to be used for transform.
+    """
+
+    def __init__(self, lambd):
+        assert isinstance(lambd, types.LambdaType)
+        self.lambd = lambd
+
+    def __call__(self, img):
+        return self.lambd(img)
+
+
+class RandomCrop(object):
+    """Crop the given PIL Image at a random location.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        padding (int or sequence, optional): Optional padding on each border
+            of the image. Default is 0, i.e., no padding. If a sequence of length
+            4 is provided, it is used to pad left, top, right, bottom borders
+            respectively.
+    """
+
+    def __init__(self, size, padding=0):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+        self.padding = padding
+
+    @staticmethod
+    def get_params(img, output_size):
+        """Get parameters for ``crop`` for a random crop.
+
+        Args:
+            img (PIL Image): Image to be cropped.
+            output_size (tuple): Expected output size of the crop.
+
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+        """
+        w, h = img.size
+        th, tw = output_size
+        if w == tw and h == th:
+            return 0, 0, h, w
+
+        i = random.randint(0, h - th)
+        j = random.randint(0, w - tw)
+        return i, j, th, tw
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped.
+
+        Returns:
+            PIL Image: Cropped image.
+        """
+        if self.padding > 0:
+            img = F.pad(img, self.padding)
+
+        i, j, h, w = self.get_params(img, self.size)
+
+        return F.crop(img, i, j, h, w)
+
+
+class RandomHorizontalFlip(object):
+    """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be flipped.
+
+        Returns:
+            PIL Image: Randomly flipped image.
+        """
+        if random.random() < 0.5:
+            return F.hflip(img)
+        return img
+
+
+class RandomVerticalFlip(object):
+    """Vertically flip the given PIL Image randomly with a probability of 0.5."""
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be flipped.
+
+        Returns:
+            PIL Image: Randomly flipped image.
+        """
+        if random.random() < 0.5:
+            return F.vflip(img)
+        return img
+
+
+class RandomResizedCrop(object):
+    """Crop the given PIL Image to random size and aspect ratio.
+
+    A crop of random size (0.08 to 1.0 of the original area) and a random
+    aspect ratio (3/4 to 4/3 of the original aspect ratio) is made. This crop
+    is finally resized to the given size.
+    This is popularly used to train the Inception networks.
+
+    Args:
+        size: expected output size of each edge
+        interpolation: Default: PIL.Image.BILINEAR
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = (size, size)
+        self.interpolation = interpolation
+
+    @staticmethod
+    def get_params(img):
+        """Get parameters for ``crop`` for a random sized crop.
+
+        Args:
+            img (PIL Image): Image to be cropped.
+
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                sized crop.
+        """
+        for attempt in range(10):
+            area = img.size[0] * img.size[1]
+            target_area = random.uniform(0.08, 1.0) * area
+            aspect_ratio = random.uniform(3. / 4, 4. / 3)
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if random.random() < 0.5:
+                w, h = h, w
+
+            if w <= img.size[0] and h <= img.size[1]:
+                i = random.randint(0, img.size[1] - h)
+                j = random.randint(0, img.size[0] - w)
+                return i, j, h, w
+
+        # Fallback
+        w = min(img.size[0], img.size[1])
+        i = (img.size[1] - w) // 2
+        j = (img.size[0] - w) // 2
+        return i, j, w, w
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped and resized.
+
+        Returns:
+            PIL Image: Randomly cropped and resized image.
+        """
+        i, j, h, w = self.get_params(img)
+        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
+
+
+class RandomSizedCrop(RandomResizedCrop):
+    """
+    Note: This transform is deprecated in favor of RandomResizedCrop.
+    """
+    def __init__(self, *args, **kwargs):
+        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
+                      "please use transforms.RandomResizedCrop instead.")
+        super(RandomSizedCrop, self).__init__(*args, **kwargs)
+
+
+class FiveCrop(object):
+    """Crop the given PIL Image into four corners and the central crop.
+
+    Note: this transform returns a tuple of images and there may be a mismatch in the number of
+    inputs and targets your `Dataset` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    """
+
+    def __init__(self, size):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+
+    def __call__(self, img):
+        return F.five_crop(img, self.size)
+
+
+class TenCrop(object):
+    """Crop the given PIL Image into four corners and the central crop plus the
+        flipped version of these (horizontal flipping is used by default).
+
+    Note: this transform returns a tuple of images and there may be a mismatch in the number of
+    inputs and targets your `Dataset` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal
+    """
+
+    def __init__(self, size, vertical_flip=False):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+        self.vertical_flip = vertical_flip
+
+    def __call__(self, img):
+        return F.ten_crop(img, self.size, self.vertical_flip)
+
+
+class LinearTransformation(object):
+    """Transform a tensor image with a square transformation matrix computed
+    offline.
+
+    Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
+    product with the transformation matrix and reshape the tensor to its
+    original shape.
+
+    Applications:
+    - whitening: zero-center the data, compute the data covariance matrix
+                 [D x D] with np.dot(X.T, X), perform SVD on this matrix and
+                 pass it as transformation_matrix.
+
+    Args:
+        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
+    """
+
+    def __init__(self, transformation_matrix):
+        if transformation_matrix.size(0) != transformation_matrix.size(1):
+            raise ValueError("transformation_matrix should be square. Got " +
+                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
+        self.transformation_matrix = transformation_matrix
+
+    def __call__(self, tensor):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
+
+        Returns:
+            Tensor: Transformed image.
+        """
+        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
+            raise ValueError("tensor and transformation matrix have incompatible shape. " +
+                             "[{} x {} x {}] != ".format(*tensor.size()) +
+                             "{}".format(self.transformation_matrix.size(0)))
+        flat_tensor = tensor.view(1, -1)
+        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
+        tensor = transformed_tensor.view(tensor.size())
+        return tensor
+
+
+class ColorJitter(object):
+    """Randomly change the brightness, contrast, saturation and hue of an image.
+
+    Args:
+        brightness (float): How much to jitter brightness. brightness_factor
+            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+        contrast (float): How much to jitter contrast. contrast_factor
+            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+        saturation (float): How much to jitter saturation. saturation_factor
+            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
+            [-hue, hue]. Should be >=0 and <= 0.5.
+    """
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+        self.brightness = brightness
+        self.contrast = contrast
+        self.saturation = saturation
+        self.hue = hue
+
+    @staticmethod
+    def get_params(brightness, contrast, saturation, hue):
+        """Get a randomized transform to be applied to an image.
+
+        Arguments are the same as those of ``__init__``.
+
+        Returns:
+            Transform which randomly adjusts brightness, contrast, saturation
+            and hue in a random order.
+        """
+        transforms = []
+        if brightness > 0:
+            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
+            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
+
+        if contrast > 0:
+            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
+            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
+
+        if saturation > 0:
+            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
+            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
+
+        if hue > 0:
+            hue_factor = np.random.uniform(-hue, hue)
+            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
+
+        np.random.shuffle(transforms)
+        transform = Compose(transforms)
+
+        return transform
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Input image.
+
+        Returns:
+            PIL Image: Color jittered image.
+ """ + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + return transform(img) From fc84c1e5b8194c5ef060e58fa3007795f729eb83 Mon Sep 17 00:00:00 2001 From: Alykhan Tejani Date: Wed, 1 Nov 2017 18:22:58 -0400 Subject: [PATCH 2/2] add __all__ to transforms.py --- torchvision/transforms/transforms.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 7882341003f..65e6db87210 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,6 +15,10 @@ from . import functional as F +__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", + "Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", + "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter"] + class Compose(object): """Composes several transforms together.
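
Usage sketch (not part of the patch itself): the example below shows how the class-based transforms and the new functional API introduced here fit together. It assumes the package layout added by this patch; 'example.jpg' is a hypothetical stand-in for any RGB image, and the normalization statistics are the usual ImageNet values, used purely for illustration.

    from PIL import Image

    import torchvision.transforms as transforms
    import torchvision.transforms.functional as F

    img = Image.open('example.jpg').convert('RGB')  # hypothetical input image

    # Class-based pipeline: random parameters are re-drawn on every call.
    preprocess = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = preprocess(img)  # torch.FloatTensor of shape (3, 224, 224)

    # Functional API: explicit, deterministic parameters, useful when the same
    # adjustment must be applied to both an image and its target.
    brightened = F.adjust_brightness(img, 1.5)
    gamma_corrected = F.adjust_gamma(img, gamma=2.2, gain=1)
    tensor_from_functional = F.to_tensor(gamma_corrected)

The deprecated Scale and RandomSizedCrop aliases keep existing code working but emit a warning pointing to Resize and RandomResizedCrop.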