From 5acb309d64de78f7b2f5a71b710c2be5f3968ee0 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 4 Nov 2022 09:54:13 +0000 Subject: [PATCH 1/4] [proto] small optim for perspective op on images, reverted concat trick on bboxes --- .../transforms/functional/_geometry.py | 65 +++++++++++++++---- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index f2a12d6f609..6c59f53062e 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,8 +4,7 @@ import PIL.Image import torch -from torch.nn.functional import interpolate, pad as torch_pad - +from torch.nn.functional import interpolate from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms.functional import ( @@ -16,6 +15,7 @@ pil_to_tensor, to_pil_image, ) +from torchvision.transforms.functional_tensor import _parse_pad_padding from ._meta import convert_format_bounding_box, get_spatial_size_image_pil @@ -906,6 +906,36 @@ def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: i return crop_image_pil(inpt, top, left, height, width) +def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ + # src/libImaging/Geometry.c#L394 + + # + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # + + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device) + output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) + output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) + + output_grid = output_grid1.div_(output_grid2).sub_(1.0) + return output_grid.view(1, oh, ow, 2) + + def perspective_image_tensor( image: torch.Tensor, perspective_coeffs: List[float], @@ -923,7 +953,19 @@ def perspective_image_tensor( else: needs_unsquash = False - output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill) + _FT._assert_grid_transform_inputs( + image, + matrix=None, + interpolation=interpolation.value, + fill=fill, + supported_interpolation_modes=["nearest", "bilinear"], + coeffs=perspective_coeffs, + ) + + ow, oh = image.shape[-1], image.shape[-2] + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device) + output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill) if needs_unsquash: output = output.reshape(shape) @@ -988,16 +1030,16 @@ def perspective_bounding_box( (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, ] - theta12_T = torch.tensor( - [ - [inv_coeffs[0], inv_coeffs[3], inv_coeffs[6], inv_coeffs[6]], - [inv_coeffs[1], inv_coeffs[4], inv_coeffs[7], inv_coeffs[7]], - [inv_coeffs[2], inv_coeffs[5], 1.0, 1.0], - ], + theta1 = torch.tensor( + [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], dtype=dtype, device=device, ) + theta2 = torch.tensor( + [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device + ) + # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). # Tensor of points has shape (N * 4, 3), where N is the number of bboxes # Single point structure is similar to @@ -1008,9 +1050,8 @@ def perspective_bounding_box( # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) - numer_denom_points = torch.matmul(points, theta12_T) - numer_points = numer_denom_points[:, :2] - denom_points = numer_denom_points[:, 2:] + numer_points = torch.matmul(points, theta1.T) + denom_points = torch.matmul(points, theta2.T) transformed_points = numer_points.div_(denom_points) # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] From 70e45717a29c87ff462b7badc991df568f67c44c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 4 Nov 2022 10:13:54 +0000 Subject: [PATCH 2/4] revert unrelated changes --- torchvision/prototype/transforms/functional/_geometry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 6c59f53062e..2710177e8d5 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,7 +4,8 @@ import PIL.Image import torch -from torch.nn.functional import interpolate +from torch.nn.functional import interpolate, pad as torch_pad + from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms.functional import ( @@ -15,7 +16,6 @@ pil_to_tensor, to_pil_image, ) -from torchvision.transforms.functional_tensor import _parse_pad_padding from ._meta import convert_format_bounding_box, get_spatial_size_image_pil From a91b646700635f176fe8d84428501ed4d5e25850 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 4 Nov 2022 10:38:42 +0000 Subject: [PATCH 3/4] PR review updates --- torchvision/prototype/transforms/functional/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 2710177e8d5..3210b97e1b4 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -928,7 +928,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, base_grid[..., 1].copy_(y_grid) base_grid[..., 2].fill_(1) - rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device) + rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)) output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) From 2b858e0209cefb1154667586394cc0db6ae165b4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 4 Nov 2022 12:25:08 +0000 Subject: [PATCH 4/4] PR review change --- torchvision/prototype/transforms/functional/_geometry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index dc6358265b2..adf494b1c42 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -962,11 +962,10 @@ def perspective_image_tensor( fill: features.FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> torch.Tensor: + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) if image.numel() == 0: return image - perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) - shape = image.shape if image.ndim > 4: