diff --git a/torchvision/transforms/_functional_video.py b/torchvision/transforms/_functional_video.py
index 06c30716908..aa7b1211f8b 100644
--- a/torchvision/transforms/_functional_video.py
+++ b/torchvision/transforms/_functional_video.py
@@ -11,52 +11,6 @@ def _is_tensor_video_clip(clip):
     return True
 
 
-def crop(clip, i, j, h, w):
-    """
-    Args:
-        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
-    """
-    assert len(clip.size()) == 4, "clip should be a 4D tensor"
-    return clip[..., i:i + h, j:j + w]
-
-
-def resize(clip, target_size, interpolation_mode):
-    assert len(target_size) == 2, "target size should be tuple (height, width)"
-    return torch.nn.functional.interpolate(
-        clip, size=target_size, mode=interpolation_mode
-    )
-
-
-def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
-    """
-    Do spatial cropping and resizing to the video clip
-    Args:
-        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
-        i (int): i in (i,j) i.e coordinates of the upper left corner.
-        j (int): j in (i,j) i.e coordinates of the upper left corner.
-        h (int): Height of the cropped region.
-        w (int): Width of the cropped region.
-        size (tuple(int, int)): height and width of resized clip
-    Returns:
-        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
-    """
-    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
-    clip = crop(clip, i, j, h, w)
-    clip = resize(clip, size, interpolation_mode)
-    return clip
-
-
-def center_crop(clip, crop_size):
-    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
-    h, w = clip.size(-2), clip.size(-1)
-    th, tw = crop_size
-    assert h >= th and w >= tw, "height and width must be no smaller than crop_size"
-
-    i = int(round((h - th) / 2.0))
-    j = int(round((w - tw) / 2.0))
-    return crop(clip, i, j, th, tw)
-
-
 def to_tensor(clip):
     """
     Convert tensor data type from uint8 to float, divide value by 255.0 and
@@ -88,14 +42,3 @@ def normalize(clip, mean, std, inplace=False):
     std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
     clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
     return clip
-
-
-def hflip(clip):
-    """
-    Args:
-        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
-    Returns:
-        flipped clip (torch.tensor): Size is (C, T, H, W)
-    """
-    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
-    return clip.flip((-1))
diff --git a/torchvision/transforms/_transforms_video.py b/torchvision/transforms/_transforms_video.py
index aa1a4b05314..b3552e5bd8e 100644
--- a/torchvision/transforms/_transforms_video.py
+++ b/torchvision/transforms/_transforms_video.py
@@ -2,13 +2,15 @@
 
 import numbers
 import random
 
+from PIL import Image
 from torchvision.transforms import (
     RandomCrop,
     RandomResizedCrop,
 )
 
-from . import _functional_video as F
+from . import _functional_video as Fv
+from . import functional_tensor as F
 
 
 __all__ = [
@@ -49,7 +51,7 @@ def __init__(
         size,
         scale=(0.08, 1.0),
         ratio=(3.0 / 4.0, 4.0 / 3.0),
-        interpolation_mode="bilinear",
+        interpolation_mode=Image.BILINEAR,
     ):
         if isinstance(size, tuple):
             assert len(size) == 2, "size should be tuple (height, width)"
@@ -119,7 +121,7 @@ def __call__(self, clip):
         Args:
             clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
         """
-        return F.normalize(clip, self.mean, self.std, self.inplace)
+        return Fv.normalize(clip, self.mean, self.std, self.inplace)
 
     def __repr__(self):
         return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format(
@@ -142,7 +144,7 @@ def __call__(self, clip):
         Return:
             clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
         """
-        return F.to_tensor(clip)
+        return Fv.to_tensor(clip)
 
     def __repr__(self):
         return self.__class__.__name__
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 8ae75f84c5b..c8020a7a9f1 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -220,10 +220,10 @@ def normalize(tensor, mean, std, inplace=False):
 
 
 def resize(img, size, interpolation=Image.BILINEAR):
-    r"""Resize the input PIL Image to the given size.
+    r"""Resize the input Image to the given size.
 
     Args:
-        img (PIL Image): Image to be resized.
+        img (Tensor or PIL Image): Image to be resized.
         size (sequence or int): Desired output size. If size is a sequence like
             (h, w), the output size will be matched to this. If size is an int,
-            the smaller edge of the image will be matched to this number maintaing
+            the smaller edge of the image will be matched to this number maintaining
@@ -235,25 +235,10 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
-        PIL Image: Resized image.
+        PIL Image or Tensor: Resized image.
     """
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
-        raise TypeError('Got inappropriate size arg: {}'.format(size))
-
-    if isinstance(size, int):
-        w, h = img.size
-        if (w <= h and w == size) or (h <= w and h == size):
-            return img
-        if w < h:
-            ow = size
-            oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
-        else:
-            oh = size
-            ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
+    if _is_pil_image(img):
+        return F_p.resize(img, size, interpolation)
     else:
-        return img.resize(size[::-1], interpolation)
+        return F_t.resize(img, size, interpolation)
 
 
 def scale(*args, **kwargs):
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
new file mode 100644
index 00000000000..b5d6efb6997
--- /dev/null
+++ b/torchvision/transforms/functional_pil.py
@@ -0,0 +1,82 @@
+from __future__ import division
+import torch
+import sys
+import math
+from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
+try:
+    import accimage
+except ImportError:
+    accimage = None
+import numpy as np
+from numpy import sin, cos, tan
+import numbers
+import collections
+import warnings
+
+if sys.version_info < (3, 3):
+    Sequence = collections.Sequence
+    Iterable = collections.Iterable
+else:
+    Sequence = collections.abc.Sequence
+    Iterable = collections.abc.Iterable
+
+
+def _is_pil_image(img):
+    if accimage is not None:
+        return isinstance(img, (Image.Image, accimage.Image))
+    else:
+        return isinstance(img, Image.Image)
+
+
+def resize(img, size, interpolation=Image.BILINEAR):
+    r"""Resize the input PIL Image to the given size.
+
+    Args:
+        img (PIL Image): Image to be resized.
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e., if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
+        interpolation (int, optional): Desired interpolation. Default is
+            ``PIL.Image.BILINEAR``
+
+    Returns:
+        PIL Image: Resized image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
+        raise TypeError('Got inappropriate size arg: {}'.format(size))
+
+    if isinstance(size, int):
+        w, h = img.size
+        if (w <= h and w == size) or (h <= w and h == size):
+            return img
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+            return img.resize((ow, oh), interpolation)
+        else:
+            oh = size
+            ow = int(size * w / h)
+            return img.resize((ow, oh), interpolation)
+    else:
+        return img.resize(size[::-1], interpolation)
+
+
+def crop(img, top, left, height, width):
+    """Crop the given PIL Image.
+    Args:
+        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        top (int): Vertical component of the top left corner of the crop box.
+        left (int): Horizontal component of the top left corner of the crop box.
+        height (int): Height of the crop box.
+        width (int): Width of the crop box.
+    Returns:
+        PIL Image: Cropped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    return img.crop((left, top, left + width, top + height))
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 7ef83c1086b..6746dc381a0 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1,5 +1,18 @@
+import numbers
 import torch
-import torchvision.transforms.functional as F
+from PIL import Image
+
+
+def _is_tensor_image(img):
+    if not isinstance(img, torch.Tensor):
+        raise TypeError("expected torch.Tensor, got {}".format(type(img)))
+    return img.ndimension() in [3, 4]
+
+
+_PIL_TO_TORCH_INTERP_MODE = {
+    Image.NEAREST: "nearest",
+    Image.BILINEAR: "bilinear"
+}
 
 
 def vflip(img_tensor):
@@ -11,7 +24,7 @@ def vflip(img_tensor):
     Returns:
         Tensor: Vertically flipped image Tensor.
     """
-    if not F._is_tensor_image(img_tensor):
+    if not _is_tensor_image(img_tensor):
         raise TypeError('tensor is not a torch image.')
 
     return img_tensor.flip(-2)
@@ -27,12 +40,65 @@ def hflip(img_tensor):
         Tensor: Horizontally flipped image Tensor.
     """
 
-    if not F._is_tensor_image(img_tensor):
+    if not _is_tensor_image(img_tensor):
         raise TypeError('tensor is not a torch image.')
 
     return img_tensor.flip(-1)
 
 
+def resize(img, size, interpolation=None):
+    r"""Resize the input Image to the given size.
+
+    Args:
+        img (torch.Tensor): Image to be resized. Can be 3D or 4D (for batches of images or videos)
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e., if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
+        interpolation (int, optional): Desired interpolation. Default is
+            ``PIL.Image.BILINEAR``
+
+    Returns:
+        torch.Tensor: Resized image.
+ """ + + if interpolation is None: + interpolation = Image.BILINEAR + + if not _is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + if not (isinstance(size, int) or (isinstance(size, (tuple, list)) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + if isinstance(size, int): + h, w = img.shape[-2:] + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + size = (oh, ow) + + interpolation_mode = _PIL_TO_TORCH_INTERP_MODE[interpolation] + + # interpolate expects batch of images for now, so should adapt input to 4D if necessary + should_squeeze = False + if img.ndim == 3: + should_squeeze = True + img = image[None] + res = torch.nn.functional.interpolate( + img, size=size, mode=interpolation_mode, align_corners=False + ) + if should_squeeze: + res = res[0] + + return res + + def crop(img, top, left, height, width): """Crop the given Image Tensor. Args: @@ -44,12 +110,53 @@ def crop(img, top, left, height, width): Returns: Tensor: Cropped image. """ - if not F._is_tensor_image(img): + if not _is_tensor_image(img): raise TypeError('tensor is not a torch image.') return img[..., top:top + height, left:left + width] +def center_crop(img, output_size): + """Crop the given Image Tensor and resize it to desired size. + + Args: + img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + Returns: + Tensor: Cropped image. + """ + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + image_height, image_width = img.shape[-2:] + crop_height, crop_width = output_size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, crop_top, crop_left, crop_height, crop_width) + + +def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR): + """Crop the given Image Tensor and resize it to desired size. + + Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. + + Args: + img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + size (sequence or int): Desired output size. Same semantics as ``resize``. + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR``. + Returns: + Tensor: Cropped image. + """ + img = crop(img, top, left, height, width) + img = resize(img, size, interpolation) + return img + + def adjust_brightness(img, brightness_factor): """Adjust brightness of an RGB image. @@ -62,7 +169,7 @@ def adjust_brightness(img, brightness_factor): Returns: Tensor: Brightness adjusted image. """ - if not F._is_tensor_image(img): + if not _is_tensor_image(img): raise TypeError('tensor is not a torch image.') return _blend(img, 0, brightness_factor) @@ -80,7 +187,7 @@ def adjust_contrast(img, contrast_factor): Returns: Tensor: Contrast adjusted image. 
""" - if not F._is_tensor_image(img): + if not _is_tensor_image(img): raise TypeError('tensor is not a torch image.') mean = torch.mean(_rgb_to_grayscale(img).to(torch.float)) @@ -100,7 +207,7 @@ def adjust_saturation(img, saturation_factor): Returns: Tensor: Saturation adjusted image. """ - if not F._is_tensor_image(img): + if not _is_tensor_image(img): raise TypeError('tensor is not a torch image.') return _blend(img, _rgb_to_grayscale(img), saturation_factor)