diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index d01a357d7b5..0532f171471 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -435,7 +435,7 @@ def test_affine(self): ) # 3) Test translation test_configs = [ - [10, 12], (12, 13) + [10, 12], (-12, -13) ] for t in test_configs: for fn in [F.affine, scripted_affine]: @@ -447,21 +447,21 @@ def test_affine(self): test_configs = [ (45, [5, 6], 1.0, [0.0, 0.0]), (33, (5, -4), 1.0, [0.0, 0.0]), - (45, [5, 4], 1.2, [0.0, 0.0]), - (33, (4, 8), 2.0, [0.0, 0.0]), + (45, [-5, 4], 1.2, [0.0, 0.0]), + (33, (-4, -8), 2.0, [0.0, 0.0]), (85, (10, -10), 0.7, [0.0, 0.0]), (0, [0, 0], 1.0, [35.0, ]), (25, [0, 0], 1.2, [0.0, 15.0]), - (45, [10, 0], 0.7, [2.0, 5.0]), - (45, [10, -10], 1.2, [4.0, 5.0]), + (45, [-10, 0], 0.7, [2.0, 5.0]), + (45, [-10, -10], 1.2, [4.0, 5.0]), ] for r in [0, ]: for a, t, s, sh in test_configs: + out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + for fn in [F.affine, scripted_affine]: out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r) - out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) - out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) - num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] # Tolerance : less than 5% of different pixels @@ -473,6 +473,47 @@ def test_affine(self): ) ) + def test_rotate(self): + # Tests on square image + tensor, pil_img = self._create_data(26, 26) + scripted_rotate = torch.jit.script(F.rotate) + + img_size = pil_img.size + + centers = [ + None, + (int(img_size[0] * 0.3), int(img_size[0] * 0.4)), + [int(img_size[0] * 0.5), int(img_size[0] * 0.6)] + ] + + for r in [0, ]: + for a in range(-120, 120, 23): + for e in [True, False]: + for c in centers: + + out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + for fn in [F.rotate, scripted_rotate]: + out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c) + + self.assertEqual( + out_tensor.shape, + out_pil_tensor.shape, + msg="{}: {} vs {}".format( + (r, a, e, c), out_tensor.shape, out_pil_tensor.shape + ) + ) + num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 + ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] + # Tolerance : less than 2% of different pixels + self.assertLess( + ratio_diff_pixels, + 0.02, + msg="{}: {}\n{} vs \n{}".format( + (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] + ) + ) + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms.py b/test/test_transforms.py index 2d3109b06b0..f3d0f48e4a2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1266,7 +1266,7 @@ def test_rotate(self): x = np.zeros((100, 100, 3), dtype=np.uint8) x[40, 40] = [255, 255, 255] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, r"img should be PIL Image"): F.rotate(x, 10) img = F.to_pil_image(x) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 74c92d61bbd..374c131cf44 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -756,40 +756,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: return F_t.adjust_gamma(img, gamma, gain) -def rotate(img, angle, resample=False, expand=False, center=None, fill=None): - """Rotate the image by angle. - - - Args: - img (PIL Image): PIL Image to be rotated. - angle (float or int): In degrees degrees counter clockwise order. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): - An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - expand (bool, optional): Optional expansion flag. - If true, expands the output image to make it large enough to hold the entire rotated image. - If false or omitted, make the output image the same size as the input image. - Note that the expand flag assumes rotation around the center and no translation. - center (2-tuple, optional): Optional center of rotation. - Origin is the upper left corner. - Default is the center of the image. - fill (n-tuple or int or float): Pixel fill value for area outside the rotated - image. If int or float, the value is used for all bands respectively. - Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. - - .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters - - """ - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - opts = _parse_fill(fill, img, '5.2.0') - - return img.rotate(angle, resample, expand, center, **opts) - - def _get_inverse_affine_matrix( - center: List[int], angle: float, translate: List[float], scale: float, shear: List[float] + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float] ) -> List[float]: # Helper method to compute inverse matrix for affine transformation @@ -838,6 +806,56 @@ def _get_inverse_affine_matrix( return matrix +def rotate( + img: Tensor, angle: float, resample: int = 0, expand: bool = False, + center: Optional[List[int]] = None, fill: Optional[int] = None +) -> Tensor: + """Rotate the image by angle. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + img (PIL Image or Tensor): image to be rotated. + angle (float or int): rotation angle value in degrees, counter-clockwise. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner. + Default is the center of the image. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. + + Returns: + PIL Image or Tensor: Rotated image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + + if not isinstance(img, torch.Tensor): + return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill) + + center_f = [0.0, 0.0] + if center is not None: + img_size = _get_image_size(img) + # Center is normalized to [-1, +1] + center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)] + # due to current incoherence of rotation angle direction between affine and rotate implementations + # we need to set -angle. + matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill) + + def affine( img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], resample: int = 0, fillcolor: Optional[int] = None @@ -847,7 +865,7 @@ def affine( to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. Args: - img (PIL Image or Tensor): image to be rotated. + img (PIL Image or Tensor): image to transform. angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction. translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) scale (float): overall scale @@ -911,7 +929,7 @@ def affine( # we need to rescale translate by image size / 2 as its values can be between -1 and 1 translate = [2.0 * t / s for s, t in zip(img_size, translate)] - matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear) + matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear) return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index a0e9467700e..0b76463b24d 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -422,3 +422,37 @@ def affine(img, matrix, resample=0, fillcolor=None): output_size = img.size opts = _parse_fill(fillcolor, img, '5.0.0') return img.transform(output_size, Image.AFFINE, matrix, resample, **opts) + + +@torch.jit.unused +def rotate(img, angle, resample=0, expand=False, center=None, fill=None): + """Rotate PIL image by angle. + + Args: + img (PIL Image): image to be rotated. + angle (float or int): rotation angle value in degrees, counter-clockwise. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. + + Returns: + PIL Image: Rotated image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + if not _is_pil_image(img): + raise TypeError("img should be PIL Image. Got {}".format(type(img))) + + opts = _parse_fill(fill, img, '5.2.0') + return img.rotate(angle, resample, expand, center, **opts) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 357f23b88fc..3641d722730 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional +from typing import Optional, Dict, Tuple import torch from torch import Tensor @@ -619,48 +619,32 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: return img -def affine( - img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None -) -> Tensor: - """Apply affine transformation on the Tensor image keeping image center invariant. +def _assert_grid_transform_inputs( + img: Tensor, matrix: List[float], resample: int, fillcolor: Optional[int], _interpolation_modes: Dict[int, str] +): + if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): + raise TypeError("img should be Tensor Image. Got {}".format(type(img))) - Args: - img (Tensor): image to be rotated. - matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. - resample (int, optional): An optional resampling filter. Default is nearest (=2). Other supported values: - bilinear(=2). - fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the - transform in the output image is always 0. + if not isinstance(matrix, list): + raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix))) - Returns: - Tensor: Transformed image. - """ - if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): - raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + if len(matrix) != 6: + raise ValueError("Argument matrix should have 6 float values") if fillcolor is not None: - warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero") - - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } + warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero") if resample not in _interpolation_modes: raise ValueError("This resampling mode is unsupported with Tensor input") - theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) - shape = img.shape - grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False) +def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor: # make image NCHW need_squeeze = False if img.ndim < 4: img = img.unsqueeze(dim=0) need_squeeze = True - mode = _interpolation_modes[resample] - out_dtype = img.dtype need_cast = False if img.dtype not in (torch.float32, torch.float64): @@ -677,3 +661,106 @@ def affine( img = torch.round(img).to(out_dtype) return img + + +def affine( + img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None +) -> Tensor: + """Apply affine transformation on the Tensor image keeping image center invariant. + + Args: + img (Tensor): image to be rotated. + matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. + resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: + bilinear(=2). + fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the + transform in the output image is always 0. + + Returns: + Tensor: Transformed image. + """ + _interpolation_modes = { + 0: "nearest", + 2: "bilinear", + } + + _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes) + + theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) + shape = img.shape + grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False) + mode = _interpolation_modes[resample] + + return _apply_grid_transform(img, grid, mode) + + +def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]: + + # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + # we need to normalize coordinates according to + # [0, s] is mapped [-1, +1] as theta translation parameters are normalized like that + pts = torch.tensor([ + [-1.0, -1.0, 1.0], + [-1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, -1.0, 1.0], + ]) + # denormalize back to w, h: + new_pts = (torch.matmul(pts, theta.t()) + 1.0) * torch.tensor([w, h]) / 2.0 + min_vals, _ = new_pts.min(dim=0) + max_vals, _ = new_pts.max(dim=0) + size = torch.ceil(max_vals) - torch.floor(min_vals) + return int(size[0]), int(size[1]) + + +def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor: + if expand: + ow, oh = _compute_output_size(theta, w, h) + else: + ow, oh = w, h + d = 0.5 # if not align_corners + + x = (torch.arange(ow) + d - ow * 0.5) / (0.5 * w) + y = (torch.arange(oh) + d - oh * 0.5) / (0.5 * h) + y, x = torch.meshgrid(y, x) + pts = torch.stack([x, y, torch.ones_like(x)], dim=-1) + output_grid = torch.matmul(pts, theta.t()) + return output_grid.unsqueeze(dim=0) + + +def rotate( + img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None +) -> Tensor: + """Rotate the Tensor image by angle. + + Args: + img (Tensor): image to be rotated. + matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation. + resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: + bilinear(=2). + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + fill (n-tuple or int or float): this option is not supported for Tensor input. + Fill value for the area outside the transform in the output image is always 0. + + Returns: + Tensor: Rotated image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + _interpolation_modes = { + 0: "nearest", + 2: "bilinear", + } + + _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes) + + theta = torch.tensor(matrix).reshape(2, 3) + shape = img.shape + grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand) + mode = _interpolation_modes[resample] + + return _apply_grid_transform(img, grid, mode)