
Commit 8c7e7bb

[BC-breaking] Unified input for F.perspective (#2558)
* [WIP] Added unified input perspective transformation code
* Unified input for F.perspective: added tests, updated docs
* Added more random test configs
* Fixed the code according to the PR review
1 parent 08af5cb commit 8c7e7bb

4 files changed: 187 additions & 28 deletions


test/test_functional_tensor.py

Lines changed: 40 additions & 0 deletions
@@ -545,6 +545,46 @@ def test_rotate(self):
                     )
                 )
 
+    def test_perspective(self):
+
+        from torchvision.transforms import RandomPerspective
+
+        for tensor, pil_img in [self._create_data(26, 34), self._create_data(26, 26)]:
+
+            scripted_transform = torch.jit.script(F.perspective)
+
+            test_configs = [
+                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+            ]
+            n = 10
+            test_configs += [
+                RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
+            ]
+
+            for r in [0, ]:
+                for spoints, epoints in test_configs:
+                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                    for fn in [F.perspective, scripted_transform]:
+                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r)
+
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance: less than 3% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.03,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                (r, spoints, epoints),
+                                ratio_diff_pixels,
+                                out_tensor[0, :7, :7],
+                                out_pil_tensor[0, :7, :7]
+                            )
+                        )
+
 
 if __name__ == '__main__':
     unittest.main()
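
The test above pins the new tensor backend to the PIL reference, pixel for pixel. A minimal sketch of how the unified API is called after this change; the image sizes and corner lists here are illustrative, not taken from the commit:

import torch
from PIL import Image
import torchvision.transforms.functional as F

# Four corners of the source and destination quadrilaterals,
# ordered [top-left, top-right, bottom-right, bottom-left].
startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]

tensor_img = torch.randint(0, 256, (3, 26, 34), dtype=torch.uint8)  # [C, H, W]
pil_img = Image.new("RGB", (34, 26))

# The same call now works for both input types; for tensors only
# NEAREST (0) and BILINEAR (2) interpolation are supported.
out_tensor = F.perspective(tensor_img, startpoints, endpoints, interpolation=0)
out_pil = F.perspective(pil_img, startpoints, endpoints, interpolation=0)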

torchvision/transforms/functional.py

Lines changed: 41 additions & 24 deletions
@@ -491,53 +491,70 @@ def hflip(img: Tensor) -> Tensor:
     return F_t.hflip(img)
 
 
-def _get_perspective_coeffs(startpoints, endpoints):
+def _get_perspective_coeffs(
+        startpoints: List[List[int]], endpoints: List[List[int]]
+) -> List[float]:
     """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
 
     In Perspective Transform each pixel (x, y) in the original image gets transformed as,
     (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
 
     Args:
-        List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
-        List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
+        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
+            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
+        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
+            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
+
     Returns:
         octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
     """
-    matrix = []
+    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
+
+    for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
+        a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
+        a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
 
-    for p1, p2 in zip(endpoints, startpoints):
-        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
-        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
+    b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
+    res = torch.lstsq(b_matrix, a_matrix)[0]
 
-    A = torch.tensor(matrix, dtype=torch.float)
-    B = torch.tensor(startpoints, dtype=torch.float).view(8)
-    res = torch.lstsq(B, A)[0]
-    return res.squeeze_(1).tolist()
+    output: List[float] = res.squeeze(1).tolist()
+    return output
 
 
-def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None):
-    """Perform perspective transform of the given PIL Image.
+def perspective(
+        img: Tensor,
+        startpoints: List[List[int]],
+        endpoints: List[List[int]],
+        interpolation: int = 2,
+        fill: Optional[int] = None
+) -> Tensor:
+    """Perform perspective transform of the given image.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
 
     Args:
-        img (PIL Image): Image to be transformed.
-        startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the original image
-        endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
-        interpolation: Default- Image.BICUBIC
+        img (PIL Image or Tensor): Image to be transformed.
+        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
+            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
+        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
+            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
+        interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and
+            ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors.
         fill (n-tuple or int or float): Pixel fill value for area outside the rotated
             image. If int or float, the value is used for all bands respectively.
-            This option is only available for ``pillow>=5.0.0``.
+            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
+            input. Fill value for the area outside the transform in the output image is always 0.
 
     Returns:
-        PIL Image: Perspectively transformed Image.
+        PIL Image or Tensor: transformed Image.
     """
 
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    coeffs = _get_perspective_coeffs(startpoints, endpoints)
 
-    opts = _parse_fill(fill, img, '5.0.0')
+    if not isinstance(img, torch.Tensor):
+        return F_pil.perspective(img, coeffs, interpolation=interpolation, fill=fill)
 
-    coeffs = _get_perspective_coeffs(startpoints, endpoints)
-    return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)
+    return F_t.perspective(img, coeffs, interpolation=interpolation, fill=fill)
 
 
 def vflip(img: Tensor) -> Tensor:
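
A useful way to read ``_get_perspective_coeffs``: since the system matrix is built from the endpoints and the right-hand side from the startpoints, the solved octuple maps points of the transformed image back onto the original image, which is the convention ``PIL.Image.transform`` consumes. A small sanity-check sketch under that reading; the corner lists and tolerance are mine, and the private helper is imported here only for illustration:

import torch
from torchvision.transforms.functional import _get_perspective_coeffs

startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]

a, b, c, d, e, f, g, h = _get_perspective_coeffs(startpoints, endpoints)

# Plugging an endpoint (output pixel) into the rational transform
# should recover the matching startpoint (input pixel).
for (sx, sy), (ex, ey) in zip(startpoints, endpoints):
    denom = g * ex + h * ey + 1
    assert abs((a * ex + b * ey + c) / denom - sx) < 1e-2
    assert abs((d * ex + e * ey + f) / denom - sy) < 1e-2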

torchvision/transforms/functional_pil.py

Lines changed: 24 additions & 0 deletions
@@ -456,3 +456,27 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
 
     opts = _parse_fill(fill, img, '5.2.0')
     return img.rotate(angle, resample, expand, center, **opts)
+
+
+@torch.jit.unused
+def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None):
+    """Perform perspective transform of the given PIL Image.
+
+    Args:
+        img (PIL Image): Image to be transformed.
+        perspective_coeffs (list of float): perspective transformation coefficients.
+        interpolation (int): Interpolation type. Default, ``Image.BICUBIC``.
+        fill (n-tuple or int or float): Pixel fill value for area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is only available for ``pillow>=5.0.0``.
+
+    Returns:
+        PIL Image: Perspectively transformed Image.
+    """
+
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    opts = _parse_fill(fill, img, '5.0.0')
+
+    return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
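
The new ``F_pil.perspective`` takes precomputed coefficients rather than corner points. It is normally reached through ``torchvision.transforms.functional.perspective``, but a quick standalone sketch (the identity coefficients are my choice for illustration):

from PIL import Image
import torchvision.transforms.functional_pil as F_pil

img = Image.new("RGB", (34, 26))

# Identity homography: (a, b, c, d, e, f, g, h) = (1, 0, 0, 0, 1, 0, 0, 0).
coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
out = F_pil.perspective(img, coeffs, interpolation=Image.NEAREST)
assert out.size == img.size  # the canvas size is preserved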

torchvision/transforms/functional_tensor.py

Lines changed: 82 additions & 4 deletions
@@ -620,22 +620,30 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
 
 
 def _assert_grid_transform_inputs(
-    img: Tensor, matrix: List[float], resample: int, fillcolor: Optional[int], _interpolation_modes: Dict[int, str]
+    img: Tensor,
+    matrix: Optional[List[float]],
+    resample: int,
+    fillcolor: Optional[int],
+    _interpolation_modes: Dict[int, str],
+    coeffs: Optional[List[float]] = None,
 ):
     if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
         raise TypeError("img should be Tensor Image. Got {}".format(type(img)))
 
-    if not isinstance(matrix, list):
+    if matrix is not None and not isinstance(matrix, list):
         raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix)))
 
-    if len(matrix) != 6:
+    if matrix is not None and len(matrix) != 6:
         raise ValueError("Argument matrix should have 6 float values")
 
+    if coeffs is not None and len(coeffs) != 8:
+        raise ValueError("Argument coeffs should have 8 float values")
+
     if fillcolor is not None:
         warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero")
 
     if resample not in _interpolation_modes:
-        raise ValueError("This resampling mode is unsupported with Tensor input")
+        raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample))
 
 
 def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
@@ -773,3 +781,73 @@ def rotate(
     mode = _interpolation_modes[resample]
 
     return _apply_grid_transform(img, grid, mode)
+
+
+def _perspective_grid(coeffs: List[float], ow: int, oh: int):
+    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
+    # src/libImaging/Geometry.c#L394
+    #
+    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    #
+
+    theta1 = torch.tensor([[
+        [coeffs[0], coeffs[1], coeffs[2]],
+        [coeffs[3], coeffs[4], coeffs[5]]
+    ]])
+    theta2 = torch.tensor([[
+        [coeffs[6], coeffs[7], 1.0],
+        [coeffs[6], coeffs[7], 1.0]
+    ]])
+
+    d = 0.5
+    base_grid = torch.empty(1, oh, ow, 3)
+    base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
+    base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
+    base_grid[..., 2].fill_(1)
+
+    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh]))
+    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
+
+    output_grid = output_grid1 / output_grid2 - 1.0
+    return output_grid.view(1, oh, ow, 2)
+
+
+def perspective(
+    img: Tensor, perspective_coeffs: List[float], interpolation: int = 2, fill: Optional[int] = None
+) -> Tensor:
+    """Perform perspective transform of the given Tensor image.
+
+    Args:
+        img (Tensor): Image to be transformed.
+        perspective_coeffs (list of float): perspective transformation coefficients.
+        interpolation (int): Interpolation type. Default, ``PIL.Image.BILINEAR``.
+        fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area
+            outside the transform in the output image is always 0.
+
+    Returns:
+        Tensor: transformed image.
+    """
+    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
+
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(
+        img,
+        matrix=None,
+        resample=interpolation,
+        fillcolor=fill,
+        _interpolation_modes=_interpolation_modes,
+        coeffs=perspective_coeffs
+    )
+
+    ow, oh = img.shape[-1], img.shape[-2]
+    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh)
+    mode = _interpolation_modes[interpolation]
+
+    return _apply_grid_transform(img, grid, mode)
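
``_perspective_grid`` mirrors Pillow's Geometry.c mapping: it evaluates the two rational expressions above at every output pixel center and normalizes the result to the [-1, 1] range that ``torch.nn.functional.grid_sample`` expects, after which ``_apply_grid_transform`` does the sampling. A minimal end-to-end sketch, assuming (as the mapping implies) that identity coefficients reproduce the input; the image values are illustrative:

import torch
import torchvision.transforms.functional_tensor as F_t

img = torch.arange(3 * 8 * 8, dtype=torch.float32).reshape(3, 8, 8)

# Identity homography: every output pixel samples its own input pixel.
coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
out = F_t.perspective(img, coeffs, interpolation=0)  # 0 -> "nearest"

assert out.shape == img.shape
assert torch.allclose(out, img)  # identity mapping leaves pixels unchanged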
