Skip to content

Commit 4839804

Browse files
committed
Unified input for F.affine
1 parent 37edd94 commit 4839804

File tree

4 files changed

+110
-46
lines changed

4 files changed

+110
-46
lines changed

test/test_functional_tensor.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,20 @@ def test_resized_crop(self):
349349
)
350350

351351
def test_affine(self):
352-
# Let's do some tests on square image at first
352+
# Tests on square image
353353
tensor, pil_img = self._create_data(26, 26)
354+
355+
scripted_affine = torch.jit.script(F.affine)
354356
# 1) identity map
355357
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
356358
self.assertTrue(
357359
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
358360
)
361+
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
362+
self.assertTrue(
363+
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
364+
)
365+
359366
# 2) Test rotation
360367
test_configs = [
361368
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
@@ -367,29 +374,68 @@ def test_affine(self):
367374
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
368375
]
369376
for a, true_tensor in test_configs:
370-
371-
out_tensor = F.affine(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
372-
if true_tensor is not None:
373-
self.assertTrue(
374-
true_tensor.equal(out_tensor),
375-
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
376-
)
377-
else:
378-
true_tensor = out_tensor
379-
380-
out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
381-
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
382-
383-
num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
384-
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
385-
# Tolerence : 6% of different pixels
386-
self.assertLess(
387-
ratio_diff_pixels,
388-
0.06,
389-
msg="{}\n{} vs \n{}".format(
390-
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
377+
for fn in [F.affine, scripted_affine]:
378+
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
379+
if true_tensor is not None:
380+
self.assertTrue(
381+
true_tensor.equal(out_tensor),
382+
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
383+
)
384+
else:
385+
true_tensor = out_tensor
386+
387+
out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
388+
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
389+
390+
num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
391+
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
392+
# Tolerance : less than 6% of different pixels
393+
self.assertLess(
394+
ratio_diff_pixels,
395+
0.06,
396+
msg="{}\n{} vs \n{}".format(
397+
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
398+
)
391399
)
392-
)
400+
# 3) Test translation
401+
test_configs = [
402+
[10, 12], (12, 13)
403+
]
404+
for t in test_configs:
405+
for fn in [F.affine, scripted_affine]:
406+
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
407+
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
408+
self.compareTensorToPIL(out_tensor, out_pil_img)
409+
410+
# 4) Test rotation + translation + scale + shear
411+
test_configs = [
412+
(45, [5, 6], 1.0, [0.0, 0.0]),
413+
(33, (5, -4), 1.0, [0.0, 0.0]),
414+
(45, [5, 4], 1.2, [0.0, 0.0]),
415+
(33, (4, 8), 2.0, [0.0, 0.0]),
416+
(85, (10, -10), 0.7, [0.0, 0.0]),
417+
(0, [0, 0], 1.0, [35.0, ]),
418+
(25, [0, 0], 1.2, [0.0, 15.0]),
419+
(45, [10, 0], 0.7, [2.0, 5.0]),
420+
(45, [10, -10], 1.2, [4.0, 5.0]),
421+
]
422+
for r in [0, ]:
423+
for a, t, s, sh in test_configs:
424+
for fn in [F.affine, scripted_affine]:
425+
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
426+
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
427+
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
428+
429+
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
430+
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
431+
# Tolerance : less than 5% of different pixels
432+
self.assertLess(
433+
ratio_diff_pixels,
434+
0.05,
435+
msg="{}: {}\n{} vs \n{}".format(
436+
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
437+
)
438+
)
393439

394440

395441
if __name__ == '__main__':

test/test_transforms.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,9 +1373,10 @@ def _test_transformation(a, t, s, sh):
13731373
inv_true_matrix = np.linalg.inv(true_matrix)
13741374
for y in range(true_result.shape[0]):
13751375
for x in range(true_result.shape[1]):
1376-
res = np.dot(inv_true_matrix, [x, y, 1])
1377-
_x = int(res[0] + 0.5)
1378-
_y = int(res[1] + 0.5)
1376+
# transform the pixel's center instead of the pixel's top-left (TL) corner
1377+
res = np.dot(inv_true_matrix, [x + 0.5, y + 0.5, 1])
1378+
_x = int(res[0])
1379+
_y = int(res[1])
13791380
if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
13801381
true_result[y, x, :] = input_img[_y, _x, :]
13811382

@@ -1384,8 +1385,8 @@ def _test_transformation(a, t, s, sh):
13841385
# Compute number of different pixels:
13851386
np_result = np.array(result)
13861387
n_diff_pixels = np.sum(np_result != true_result) / 3
1387-
# Accept 3 wrong pixels
1388-
self.assertLess(n_diff_pixels, 3,
1388+
# Accept 7 wrong pixels
1389+
self.assertLess(n_diff_pixels, 7,
13891390
"a={}, t={}, s={}, sh={}\n".format(a, t, s, sh) +
13901391
"n diff pixels={}\n".format(np.sum(np.array(result)[:, :, 0] != true_result[:, :, 0])))
13911392

torchvision/transforms/functional.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
791791
return img.rotate(angle, resample, expand, center, **opts)
792792

793793

794-
def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
794+
def _get_inverse_affine_matrix(
795+
center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
796+
) -> List[float]:
795797
# Helper method to compute inverse matrix for affine transformation
796798

797799
# As it is explained in PIL.Image.rotate
@@ -818,14 +820,14 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
818820
tx, ty = translate
819821

820822
# RSS without scaling
821-
a = cos(rot - sy) / cos(sy)
822-
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
823-
c = sin(rot - sy) / cos(sy)
824-
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
823+
a = math.cos(rot - sy) / math.cos(sy)
824+
b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
825+
c = math.sin(rot - sy) / math.cos(sy)
826+
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
825827

826828
# Inverted rotation matrix with scale and shear
827829
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
828-
matrix = [d, -b, 0, -c, a, 0]
830+
matrix = [d, -b, 0.0, -c, a, 0.0]
829831
matrix = [x / scale for x in matrix]
830832

831833
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
@@ -835,11 +837,12 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
835837
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
836838
matrix[2] += cx
837839
matrix[5] += cy
840+
838841
return matrix
839842

840843

841844
def affine(
842-
img: Tensor, angle: int, translate: List[int], scale: float, shear: List[float],
845+
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
843846
resample: int = 0, fillcolor: Optional[int] = None
844847
) -> Tensor:
845848
"""Apply affine transformation on the image keeping image center invariant.
@@ -863,7 +866,10 @@ def affine(
863866
Returns:
864867
PIL Image or Tensor: Transformed image.
865868
"""
866-
if not isinstance(translate, Sequence):
869+
if not isinstance(angle, (int, float)):
870+
raise TypeError("Argument angle should be int or float")
871+
872+
if not isinstance(translate, (list, tuple)):
867873
raise TypeError("Argument translate should be a sequence")
868874

869875
if len(translate) != 2:
@@ -872,30 +878,41 @@ def affine(
872878
if scale <= 0.0:
873879
raise ValueError("Argument scale should be positive")
874880

875-
if not isinstance(shear, (numbers.Number, Sequence)):
881+
if not isinstance(shear, (numbers.Number, (list, tuple))):
876882
raise TypeError("Shear should be either a single value or a sequence of two values")
877883

884+
if isinstance(angle, int):
885+
angle = float(angle)
886+
887+
if isinstance(translate, tuple):
888+
translate = list(translate)
889+
878890
if isinstance(shear, numbers.Number):
879-
shear = [shear, 0]
891+
shear = [shear, 0.0]
892+
893+
if isinstance(shear, tuple):
894+
shear = list(shear)
895+
896+
if len(shear) == 1:
897+
shear = [shear[0], shear[0]]
880898

881899
if len(shear) != 2:
882900
raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))
883901

902+
img_size = _get_image_size(img)
884903
if not isinstance(img, torch.Tensor):
885-
img_size = _get_image_size(img)
886904
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
887905
# it is visually better to estimate the center without 0.5 offset
888906
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
889-
center = (img_size[0] * 0.5, img_size[1] * 0.5)
907+
center = [img_size[0] * 0.5, img_size[1] * 0.5]
890908
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
891909

892910
return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
893911

894-
# compute affine matrix (not inversed)
895-
# matrix = _get_inverse_affine_matrix(
896-
# (0, 0), -angle, [-t for t in translate], 1.0 / scale, [-s for s in shear]
897-
# )
898-
matrix = _get_inverse_affine_matrix((0, 0), angle, translate, scale, shear)
912+
# we need to rescale translate by image size / 2 as its values can be between -1 and 1
913+
translate = [2.0 * t / s for s, t in zip(img_size, translate)]
914+
915+
matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear)
899916
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
900917

901918

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
579579

580580

581581
def affine(
582-
img: Tensor, matrix: List[int], resample: int = 0, fillcolor: Optional[int] = None
582+
img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
583583
) -> Tensor:
584584
"""Apply affine transformation on the Tensor image keeping image center invariant.
585585

0 commit comments

Comments
 (0)