diff --git a/torchvision/transforms.py b/torchvision/transforms.py index da58aa12b9a..75e8a64012e 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -13,6 +13,259 @@ import collections +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +def _is_tensor_image(img): + return torch.is_tensor(img) and img.ndimension() == 3 + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def to_tensor(pic): + """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. + + See ``ToTensor`` for more details. + + Args: + pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not(_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + # backward compatibility + return img.float().div(255) + + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + pic.copyto(nppic) + return torch.from_numpy(nppic) + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float().div(255) + else: + return img + + +def to_pil_image(pic): + """Convert a tensor or an ndarray to PIL Image. + + See ``ToPIlImage`` for more details. + + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image. + + Returns: + PIL.Image: Image converted to PIL.Image. + """ + if not(_is_numpy_image(pic) or _is_tensor_image(pic)): + raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) + + npimg = pic + mode = None + if isinstance(pic, torch.FloatTensor): + pic = pic.mul(255).byte() + if torch.is_tensor(pic): + npimg = np.transpose(pic.numpy(), (1, 2, 0)) + assert isinstance(npimg, np.ndarray) + if npimg.shape[2] == 1: + npimg = npimg[:, :, 0] + + if npimg.dtype == np.uint8: + mode = 'L' + if npimg.dtype == np.int16: + mode = 'I;16' + if npimg.dtype == np.int32: + mode = 'I' + elif npimg.dtype == np.float32: + mode = 'F' + elif npimg.shape[2] == 4: + if npimg.dtype == np.uint8: + mode = 'RGBA' + else: + if npimg.dtype == np.uint8: + mode = 'RGB' + assert mode is not None, '{} is not supported'.format(npimg.dtype) + return Image.fromarray(npimg, mode=mode) + + +def normalize(tensor, mean, std): + """Normalize an tensor image with mean and standard deviation. + + See ``Normalize`` for more details. + + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + mean (sequence): Sequence of means for R, G, B channels respecitvely. + std (sequence): Sequence of standard deviations for R, G, B channels + respecitvely. + + Returns: + Tensor: Normalized image. + """ + if not _is_tensor_image(tensor): + raise TypeError('tensor is not a torch image.') + # TODO: make efficient + for t, m, s in zip(tensor, mean, std): + t.sub_(m).div_(s) + return tensor + + +def scale(img, size, interpolation=Image.BILINEAR): + """Rescale the input PIL.Image to the given size. + + Args: + img (PIL.Image): Image to be scaled. + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + + Returns: + PIL.Image: Rescaled image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + if isinstance(size, int): + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), interpolation) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), interpolation) + else: + return img.resize(size[::-1], interpolation) + + +def pad(img, padding, fill=0): + """Pad the given PIL.Image on all sides with the given "pad" value. + + Args: + img (PIL.Image): Image to be padded. + padding (int or tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill: Pixel fill value. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + + Returns: + PIL.Image: Padded image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if not isinstance(padding, (numbers.Number, tuple)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError('Got inappropriate fill arg') + + if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: + raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + return ImageOps.expand(img, border=padding, fill=fill) + + +def crop(img, i, j, h, w): + """Crop the given PIL.Image. + + Args: + img (PIL.Image): Image to be cropped. + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + + Returns: + PIL.Image: Cropped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.crop((j, i, j + w, i + h)) + + +def scaled_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): + """Crop the given PIL.Image and scale it to desired size. + + Notably used in RandomSizedCrop. + + Args: + img (PIL.Image): Image to be cropped. + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + size (sequence or int): Desired output size. Same semantics as ``scale``. + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR``. + Returns: + PIL.Image: Cropped image. + """ + assert _is_pil_image(img), 'img should be PIL Image' + img = crop(img, i, j, h, w) + img = scale(img, size, interpolation) + return img + + +def hflip(img): + """Horizontally flip the given PIL.Image. + + Args: + img (PIL.Image): Image to be flipped. + + Returns: + PIL.Image: Horizontall flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + class Compose(object): """Composes several transforms together. @@ -50,43 +303,11 @@ def __call__(self, pic): Returns: Tensor: Converted image. """ - if isinstance(pic, np.ndarray): - # handle numpy array - img = torch.from_numpy(pic.transpose((2, 0, 1))) - # backward compatibility - return img.float().div(255) - - if accimage is not None and isinstance(pic, accimage.Image): - nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) - pic.copyto(nppic) - return torch.from_numpy(nppic) - - # handle PIL Image - if pic.mode == 'I': - img = torch.from_numpy(np.array(pic, np.int32, copy=False)) - elif pic.mode == 'I;16': - img = torch.from_numpy(np.array(pic, np.int16, copy=False)) - else: - img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK - if pic.mode == 'YCbCr': - nchannel = 3 - elif pic.mode == 'I;16': - nchannel = 1 - else: - nchannel = len(pic.mode) - img = img.view(pic.size[1], pic.size[0], nchannel) - # put it from HWC to CHW format - # yikes, this transpose takes 80% of the loading time/CPU - img = img.transpose(0, 1).transpose(0, 2).contiguous() - if isinstance(img, torch.ByteTensor): - return img.float().div(255) - else: - return img + return to_tensor(pic) class ToPILImage(object): - """Convert a tensor to PIL Image. + """Convert a tensor or an ndarray to PIL Image. Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C to a PIL.Image while preserving the value range. @@ -101,32 +322,7 @@ def __call__(self, pic): PIL.Image: Image converted to PIL.Image. """ - npimg = pic - mode = None - if isinstance(pic, torch.FloatTensor): - pic = pic.mul(255).byte() - if torch.is_tensor(pic): - npimg = np.transpose(pic.numpy(), (1, 2, 0)) - assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' - if npimg.shape[2] == 1: - npimg = npimg[:, :, 0] - - if npimg.dtype == np.uint8: - mode = 'L' - if npimg.dtype == np.int16: - mode = 'I;16' - if npimg.dtype == np.int32: - mode = 'I' - elif npimg.dtype == np.float32: - mode = 'F' - elif npimg.shape[2] == 4: - if npimg.dtype == np.uint8: - mode = 'RGBA' - else: - if npimg.dtype == np.uint8: - mode = 'RGB' - assert mode is not None, '{} is not supported'.format(npimg.dtype) - return Image.fromarray(npimg, mode=mode) + return to_pil_image(pic) class Normalize(object): @@ -154,10 +350,7 @@ def __call__(self, tensor): Returns: Tensor: Normalized image. """ - # TODO: make efficient - for t, m, s in zip(tensor, self.mean, self.std): - t.sub_(m).div_(s) - return tensor + return normalize(tensor, self.mean, self.std) class Scale(object): @@ -186,20 +379,7 @@ def __call__(self, img): Returns: PIL.Image: Rescaled image. """ - if isinstance(self.size, int): - w, h = img.size - if (w <= h and w == self.size) or (h <= w and h == self.size): - return img - if w < h: - ow = self.size - oh = int(self.size * h / w) - return img.resize((ow, oh), self.interpolation) - else: - oh = self.size - ow = int(self.size * w / h) - return img.resize((ow, oh), self.interpolation) - else: - return img.resize(self.size[::-1], self.interpolation) + return scale(img, self.size, self.interpolation) class CenterCrop(object): @@ -217,6 +397,23 @@ def __init__(self, size): else: self.size = size + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for center crop. + + Args: + img (PIL.Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. + """ + w, h = img.size + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return i, j, th, tw + def __call__(self, img): """ Args: @@ -225,11 +422,8 @@ def __call__(self, img): Returns: PIL.Image: Cropped image. """ - w, h = img.size - th, tw = self.size - x1 = int(round((w - tw) / 2.)) - y1 = int(round((h - th) / 2.)) - return img.crop((x1, y1, x1 + tw, y1 + th)) + i, j, h, w = self.get_params(img, self.size) + return crop(img, i, j, h, w) class Pad(object): @@ -263,7 +457,7 @@ def __call__(self, img): Returns: PIL.Image: Padded image. """ - return ImageOps.expand(img, border=self.padding, fill=self.fill) + return pad(img, self.padding, self.fill) class Lambda(object): @@ -301,6 +495,26 @@ def __init__(self, size, padding=0): self.size = size self.padding = padding + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL.Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + w, h = img.size + th, tw = output_size + if w == tw and h == th: + return img + + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + def __call__(self, img): """ Args: @@ -310,16 +524,11 @@ def __call__(self, img): PIL.Image: Cropped image. """ if self.padding > 0: - img = ImageOps.expand(img, border=self.padding, fill=0) + img = pad(img, self.padding) - w, h = img.size - th, tw = self.size - if w == tw and h == th: - return img + i, j, h, w = self.get_params(img, self.size) - x1 = random.randint(0, w - tw) - y1 = random.randint(0, h - th) - return img.crop((x1, y1, x1 + tw, y1 + th)) + return crop(img, i, j, h, w) class RandomHorizontalFlip(object): @@ -334,7 +543,7 @@ def __call__(self, img): PIL.Image: Randomly flipped image. """ if random.random() < 0.5: - return img.transpose(Image.FLIP_LEFT_RIGHT) + return hflip(img) return img @@ -347,15 +556,25 @@ class RandomSizedCrop(object): This is popularly used to train the Inception networks. Args: - size: size of the smaller edge + size: expected output size of each edge interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, interpolation=Image.BILINEAR): - self.size = size + self.size = (size, size) self.interpolation = interpolation - def __call__(self, img): + @staticmethod + def get_params(img): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL.Image): Image to be cropped. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ for attempt in range(10): area = img.size[0] * img.size[1] target_area = random.uniform(0.08, 1.0) * area @@ -368,15 +587,23 @@ def __call__(self, img): w, h = h, w if w <= img.size[0] and h <= img.size[1]: - x1 = random.randint(0, img.size[0] - w) - y1 = random.randint(0, img.size[1] - h) + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w - img = img.crop((x1, y1, x1 + w, y1 + h)) - assert(img.size == (w, h)) + # Fallback + w = min(img.size[0], img.shape[1]) + i = (img.shape[1] - w) // 2 + j = (img.shape[0] - w) // 2 + return i, j, w, w - return img.resize((self.size, self.size), self.interpolation) + def __call__(self, img): + """ + Args: + img (PIL.Image): Image to be flipped. - # Fallback - scale = Scale(self.size, interpolation=self.interpolation) - crop = CenterCrop(self.size) - return crop(scale(img)) + Returns: + PIL.Image: Randomly cropped and scaled image. + """ + i, j, h, w = self.get_params(img) + return scaled_crop(img, i, j, h, w, self.size, self.interpolation)