diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 57bd53b4c40..2ba6b1115d7 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,12 +1,16 @@ import functools import itertools +import math +import numpy as np import pytest import torch.testing import torchvision.prototype.transforms.functional as F +from common_utils import cpu_and_gpu from torch import jit from torch.nn.functional import one_hot from torchvision.prototype import features +from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format from torchvision.transforms.functional_tensor import _max_value as get_max_value make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") @@ -205,6 +209,45 @@ def resize_bounding_box(): yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) +@register_kernel_info_from_sample_inputs_fn +def affine_image_tensor(): + for image, angle, translate, scale, shear in itertools.product( + make_images(extra_dims=()), + [-87, 15, 90], # angle + [5, -5], # translate + [0.77, 1.27], # scale + [0, 12], # shear + ): + yield SampleInput( + image, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + interpolation=F.InterpolationMode.NEAREST, + ) + + +@register_kernel_info_from_sample_inputs_fn +def affine_bounding_box(): + for bounding_box, angle, translate, scale, shear in itertools.product( + make_bounding_boxes(), + [-87, 15, 90], # angle + [5, -5], # translate + [0.77, 1.27], # scale + [0, 12], # shear + ): + yield SampleInput( + bounding_box, + format=bounding_box.format, + image_size=bounding_box.image_size, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + ) + + @pytest.mark.parametrize( "kernel", [ @@ -233,3 +276,154 @@ def test_eager_vs_scripted(functional_info, sample_input): scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) torch.testing.assert_close(eager, scripted) + + +def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): + rot = math.radians(angle_) + cx, cy = center_ + tx, ty = translate_ + sx, sy = [math.radians(sh_) for sh_ in shear_] + + c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) + t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) + c_matrix_inv = np.linalg.inv(c_matrix) + rs_matrix = np.array( + [ + [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0], + [scale_ * math.sin(rot), scale_ * math.cos(rot), 0], + [0, 0, 1], + ] + ) + shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) + shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) + rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) + true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) + return true_matrix + + +@pytest.mark.parametrize("angle", range(-90, 90, 56)) +@pytest.mark.parametrize("translate", range(-10, 10, 8)) +@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27]) +@pytest.mark.parametrize("shear", range(-15, 15, 8)) +@pytest.mark.parametrize("center", [None, (12, 14)]) +def test_correctness_affine_bounding_box(angle, translate, scale, shear, center): + def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): + affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_) + affine_matrix = affine_matrix[:2, :] + + bbox_xyxy = 
convert_bounding_box_format( + bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY + ) + points = np.array( + [ + [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], + [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0], + [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0], + [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], + ] + ) + transformed_points = np.matmul(points, affine_matrix.T) + out_bbox = [ + np.min(transformed_points[:, 0]), + np.min(transformed_points[:, 1]), + np.max(transformed_points[:, 0]), + np.max(transformed_points[:, 1]), + ] + out_bbox = features.BoundingBox( + out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=bbox.image_size, dtype=torch.float32 + ) + out_bbox = convert_bounding_box_format( + out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False + ) + return out_bbox.to(bbox.device) + + image_size = (32, 38) + + for bboxes in make_bounding_boxes( + image_sizes=[ + image_size, + ], + extra_dims=((4,),), + ): + output_bboxes = F.affine_bounding_box( + bboxes, + bboxes.format, + image_size=image_size, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + center=center, + ) + if center is None: + center = [s // 2 for s in image_size[::-1]] + + bboxes_format = bboxes.format + bboxes_image_size = bboxes.image_size + if bboxes.ndim < 2: + bboxes = [ + bboxes, + ] + + expected_bboxes = [] + for bbox in bboxes: + bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) + expected_bboxes.append( + _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) + ) + if len(expected_bboxes) > 1: + expected_bboxes = torch.stack(expected_bboxes) + else: + expected_bboxes = expected_bboxes[0] + torch.testing.assert_close(output_bboxes, expected_bboxes) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_affine_bounding_box_on_fixed_input(device): + # Check transformation against known expected output + image_size = (64, 64) + # xyxy format + in_boxes = [ + [20, 25, 35, 45], + [50, 5, 70, 22], + [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10], + [1, 1, 5, 5], + ] + in_boxes = features.BoundingBox( + in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64 + ).to(device) + # Tested parameters + angle = 63 + scale = 0.89 + dx = 0.12 + dy = 0.23 + + # Expected bboxes computed using albumentations: + # from albumentations.augmentations.geometric.functional import bbox_shift_scale_rotate + # from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox + # expected_bboxes = [] + # for in_box in in_boxes: + # n_in_box = normalize_bbox(in_box, *image_size) + # n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *image_size) + # out_box = denormalize_bbox(n_out_box, *image_size) + # expected_bboxes.append(out_box) + expected_bboxes = [ + (24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695), + (54.88288587110401, 50.08453280875634, 76.44484547743795, 72.81332520036864), + (27.709526487041554, 34.74952648704156, 51.650473512958435, 58.69047351295844), + (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221), + ]
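+ + # Note: dx and dy above are fractions of the image size (albumentations' shift convention), + # so they are converted to pixel translations in the call below.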
+ output_boxes = F.affine_bounding_box( + in_boxes, + in_boxes.format, + in_boxes.image_size, + angle, + (dx * image_size[1], dy * image_size[0]), + scale, + shear=(0, 0), + ) + + assert len(output_boxes) == len(expected_bboxes) + for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()): + np.testing.assert_allclose(out_box.numpy(), a_out_box) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index fa65051dfac..6a317e87182 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -49,6 +49,7 @@ center_crop_image_pil, resized_crop_image_tensor, resized_crop_image_pil, + affine_bounding_box, affine_image_tensor, affine_image_pil, rotate_image_tensor, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 6ee76228fbc..b35db660438 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -178,6 +178,57 @@ def affine_image_pil( return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill) + +def affine_bounding_box( + bounding_box: torch.Tensor, + format: features.BoundingBoxFormat, + image_size: Tuple[int, int], + angle: float, + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, +) -> torch.Tensor: + original_shape = bounding_box.shape + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 + device = bounding_box.device + + if center is None: + height, width = image_size + center_f = [width * 0.5, height * 0.5] + else: + center_f = [float(c) for c in center] + + translate_f = [float(t) for t in translate] + affine_matrix = torch.tensor( + _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False), + dtype=dtype, + device=device, + ).view(2, 3)
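+ # With inverted=False the helper returns the forward (input -> output) affine matrix, + # so it can be applied to the box corner points directly; the image kernels keep the + # default inverted matrix because grid sampling maps each output pixel back to the input.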
+ # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners). + # Tensor of points has shape (N * 4, 3), where N is the number of bboxes + # Points for a single box are laid out as + # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) + # 2) Now let's transform the points using the affine matrix + transformed_points = torch.matmul(points, affine_matrix.T) + # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] + # and compute bounding box from 4 transformed points: + transformed_points = transformed_points.view(-1, 4, 2) + out_bbox_mins, _ = torch.min(transformed_points, dim=1) + out_bbox_maxs, _ = torch.max(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) + # out_bboxes should be of shape [N boxes, 4] + + return convert_bounding_box_format( + out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(original_shape) + + def rotate_image_tensor( img: torch.Tensor, angle: float, diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5b762ff2975..7234d923fbc 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -931,11 +931,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def _get_inverse_affine_matrix( - center: List[float], - angle: float, - translate: List[float], - scale: float, - shear: List[float], + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True ) -> List[float]: # Helper method to compute inverse matrix for affine transformation @@ -970,18 +966,26 @@ def _get_inverse_affine_matrix( c = math.sin(rot - sy) / math.cos(sy) d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot) - # Inverted rotation matrix with scale and shear - # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 - matrix = [d, -b, 0.0, -c, a, 0.0] - matrix = [x / scale for x in matrix] - - # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 - matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) - matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) - - # Apply center translation: C * RSS^-1 * C^-1 * T^-1 - matrix[2] += cx - matrix[5] += cy + if inverted: + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + matrix = [d, -b, 0.0, -c, a, 0.0] + matrix = [x / scale for x in matrix] + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += cx + matrix[5] += cy + else: + matrix = [a, b, 0.0, c, d, 0.0] + matrix = [x * scale for x in matrix] + # Apply inverse of center translation: RSS * C^-1 + matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy) + matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy) + # Apply translation and center: T * C * RSS * C^-1 + matrix[2] += cx + tx + matrix[5] += cy + ty return matrix
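A quick way to sanity-check the new inverted flag (a standalone sketch, not part of the diff above; it only assumes the _get_inverse_affine_matrix signature introduced here): inverted=False returns T * C * RSS * C^-1 and the default returns its inverse, so composing the two matrices in homogeneous form should give the identity.

import numpy as np

from torchvision.transforms.functional import _get_inverse_affine_matrix

# Arbitrary parameters: center/translate in pixels, angle/shear in degrees.
center, angle, translate, scale, shear = [16.0, 19.0], 30.0, [2.0, -3.0], 1.2, [10.0, 5.0]

# Each call returns the six coefficients of a 2x3 affine matrix, row-major.
fwd = np.array(_get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)).reshape(2, 3)
inv = np.array(_get_inverse_affine_matrix(center, angle, translate, scale, shear)).reshape(2, 3)

# Promote both to 3x3 homogeneous form and verify they are inverses of each other.
fwd3 = np.vstack([fwd, [0.0, 0.0, 1.0]])
inv3 = np.vstack([inv, [0.0, 0.0, 1.0]])
np.testing.assert_allclose(fwd3 @ inv3, np.eye(3), atol=1e-10)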