diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 4ffa8cf280e..2675226d3b7 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -545,6 +545,46 @@ def test_rotate(self): ) ) + def test_perspective(self): + + from torchvision.transforms import RandomPerspective + + for tensor, pil_img in [self._create_data(26, 34), self._create_data(26, 26)]: + + scripted_tranform = torch.jit.script(F.perspective) + + test_configs = [ + [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]], + [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]], + [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]], + ] + n = 10 + test_configs += [ + RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n) + ] + + for r in [0, ]: + for spoints, epoints in test_configs: + out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r) + out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) + + for fn in [F.perspective, scripted_tranform]: + out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r) + + num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 + ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] + # Tolerance : less than 3% of different pixels + self.assertLess( + ratio_diff_pixels, + 0.03, + msg="{}: {}\n{} vs \n{}".format( + (r, spoints, epoints), + ratio_diff_pixels, + out_tensor[0, :7, :7], + out_pil_tensor[0, :7, :7] + ) + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 689137b44cb..0cc2c7d97bd 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -491,53 +491,70 @@ def hflip(img: Tensor) -> Tensor: return F_t.hflip(img) -def _get_perspective_coeffs(startpoints, endpoints): +def _get_perspective_coeffs( + startpoints: List[List[int]], endpoints: List[List[int]] +) -> List[float]: """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. In Perspective Transform each pixel (x, y) in the original image gets transformed as, (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) ) Args: - List containing [top-left, top-right, bottom-right, bottom-left] of the original image, - List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image + startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the original image. + endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image. + Returns: octuple (a, b, c, d, e, f, g, h) for transforming each pixel. """ - matrix = [] + a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float) + + for i, (p1, p2) in enumerate(zip(endpoints, startpoints)): + a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) + a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) - for p1, p2 in zip(endpoints, startpoints): - matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) - matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) + b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8) + res = torch.lstsq(b_matrix, a_matrix)[0] - A = torch.tensor(matrix, dtype=torch.float) - B = torch.tensor(startpoints, dtype=torch.float).view(8) - res = torch.lstsq(B, A)[0] - return res.squeeze_(1).tolist() + output: List[float] = res.squeeze(1).tolist() + return output -def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None): - """Perform perspective transform of the given PIL Image. +def perspective( + img: Tensor, + startpoints: List[List[int]], + endpoints: List[List[int]], + interpolation: int = 2, + fill: Optional[int] = None +) -> Tensor: + """Perform perspective transform of the given image. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. Args: - img (PIL Image): Image to be transformed. - startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the original image - endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image - interpolation: Default- Image.BICUBIC + img (PIL Image or Tensor): Image to be transformed. + startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the original image. + endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image. + interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and + ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. - This option is only available for ``pillow>=5.0.0``. + This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor + input. Fill value for the area outside the transform in the output image is always 0. Returns: - PIL Image: Perspectively transformed Image. + PIL Image or Tensor: transformed Image. """ - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + coeffs = _get_perspective_coeffs(startpoints, endpoints) - opts = _parse_fill(fill, img, '5.0.0') + if not isinstance(img, torch.Tensor): + return F_pil.perspective(img, coeffs, interpolation=interpolation, fill=fill) - coeffs = _get_perspective_coeffs(startpoints, endpoints) - return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts) + return F_t.perspective(img, coeffs, interpolation=interpolation, fill=fill) def vflip(img: Tensor) -> Tensor: diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 0b76463b24d..f1e8504f874 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -456,3 +456,27 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): opts = _parse_fill(fill, img, '5.2.0') return img.rotate(angle, resample, expand, center, **opts) + + +@torch.jit.unused +def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None): + """Perform perspective transform of the given PIL Image. + + Args: + img (PIL Image): Image to be transformed. + perspective_coeffs (list of float): perspective transformation coefficients. + interpolation (int): Interpolation type. Default, ``Image.BICUBIC``. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + This option is only available for ``pillow>=5.0.0``. + + Returns: + PIL Image: Perspectively transformed Image. + """ + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + opts = _parse_fill(fill, img, '5.0.0') + + return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 452e5d37ed8..433575ac6e2 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -620,22 +620,30 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: def _assert_grid_transform_inputs( - img: Tensor, matrix: List[float], resample: int, fillcolor: Optional[int], _interpolation_modes: Dict[int, str] + img: Tensor, + matrix: Optional[List[float]], + resample: int, + fillcolor: Optional[int], + _interpolation_modes: Dict[int, str], + coeffs: Optional[List[float]] = None, ): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): raise TypeError("img should be Tensor Image. Got {}".format(type(img))) - if not isinstance(matrix, list): + if matrix is not None and not isinstance(matrix, list): raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix))) - if len(matrix) != 6: + if matrix is not None and len(matrix) != 6: raise ValueError("Argument matrix should have 6 float values") + if coeffs is not None and len(coeffs) != 8: + raise ValueError("Argument coeffs should have 8 float values") + if fillcolor is not None: warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero") if resample not in _interpolation_modes: - raise ValueError("This resampling mode is unsupported with Tensor input") + raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample)) def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor: @@ -773,3 +781,73 @@ def rotate( mode = _interpolation_modes[resample] return _apply_grid_transform(img, grid, mode) + + +def _perspective_grid(coeffs: List[float], ow: int, oh: int): + # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ + # src/libImaging/Geometry.c#L394 + + # + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # + + theta1 = torch.tensor([[ + [coeffs[0], coeffs[1], coeffs[2]], + [coeffs[3], coeffs[4], coeffs[5]] + ]]) + theta2 = torch.tensor([[ + [coeffs[6], coeffs[7], 1.0], + [coeffs[6], coeffs[7], 1.0] + ]]) + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3) + base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow)) + base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1)) + base_grid[..., 2].fill_(1) + + output_grid1 = base_grid.view(1, oh * ow, 3).bmm(theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh])) + output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) + + output_grid = output_grid1 / output_grid2 - 1.0 + return output_grid.view(1, oh, ow, 2) + + +def perspective( + img: Tensor, perspective_coeffs: List[float], interpolation: int = 2, fill: Optional[int] = None +) -> Tensor: + """Perform perspective transform of the given Tensor image. + + Args: + img (Tensor): Image to be transformed. + perspective_coeffs (list of float): perspective transformation coefficients. + interpolation (int): Interpolation type. Default, ``PIL.Image.BILINEAR``. + fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area + outside the transform in the output image is always 0. + + Returns: + Tensor: transformed image. + """ + if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): + raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + + _interpolation_modes = { + 0: "nearest", + 2: "bilinear", + } + + _assert_grid_transform_inputs( + img, + matrix=None, + resample=interpolation, + fillcolor=fill, + _interpolation_modes=_interpolation_modes, + coeffs=perspective_coeffs + ) + + ow, oh = img.shape[-1], img.shape[-2] + grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh) + mode = _interpolation_modes[interpolation] + + return _apply_grid_transform(img, grid, mode)