Skip to content

Commit 2e04aa2

Browse files
committed
Added center option to F.affine and RandomAffine ops
1 parent 4946827 commit 2e04aa2

File tree

4 files changed

+84
-18
lines changed

4 files changed

+84
-18
lines changed

test/test_functional_tensor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn):
232232
@pytest.mark.parametrize("dt", ALL_DTYPES)
233233
@pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
234234
@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
235-
def test_rect_rotations(self, device, height, width, dt, angle, fn):
235+
@pytest.mark.parametrize("center", [None, [0, 0]])
236+
def test_rect_rotations(self, device, height, width, dt, angle, fn, center):
236237
# Tests on rectangular images
237238
tensor, pil_img = _create_data(height, width, device=device)
238239

@@ -244,11 +245,13 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn):
244245
tensor = tensor.to(dtype=dt)
245246

246247
out_pil_img = F.affine(
247-
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
248+
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
248249
)
249250
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
250251

251-
out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST).cpu()
252+
out_tensor = fn(
253+
tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
254+
).cpu()
252255

253256
if out_tensor.dtype != torch.uint8:
254257
out_tensor = out_tensor.to(torch.uint8)

test/test_transforms.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,11 +1983,11 @@ def _to_3x3_inv(self, inv_result_matrix):
19831983
result_matrix[2, 2] = 1
19841984
return np.linalg.inv(result_matrix)
19851985

1986-
def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img):
1986+
def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img, center=None):
19871987

19881988
a_rad = math.radians(angle)
19891989
s_rad = [math.radians(sh_) for sh_ in shear]
1990-
cnt = [20, 20]
1990+
cnt = [20, 20] if center is None else center
19911991
cx, cy = cnt
19921992
tx, ty = translate
19931993
sx, sy = s_rad
@@ -2032,7 +2032,7 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_
20322032
if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
20332033
true_result[y, x, :] = input_img[_y, _x, :]
20342034

2035-
result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear)
2035+
result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear, center=center)
20362036
assert result.size == pil_image.size
20372037
# Compute number of different pixels:
20382038
np_result = np.array(result)
@@ -2050,6 +2050,18 @@ def test_transformation_discrete(self, pil_image, input_img):
20502050
angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
20512051
)
20522052

2053+
# Test rotation with top-left as center
2054+
angle = 45
2055+
self._test_transformation(
2056+
angle=angle,
2057+
translate=(0, 0),
2058+
scale=1.0,
2059+
shear=(0.0, 0.0),
2060+
pil_image=pil_image,
2061+
input_img=input_img,
2062+
center=[0, 0],
2063+
)
2064+
20532065
# Test translation
20542066
translate = [10, 15]
20552067
self._test_transformation(
@@ -2068,6 +2080,18 @@ def test_transformation_discrete(self, pil_image, input_img):
20682080
angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img
20692081
)
20702082

2083+
# Test shear with top-left as center
2084+
shear = [45.0, 25.0]
2085+
self._test_transformation(
2086+
angle=0.0,
2087+
translate=(0.0, 0.0),
2088+
scale=1.0,
2089+
shear=shear,
2090+
pil_image=pil_image,
2091+
input_img=input_img,
2092+
center=[0, 0],
2093+
)
2094+
20712095
@pytest.mark.parametrize("angle", range(-90, 90, 36))
20722096
@pytest.mark.parametrize("translate", range(-10, 10, 5))
20732097
@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])

torchvision/transforms/functional.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -945,26 +945,42 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
945945

946946

947947
def _get_inverse_affine_matrix(
948-
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
948+
center: List[float],
949+
angle: float,
950+
translate: List[float],
951+
scale: float,
952+
shear: List[float],
953+
centered_shear: bool = True,
949954
) -> List[float]:
950955
# Helper method to compute inverse matrix for affine transformation
951956

952-
# As it is explained in PIL.Image.rotate
953-
# We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
957+
# Pillow requires the inverse of the affine transformation matrix M, where:
958+
# option 1 (centered_shear=True) curr : M = T * C * RotateScaleShear * C^-1
959+
# option 2 (centered_shear=False) new : M = T * C * RotateScale * C^-1 * Shear
960+
#
954961
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
955962
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
956-
# RSS is rotation with scale and shear matrix
957-
# RSS(a, s, (sx, sy)) =
963+
# RotateScaleShear is rotation with scale and shear matrix
964+
# RotateScale is rotation with scale matrix
965+
#
966+
# RotateScaleShear(a, s, (sx, sy)) =
958967
# = R(a) * S(s) * SHy(sy) * SHx(sx)
959-
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
960-
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
968+
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
969+
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
961970
# [ 0 , 0 , 1 ]
962971
#
972+
# RotateScale(a, s) =
973+
# = R(a) * S(s)
974+
# = [ s*cos(a), -s*sin(a), 0 ]
975+
# [ s*sin(a), s*cos(a), 0 ]
976+
# [ 0 , 0 , 1 ]
977+
#
963978
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
964979
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
965980
# [0, 1 ] [-tan(s), 1]
966-
#
967-
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
981+
982+
# TODO: implement option 2 (centered_shear=False); only option 1 is supported for now
983+
assert centered_shear
968984

969985
rot = math.radians(angle)
970986
sx = math.radians(shear[0])
@@ -1085,6 +1101,7 @@ def affine(
10851101
fill: Optional[List[float]] = None,
10861102
resample: Optional[int] = None,
10871103
fillcolor: Optional[List[float]] = None,
1104+
center: Optional[List[int]] = None,
10881105
) -> Tensor:
10891106
"""Apply affine transformation on the image keeping image center invariant.
10901107
If the image is torch Tensor, it is expected
@@ -1112,6 +1129,8 @@ def affine(
11121129
Please use the ``fill`` parameter instead.
11131130
resample (int, optional): deprecated argument and will be removed since v0.10.0.
11141131
Please use the ``interpolation`` parameter instead.
1132+
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
1133+
Default is the center of the image.
11151134
11161135
Returns:
11171136
PIL Image or Tensor: Transformed image.
@@ -1172,18 +1191,28 @@ def affine(
11721191
if len(shear) != 2:
11731192
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
11741193

1194+
if center is not None and not isinstance(center, (list, tuple)):
1195+
raise TypeError("Argument center should be a sequence")
1196+
11751197
img_size = get_image_size(img)
11761198
if not isinstance(img, torch.Tensor):
11771199
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
11781200
# it is visually better to estimate the center without 0.5 offset
11791201
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
1180-
center = [img_size[0] * 0.5, img_size[1] * 0.5]
1202+
if center is None:
1203+
center = [img_size[0] * 0.5, img_size[1] * 0.5]
11811204
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
11821205
pil_interpolation = pil_modes_mapping[interpolation]
11831206
return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
11841207

1208+
center_f = [0.0, 0.0]
1209+
if center is not None:
1210+
img_size = get_image_size(img)
1211+
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
1212+
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
1213+
11851214
translate_f = [1.0 * t for t in translate]
1186-
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
1215+
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
11871216
return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
11881217

11891218

torchvision/transforms/transforms.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1414,6 +1414,8 @@ class RandomAffine(torch.nn.Module):
14141414
Please use the ``fill`` parameter instead.
14151415
resample (int, optional): deprecated argument and will be removed since v0.10.0.
14161416
Please use the ``interpolation`` parameter instead.
1417+
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
1418+
Default is the center of the image.
14171419
14181420
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
14191421
@@ -1429,6 +1431,7 @@ def __init__(
14291431
fill=0,
14301432
fillcolor=None,
14311433
resample=None,
1434+
center=None,
14321435
):
14331436
super().__init__()
14341437
_log_api_usage_once(self)
@@ -1482,6 +1485,11 @@ def __init__(
14821485

14831486
self.fillcolor = self.fill = fill
14841487

1488+
if center is not None:
1489+
_check_sequence_input(center, "center", req_sizes=(2,))
1490+
1491+
self.center = center
1492+
14851493
@staticmethod
14861494
def get_params(
14871495
degrees: List[float],
@@ -1538,7 +1546,7 @@ def forward(self, img):
15381546

15391547
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
15401548

1541-
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
1549+
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
15421550

15431551
def __repr__(self):
15441552
s = "{name}(degrees={degrees}"
@@ -1552,6 +1560,8 @@ def __repr__(self):
15521560
s += ", interpolation={interpolation}"
15531561
if self.fill != 0:
15541562
s += ", fill={fill}"
1563+
if self.center is not None:
1564+
s += ", center={center}"
15551565
s += ")"
15561566
d = dict(self.__dict__)
15571567
d["interpolation"] = self.interpolation.value

0 commit comments

Comments
 (0)