diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index 361a921b18e..25daf3da59f 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -915,7 +915,7 @@ def sample_inputs_rotate_video():
             reference_inputs_fn=reference_inputs_rotate_image_tensor,
             float32_vs_uint8=True,
             # TODO: investigate
-            closeness_kwargs=pil_reference_pixel_difference(100, agg_method="mean"),
+            closeness_kwargs=pil_reference_pixel_difference(110, agg_method="mean"),
             test_marks=[
                 xfail_jit_tuple_instead_of_list("fill"),
                 # TODO: check if this is a regression since it seems that should be supported if `int` is ok
diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 0cc52f8b838..d82d9ebea4f 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -401,6 +401,7 @@ def __init__(
            ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
            ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
        ],
+        closeness_kwargs={"atol": None, "rtol": None},
    ),
    ConsistencyConfig(
        prototype_transforms.RandomRotation,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 98bd7a52712..41262185b5d 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1,16 +1,16 @@
+import math
 import numbers
 import warnings
 from typing import List, Optional, Sequence, Tuple, Union
 
 import PIL.Image
 import torch
-from torch.nn.functional import interpolate, pad as torch_pad
+from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 from torchvision.transforms.functional import (
     _compute_resized_output_size as __compute_resized_output_size,
-    _get_inverse_affine_matrix,
     _get_perspective_coeffs,
     InterpolationMode,
     pil_modes_mapping,
@@ -272,6 +272,195 @@ def _affine_parse_args(
     return angle, translate, shear, center
 
 
+def _get_inverse_affine_matrix(
+    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
+) -> List[float]:
+    # Helper method to compute the inverse matrix of an affine transformation.
+
+    # Pillow requires the inverse affine transformation matrix:
+    # The affine matrix is: M = T * C * RotateScaleShear * C^-1
+    #
+    # where T is the translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
+    #       C is the translation matrix to keep the center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
+    #       RotateScaleShear is the rotation matrix with scale and shear:
+    #
+    #       RotateScaleShear(a, s, (sx, sy)) =
+    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
+    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
+    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
+    #         [ 0                    , 0                                        , 1 ]
+    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
+    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
+    #          [0, 1      ]             [-tan(s), 1]
+    #
+    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
+
+    rot = math.radians(angle)
+    sx = math.radians(shear[0])
+    sy = math.radians(shear[1])
+
+    cx, cy = center
+    tx, ty = translate
+
+    # Cached results
+    cos_sy = math.cos(sy)
+    tan_sx = math.tan(sx)
+    rot_minus_sy = rot - sy
+    cx_plus_tx = cx + tx
+    cy_plus_ty = cy + ty
+
+    # Rotate Scale Shear (RSS) without scaling
+    a = math.cos(rot_minus_sy) / cos_sy
+    b = -(a * tan_sx + math.sin(rot))
+    c = math.sin(rot_minus_sy) / cos_sy
+    d = math.cos(rot) - c * tan_sx
+
+    if inverted:
+        # Inverted rotation matrix with scale and shear
+        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
+        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
+        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
+        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
+    else:
+        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
+        # Apply inverse of center translation: RSS * C^-1
+        # and then apply translation and center translation: T * C * RSS * C^-1
+        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
+        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy
+
+    return matrix
+
+
+def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
+    # Inspired by the PIL implementation:
+    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
+
+    # pts are the Top-Left, Bottom-Left, Bottom-Right and Top-Right corner points.
+    # The points are shifted according to the torch affine-matrix convention about
+    # the center point: (0, 0) corresponds to the image center pivot point (w * 0.5, h * 0.5).
+    half_w = 0.5 * w
+    half_h = 0.5 * h
+    pts = torch.tensor(
+        [
+            [-half_w, -half_h, 1.0],
+            [-half_w, half_h, 1.0],
+            [half_w, half_h, 1.0],
+            [half_w, -half_h, 1.0],
+        ]
+    )
+    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
+    new_pts = torch.matmul(pts, theta.T)
+    min_vals, max_vals = new_pts.aminmax(dim=0)
+
+    # shift points to the [0, w] and [0, h] intervals to match PIL results
+    halfs = torch.tensor((half_w, half_h))
+    min_vals.add_(halfs)
+    max_vals.add_(halfs)
+
+    # Truncate precision to 1e-4 to avoid values like 1e-15 being ceiled up to 1.0
+    tol = 1e-4
+    inv_tol = 1.0 / tol
+    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
+    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
+    size = cmax.sub_(cmin)
+    return int(size[0]), int(size[1])  # w, h
+
+
+def _apply_grid_transform(
+    float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: features.FillTypeJIT
+) -> torch.Tensor:
+
+    shape = float_img.shape
+    if shape[0] > 1:
+        # Apply the same grid to a batch of images
+        grid = grid.expand(shape[0], -1, -1, -1)
+
+    # Append a dummy mask for customized fill colors; this should be faster than calling grid_sample() twice
+    if fill is not None:
+        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
+        float_img = torch.cat((float_img, mask), dim=1)
+
+    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
+
+    # Fill with the required color
+    if fill is not None:
+        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
+        mask = mask.expand_as(float_img)
+        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]
+        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
+        if mode == "nearest":
+            bool_mask = mask < 0.5
+            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
+        else:  # 'bilinear'
+            # The following is mathematically equivalent to:
+            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
+            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
+
+    return float_img
+
+
+def _assert_grid_transform_inputs(
+    image: torch.Tensor,
+    matrix: Optional[List[float]],
+    interpolation: str,
+    fill: features.FillTypeJIT,
+    supported_interpolation_modes: List[str],
+    coeffs: Optional[List[float]] = None,
+) -> None:
+    if matrix is not None:
+        if not isinstance(matrix, list):
+            raise TypeError("Argument matrix should be a list")
+        elif len(matrix) != 6:
+            raise ValueError("Argument matrix should have 6 float values")
+
+    if coeffs is not None and len(coeffs) != 8:
+        raise ValueError("Argument coeffs should have 8 float values")
+
+    if fill is not None:
+        if isinstance(fill, (tuple, list)):
+            length = len(fill)
+            num_channels = image.shape[-3]
+            if length > 1 and length != num_channels:
+                raise ValueError(
+                    "The number of elements in 'fill' cannot broadcast to match the number of "
+                    f"channels of the image ({length} != {num_channels})"
+                )
+        elif not isinstance(fill, (int, float)):
+            raise ValueError("Argument fill should be either int, float, tuple or list")
+
+    if interpolation not in supported_interpolation_modes:
+        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
+
+
+def _affine_grid(
+    theta: torch.Tensor,
+    w: int,
+    h: int,
+    ow: int,
+    oh: int,
+) -> torch.Tensor:
+    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
+    # AffineGridGenerator.cpp#L18
+    # Differences from AffineGridGenerator are that:
+    # 1) we normalize the grid values after applying theta
+    # 2) we can normalize by another image size, so that it covers the "expand" option as in PIL.Image.rotate
+    dtype = theta.dtype
+    device = theta.device
+
+    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
+    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
+    base_grid[..., 0].copy_(x_grid)
+    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
+    base_grid[..., 1].copy_(y_grid)
+    base_grid[..., 2].fill_(1)
+
+    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
+    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
+    return output_grid.view(1, oh, ow, 2)
+
+
 def affine_image_tensor(
     image: torch.Tensor,
     angle: Union[int, float],
@@ -286,9 +475,19 @@
         return image
 
     shape = image.shape
-    num_channels, height, width = shape[-3:]
-    image = image.reshape(-1, num_channels, height, width)
+    ndim = image.ndim
+    fp = torch.is_floating_point(image)
+    if ndim > 4:
+        image = image.reshape((-1,) + shape[-3:])
+        needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
+    else:
+        needs_unsquash = False
+
+    height, width = shape[-2:]
 
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
 
     center_f = [0.0, 0.0]
@@ -299,8 +498,20 @@
         translate_f = [float(t) for t in translate]
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
 
-    output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.reshape(shape)
+    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
+
+    dtype = image.dtype if fp else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
+    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)
+
+    if needs_unsquash:
+        output = output.reshape(shape)
+
+    return output
 
 
 @torch.jit.unused
@@ -395,7 +606,7 @@ def _affine_bounding_box_xyxy(
     out_bboxes.sub_(tr.repeat((1, 2)))
     # Estimate meta-data for image with inverted=True and with center=[0,0]
     affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
-    new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
+    new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
     spatial_size = (new_height, new_width)
 
     return out_bboxes.to(bounding_box.dtype), spatial_size
@@ -543,18 +754,26 @@ def rotate_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
 
     if image.numel() > 0:
-        image = _FT.rotate(
-            image.reshape(-1, num_channels, height, width),
-            matrix,
-            interpolation=interpolation.value,
-            expand=expand,
-            fill=fill,
-        )
-        new_height, new_width = image.shape[-2:]
+        fp = torch.is_floating_point(image)
+        image = image.reshape(-1, num_channels, height, width)
+
+        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
+
+        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
+        dtype = image.dtype if fp else torch.float32
+        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
+        output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
+
+        if not fp:
+            output = output.round_().to(image.dtype)
+
+        new_height, new_width = output.shape[-2:]
     else:
-        new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
+        output = image
+        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
 
-    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
+    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
 @torch.jit.unused
@@ -944,7 +1163,6 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
-    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
     )
@@ -959,8 +1177,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     base_grid[..., 2].fill_(1)
 
     rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
-    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
-    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
+    shape = (1, oh * ow, 3)
+    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
+    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
 
     output_grid = output_grid1.div_(output_grid2).sub_(1.0)
     return output_grid.view(1, oh, ow, 2)
@@ -996,14 +1215,19 @@ def perspective_image_tensor(
         return image
 
     shape = image.shape
+    ndim = image.ndim
+    fp = torch.is_floating_point(image)
 
-    if image.ndim > 4:
+    if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
     else:
         needs_unsquash = False
 
-    _FT._assert_grid_transform_inputs(
+    _assert_grid_transform_inputs(
         image,
         matrix=None,
         interpolation=interpolation.value,
@@ -1012,10 +1236,13 @@ def perspective_image_tensor(
         coeffs=perspective_coeffs,
     )
 
-    ow, oh = image.shape[-1], image.shape[-2]
-    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    oh, ow = shape[-2:]
+    dtype = image.dtype if fp else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill)
+    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)
 
     if needs_unsquash:
         output = output.reshape(shape)
@@ -1086,7 +1313,6 @@ def perspective_bounding_box(
         (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
     ]
 
-    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
         dtype=dtype,
@@ -1193,17 +1419,25 @@ def elastic_image_tensor(
         return image
 
     shape = image.shape
+    ndim = image.ndim
    device = image.device
+    fp = torch.is_floating_point(image)
 
-    if image.ndim > 4:
+    if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
     else:
         needs_unsquash = False
 
     image_height, image_width = shape[-2:]
     grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
-    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
+    output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)
 
     if needs_unsquash:
         output = output.reshape(shape)
@@ -1361,7 +1595,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
+        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)
 
         image_height, image_width = image.shape[-2:]
 
     if crop_width == image_width and crop_height == image_height:
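
Note for reviewers: below is a minimal sketch (not part of the diff) of what the ported affine path now does for an integer-dtype input: build the inverse affine matrix for a pure rotation, construct the sampling grid the same way _affine_grid does, sample with torch.nn.functional.grid_sample, and round back to the original dtype. The image shape, angle, and variable names are illustrative; it assumes only that torch is installed.

    import math

    import torch
    from torch.nn.functional import grid_sample

    # Illustrative 1x3x32x32 uint8 image; the kernels unsqueeze 3D inputs to 4D.
    img = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)
    h = w = 32

    # Inverse affine matrix for a 30-degree rotation about the center; this is what
    # _get_inverse_affine_matrix returns for translate=(0, 0), scale=1, shear=(0, 0).
    rot = math.radians(30.0)
    matrix = [math.cos(rot), math.sin(rot), 0.0, -math.sin(rot), math.cos(rot), 0.0]
    theta = torch.tensor(matrix, dtype=torch.float32).reshape(1, 2, 3)

    # Same construction as _affine_grid with ow == w and oh == h (i.e. no expansion).
    base_grid = torch.empty(1, h, w, 3, dtype=torch.float32)
    base_grid[..., 0].copy_(torch.linspace((1.0 - w) * 0.5, (w - 1.0) * 0.5, steps=w))
    base_grid[..., 1].copy_(torch.linspace((1.0 - h) * 0.5, (h - 1.0) * 0.5, steps=h).unsqueeze_(-1))
    base_grid[..., 2].fill_(1)
    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h]))
    grid = base_grid.view(1, h * w, 3).bmm(rescaled_theta).view(1, h, w, 2)

    # Integer images are sampled in float32 and rounded back, as in the new kernels.
    out = grid_sample(img.to(torch.float32), grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    out = out.round_().to(img.dtype)
    print(out.shape)  # torch.Size([1, 3, 32, 32])

With expand=True, rotate_image_tensor instead derives (ow, oh) from _compute_affine_output_size, which is why _affine_grid takes both the input size (w, h), used for normalization, and the output size (ow, oh), used to build the base grid.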