diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index d23930e7313..aaecffe6c5b 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -348,6 +348,95 @@ def test_resized_crop(self): msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]) ) + def test_affine(self): + # Tests on square image + tensor, pil_img = self._create_data(26, 26) + + scripted_affine = torch.jit.script(F.affine) + # 1) identity map + out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + self.assertTrue( + tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) + ) + out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + self.assertTrue( + tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) + ) + + # 2) Test rotation + test_configs = [ + (90, torch.rot90(tensor, k=1, dims=(-1, -2))), + (45, None), + (30, None), + (-30, None), + (-45, None), + (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))), + (180, torch.rot90(tensor, k=2, dims=(-1, -2))), + ] + for a, true_tensor in test_configs: + for fn in [F.affine, scripted_affine]: + out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + if true_tensor is not None: + self.assertTrue( + true_tensor.equal(out_tensor), + msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5]) + ) + else: + true_tensor = out_tensor + + out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + + num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0 + ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2] + # Tolerance : less than 6% of different pixels + self.assertLess( + ratio_diff_pixels, + 0.06, + msg="{}\n{} vs \n{}".format( + ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] + ) + ) + # 3) Test translation + test_configs = [ + [10, 12], (12, 13) + ] + for t in test_configs: + for fn in [F.affine, scripted_affine]: + out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + self.compareTensorToPIL(out_tensor, out_pil_img) + + # 3) Test rotation + translation + scale + share + test_configs = [ + (45, [5, 6], 1.0, [0.0, 0.0]), + (33, (5, -4), 1.0, [0.0, 0.0]), + (45, [5, 4], 1.2, [0.0, 0.0]), + (33, (4, 8), 2.0, [0.0, 0.0]), + (85, (10, -10), 0.7, [0.0, 0.0]), + (0, [0, 0], 1.0, [35.0, ]), + (25, [0, 0], 1.2, [0.0, 15.0]), + (45, [10, 0], 0.7, [2.0, 5.0]), + (45, [10, -10], 1.2, [4.0, 5.0]), + ] + for r in [0, ]: + for a, t, s, sh in test_configs: + for fn in [F.affine, scripted_affine]: + out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r) + out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + + num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 + ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] + # Tolerance : less than 5% of different pixels + self.assertLess( + ratio_diff_pixels, + 0.05, + msg="{}: {}\n{} vs \n{}".format( + (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] + ) + ) + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms.py b/test/test_transforms.py index d583881b472..125502a3ad5 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1317,8 +1317,8 @@ def test_affine(self): for j in range(-5, 5): input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55] - with self.assertRaises(TypeError): - F.affine(input_img, 10) + with self.assertRaises(TypeError, msg="Argument translate should be a sequence"): + F.affine(input_img, 10, translate=0, scale=1, shear=1) pil_img = F.to_pil_image(input_img) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 35cd222acd9..340592a01f0 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,11 +1,10 @@ import math import numbers import warnings -from typing import Any +from typing import Any, Optional import numpy as np -from numpy import sin, cos, tan -from PIL import Image, __version__ as PILLOW_VERSION +from PIL import Image import torch from torch import Tensor @@ -21,6 +20,7 @@ _is_pil_image = F_pil._is_pil_image +_parse_fill = F_pil._parse_fill def _get_image_size(img: Tensor) -> List[int]: @@ -485,43 +485,6 @@ def hflip(img: Tensor) -> Tensor: return F_t.hflip(img) -def _parse_fill(fill, img, min_pil_version): - """Helper function to get the fill color for rotate and perspective transforms. - - Args: - fill (n-tuple or int or float): Pixel fill value for area outside the transformed - image. If int or float, the value is used for all bands respectively. - Defaults to 0 for all bands. - img (PIL Image): Image to be filled. - min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option - was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0) - - Returns: - dict: kwarg for ``fillcolor`` - """ - major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2]) - major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2]) - if major_found < major_required or (major_found == major_required and minor_found < minor_required): - if fill is None: - return {} - else: - msg = ("The option to fill background area of the transformed image, " - "requires pillow>={}") - raise RuntimeError(msg.format(min_pil_version)) - - num_bands = len(img.getbands()) - if fill is None: - fill = 0 - if isinstance(fill, (int, float)) and num_bands > 1: - fill = tuple([fill] * num_bands) - if not isinstance(fill, (int, float)) and len(fill) != num_bands: - msg = ("The number of elements in 'fill' does not match the number of " - "bands of the image ({} != {})") - raise ValueError(msg.format(len(fill), num_bands)) - - return {"fillcolor": fill} - - def _get_perspective_coeffs(startpoints, endpoints): """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. @@ -827,7 +790,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None): return img.rotate(angle, resample, expand, center, **opts) -def _get_inverse_affine_matrix(center, angle, translate, scale, shear): +def _get_inverse_affine_matrix( + center: List[int], angle: float, translate: List[float], scale: float, shear: List[float] +) -> List[float]: # Helper method to compute inverse matrix for affine transformation # As it is explained in PIL.Image.rotate @@ -847,14 +812,6 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear): # # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 - if isinstance(shear, numbers.Number): - shear = [shear, 0] - - if not isinstance(shear, (tuple, list)) and len(shear) == 2: - raise ValueError( - "Shear should be a single value or a tuple/list containing " + - "two values. Got {}".format(shear)) - rot = math.radians(angle) sx, sy = [math.radians(s) for s in shear] @@ -862,60 +819,100 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear): tx, ty = translate # RSS without scaling - a = cos(rot - sy) / cos(sy) - b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) - c = sin(rot - sy) / cos(sy) - d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) + a = math.cos(rot - sy) / math.cos(sy) + b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot) + c = math.sin(rot - sy) / math.cos(sy) + d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot) # Inverted rotation matrix with scale and shear # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 - M = [d, -b, 0, - -c, a, 0] - M = [x / scale for x in M] + matrix = [d, -b, 0.0, -c, a, 0.0] + matrix = [x / scale for x in matrix] # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 - M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) - M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) + matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) # Apply center translation: C * RSS^-1 * C^-1 * T^-1 - M[2] += cx - M[5] += cy - return M + matrix[2] += cx + matrix[5] += cy + return matrix -def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): - """Apply affine transformation on the image keeping image center invariant + +def affine( + img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], + resample: int = 0, fillcolor: Optional[int] = None +) -> Tensor: + """Apply affine transformation on the image keeping image center invariant. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. Args: - img (PIL Image): PIL Image to be rotated. + img (PIL Image or Tensor): image to be rotated. angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction. translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) scale (float): overall scale shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction. - If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while - the second value corresponds to a shear parallel to the y axis. + If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while + the second value corresponds to a shear parallel to the y axis. resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): - An optional resampling filter. - See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + + Returns: + PIL Image or Tensor: Transformed image. """ - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if not isinstance(translate, (list, tuple)): + raise TypeError("Argument translate should be a sequence") + + if len(translate) != 2: + raise ValueError("Argument translate should be a sequence of length 2") + + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + + if not isinstance(shear, (numbers.Number, (list, tuple))): + raise TypeError("Shear should be either a single value or a sequence of two values") + + if isinstance(angle, int): + angle = float(angle) + + if isinstance(translate, tuple): + translate = list(translate) + + if isinstance(shear, numbers.Number): + shear = [shear, 0.0] + + if isinstance(shear, tuple): + shear = list(shear) + + if len(shear) == 1: + shear = [shear[0], shear[0]] + + if len(shear) != 2: + raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear)) + + img_size = _get_image_size(img) + if not isinstance(img, torch.Tensor): + # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) + # it is visually better to estimate the center without 0.5 offset + # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine + center = [img_size[0] * 0.5, img_size[1] * 0.5] + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) - assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ - "Argument translate should be a list or tuple of length 2" + return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) - assert scale > 0.0, "Argument scale should be positive" + # we need to rescale translate by image size / 2 as its values can be between -1 and 1 + translate = [2.0 * t / s for s, t in zip(img_size, translate)] - output_size = img.size - # center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) - # it is visually better to estimate the center without 0.5 offset - # otherwise image rotated by 90 degrees is shifted 1 pixel - center = (img.size[0] * 0.5, img.size[1] * 0.5) - matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) - kwargs = {"fillcolor": fillcolor} if int(PILLOW_VERSION.split('.')[0]) >= 5 else {} - return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs) + matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear) + return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) def to_grayscale(img, num_output_channels=1): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 994988ce1f6..f165b65f8d8 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -1,13 +1,14 @@ import numbers from typing import Any, List, Sequence +import numpy as np import torch +from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION + try: import accimage except ImportError: accimage = None -from PIL import Image, ImageOps, ImageEnhance -import numpy as np @torch.jit.unused @@ -327,3 +328,65 @@ def resize(img, size, interpolation=Image.BILINEAR): return img.resize((ow, oh), interpolation) else: return img.resize(size[::-1], interpolation) + + +@torch.jit.unused +def _parse_fill(fill, img, min_pil_version): + """Helper function to get the fill color for rotate and perspective transforms. + + Args: + fill (n-tuple or int or float): Pixel fill value for area outside the transformed + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. + img (PIL Image): Image to be filled. + min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option + was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0) + + Returns: + dict: kwarg for ``fillcolor`` + """ + major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2]) + major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2]) + if major_found < major_required or (major_found == major_required and minor_found < minor_required): + if fill is None: + return {} + else: + msg = ("The option to fill background area of the transformed image, " + "requires pillow>={}") + raise RuntimeError(msg.format(min_pil_version)) + + num_bands = len(img.getbands()) + if fill is None: + fill = 0 + if isinstance(fill, (int, float)) and num_bands > 1: + fill = tuple([fill] * num_bands) + if not isinstance(fill, (int, float)) and len(fill) != num_bands: + msg = ("The number of elements in 'fill' does not match the number of " + "bands of the image ({} != {})") + raise ValueError(msg.format(len(fill), num_bands)) + + return {"fillcolor": fill} + + +@torch.jit.unused +def affine(img, matrix, resample=0, fillcolor=None): + """Apply affine transformation on the PIL Image keeping image center invariant. + + Args: + img (PIL Image): image to be rotated. + matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. + See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + + Returns: + PIL Image: Transformed image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + output_size = img.size + opts = _parse_fill(fillcolor, img, '5.0.0') + return img.transform(output_size, Image.AFFINE, matrix, resample, **opts) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 59cf6bc2764..2bd4549059e 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,9 @@ +import warnings +from typing import Optional + import torch from torch import Tensor +from torch.nn.functional import affine_grid, grid_sample from torch.jit.annotations import List, BroadcastingList2 @@ -496,7 +500,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. In torchscript mode padding as a single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation. Default is bilinear. + interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values: + nearest(=0) and bicubic(=3). Returns: Tensor: Resized image. @@ -571,3 +576,63 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: img = img.to(out_dtype) return img + + +def affine( + img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None +) -> Tensor: + """Apply affine transformation on the Tensor image keeping image center invariant. + + Args: + img (Tensor): image to be rotated. + matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. + resample (int, optional): An optional resampling filter. Default is nearest (=2). Other supported values: + bilinear(=2). + fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the + transform in the output image is always 0. + + Returns: + Tensor: Transformed image. + """ + if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): + raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + + if fillcolor is not None: + warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero") + + _interpolation_modes = { + 0: "nearest", + 2: "bilinear", + } + + if resample not in _interpolation_modes: + raise ValueError("This resampling mode is unsupported with Tensor input") + + theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) + shape = img.shape + grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False) + + # make image NCHW + need_squeeze = False + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + mode = _interpolation_modes[resample] + + out_dtype = img.dtype + need_cast = False + if img.dtype not in (torch.float32, torch.float64): + need_cast = True + img = img.to(torch.float32) + + img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False) + + if need_squeeze: + img = img.squeeze(dim=0) + + if need_cast: + # it is better to round before cast + img = torch.round(img).to(out_dtype) + + return img