From 98d4b3c942b758436afde8dbacf41b491ac0089f Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 18 Jan 2022 19:39:49 +0000 Subject: [PATCH 1/2] Added center option to F.affine and RandomAffine ops --- test/test_functional_tensor.py | 9 +++-- test/test_transforms.py | 30 ++++++++++++++-- torchvision/transforms/functional.py | 51 ++++++++++++++++++++++------ torchvision/transforms/transforms.py | 12 ++++++- 4 files changed, 84 insertions(+), 18 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 8f923475664..b4807c11f51 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -232,7 +232,8 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn): @pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120]) @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) - def test_rect_rotations(self, device, height, width, dt, angle, fn): + @pytest.mark.parametrize("center", [None, [0, 0]]) + def test_rect_rotations(self, device, height, width, dt, angle, fn, center): # Tests on rectangular images tensor, pil_img = _create_data(height, width, device=device) @@ -244,11 +245,13 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn): tensor = tensor.to(dtype=dt) out_pil_img = F.affine( - pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST + pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) - out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST).cpu() + out_tensor = fn( + tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center + ).cpu() if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) diff --git a/test/test_transforms.py b/test/test_transforms.py index 512a343ee59..41aafa5e271 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1983,11 +1983,11 @@ def _to_3x3_inv(self, inv_result_matrix): result_matrix[2, 2] = 1 return np.linalg.inv(result_matrix) - def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img): + def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img, center=None): a_rad = math.radians(angle) s_rad = [math.radians(sh_) for sh_ in shear] - cnt = [20, 20] + cnt = [20, 20] if center is None else center cx, cy = cnt tx, ty = translate sx, sy = s_rad @@ -2032,7 +2032,7 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_ if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]: true_result[y, x, :] = input_img[_y, _x, :] - result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear) + result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear, center=center) assert result.size == pil_image.size # Compute number of different pixels: np_result = np.array(result) @@ -2050,6 +2050,18 @@ def test_transformation_discrete(self, pil_image, input_img): angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img ) + # Test rotation + angle = 45 + self._test_transformation( + angle=angle, + translate=(0, 0), + scale=1.0, + shear=(0.0, 0.0), + pil_image=pil_image, + input_img=input_img, + center=[0, 0], + ) + # Test translation translate = [10, 15] self._test_transformation( @@ -2068,6 +2080,18 @@ def test_transformation_discrete(self, pil_image, input_img): angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img ) + # Test shear with top-left as center + shear = [45.0, 25.0] + self._test_transformation( + angle=0.0, + translate=(0.0, 0.0), + scale=1.0, + shear=shear, + pil_image=pil_image, + input_img=input_img, + center=[0, 0], + ) + @pytest.mark.parametrize("angle", range(-90, 90, 36)) @pytest.mark.parametrize("translate", range(-10, 10, 5)) @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27]) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7d7d5382291..4484a2ff9cf 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -945,26 +945,42 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def _get_inverse_affine_matrix( - center: List[float], angle: float, translate: List[float], scale: float, shear: List[float] + center: List[float], + angle: float, + translate: List[float], + scale: float, + shear: List[float], + centered_shear: bool = True, ) -> List[float]: # Helper method to compute inverse matrix for affine transformation - # As it is explained in PIL.Image.rotate - # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 + # Pillow requires inverse affine transformation matrix: + # option 1 (centered_shear=True) curr : M = T * C * RotateScaleShear * C^-1 + # option 2 (centered_shear=False) new : M = T * C * RotateScale * C^-1 * Shear + # # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] - # RSS is rotation with scale and shear matrix - # RSS(a, s, (sx, sy)) = + # RotateScaleShear is rotation with scale and shear matrix + # RotateScale is rotation with scale matrix + # + # RotateScaleShear(a, s, (sx, sy)) = # = R(a) * S(s) * SHy(sy) * SHx(sx) - # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ] - # [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ] + # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ] + # [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ] # [ 0 , 0 , 1 ] # + # RotateScale(a, s) = + # = R(a) * S(s) + # = [ s*cos(a), -s*sin(a), 0 ] + # [ s*sin(a), s*cos(a), 0 ] + # [ 0 , 0 , 1 ] + # # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] # [0, 1 ] [-tan(s), 1] - # - # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 + + # TODO: implement the option + assert centered_shear rot = math.radians(angle) sx = math.radians(shear[0]) @@ -1085,6 +1101,7 @@ def affine( fill: Optional[List[float]] = None, resample: Optional[int] = None, fillcolor: Optional[List[float]] = None, + center: Optional[List[int]] = None, ) -> Tensor: """Apply affine transformation on the image keeping image center invariant. If the image is torch Tensor, it is expected @@ -1112,6 +1129,8 @@ def affine( Please use the ``fill`` parameter instead. resample (int, optional): deprecated argument and will be removed since v0.10.0. Please use the ``interpolation`` parameter instead. + center (sequence, optional): Optional center of rotation. Origin is the upper left corner. + Default is the center of the image. Returns: PIL Image or Tensor: Transformed image. @@ -1172,18 +1191,28 @@ def affine( if len(shear) != 2: raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + img_size = get_image_size(img) if not isinstance(img, torch.Tensor): # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine - center = [img_size[0] * 0.5, img_size[1] * 0.5] + if center is None: + center = [img_size[0] * 0.5, img_size[1] * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) + center_f = [0.0, 0.0] + if center is not None: + img_size = get_image_size(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] + translate_f = [1.0 * t for t in translate] - matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear) + matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 5cacecab625..660f1a6f2dd 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1414,6 +1414,8 @@ class RandomAffine(torch.nn.Module): Please use the ``fill`` parameter instead. resample (int, optional): deprecated argument and will be removed since v0.10.0. Please use the ``interpolation`` parameter instead. + center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters @@ -1429,6 +1431,7 @@ def __init__( fill=0, fillcolor=None, resample=None, + center=None, ): super().__init__() _log_api_usage_once(self) @@ -1482,6 +1485,11 @@ def __init__( self.fillcolor = self.fill = fill + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + @staticmethod def get_params( degrees: List[float], @@ -1538,7 +1546,7 @@ def forward(self, img): ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) - return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) + return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center) def __repr__(self): s = "{name}(degrees={degrees}" @@ -1552,6 +1560,8 @@ def __repr__(self): s += ", interpolation={interpolation}" if self.fill != 0: s += ", fill={fill}" + if self.center is not None: + s += ", center={center}" s += ")" d = dict(self.__dict__) d["interpolation"] = self.interpolation.value From f75bfb1c9b98ce22468e63c49cb819b767ccf104 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 24 Jan 2022 15:33:58 +0000 Subject: [PATCH 2/2] Updates according to the review --- torchvision/transforms/functional.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 4484a2ff9cf..dc51877f821 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -950,37 +950,26 @@ def _get_inverse_affine_matrix( translate: List[float], scale: float, shear: List[float], - centered_shear: bool = True, ) -> List[float]: # Helper method to compute inverse matrix for affine transformation # Pillow requires inverse affine transformation matrix: - # option 1 (centered_shear=True) curr : M = T * C * RotateScaleShear * C^-1 - # option 2 (centered_shear=False) new : M = T * C * RotateScale * C^-1 * Shear + # Affine matrix is : M = T * C * RotateScaleShear * C^-1 # # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] # RotateScaleShear is rotation with scale and shear matrix - # RotateScale is rotation with scale matrix # # RotateScaleShear(a, s, (sx, sy)) = # = R(a) * S(s) * SHy(sy) * SHx(sx) # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ] # [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ] # [ 0 , 0 , 1 ] - # - # RotateScale(a, s) = - # = R(a) * S(s) - # = [ s*cos(a), -s*sin(a), 0 ] - # [ s*sin(a), s*cos(a), 0 ] - # [ 0 , 0 , 1 ] - # # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] # [0, 1 ] [-tan(s), 1] - - # TODO: implement the option - assert centered_shear + # + # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1 rot = math.radians(angle) sx = math.radians(shear[0])