136 changes: 57 additions & 79 deletions torchvision/transforms/functional.py
@@ -2,13 +2,12 @@
import torch
import sys
import math
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
try:
import accimage
except ImportError:
accimage = None
import numpy as np
from numpy import sin, cos, tan
import numbers
import collections
import warnings
@@ -350,61 +349,54 @@ def pad(img, padding, fill=0, padding_mode='constant'):
return Image.fromarray(img)


def crop(img, top, left, height, width):
def crop(img, i, j, h, w):
"""Crop the given PIL Image.

Args:
img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
img (PIL Image): Image to be cropped.
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped image.
w (int): Width of the cropped image.

Returns:
PIL Image: Cropped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.crop((left, top, left + width, top + height))
return img.crop((j, i, j + w, i + h))
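The restored signature takes a (row, column) upper-left corner, which has to be reordered into PIL's (left, upper, right, lower) box. A minimal sketch of that mapping, using an assumed 200x100 synthetic image:

from PIL import Image

img = Image.new('RGB', (200, 100))      # width=200, height=100 (assumed sizes)
i, j, h, w = 10, 30, 50, 80             # top row, left column, crop height, crop width
patch = img.crop((j, i, j + w, i + h))  # PIL boxes are (left, upper, right, lower)
assert patch.size == (w, h)             # PIL reports size as (width, height)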


def center_crop(img, output_size):
"""Crop the given PIL Image and resize it to desired size.

Args:
img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
output_size (sequence or int): (height, width) of the crop box. If int,
it is used for both directions
Returns:
PIL Image: Cropped image.
"""
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
image_width, image_height = img.size
crop_height, crop_width = output_size
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
return crop(img, crop_top, crop_left, crop_height, crop_width)
w, h = img.size
th, tw = output_size
i = int(round((h - th) / 2.))
j = int(round((w - tw) / 2.))
return crop(img, i, j, th, tw)
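A short worked check of the centering arithmetic above, with assumed sizes; note that img.size is ordered (width, height) while the crop offsets are (row, column):

w, h = 200, 100                  # img.size order: (width, height); assumed values
th, tw = 60, 60                  # requested (height, width)
i = int(round((h - th) / 2.))    # (100 - 60) / 2 = 20 -> top offset
j = int(round((w - tw) / 2.))    # (200 - 60) / 2 = 70 -> left offset
assert (i, j) == (20, 70)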


def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR):
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
"""Crop the given PIL Image and resize it to desired size.

Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

Args:
img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
img (PIL Image): Image to be cropped.
i (int): i in (i,j) i.e coordinates of the upper left corner
j (int): j in (i,j) i.e coordinates of the upper left corner
h (int): Height of the cropped image.
w (int): Width of the cropped image.
size (sequence or int): Desired output size. Same semantics as ``resize``.
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``.
Returns:
PIL Image: Cropped image.
"""
assert _is_pil_image(img), 'img should be PIL Image'
img = crop(img, top, left, height, width)
img = crop(img, i, j, h, w)
img = resize(img, size, interpolation)
return img
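As a usage sketch (all sizes are assumptions): resized_crop is simply crop followed by resize, which is how RandomResizedCrop consumes it:

from PIL import Image
import torchvision.transforms.functional as F

img = Image.new('RGB', (640, 480))                             # assumed input image
out = F.resized_crop(img, 50, 100, 200, 300, size=(224, 224))  # i, j, h, w, then target size
assert out.size == (224, 224)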

@@ -503,18 +495,16 @@ def five_crop(img, size):
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

image_width, image_height = img.size
crop_height, crop_width = size
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))

tl = img.crop((0, 0, crop_width, crop_height))
tr = img.crop((image_width - crop_width, 0, image_width, crop_height))
bl = img.crop((0, image_height - crop_height, crop_width, image_height))
br = img.crop((image_width - crop_width, image_height - crop_height,
image_width, image_height))
center = center_crop(img, (crop_height, crop_width))
w, h = img.size
crop_h, crop_w = size
if crop_w > w or crop_h > h:
raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
(h, w)))
tl = img.crop((0, 0, crop_w, crop_h))
tr = img.crop((w - crop_w, 0, w, crop_h))
bl = img.crop((0, h - crop_h, crop_w, h))
br = img.crop((w - crop_w, h - crop_h, w, h))
center = center_crop(img, (crop_h, crop_w))
return (tl, tr, bl, br, center)
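A brief sketch of the tuple returned above, with an assumed 100x100 input and 60x60 crops (four corner crops plus the centered crop):

from PIL import Image
import torchvision.transforms.functional as F

img = Image.new('RGB', (100, 100))                  # assumed input size
tl, tr, bl, br, center = F.five_crop(img, (60, 60))
assert all(c.size == (60, 60) for c in (tl, tr, bl, br, center))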


@@ -714,7 +704,7 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
Origin is the upper left corner.
Default is the center of the image.
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
If int, it is used for all channels respectively.

.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

@@ -723,9 +713,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if isinstance(fill, int):
if isinstance(fill, int) and img.mode != 'L':
fill = tuple([fill] * 3)

return img.rotate(angle, resample, expand, center, fillcolor=fill)
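A minimal sketch of the guard above, kept separate from the library call: an int fill is broadcast to an RGB triple unless the image is single-channel ('L'). The helper name below is hypothetical and used only for illustration:

def _expand_fill(fill, mode):
    # mirrors the branch in rotate(): scalar fill -> per-channel tuple for RGB images
    if isinstance(fill, int) and mode != 'L':
        return tuple([fill] * 3)
    return fill

assert _expand_fill(255, 'RGB') == (255, 255, 255)
assert _expand_fill(255, 'L') == 255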


@@ -737,52 +727,40 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RSS is rotation with scale and shear matrix
# RSS(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
#
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# RSS(a, scale, shear) = [ cos(a + shear_y)*scale -sin(a + shear_x)*scale 0]
# [ sin(a + shear_y)*scale cos(a + shear_x)*scale 0]
# [ 0 0 1]
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1

if isinstance(shear, numbers.Number):
angle = math.radians(angle)
if isinstance(shear, (tuple, list)) and len(shear) == 2:
shear = [math.radians(s) for s in shear]
elif isinstance(shear, numbers.Number):
shear = math.radians(shear)
shear = [shear, 0]

if not isinstance(shear, (tuple, list)) and len(shear) == 2:
else:
raise ValueError(
"Shear should be a single value or a tuple/list containing " +
"two values. Got {}".format(shear))

rot = math.radians(angle)
sx, sy = [math.radians(s) for s in shear]

cx, cy = center
tx, ty = translate

# RSS without scaling
a = cos(rot - sy) / cos(sy)
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
c = sin(rot - sy) / cos(sy)
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
scale = 1.0 / scale

# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
M = [d, -b, 0,
-c, a, 0]
M = [x / scale for x in M]
d = math.cos(angle + shear[0]) * math.cos(angle + shear[1]) + \
math.sin(angle + shear[0]) * math.sin(angle + shear[1])
matrix = [
math.cos(angle + shear[0]), math.sin(angle + shear[0]), 0,
-math.sin(angle + shear[1]), math.cos(angle + shear[1]), 0
]
matrix = [scale / d * m for m in matrix]

# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])

# Apply center translation: C * RSS^-1 * C^-1 * T^-1
M[2] += cx
M[5] += cy
return M
matrix[2] += center[0]
matrix[5] += center[1]
return matrix
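A small self-contained check of the restored closed form, under assumed parameters (pure rotation: no shear, no translation, unit scale, center at the origin). In that case d = 1 and the coefficients reduce to [cos a, sin a; -sin a, cos a], the inverse of the forward rotation:

import math

angle = math.radians(30.0)                               # assumed angle
cos_a, sin_a = math.cos(angle), math.sin(angle)

x, y = 3.0, 4.0                                          # arbitrary test point
fx, fy = cos_a * x - sin_a * y, sin_a * x + cos_a * y    # forward rotation

a, b, c, d = cos_a, sin_a, -sin_a, cos_a                 # inverse coefficients from the formula
bx, by = a * fx + b * fy, c * fx + d * fy                # map the point back

assert abs(bx - x) < 1e-9 and abs(by - y) < 1e-9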


def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
4 changes: 2 additions & 2 deletions torchvision/transforms/transforms.py
@@ -957,7 +957,7 @@ class RandomRotation(object):
Origin is the upper left corner.
Default is the center of the image.
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
If int, it is used for all channels respectively.

.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

@@ -999,7 +999,7 @@ def __call__(self, img):
"""

angle = self.get_params(self.degrees)

return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
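Usage sketch with assumed parameters, relying on this branch's RandomRotation accepting a fill argument: the value is forwarded untouched to F.rotate, so an int fills every channel of the area left uncovered by the rotation (with expand=False the output keeps the input size):

from PIL import Image
from torchvision import transforms

rotation = transforms.RandomRotation(degrees=15, fill=128)  # assumed parameters
out = rotation(Image.new('RGB', (64, 64)))
assert out.size == (64, 64)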

def __repr__(self):