Skip to content

Commit 1214e6f

Browse files
committed
WIP on differentiable F.rotate
1 parent 93b26da commit 1214e6f

File tree

3 files changed

+95
-19
lines changed

3 files changed

+95
-19
lines changed

test/test_functional_tensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,19 @@ def test_rotate_interpolation_type(self):
154154
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
155155
assert_equal(res1, res2)
156156

157+
@pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
@pytest.mark.parametrize("center", [None, [0.1, 0.2]])
def test_differentiable_rotate(self, fn, center):
    """Check that gradients flow from the rotated image back to angle and center.

    ``center`` is parametrized as plain values rather than as a tensor with
    ``requires_grad=True``: a tensor placed in the parametrize list is created
    once and shared by every ``fn`` variant, so its ``.grad`` would already be
    populated by an earlier run and the assertion below would pass vacuously.
    """
    # Fresh leaf tensors per invocation so that ``.grad`` starts out as None.
    center_t = None if center is None else torch.tensor(center, requires_grad=True)
    alpha = torch.tensor(1.0, requires_grad=True)
    x = torch.zeros(1, 3, 10, 10)

    y = fn(x, alpha, interpolation=BILINEAR, center=center_t)
    assert y.requires_grad

    y.mean().backward()
    assert alpha.grad is not None
    if center_t is not None:
        assert center_t.grad is not None
169+
157170

158171
class TestAffine:
159172

torchvision/transforms/functional.py

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numbers
33
import warnings
44
from enum import Enum
5-
from typing import List, Tuple, Any, Optional
5+
from typing import List, Tuple, Any, Optional, Union
66

77
import numpy as np
88
import torch
@@ -948,12 +948,48 @@ def _get_inverse_affine_matrix(
948948
return matrix
949949

950950

951+
def _get_inverse_affine_matrix_tensor(
    center: Tensor, angle: Tensor, translate: Tensor, scale: Tensor, shear: Tensor
) -> Tensor:
    """Build the inverse affine matrix from tensor inputs, keeping it differentiable.

    The result is the first two rows of ``C @ RSS^-1 @ C^-1 @ T^-1`` where ``C``
    is the translation to ``center``, ``T`` the translation by ``translate`` and
    ``RSS^-1`` the inverted rotation/shear block divided by ``scale``.

    All intermediate matrices are created with the device of ``angle`` and a
    floating dtype derived from it, so non-default devices and float64 inputs
    are not silently moved or downcast.

    Args:
        center: two values (x, y) written into the translation column.
        angle: rotation angle in degrees.
        translate: two values (x, y) of the translation to invert.
        scale: divisor applied to the rotation/shear block
            (assumed to be a scalar tensor -- TODO confirm with callers).
        shear: two shear angles in degrees, indexed as ``shear[0]``, ``shear[1]``.

    Returns:
        Tensor of shape ``[2, 3]``.
    """
    device = angle.device
    dtype = angle.dtype if angle.is_floating_point() else torch.get_default_dtype()

    center = center.to(dtype=dtype, device=device)
    translate = translate.to(dtype=dtype, device=device)
    scale = scale.to(dtype=dtype, device=device)

    # Degrees -> radians.
    rot = angle.to(dtype=dtype) * torch.pi / 180.0
    shear_rad = shear.to(dtype=dtype, device=device) * torch.pi / 180.0

    m_center = torch.eye(3, dtype=dtype, device=device)
    m_center[:2, 2] = center

    i_m_center = torch.eye(3, dtype=dtype, device=device)
    i_m_center[:2, 2] = -center

    i_m_translate = torch.eye(3, dtype=dtype, device=device)
    i_m_translate[:2, 2] = -translate

    # RSS without scaling
    sx, sy = shear_rad[0], shear_rad[1]
    a = torch.cos(rot - sy) / torch.cos(sy)
    b = torch.cos(rot - sy) * torch.tan(sx) / torch.cos(sy) + torch.sin(rot)
    c = -torch.sin(rot - sy) / torch.cos(sy)
    d = -torch.sin(rot - sy) * torch.tan(sx) / torch.cos(sy) + torch.cos(rot)

    # Start from the identity so the homogeneous entry [2, 2] is already 1 and
    # only the 2x2 block is scaled; this avoids an in-place write on the result
    # of the division.
    i_rss = torch.eye(3, dtype=dtype, device=device)
    i_rss[0, 0] = d / scale
    i_rss[0, 1] = b / scale
    i_rss[1, 0] = c / scale
    i_rss[1, 1] = a / scale

    # Plain matmul chain instead of the deprecated torch.chain_matmul.
    full = m_center @ i_rss @ i_m_center @ i_m_translate
    return full[:2, :]
985+
986+
951987
def rotate(
952988
img: Tensor,
953-
angle: float,
989+
angle: Union[float, int, Tensor],
954990
interpolation: InterpolationMode = InterpolationMode.NEAREST,
955991
expand: bool = False,
956-
center: Optional[List[int]] = None,
992+
center: Optional[Union[List[int], Tuple[int, int], Tensor]] = None,
957993
fill: Optional[List[float]] = None,
958994
resample: Optional[int] = None,
959995
) -> Tensor:
@@ -963,7 +999,7 @@ def rotate(
963999
9641000
Args:
9651001
img (PIL Image or Tensor): image to be rotated.
966-
angle (number): rotation angle value in degrees, counter-clockwise.
1002+
angle (number or Tensor): rotation angle value in degrees, counter-clockwise.
9671003
interpolation (InterpolationMode): Desired interpolation enum defined by
9681004
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
9691005
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
@@ -972,7 +1008,7 @@ def rotate(
9721008
If true, expands the output image to make it large enough to hold the entire rotated image.
9731009
If false or omitted, make the output image the same size as the input image.
9741010
Note that the expand flag assumes rotation around the center and no translation.
975-
center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
1011+
center (sequence or Tensor, optional): Optional center of rotation. Origin is the upper left corner.
9761012
Default is the center of the image.
9771013
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
9781014
image. If given a number, the value is used for all bands respectively.
@@ -1001,28 +1037,48 @@ def rotate(
10011037
)
10021038
interpolation = _interpolation_modes_from_int(interpolation)
10031039

1004-
if not isinstance(angle, (int, float)):
1005-
raise TypeError("Argument angle should be int or float")
1040+
if not isinstance(angle, (int, float, Tensor)):
1041+
raise TypeError("Argument angle should be int or float or Tensor")
10061042

1007-
if center is not None and not isinstance(center, (list, tuple)):
1008-
raise TypeError("Argument center should be a sequence")
1043+
if center is not None and not isinstance(center, (list, tuple, Tensor)):
1044+
raise TypeError("Argument center should be a sequence or a Tensor")
10091045

10101046
if not isinstance(interpolation, InterpolationMode):
10111047
raise TypeError("Argument interpolation should be a InterpolationMode")
10121048

10131049
if not isinstance(img, torch.Tensor):
1050+
if not isinstance(angle, (int, float)):
1051+
raise TypeError("Argument angle should be int or float")
1052+
1053+
if center is not None and not isinstance(center, (list, tuple)):
1054+
raise TypeError("Argument center should be a sequence")
1055+
10141056
pil_interpolation = pil_modes_mapping[interpolation]
10151057
return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
10161058

1017-
center_f = [0.0, 0.0]
1059+
if isinstance(angle, torch.Tensor) and angle.requires_grad:
1060+
# assert img.dtype is float
1061+
pass
1062+
1063+
center_t = torch.tensor([0.0, 0.0])
10181064
if center is not None:
1019-
img_size = get_image_size(img)
1065+
# ct = torch.tensor([float(c) for c in list(center)]) if not isinstance(center, Tensor) else center
1066+
# THIS DOES NOT PASS JIT as we mix list/tuple of ints but list/tuple of floats are required
1067+
ct = torch.tensor(center) if not isinstance(center, Tensor) else center
1068+
img_size = torch.tensor(get_image_size(img))
10201069
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
1021-
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
1070+
center_t = 1.0 * (ct - img_size * 0.5)
10221071

10231072
# due to current incoherence of rotation angle direction between affine and rotate implementations
10241073
# we need to set -angle.
1025-
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
1074+
angle_t = torch.tensor(float(angle)) if not isinstance(angle, Tensor) else angle
1075+
matrix = _get_inverse_affine_matrix_tensor(
1076+
center_t,
1077+
-angle_t,
1078+
torch.tensor([0.0, 0.0]),
1079+
torch.tensor(1.0),
1080+
torch.tensor([0.0, 0.0])
1081+
)
10261082
return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
10271083

10281084

torchvision/transforms/functional_tensor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import warnings
23
from typing import Optional, Tuple, List
34

@@ -586,6 +587,12 @@ def _assert_grid_transform_inputs(
586587
if matrix is not None and len(matrix) != 6:
587588
raise ValueError("Argument matrix should have 6 float values")
588589

590+
# if matrix is not None and not isinstance(matrix, Tensor):
591+
# raise TypeError("Argument matrix should be a Tensor")
592+
593+
# if matrix is not None and list(matrix.shape) != [2, 3]:
594+
# raise ValueError("Argument matrix should have shape [2, 3]")
595+
589596
if coeffs is not None and len(coeffs) != 8:
590597
raise ValueError("Argument coeffs should have 8 float values")
591598

@@ -710,7 +717,7 @@ def affine(
710717
return _apply_grid_transform(img, grid, interpolation, fill=fill)
711718

712719

713-
def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
720+
def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
714721

715722
# Inspired of PIL implementation:
716723
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
@@ -724,7 +731,6 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
724731
[0.5 * w, -0.5 * h, 1.0],
725732
]
726733
)
727-
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
728734
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
729735
min_vals, _ = new_pts.min(dim=0)
730736
max_vals, _ = new_pts.max(dim=0)
@@ -739,16 +745,17 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
739745

740746
def rotate(
741747
img: Tensor,
742-
matrix: List[float],
748+
matrix: Tensor,
743749
interpolation: str = "nearest",
744750
expand: bool = False,
745751
fill: Optional[List[float]] = None,
746752
) -> Tensor:
747-
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
753+
# _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
754+
matrix = matrix.unsqueeze(0)
748755
w, h = img.shape[-1], img.shape[-2]
749-
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
756+
ow, oh = _compute_output_size(matrix.detach(), w, h) if expand else (w, h)
750757
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
751-
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
758+
theta = matrix.to(dtype=dtype, device=img.device)
752759
# grid will be generated on the same device as theta and img
753760
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
754761

0 commit comments

Comments
 (0)