Added center arg to F.affine and RandomAffine ops #5208

Merged · 3 commits · Jan 24, 2022
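For context, a minimal usage sketch of the functional-level API this PR adds: rotating about the top-left corner instead of the default image center (the input tensor here is illustrative):

```python
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

img = torch.randint(0, 256, (3, 26, 32), dtype=torch.uint8)

# Rotate 45 degrees about the top-left corner rather than the image center.
out = F.affine(
    img,
    angle=45.0,
    translate=[0, 0],
    scale=1.0,
    shear=[0.0, 0.0],
    interpolation=InterpolationMode.NEAREST,
    center=[0, 0],
)
```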
9 changes: 6 additions & 3 deletions test/test_functional_tensor.py
@@ -232,7 +232,8 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn):
@pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
def test_rect_rotations(self, device, height, width, dt, angle, fn):
@pytest.mark.parametrize("center", [None, [0, 0]])
def test_rect_rotations(self, device, height, width, dt, angle, fn, center):
# Tests on rectangular images
tensor, pil_img = _create_data(height, width, device=device)

@@ -244,11 +245,13 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn):
tensor = tensor.to(dtype=dt)

out_pil_img = F.affine(
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST).cpu()
out_tensor = fn(
tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
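The updated test above checks that the PIL and tensor code paths agree when an explicit center is passed. A condensed, hedged sketch of that consistency check (the threshold and helper setup are illustrative, not the exact test code):

```python
import numpy as np
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from PIL import Image

tensor = torch.randint(0, 256, (3, 26, 32), dtype=torch.uint8)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).numpy())

kwargs = dict(
    angle=45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0],
    interpolation=InterpolationMode.NEAREST, center=[0, 0],
)

out_pil = torch.from_numpy(np.array(F.affine(pil_img, **kwargs)).transpose((2, 0, 1)))
out_tensor = F.affine(tensor, **kwargs)

# Nearest-neighbour rounding differs slightly between PIL and the tensor kernel,
# so only a small fraction of pixels is allowed to differ.
ratio_diff = (out_tensor != out_pil).float().mean().item()
assert ratio_diff < 0.05
```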
30 changes: 27 additions & 3 deletions test/test_transforms.py
@@ -1983,11 +1983,11 @@ def _to_3x3_inv(self, inv_result_matrix):
result_matrix[2, 2] = 1
return np.linalg.inv(result_matrix)

def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img):
def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img, center=None):

a_rad = math.radians(angle)
s_rad = [math.radians(sh_) for sh_ in shear]
cnt = [20, 20]
cnt = [20, 20] if center is None else center
cx, cy = cnt
tx, ty = translate
sx, sy = s_rad
@@ -2032,7 +2032,7 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_
if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
true_result[y, x, :] = input_img[_y, _x, :]

result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear)
result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear, center=center)
assert result.size == pil_image.size
# Compute number of different pixels:
np_result = np.array(result)
@@ -2050,6 +2050,18 @@ def test_transformation_discrete(self, pil_image, input_img):
angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
)

# Test rotation
angle = 45
self._test_transformation(
angle=angle,
translate=(0, 0),
scale=1.0,
shear=(0.0, 0.0),
pil_image=pil_image,
input_img=input_img,
center=[0, 0],
)

# Test translation
translate = [10, 15]
self._test_transformation(
@@ -2068,6 +2080,18 @@ def test_transformation_discrete(self, pil_image, input_img):
angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img
)

# Test shear with top-left as center
shear = [45.0, 25.0]
self._test_transformation(
angle=0.0,
translate=(0.0, 0.0),
scale=1.0,
shear=shear,
pil_image=pil_image,
input_img=input_img,
center=[0, 0],
)

@pytest.mark.parametrize("angle", range(-90, 90, 36))
@pytest.mark.parametrize("translate", range(-10, 10, 5))
@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
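The `_test_transformation` helper above builds its ground truth by inverting a hand-constructed 3x3 affine matrix about the chosen center and mapping every output pixel back to the input. A hedged sketch of that procedure (function and variable names are illustrative):

```python
import numpy as np

def reference_affine(input_img, true_matrix, out_shape):
    # Map each output pixel back through the inverse transform and copy the
    # corresponding input pixel (nearest-neighbour lookup, no interpolation).
    inv = np.linalg.inv(true_matrix)
    h, w = out_shape
    out = np.zeros((h, w, input_img.shape[2]), dtype=input_img.dtype)
    for y in range(h):
        for x in range(w):
            _x, _y = (inv @ np.array([x, y, 1.0]))[:2].round().astype(int)
            if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                out[y, x, :] = input_img[_y, _x, :]
    return out
```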
40 changes: 29 additions & 11 deletions torchvision/transforms/functional.py
@@ -945,26 +945,31 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:


def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
center: List[float],
angle: float,
translate: List[float],
scale: float,
shear: List[float],
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation

# As it is explained in PIL.Image.rotate
# We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
# Pillow requires inverse affine transformation matrix:
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
#
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RSS is rotation with scale and shear matrix
# RSS(a, s, (sx, sy)) =
# RotateScaleShear is rotation with scale and shear matrix
#
# RotateScaleShear(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
#
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
# Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

rot = math.radians(angle)
sx = math.radians(shear[0])
@@ -1085,6 +1090,7 @@ def affine(
fill: Optional[List[float]] = None,
resample: Optional[int] = None,
fillcolor: Optional[List[float]] = None,
center: Optional[List[int]] = None,
) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant.
If the image is torch Tensor, it is expected
@@ -1112,6 +1118,8 @@
Please use the ``fill`` parameter instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead.
center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
Default is the center of the image.

Returns:
PIL Image or Tensor: Transformed image.
@@ -1172,18 +1180,28 @@
if len(shear) != 2:
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")

img_size = get_image_size(img)
if not isinstance(img, torch.Tensor):
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
center = [img_size[0] * 0.5, img_size[1] * 0.5]
if center is None:
center = [img_size[0] * 0.5, img_size[1] * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

center_f = [0.0, 0.0]
if center is not None:
img_size = get_image_size(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)


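To make the matrix algebra in `_get_inverse_affine_matrix` and the `center_f` shift concrete, here is a hedged numpy sketch of the composition described in the comment, M = T * C * RotateScaleShear * C^-1, and its inverse (the helper names are illustrative, not torchvision internals):

```python
import math
import numpy as np

def translation(tx, ty):
    return np.array([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])

def rotate_scale_shear(angle, scale, shear):
    # R(a) * S(s) * SHy(sy) * SHx(sx), with angles given in degrees.
    a = math.radians(angle)
    sx, sy = (math.radians(s) for s in shear)
    R = np.array([[math.cos(a), -math.sin(a), 0.0],
                  [math.sin(a),  math.cos(a), 0.0],
                  [0.0, 0.0, 1.0]])
    S = np.diag([scale, scale, 1.0])
    SHx = np.array([[1.0, -math.tan(sx), 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    SHy = np.array([[1.0, 0.0, 0.0], [-math.tan(sy), 1.0, 0.0], [0.0, 0.0, 1.0]])
    return R @ S @ SHy @ SHx

C = translation(5.0, -3.0)                 # move origin to the chosen center
T = translation(2.0, 1.0)                  # output translation
RSS = rotate_scale_shear(angle=45.0, scale=1.2, shear=[10.0, 5.0])

M = T @ C @ RSS @ np.linalg.inv(C)         # forward affine matrix
M_inv = C @ np.linalg.inv(RSS) @ np.linalg.inv(C) @ np.linalg.inv(T)
assert np.allclose(M_inv, np.linalg.inv(M))

# Tensor path: the user-facing center has its origin at the upper-left corner,
# while the tensor kernel expects coordinates relative to the image center,
# hence center_f = c - size * 0.5.
img_w, img_h = 32, 26
user_center = [0, 0]                       # top-left corner of the image
center_f = [1.0 * (c - s * 0.5) for c, s in zip(user_center, [img_w, img_h])]
```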
12 changes: 11 additions & 1 deletion torchvision/transforms/transforms.py
@@ -1414,6 +1414,8 @@ class RandomAffine(torch.nn.Module):
Please use the ``fill`` parameter instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image.

.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

@@ -1429,6 +1431,7 @@ def __init__(
fill=0,
fillcolor=None,
resample=None,
center=None,
):
super().__init__()
_log_api_usage_once(self)
@@ -1482,6 +1485,11 @@ def __init__(

self.fillcolor = self.fill = fill

if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))

self.center = center

@staticmethod
def get_params(
degrees: List[float],
@@ -1538,7 +1546,7 @@ def forward(self, img):

ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

def __repr__(self):
s = "{name}(degrees={degrees}"
@@ -1552,6 +1560,8 @@ def __repr__(self):
s += ", interpolation={interpolation}"
if self.fill != 0:
s += ", fill={fill}"
if self.center is not None:
s += ", center={center}"
s += ")"
d = dict(self.__dict__)
d["interpolation"] = self.interpolation.value
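Finally, a hedged usage sketch of the transform-level API after this change (the parameter values are illustrative):

```python
import torch
from torchvision import transforms

# Random affine whose rotation/shear pivot is the top-left corner rather than
# the default image center.
aug = transforms.RandomAffine(
    degrees=30,
    translate=(0.1, 0.1),
    scale=(0.9, 1.1),
    shear=10,
    center=[0, 0],
)

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
out = aug(img)
```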