Skip to content

Commit 5f4b579

Browse files
authored
Unified input for F.affine (#2444)
* [WIP] F.affine
* [WIP] F.affine + tests
* Unified input for F.affine
* Removed commented code
* Removed unused imports
1 parent 03b1d38 commit 5f4b579

File tree

5 files changed

+301
-87
lines changed

5 files changed

+301
-87
lines changed

test/test_functional_tensor.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,95 @@ def test_resized_crop(self):
348348
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
349349
)
350350

351+
def test_affine(self):
352+
# Tests on square image
353+
tensor, pil_img = self._create_data(26, 26)
354+
355+
scripted_affine = torch.jit.script(F.affine)
356+
# 1) identity map
357+
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
358+
self.assertTrue(
359+
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
360+
)
361+
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
362+
self.assertTrue(
363+
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
364+
)
365+
366+
# 2) Test rotation
367+
test_configs = [
368+
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
369+
(45, None),
370+
(30, None),
371+
(-30, None),
372+
(-45, None),
373+
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
374+
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
375+
]
376+
for a, true_tensor in test_configs:
377+
for fn in [F.affine, scripted_affine]:
378+
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
379+
if true_tensor is not None:
380+
self.assertTrue(
381+
true_tensor.equal(out_tensor),
382+
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
383+
)
384+
else:
385+
true_tensor = out_tensor
386+
387+
out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
388+
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
389+
390+
num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
391+
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
392+
# Tolerance : less than 6% of different pixels
393+
self.assertLess(
394+
ratio_diff_pixels,
395+
0.06,
396+
msg="{}\n{} vs \n{}".format(
397+
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
398+
)
399+
)
400+
# 3) Test translation
401+
test_configs = [
402+
[10, 12], (12, 13)
403+
]
404+
for t in test_configs:
405+
for fn in [F.affine, scripted_affine]:
406+
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
407+
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
408+
self.compareTensorToPIL(out_tensor, out_pil_img)
409+
410+
# 4) Test rotation + translation + scale + shear
411+
test_configs = [
412+
(45, [5, 6], 1.0, [0.0, 0.0]),
413+
(33, (5, -4), 1.0, [0.0, 0.0]),
414+
(45, [5, 4], 1.2, [0.0, 0.0]),
415+
(33, (4, 8), 2.0, [0.0, 0.0]),
416+
(85, (10, -10), 0.7, [0.0, 0.0]),
417+
(0, [0, 0], 1.0, [35.0, ]),
418+
(25, [0, 0], 1.2, [0.0, 15.0]),
419+
(45, [10, 0], 0.7, [2.0, 5.0]),
420+
(45, [10, -10], 1.2, [4.0, 5.0]),
421+
]
422+
for r in [0, ]:
423+
for a, t, s, sh in test_configs:
424+
for fn in [F.affine, scripted_affine]:
425+
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
426+
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
427+
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
428+
429+
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
430+
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
431+
# Tolerance : less than 5% of different pixels
432+
self.assertLess(
433+
ratio_diff_pixels,
434+
0.05,
435+
msg="{}: {}\n{} vs \n{}".format(
436+
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
437+
)
438+
)
439+
351440

352441
if __name__ == '__main__':
353442
unittest.main()

test/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,8 +1317,8 @@ def test_affine(self):
13171317
for j in range(-5, 5):
13181318
input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
13191319

1320-
with self.assertRaises(TypeError):
1321-
F.affine(input_img, 10)
1320+
with self.assertRaises(TypeError, msg="Argument translate should be a sequence"):
1321+
F.affine(input_img, 10, translate=0, scale=1, shear=1)
13221322

13231323
pil_img = F.to_pil_image(input_img)
13241324

torchvision/transforms/functional.py

Lines changed: 79 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import math
22
import numbers
33
import warnings
4-
from typing import Any
4+
from typing import Any, Optional
55

66
import numpy as np
7-
from numpy import sin, cos, tan
8-
from PIL import Image, __version__ as PILLOW_VERSION
7+
from PIL import Image
98

109
import torch
1110
from torch import Tensor
@@ -21,6 +20,7 @@
2120

2221

2322
_is_pil_image = F_pil._is_pil_image
23+
_parse_fill = F_pil._parse_fill
2424

2525

2626
def _get_image_size(img: Tensor) -> List[int]:
@@ -485,43 +485,6 @@ def hflip(img: Tensor) -> Tensor:
485485
return F_t.hflip(img)
486486

487487

488-
def _parse_fill(fill, img, min_pil_version):
489-
"""Helper function to get the fill color for rotate and perspective transforms.
490-
491-
Args:
492-
fill (n-tuple or int or float): Pixel fill value for area outside the transformed
493-
image. If int or float, the value is used for all bands respectively.
494-
Defaults to 0 for all bands.
495-
img (PIL Image): Image to be filled.
496-
min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
497-
was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)
498-
499-
Returns:
500-
dict: kwarg for ``fillcolor``
501-
"""
502-
major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2])
503-
major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2])
504-
if major_found < major_required or (major_found == major_required and minor_found < minor_required):
505-
if fill is None:
506-
return {}
507-
else:
508-
msg = ("The option to fill background area of the transformed image, "
509-
"requires pillow>={}")
510-
raise RuntimeError(msg.format(min_pil_version))
511-
512-
num_bands = len(img.getbands())
513-
if fill is None:
514-
fill = 0
515-
if isinstance(fill, (int, float)) and num_bands > 1:
516-
fill = tuple([fill] * num_bands)
517-
if not isinstance(fill, (int, float)) and len(fill) != num_bands:
518-
msg = ("The number of elements in 'fill' does not match the number of "
519-
"bands of the image ({} != {})")
520-
raise ValueError(msg.format(len(fill), num_bands))
521-
522-
return {"fillcolor": fill}
523-
524-
525488
def _get_perspective_coeffs(startpoints, endpoints):
526489
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
527490
@@ -827,7 +790,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
827790
return img.rotate(angle, resample, expand, center, **opts)
828791

829792

830-
def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
793+
def _get_inverse_affine_matrix(
794+
center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
795+
) -> List[float]:
831796
# Helper method to compute inverse matrix for affine transformation
832797

833798
# As it is explained in PIL.Image.rotate
@@ -847,75 +812,107 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
847812
#
848813
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
849814

850-
if isinstance(shear, numbers.Number):
851-
shear = [shear, 0]
852-
853-
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
854-
raise ValueError(
855-
"Shear should be a single value or a tuple/list containing " +
856-
"two values. Got {}".format(shear))
857-
858815
rot = math.radians(angle)
859816
sx, sy = [math.radians(s) for s in shear]
860817

861818
cx, cy = center
862819
tx, ty = translate
863820

864821
# RSS without scaling
865-
a = cos(rot - sy) / cos(sy)
866-
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
867-
c = sin(rot - sy) / cos(sy)
868-
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
822+
a = math.cos(rot - sy) / math.cos(sy)
823+
b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
824+
c = math.sin(rot - sy) / math.cos(sy)
825+
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
869826

870827
# Inverted rotation matrix with scale and shear
871828
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
872-
M = [d, -b, 0,
873-
-c, a, 0]
874-
M = [x / scale for x in M]
829+
matrix = [d, -b, 0.0, -c, a, 0.0]
830+
matrix = [x / scale for x in matrix]
875831

876832
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
877-
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
878-
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
833+
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
834+
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
879835

880836
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
881-
M[2] += cx
882-
M[5] += cy
883-
return M
837+
matrix[2] += cx
838+
matrix[5] += cy
884839

840+
return matrix
885841

886-
def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
887-
"""Apply affine transformation on the image keeping image center invariant
842+
843+
def affine(
844+
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
845+
resample: int = 0, fillcolor: Optional[int] = None
846+
) -> Tensor:
847+
"""Apply affine transformation on the image keeping image center invariant.
848+
The image can be a PIL Image or a Tensor, in which case it is expected
849+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
888850
889851
Args:
890-
img (PIL Image): PIL Image to be rotated.
852+
img (PIL Image or Tensor): image to be transformed.
891853
angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
892854
translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
893855
scale (float): overall scale
894856
shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction.
895-
If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while
896-
the second value corresponds to a shear parallel to the y axis.
857+
If a tuple or list is specified, the first value corresponds to a shear parallel to the x axis, while
858+
the second value corresponds to a shear parallel to the y axis.
897859
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
898-
An optional resampling filter.
899-
See `filters`_ for more information.
900-
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
860+
An optional resampling filter. See `filters`_ for more information.
861+
If omitted, or if the image is a PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
862+
If the input is a Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
901863
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
864+
865+
Returns:
866+
PIL Image or Tensor: Transformed image.
902867
"""
903-
if not F_pil._is_pil_image(img):
904-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
868+
if not isinstance(angle, (int, float)):
869+
raise TypeError("Argument angle should be int or float")
870+
871+
if not isinstance(translate, (list, tuple)):
872+
raise TypeError("Argument translate should be a sequence")
873+
874+
if len(translate) != 2:
875+
raise ValueError("Argument translate should be a sequence of length 2")
876+
877+
if scale <= 0.0:
878+
raise ValueError("Argument scale should be positive")
879+
880+
if not isinstance(shear, (numbers.Number, (list, tuple))):
881+
raise TypeError("Shear should be either a single value or a sequence of two values")
882+
883+
if isinstance(angle, int):
884+
angle = float(angle)
885+
886+
if isinstance(translate, tuple):
887+
translate = list(translate)
888+
889+
if isinstance(shear, numbers.Number):
890+
shear = [shear, 0.0]
891+
892+
if isinstance(shear, tuple):
893+
shear = list(shear)
894+
895+
if len(shear) == 1:
896+
shear = [shear[0], shear[0]]
897+
898+
if len(shear) != 2:
899+
raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))
900+
901+
img_size = _get_image_size(img)
902+
if not isinstance(img, torch.Tensor):
903+
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
904+
# it is visually better to estimate the center without 0.5 offset
905+
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
906+
center = [img_size[0] * 0.5, img_size[1] * 0.5]
907+
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
905908

906-
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
907-
"Argument translate should be a list or tuple of length 2"
909+
return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
908910

909-
assert scale > 0.0, "Argument scale should be positive"
911+
# we need to rescale translate by image size / 2 as its values can be between -1 and 1
912+
translate = [2.0 * t / s for s, t in zip(img_size, translate)]
910913

911-
output_size = img.size
912-
# center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
913-
# it is visually better to estimate the center without 0.5 offset
914-
# otherwise image rotated by 90 degrees is shifted 1 pixel
915-
center = (img.size[0] * 0.5, img.size[1] * 0.5)
916-
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
917-
kwargs = {"fillcolor": fillcolor} if int(PILLOW_VERSION.split('.')[0]) >= 5 else {}
918-
return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
914+
matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear)
915+
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
919916

920917

921918
def to_grayscale(img, num_output_channels=1):

0 commit comments

Comments
 (0)