Unified input for F.affine #2444


Merged: 7 commits, Jul 16, 2020
89 changes: 89 additions & 0 deletions test/test_functional_tensor.py
@@ -348,6 +348,95 @@ def test_resized_crop(self):
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
)

def test_affine(self):
# Tests on square image
tensor, pil_img = self._create_data(26, 26)

scripted_affine = torch.jit.script(F.affine)
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)

# 2) Test rotation
test_configs = [
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
(45, None),
(30, None),
(-30, None),
(-45, None),
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
else:
true_tensor = out_tensor

out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
# Tolerance: less than 6% of pixels may differ
self.assertLess(
ratio_diff_pixels,
0.06,
msg="{}\n{} vs \n{}".format(
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
# 3) Test translation
test_configs = [
[10, 12], (12, 13)
]
for t in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
self.compareTensorToPIL(out_tensor, out_pil_img)

# 4) Test rotation + translation + scale + shear
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]),
(33, (5, -4), 1.0, [0.0, 0.0]),
(45, [5, 4], 1.2, [0.0, 0.0]),
(33, (4, 8), 2.0, [0.0, 0.0]),
(85, (10, -10), 0.7, [0.0, 0.0]),
(0, [0, 0], 1.0, [35.0, ]),
(25, [0, 0], 1.2, [0.0, 15.0]),
(45, [10, 0], 0.7, [2.0, 5.0]),
(45, [10, -10], 1.2, [4.0, 5.0]),
]
for r in [0, ]:
for a, t, s, sh in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance: less than 5% of pixels may differ
self.assertLess(
ratio_diff_pixels,
0.05,
msg="{}: {}\n{} vs \n{}".format(
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)


if __name__ == '__main__':
unittest.main()
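For reference, the fuzzy comparison used in these tests can be reproduced standalone. The sketch below is an illustration, not part of the diff; the tolerance mirrors the 6% bound in the rotation test, and the comparison is done by ratio because the tensor and PIL backends interpolate slightly differently:

import numpy as np
import torch
import torchvision.transforms.functional as F

tensor = torch.randint(0, 256, (3, 26, 26), dtype=torch.uint8)
pil_img = F.to_pil_image(tensor)

out_tensor = F.affine(tensor, angle=45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil = F.affine(pil_img, angle=45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_tensor = torch.from_numpy(np.array(out_pil).transpose((2, 0, 1)))

# count per-pixel mismatches, averaged over the 3 channels, and compare the ratio
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / (out_tensor.shape[-1] * out_tensor.shape[-2])
assert ratio_diff_pixels < 0.06, ratio_diff_pixels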
4 changes: 2 additions & 2 deletions test/test_transforms.py
@@ -1317,8 +1317,8 @@ def test_affine(self):
for j in range(-5, 5):
input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]

with self.assertRaises(TypeError):
F.affine(input_img, 10)
with self.assertRaises(TypeError, msg="Argument translate should be a sequence"):
F.affine(input_img, 10, translate=0, scale=1, shear=1)

pil_img = F.to_pil_image(input_img)
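
A minimal sketch (an illustration, using the validation added in this PR) of the error the updated assertion expects:

import torch
import torchvision.transforms.functional as F

img = torch.zeros(3, 26, 26)
try:
    F.affine(img, angle=10, translate=0, scale=1, shear=1)  # translate is not a sequence
except TypeError as e:
    print(e)  # "Argument translate should be a sequence"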

161 changes: 79 additions & 82 deletions torchvision/transforms/functional.py
@@ -1,11 +1,10 @@
import math
import numbers
import warnings
from typing import Any
from typing import Any, Optional

import numpy as np
from numpy import sin, cos, tan
from PIL import Image, __version__ as PILLOW_VERSION
from PIL import Image

import torch
from torch import Tensor
@@ -21,6 +20,7 @@


_is_pil_image = F_pil._is_pil_image
_parse_fill = F_pil._parse_fill


def _get_image_size(img: Tensor) -> List[int]:
@@ -485,43 +485,6 @@ def hflip(img: Tensor) -> Tensor:
return F_t.hflip(img)


def _parse_fill(fill, img, min_pil_version):
"""Helper function to get the fill color for rotate and perspective transforms.

Args:
fill (n-tuple or int or float): Pixel fill value for area outside the transformed
image. If int or float, the value is used for all bands respectively.
Defaults to 0 for all bands.
img (PIL Image): Image to be filled.
min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)

Returns:
dict: kwarg for ``fillcolor``
"""
major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2])
major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2])
if major_found < major_required or (major_found == major_required and minor_found < minor_required):
if fill is None:
return {}
else:
msg = ("The option to fill background area of the transformed image, "
"requires pillow>={}")
raise RuntimeError(msg.format(min_pil_version))

num_bands = len(img.getbands())
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if not isinstance(fill, (int, float)) and len(fill) != num_bands:
msg = ("The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))

return {"fillcolor": fill}


def _get_perspective_coeffs(startpoints, endpoints):
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

@@ -827,7 +790,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
return img.rotate(angle, resample, expand, center, **opts)


def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
def _get_inverse_affine_matrix(
center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation

# As it is explained in PIL.Image.rotate
@@ -847,75 +812,107 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
#
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1

if isinstance(shear, numbers.Number):
shear = [shear, 0]

if not isinstance(shear, (tuple, list)) and len(shear) == 2:
raise ValueError(
"Shear should be a single value or a tuple/list containing " +
"two values. Got {}".format(shear))

rot = math.radians(angle)
sx, sy = [math.radians(s) for s in shear]

cx, cy = center
tx, ty = translate

# RSS without scaling
a = cos(rot - sy) / cos(sy)
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
c = sin(rot - sy) / cos(sy)
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
a = math.cos(rot - sy) / math.cos(sy)
b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
c = math.sin(rot - sy) / math.cos(sy)
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)

# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
M = [d, -b, 0,
-c, a, 0]
M = [x / scale for x in M]
matrix = [d, -b, 0.0, -c, a, 0.0]
matrix = [x / scale for x in matrix]

# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)

# Apply center translation: C * RSS^-1 * C^-1 * T^-1
M[2] += cx
M[5] += cy
return M
matrix[2] += cx
matrix[5] += cy

return matrix
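
A quick sanity check of the helper (a sketch, not part of the diff): with zero angle, translation, and shear, and unit scale, the returned row-major 2x3 matrix reduces to the identity for any center:

# rot = sx = sy = 0 gives a = d = 1 and b = c = 0, so matrix = [1, 0, 0, 0, 1, 0];
# the center/translation terms cancel: matrix[2] = 1 * (-cx) + 0 * (-cy) + cx = 0
identity = _get_inverse_affine_matrix([13, 13], 0.0, [0.0, 0.0], 1.0, [0.0, 0.0])
assert identity == [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]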

def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
"""Apply affine transformation on the image keeping image center invariant

def affine(
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
resample: int = 0, fillcolor: Optional[int] = None
) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

Args:
img (PIL Image): PIL Image to be rotated.
img (PIL Image or Tensor): image to be rotated.
angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
scale (float): overall scale
shear (float or tuple or list): shear angle value in degrees between -180 and 180, clockwise direction.
If a tuple or list is specified, the first value corresponds to a shear parallel to the x axis, while
the second value corresponds to a shear parallel to the y axis.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
An optional resampling filter.
See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image is a PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)

Returns:
PIL Image or Tensor: Transformed image.
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")

if not isinstance(translate, (list, tuple)):
raise TypeError("Argument translate should be a sequence")

if len(translate) != 2:
raise ValueError("Argument translate should be a sequence of length 2")

if scale <= 0.0:
raise ValueError("Argument scale should be positive")

if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError("Shear should be either a single value or a sequence of two values")

if isinstance(angle, int):
angle = float(angle)

if isinstance(translate, tuple):
translate = list(translate)

if isinstance(shear, numbers.Number):
shear = [shear, 0.0]

if isinstance(shear, tuple):
shear = list(shear)

if len(shear) == 1:
shear = [shear[0], shear[0]]

if len(shear) != 2:
raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))

img_size = _get_image_size(img)
if not isinstance(img, torch.Tensor):
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise an image rotated by 90 degrees is shifted relative to the output of torch.rot90 or F_t.affine
center = [img_size[0] * 0.5, img_size[1] * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"Argument translate should be a list or tuple of length 2"
return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)

assert scale > 0.0, "Argument scale should be positive"
# we need to rescale translate by image size / 2 as its values can be between -1 and 1
translate = [2.0 * t / s for s, t in zip(img_size, translate)]
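# (illustration) a 10-pixel horizontal shift on a 26-pixel-wide image maps to
# 2.0 * 10 / 26 ≈ 0.77 in the normalized [-1, 1] coordinates used by the tensor backend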

output_size = img.size
# center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted 1 pixel
center = (img.size[0] * 0.5, img.size[1] * 0.5)
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
kwargs = {"fillcolor": fillcolor} if int(PILLOW_VERSION.split('.')[0]) >= 5 else {}
return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear)
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
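
Taken together, this is what the unification buys; a minimal usage sketch (assuming a torchvision build with this PR merged):

import torch
import torchvision.transforms.functional as F

t = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
p = F.to_pil_image(t)

# one call signature for both input types
out_t = F.affine(t, angle=30, translate=[4, 5], scale=1.1, shear=[0.0, 0.0])
out_p = F.affine(p, angle=30, translate=[4, 5], scale=1.1, shear=[0.0, 0.0])

# the tensor path is also scriptable, as exercised by the tests
scripted_affine = torch.jit.script(F.affine)
out_s = scripted_affine(t, angle=30.0, translate=[4, 5], scale=1.1, shear=[0.0, 0.0])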


def to_grayscale(img, num_output_channels=1):