Skip to content

[WIP] Start unification of PIL / Tensor transforms #1532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 0 additions & 57 deletions torchvision/transforms/_functional_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,6 @@ def _is_tensor_video_clip(clip):
return True


def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
return clip[..., i:i + h, j:j + w]


def resize(clip, target_size, interpolation_mode):
assert len(target_size) == 2, "target size should be tuple (height, width)"
return torch.nn.functional.interpolate(
clip, size=target_size, mode=interpolation_mode
)


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip


def center_crop(clip, crop_size):
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
assert h >= th and w >= tw, "height and width must be no smaller than crop_size"

i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)


def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
Expand Down Expand Up @@ -88,14 +42,3 @@ def normalize(clip, mean, std, inplace=False):
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip


def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-1))
10 changes: 6 additions & 4 deletions torchvision/transforms/_transforms_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import numbers
import random
from PIL import Image

from torchvision.transforms import (
RandomCrop,
RandomResizedCrop,
)

from . import _functional_video as F
from . import _functional_video as Fv
from . import functional_tensor as F


__all__ = [
Expand Down Expand Up @@ -49,7 +51,7 @@ def __init__(
size,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation_mode="bilinear",
interpolation_mode=Image.BILINEAR,
):
if isinstance(size, tuple):
assert len(size) == 2, "size should be tuple (height, width)"
Expand Down Expand Up @@ -119,7 +121,7 @@ def __call__(self, clip):
Args:
clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
"""
return F.normalize(clip, self.mean, self.std, self.inplace)
return Fv.normalize(clip, self.mean, self.std, self.inplace)

def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format(
Expand All @@ -142,7 +144,7 @@ def __call__(self, clip):
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
return F.to_tensor(clip)
return Fv.to_tensor(clip)

def __repr__(self):
return self.__class__.__name__
Expand Down
25 changes: 5 additions & 20 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ def normalize(tensor, mean, std, inplace=False):


def resize(img, size, interpolation=Image.BILINEAR):
r"""Resize the input PIL Image to the given size.
r"""Resize the input Image to the given size.

Args:
img (PIL Image): Image to be resized.
img (Tensor or PIL Image): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaing
Expand All @@ -235,25 +235,10 @@ def resize(img, size, interpolation=Image.BILINEAR):
Returns:
PIL Image: Resized image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
raise TypeError('Got inappropriate size arg: {}'.format(size))

if isinstance(size, int):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), interpolation)
if _is_pil_image(img):
return F_p.resize(img, size, interpolation)
else:
return img.resize(size[::-1], interpolation)
return F_t.resize(img, size, interpolation)


def scale(*args, **kwargs):
Expand Down
82 changes: 82 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import division
import torch
import sys
import math
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
try:
import accimage
except ImportError:
accimage = None
import numpy as np
from numpy import sin, cos, tan
import numbers
import collections
import warnings

if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable


def _is_pil_image(img):
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)


def resize(img, size, interpolation=Image.BILINEAR):
r"""Resize the input PIL Image to the given size.

Args:
img (PIL Image): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaing
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``

Returns:
PIL Image: Resized image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
raise TypeError('Got inappropriate size arg: {}'.format(size))

if isinstance(size, int):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), interpolation)
else:
return img.resize(size[::-1], interpolation)


def crop(img, top, left, height, width):
"""Crop the given PIL Image.
Args:
img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
Returns:
PIL Image: Cropped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.crop((left, top, left + width, top + height))
Loading