
Commit 205051e

Unified input for F.perspective
- added tests
- updated docs
1 parent 2a4cd36 commit 205051e

3 files changed: +116 -49 lines changed

test/test_functional_tensor.py

Lines changed: 30 additions & 26 deletions
@@ -546,33 +546,37 @@ def test_rotate(self):
         )
 
     def test_perspective(self):
-        tensor, pil_img = self._create_data(26, 34)
-
-        scripted_tranform = torch.jit.script(F.perspective)
-
-        test_configs = [
-            ([(0, 0), (33, 0), (33, 25), (0, 25)], [(3, 2), (32, 3), (30, 24), (2, 25)]),
-            ([(3, 2), (32, 3), (30, 24), (2, 25)], [(0, 0), (33, 0), (33, 25), (0, 25)]),
-            ([(3, 2), (32, 3), (30, 24), (2, 25)], [(5, 5), (30, 3), (33, 19), (4, 25)]),
-        ]
-        for r in [0, ]:
-            for spoints, epoints in test_configs:
-                out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
-                for fn in [F.perspective, scripted_tranform]:
-                    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r)
-
-                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 5% of different pixels
-                    self.assertLess(
-                        ratio_diff_pixels,
-                        0.05,
-                        msg="{}: {}\n{} vs \n{}".format(
-                            (r, spoints, epoints), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+
+        for tensor, pil_img in [self._create_data(26, 34), self._create_data(26, 26)]:
+
+            scripted_tranform = torch.jit.script(F.perspective)
+
+            test_configs = [
+                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+            ]
+            for r in [0, ]:
+                for spoints, epoints in test_configs:
+                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                    for fn in [F.perspective, scripted_tranform]:
+                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r)
+
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance : less than 3% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.03,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                (r, spoints, epoints),
+                                ratio_diff_pixels,
+                                out_tensor[0, :7, :7],
+                                out_pil_tensor[0, :7, :7]
+                            )
                         )
-                    )
 
 
 if __name__ == '__main__':
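
For context, the test above checks the tensor path against PIL on the same inputs. A minimal standalone sketch of the unified call pattern it exercises (assumes a torchvision build containing this commit; sizes and corner values are taken from the test):

import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as F

# Four [x, y] corners, ordered [top-left, top-right, bottom-right, bottom-left];
# the same list-of-lists format now works for PIL images and tensors alike.
startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]

pil_img = Image.fromarray(np.random.randint(0, 256, size=(26, 34, 3), dtype=np.uint8))
tensor = torch.from_numpy(np.array(pil_img).transpose((2, 0, 1)))

out_pil = F.perspective(pil_img, startpoints=startpoints, endpoints=endpoints, interpolation=0)
out_tensor = F.perspective(tensor, startpoints=startpoints, endpoints=endpoints, interpolation=0)

# Fraction of differing pixel values, as in the test; expected below the 3% tolerance.
out_pil_tensor = torch.from_numpy(np.array(out_pil).transpose((2, 0, 1)))
num_diff = (out_tensor != out_pil_tensor).sum().item() / 3.0
print(num_diff / (out_tensor.shape[-1] * out_tensor.shape[-2]))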

torchvision/transforms/functional.py

Lines changed: 23 additions & 17 deletions
@@ -492,38 +492,39 @@ def hflip(img: Tensor) -> Tensor:
 
 
 def _get_perspective_coeffs(
-        startpoints: List[Tuple[int, int]], endpoints: List[Tuple[int, int]]
+        startpoints: List[List[int]], endpoints: List[List[int]]
 ) -> List[float]:
     """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
 
     In Perspective Transform each pixel (x, y) in the original image gets transformed as,
         (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
 
     Args:
-        startpoints (list of tuples): List containing four tuples of two integers corresponding to four corners
+        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
             ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
-        endpoints (list of tuples): List containing four tuples of two integers corresponding to four corners
+        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
             ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
 
     Returns:
         octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
     """
-    matrix = []
+    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
 
-    for p1, p2 in zip(endpoints, startpoints):
-        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
-        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
+    for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
+        a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
+        a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
 
-    A = torch.tensor(matrix, dtype=torch.float)
-    B = torch.tensor(startpoints, dtype=torch.float).view(8)
-    res = torch.lstsq(B, A)[0]
-    return res.squeeze_(1).tolist()
+    b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
+    res = torch.lstsq(b_matrix, a_matrix)[0]
+    # We have to explicitly produce the list of floats, otherwise torch.jit.script does not recognize the output type:
+    # RuntimeError: Expected type hint for result of tolist()
+    return [float(i.item()) for i in res[:, 0]]
 
 
 def perspective(
         img: Tensor,
-        startpoints: List[Tuple[int, int]],
-        endpoints: List[Tuple[int, int]],
+        startpoints: List[List[int]],
+        endpoints: List[List[int]],
         interpolation: int = 3,
         fill: Optional[int] = None
 ) -> Tensor:
@@ -533,9 +534,9 @@ def perspective(
 
     Args:
         img (PIL Image or Tensor): Image to be transformed.
-        startpoints (list of tuples): List containing four tuples of two integers corresponding to four corners
+        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
             ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
-        endpoints (list of tuples): List containing four tuples of two integers corresponding to four corners
+        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
             ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
         interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and
             ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BICUBIC`` for PIL images and
@@ -546,15 +547,20 @@ def perspective(
             input. Fill value for the area outside the transform in the output image is always 0.
 
     Returns:
-        PIL Image or Tensor: Perspectively transformed Image.
+        PIL Image or Tensor: transformed Image.
     """
 
     coeffs = _get_perspective_coeffs(startpoints, endpoints)
 
     if not isinstance(img, torch.Tensor):
         return F_pil.perspective(img, coeffs, interpolation=interpolation, fill=fill)
 
-    return F_t.perspective()
+    if interpolation == Image.BICUBIC:
+        # bicubic is not supported by pytorch
+        # set to bilinear interpolation
+        interpolation = 2
+
+    return F_t.perspective(img, coeffs, interpolation=interpolation, fill=fill)
 
 
 def vflip(img: Tensor) -> Tensor:
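
The coefficient solve above follows PIL's convention: the eight values map each pixel of the output image back to a source location in the input, which is why the rows are built from endpoints while the right-hand side comes from startpoints. A re-derivation of the same 8x8 system in NumPy, for illustration only (uses np.linalg.lstsq in place of the torch.lstsq call in the diff):

import numpy as np

startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]  # corners in the original image
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]    # corners in the transformed image

# Each corner pair contributes two rows of A @ coeffs = b, obtained by clearing
# the denominator in (a*x + b*y + c) / (g*x + h*y + 1) = x_src (same for y_src).
a = np.zeros((8, 8))
for i, ((xs, ys), (xe, ye)) in enumerate(zip(startpoints, endpoints)):
    a[2 * i] = [xe, ye, 1, 0, 0, 0, -xs * xe, -xs * ye]
    a[2 * i + 1] = [0, 0, 0, xe, ye, 1, -ys * xe, -ys * ye]
b = np.array(startpoints, dtype=float).reshape(8)

coeffs, *_ = np.linalg.lstsq(a, b, rcond=None)

# Sanity check: the coefficients send each endpoint back to its startpoint.
for (xs, ys), (xe, ye) in zip(startpoints, endpoints):
    denom = coeffs[6] * xe + coeffs[7] * ye + 1.0
    assert abs((coeffs[0] * xe + coeffs[1] * ye + coeffs[2]) / denom - xs) < 1e-4
    assert abs((coeffs[3] * xe + coeffs[4] * ye + coeffs[5]) / denom - ys) < 1e-4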

torchvision/transforms/functional_tensor.py

Lines changed: 63 additions & 6 deletions
@@ -620,17 +620,25 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
 
 
 def _assert_grid_transform_inputs(
-        img: Tensor, matrix: List[float], resample: int, fillcolor: Optional[int], _interpolation_modes: Dict[int, str]
+        img: Tensor,
+        matrix: Optional[List[float]],
+        resample: int,
+        fillcolor: Optional[int],
+        _interpolation_modes: Dict[int, str],
+        coeffs: Optional[List[float]] = None,
 ):
     if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
         raise TypeError("img should be Tensor Image. Got {}".format(type(img)))
 
-    if not isinstance(matrix, list):
+    if matrix is not None and not isinstance(matrix, list):
         raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix)))
 
-    if len(matrix) != 6:
+    if matrix is not None and len(matrix) != 6:
         raise ValueError("Argument matrix should have 6 float values")
 
+    if coeffs is not None and len(coeffs) != 8:
+        raise ValueError("Argument coeffs should have 8 float values")
+
     if fillcolor is not None:
         warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero")
 
@@ -775,6 +783,37 @@ def rotate(
     return _apply_grid_transform(img, grid, mode)
 
 
+def _perspective_grid(coeffs: List[float], ow: int, oh: int):
+    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
+    # src/libImaging/Geometry.c#L394
+
+    #
+    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    #
+
+    theta1 = torch.tensor([[
+        [coeffs[0], coeffs[1], coeffs[2]],
+        [coeffs[3], coeffs[4], coeffs[5]]
+    ]])
+    theta2 = torch.tensor([[
+        [coeffs[6], coeffs[7], 1.0],
+        [coeffs[6], coeffs[7], 1.0]
+    ]])
+
+    d = 0.5
+    base_grid = torch.empty(1, oh, ow, 3)
+    base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
+    base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
+    base_grid[..., 2].fill_(1)
+
+    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh]))
+    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
+
+    output_grid = output_grid1 / output_grid2 - 1.0
+    return output_grid.view(1, oh, ow, 2)
+
+
 def perspective(
     img: Tensor, perspective_coeffs: List[float], interpolation: int = 2, fill: Optional[int] = None
 ) -> Tensor:
@@ -783,14 +822,32 @@ def perspective(
     Args:
         img (Tensor): Image to be transformed.
         perspective_coeffs (list of float): perspective transformation coefficients.
-        interpolation (int): Interpolation type. Default, ``Image.BICUBIC``.
+        interpolation (int): Interpolation type. Default, ``PIL.Image.BILINEAR``.
         fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area
             outside the transform in the output image is always 0.
 
     Returns:
-        Tensor: Perspectively transformed Image.
+        Tensor: transformed image.
     """
     if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
         raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
 
-    return None
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(
+        img,
+        matrix=None,
+        resample=interpolation,
+        fillcolor=fill,
+        _interpolation_modes=_interpolation_modes,
+        coeffs=perspective_coeffs
+    )
+
+    ow, oh = img.shape[-1], img.shape[-2]
+    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh)
+    mode = _interpolation_modes[interpolation]
+
+    return _apply_grid_transform(img, grid, mode)
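
Here _perspective_grid evaluates the PIL mapping at every output pixel as two batched matrix products and normalizes the result to [-1, 1], the coordinate range torch.nn.functional.grid_sample expects; _apply_grid_transform then samples the input image through that grid. A rough standalone sketch of the same idea (illustrative only; perspective_sample is a hypothetical name, and it skips the half-pixel-center offset d = 0.5 used in the diff, so its output can differ slightly):

import torch

def perspective_sample(img: torch.Tensor, coeffs, mode: str = "bilinear") -> torch.Tensor:
    # img: (C, H, W) float tensor; coeffs: the 8 PIL-style coefficients.
    oh, ow = img.shape[-2], img.shape[-1]
    ys, xs = torch.meshgrid(torch.arange(oh, dtype=torch.float32),
                            torch.arange(ow, dtype=torch.float32))
    # PIL convention: map each output pixel (x, y) back to a source location.
    denom = coeffs[6] * xs + coeffs[7] * ys + 1.0
    x_src = (coeffs[0] * xs + coeffs[1] * ys + coeffs[2]) / denom
    y_src = (coeffs[3] * xs + coeffs[4] * ys + coeffs[5]) / denom
    # Normalize source coordinates to [-1, 1] as grid_sample expects.
    grid = torch.stack([x_src / (0.5 * ow) - 1.0, y_src / (0.5 * oh) - 1.0], dim=-1)
    return torch.nn.functional.grid_sample(
        img.unsqueeze(0), grid.unsqueeze(0),
        mode=mode, padding_mode="zeros", align_corners=False,
    ).squeeze(0)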
