From 234f113ec9911877a60848c521cfe19fd8ddd3d1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 8 Mar 2022 13:58:14 +0000 Subject: [PATCH 01/14] Added functional affine_bounding_box op with tests --- test/test_prototype_transforms_functional.py | 138 ++++++++++++++++++ .../transforms/functional/__init__.py | 1 + .../transforms/functional/_geometry.py | 47 ++++++ torchvision/transforms/functional.py | 38 ++--- 4 files changed, 207 insertions(+), 17 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 57bd53b4c40..76ca7ee50e9 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,14 +1,18 @@ import functools import itertools +import math +import numpy as np import pytest import torch.testing import torchvision.prototype.transforms.functional as F from torch import jit from torch.nn.functional import one_hot from torchvision.prototype import features +from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format from torchvision.transforms.functional_tensor import _max_value as get_max_value + make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") @@ -205,6 +209,45 @@ def resize_bounding_box(): yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) +@register_kernel_info_from_sample_inputs_fn +def affine_image_tensor(): + for image, angle, translate, scale, shear in itertools.product( + make_images(extra_dims=()), + [-87, 15, 90], # angle + [5, -5], # translate + [0.77, 1.27], # scale + [0, 12], # shear + ): + yield SampleInput( + image, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + interpolation=F.InterpolationMode.NEAREST, + ) + + +@register_kernel_info_from_sample_inputs_fn +def affine_bounding_box(): + for bounding_box, angle, translate, scale, shear in itertools.product( + make_bounding_boxes(), + [-87, 15, 90], # angle + [5, -5], # translate + [0.77, 1.27], # scale + [0, 12], # shear + ): + yield SampleInput( + bounding_box, + format=bounding_box.format, + image_size=bounding_box.image_size, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + ) + + @pytest.mark.parametrize( "kernel", [ @@ -233,3 +276,98 @@ def test_eager_vs_scripted(functional_info, sample_input): scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) torch.testing.assert_close(eager, scripted) + + +@pytest.mark.parametrize("angle", range(-90, 90, 36)) +@pytest.mark.parametrize("translate", range(-10, 10, 5)) +@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27]) +@pytest.mark.parametrize("shear", range(-15, 15, 5)) +@pytest.mark.parametrize("center", [None, (12, 14)]) +def test_correctness_affine_bounding_box(angle, translate, scale, shear, center): + def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): + rot = math.radians(angle_) + cx, cy = center_ + tx, ty = translate_ + sx, sy = [math.radians(sh_) for sh_ in shear_] + + c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) + t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) + c_matrix_inv = np.linalg.inv(c_matrix) + rs_matrix = np.array( + [ + [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0], + [scale_ * math.sin(rot), scale_ * math.cos(rot), 0], + [0, 0, 1], + ] + ) + shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) + shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 
0, 1]]) + rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) + true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) + true_matrix = true_matrix[:2, :] + + bbox_xyxy = convert_bounding_box_format( + bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY + ) + points = np.array( + [ + [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], + [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0], + [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0], + [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], + ] + ) + transformed_points = points @ true_matrix.T + out_bbox = [ + np.min(transformed_points[:, 0]), + np.min(transformed_points[:, 1]), + np.max(transformed_points[:, 0]), + np.max(transformed_points[:, 1]), + ] + out_bbox = features.BoundingBox( + out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32 + ) + out_bbox = convert_bounding_box_format( + out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format + ) + return out_bbox + + image_size = (32, 32) + + for bboxes in make_bounding_boxes( + image_sizes=[ + image_size, + ], + extra_dims=((4,),), + ): + output_bboxes = F.affine_bounding_box( + bboxes, + bboxes.format, + image_size=image_size, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + center=center, + ) + if center is None: + center = [s // 2 for s in image_size] + + bboxes_format = bboxes.format + bboxes_image_size = bboxes.image_size + if bboxes.ndim < 2: + bboxes = [ + bboxes, + ] + + expected_bboxes = [] + for bbox in bboxes: + bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) + expected_bboxes.append( + _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) + ) + expected_bboxes = torch.stack(expected_bboxes) + if expected_bboxes.shape[0] < 2: + expected_bboxes = expected_bboxes.squeeze(0) + + torch.testing.assert_close(output_bboxes, expected_bboxes) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index c0825784f66..5e023e4885b 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -48,6 +48,7 @@ center_crop_image_pil, resized_crop_image_tensor, resized_crop_image_pil, + affine_bounding_box, affine_image_tensor, affine_image_pil, rotate_image_tensor, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 84d1fe963c9..7b3c19702c0 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -174,6 +174,53 @@ def affine_image_pil( return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill) +def affine_bounding_box( + bounding_box: torch.Tensor, + format: features.BoundingBoxFormat, + image_size: Tuple[int, int], + angle: float, + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, +) -> torch.Tensor: + original_shape = bounding_box.shape + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 + device = bounding_box.device + + if center is None: + height, width = image_size + center_f = [width 
* 0.5, height * 0.5] + else: + center_f = [float(c) for c in center] + + translate_f = [float(t) for t in translate] + affine_matrix = torch.tensor( + _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False), + dtype=dtype, + device=device, + ).view(2, 3) + # bboxes to 4 points like: + # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1), ...] + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) + points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1) + transformed_points = points @ affine_matrix.T + # reshape transformed points to [N boxes, 4 points, x/y coords] + transformed_points = transformed_points.view(-1, 4, 2) + # compute bounding box from 4 transformed points: + out_bbox_mins, _ = torch.min(transformed_points, dim=1) + out_bbox_maxs, _ = torch.max(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) + # out_bboxes should be of shape [N boxes, 4] + return convert_bounding_box_format(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).view( + original_shape + ) + + def rotate_image_tensor( img: torch.Tensor, angle: float, diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5b762ff2975..7234d923fbc 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -931,11 +931,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def _get_inverse_affine_matrix( - center: List[float], - angle: float, - translate: List[float], - scale: float, - shear: List[float], + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True ) -> List[float]: # Helper method to compute inverse matrix for affine transformation @@ -970,18 +966,26 @@ def _get_inverse_affine_matrix( c = math.sin(rot - sy) / math.cos(sy) d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot) - # Inverted rotation matrix with scale and shear - # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 - matrix = [d, -b, 0.0, -c, a, 0.0] - matrix = [x / scale for x in matrix] - - # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 - matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) - matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) - - # Apply center translation: C * RSS^-1 * C^-1 * T^-1 - matrix[2] += cx - matrix[5] += cy + if inverted: + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + matrix = [d, -b, 0.0, -c, a, 0.0] + matrix = [x / scale for x in matrix] + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += cx + matrix[5] += cy + else: + matrix = [a, b, 0.0, c, d, 0.0] + matrix = [x * scale for x in matrix] + # Apply inverse of center translation: RSS * C^-1 + matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy) + matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy) + # Apply translation and center : T * C * RSS * C^-1 + matrix[2] += cx + tx + matrix[5] += cy + ty return matrix From a24fca7a24cd6877380de592ca603888228acc2d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 14 Mar 2022 14:27:28 +0000 Subject: [PATCH 02/14] Updated comments and added another test case 
--- test/test_prototype_transforms_functional.py | 52 ++++++++++++++++++- .../transforms/functional/_geometry.py | 14 +++-- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 76ca7ee50e9..2b0ee947284 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -317,7 +317,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], ] ) - transformed_points = points @ true_matrix.T + transformed_points = np.matmul(points, true_matrix.T) out_bbox = [ np.min(transformed_points[:, 0]), np.min(transformed_points[:, 1]), @@ -371,3 +371,53 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): expected_bboxes = expected_bboxes.squeeze(0) torch.testing.assert_close(output_bboxes, expected_bboxes) + + +def test_correctness_affine_bounding_box_on_fixed_input(): + # Check transformation against known expected output + image_size = (64, 64) + # xyxy format + in_boxes = [ + [20, 25, 35, 45], + [50, 5, 70, 22], + [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10], + [1, 1, 5, 5], + ] + in_boxes = features.BoundingBox( + in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64 + ) + # Tested parameters + angle = 63 + scale = 0.89 + dx = 0.12 + dy = 0.23 + + # Expected bboxes computed using albumentations: + # from albumentations.augmentations.geometric.functional import bbox_shift_scale_rotate + # from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox + # expected_bboxes = [] + # for in_box in in_boxes: + # n_in_box = normalize_bbox(in_box, *image_size) + # n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *image_size) + # out_box = denormalize_bbox(n_out_box, *image_size) + # expected_bboxes.append(out_box) + expected_bboxes = [ + (24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695), + (54.88288587110401, 50.08453280875634, 76.44484547743795, 72.81332520036864), + (27.709526487041554, 34.74952648704156, 51.650473512958435, 58.69047351295844), + (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221), + ] + + output_boxes = F.affine_bounding_box( + in_boxes, + in_boxes.format, + in_boxes.image_size, + angle, + (dx * image_size[1], dy * image_size[0]), + scale, + shear=(0, 0), + ) + + assert len(output_boxes) == len(expected_bboxes) + for a_out_box, out_box in zip(expected_bboxes, output_boxes): + np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 7b3c19702c0..0c2bb868968 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -204,18 +204,22 @@ def affine_bounding_box( dtype=dtype, device=device, ).view(2, 3) - # bboxes to 4 points like: - # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1), ...] + # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). 
+ # Tensor of points has shape (N * 4, 3), where N is the number of bboxes + # Single point structure is similar to + # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1) - transformed_points = points @ affine_matrix.T - # reshape transformed points to [N boxes, 4 points, x/y coords] + # 2) Now let's transform the points using affine matrix + transformed_points = torch.matmul(points, affine_matrix.T) + # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] + # and compute bounding box from 4 transformed points: transformed_points = transformed_points.view(-1, 4, 2) - # compute bounding box from 4 transformed points: out_bbox_mins, _ = torch.min(transformed_points, dim=1) out_bbox_maxs, _ = torch.max(transformed_points, dim=1) out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) # out_bboxes should be of shape [N boxes, 4] + return convert_bounding_box_format(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).view( original_shape ) From a872483c1bd8f99b4b4a5cdf47a4123b6fd87670 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 14 Mar 2022 16:15:41 +0100 Subject: [PATCH 03/14] Update _geometry.py --- torchvision/prototype/transforms/functional/_geometry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 94c70296761..5f9f5c54d0e 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -220,9 +220,9 @@ def affine_bounding_box( out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) # out_bboxes should be of shape [N boxes, 4] - return convert_bounding_box_format(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).view( - original_shape - ) + return convert_bounding_box_format( + out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(original_shape) def rotate_image_tensor( From 7ab7d8adc77b132686aae6ca6d556cbd33760b14 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 14 Mar 2022 17:56:57 +0000 Subject: [PATCH 04/14] Added affine_segmentation_mask with tests --- test/test_prototype_transforms_functional.py | 164 ++++++++++++++---- .../transforms/functional/__init__.py | 1 + .../transforms/functional/_geometry.py | 23 ++- 3 files changed, 151 insertions(+), 37 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 2b0ee947284..ea8a2ac3512 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -6,13 +6,13 @@ import pytest import torch.testing import torchvision.prototype.transforms.functional as F +from common_utils import cpu_and_gpu from torch import jit from torch.nn.functional import one_hot from torchvision.prototype import features from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format from torchvision.transforms.functional_tensor import _max_value as get_max_value - make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") @@ -138,6 +138,22 @@ def make_one_hot_labels( yield make_one_hot_label(extra_dims_) +def make_segmentation_mask(size=None, *, max_value=80, extra_dims=(), dtype=torch.uint8): + size = size or torch.randint(16, 33, 
(2,)).tolist() + shape = (*extra_dims, 1, *size) + data = make_tensor(shape, low=0, high=max_value, dtype=dtype) + return features.SegmentationMask(data) + + +def make_segmentation_masks( + image_sizes=((32, 32), (32, 42)), + dtypes=(torch.long,), + extra_dims=((), (4,)), +): + for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims): + yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_) + + class SampleInput: def __init__(self, *args, **kwargs): self.args = args @@ -248,6 +264,24 @@ def affine_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn +def affine_segmentation_mask(): + for image, angle, translate, scale, shear in itertools.product( + make_images(extra_dims=()), + [-87, 15, 90], # angle + [5, -5], # translate + [0.77, 1.27], # scale + [0, 12], # shear + ): + yield SampleInput( + image, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + ) + + @pytest.mark.parametrize( "kernel", [ @@ -278,33 +312,38 @@ def test_eager_vs_scripted(functional_info, sample_input): torch.testing.assert_close(eager, scripted) -@pytest.mark.parametrize("angle", range(-90, 90, 36)) -@pytest.mark.parametrize("translate", range(-10, 10, 5)) +def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): + rot = math.radians(angle_) + cx, cy = center_ + tx, ty = translate_ + sx, sy = [math.radians(sh_) for sh_ in shear_] + + c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) + t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) + c_matrix_inv = np.linalg.inv(c_matrix) + rs_matrix = np.array( + [ + [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0], + [scale_ * math.sin(rot), scale_ * math.cos(rot), 0], + [0, 0, 1], + ] + ) + shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) + shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) + rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) + true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) + return true_matrix + + +@pytest.mark.parametrize("angle", range(-90, 90, 56)) +@pytest.mark.parametrize("translate", range(-10, 10, 8)) @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27]) -@pytest.mark.parametrize("shear", range(-15, 15, 5)) +@pytest.mark.parametrize("shear", range(-15, 15, 8)) @pytest.mark.parametrize("center", [None, (12, 14)]) def test_correctness_affine_bounding_box(angle, translate, scale, shear, center): def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): - rot = math.radians(angle_) - cx, cy = center_ - tx, ty = translate_ - sx, sy = [math.radians(sh_) for sh_ in shear_] - - c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) - t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) - c_matrix_inv = np.linalg.inv(c_matrix) - rs_matrix = np.array( - [ - [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0], - [scale_ * math.sin(rot), scale_ * math.cos(rot), 0], - [0, 0, 1], - ] - ) - shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) - shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) - rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) - true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) - true_matrix = true_matrix[:2, :] + affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_) + affine_matrix = affine_matrix[:2, :] bbox_xyxy = 
convert_bounding_box_format( bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY @@ -317,7 +356,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], ] ) - transformed_points = np.matmul(points, true_matrix.T) + transformed_points = np.matmul(points, affine_matrix.T) out_bbox = [ np.min(transformed_points[:, 0]), np.min(transformed_points[:, 1]), @@ -328,11 +367,11 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32 ) out_bbox = convert_bounding_box_format( - out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format + out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False ) - return out_bbox + return out_bbox.to(bbox.device) - image_size = (32, 32) + image_size = (32, 38) for bboxes in make_bounding_boxes( image_sizes=[ @@ -351,7 +390,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): center=center, ) if center is None: - center = [s // 2 for s in image_size] + center = [s // 2 for s in image_size[::-1]] bboxes_format = bboxes.format bboxes_image_size = bboxes.image_size @@ -366,14 +405,15 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): expected_bboxes.append( _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) ) - expected_bboxes = torch.stack(expected_bboxes) - if expected_bboxes.shape[0] < 2: - expected_bboxes = expected_bboxes.squeeze(0) - + if len(expected_bboxes) > 1: + expected_bboxes = torch.stack(expected_bboxes) + else: + expected_bboxes = expected_bboxes[0] torch.testing.assert_close(output_bboxes, expected_bboxes) -def test_correctness_affine_bounding_box_on_fixed_input(): +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_affine_bounding_box_on_fixed_input(device): # Check transformation against known expected output image_size = (64, 64) # xyxy format @@ -385,7 +425,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(): ] in_boxes = features.BoundingBox( in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64 - ) + ).to(device) # Tested parameters angle = 63 scale = 0.89 @@ -419,5 +459,57 @@ def test_correctness_affine_bounding_box_on_fixed_input(): ) assert len(output_boxes) == len(expected_bboxes) - for a_out_box, out_box in zip(expected_bboxes, output_boxes): + for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()): np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box) + + +@pytest.mark.parametrize("angle", [-54, 56]) +@pytest.mark.parametrize("translate", [-7, 8]) +@pytest.mark.parametrize("scale", [0.89, 1.12]) +@pytest.mark.parametrize("shear", [4]) +@pytest.mark.parametrize("center", [None, (12, 14)]) +def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center): + def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): + assert mask.ndim == 3 and mask.shape[0] == 1 + affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_) + inv_affine_matrix = np.linalg.inv(affine_matrix) + inv_affine_matrix = inv_affine_matrix[:2, :] + + expected_mask = torch.zeros_like(mask.cpu()) + for out_y in range(expected_mask.shape[1]): + for out_x in range(expected_mask.shape[2]): + output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0]) + input_pt = 
np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32) + in_x, in_y = input_pt[:2] + if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]: + expected_mask[0, out_y, out_x] = mask[0, in_y, in_x] + return expected_mask.to(mask.device) + + for mask in make_segmentation_masks(): + output_mask = F.affine_segmentation_mask( + mask, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + center=center, + ) + if center is None: + center = [s // 2 for s in mask.shape[-2:][::-1]] + + if mask.ndim < 4: + masks = [ + mask, + ] + else: + masks = [m for m in mask] + + expected_masks = [] + for mask in masks: + expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center) + expected_masks.append(expected_mask) + if len(expected_masks) > 1: + expected_masks = torch.stack(expected_masks) + else: + expected_masks = expected_masks[0] + torch.testing.assert_close(output_mask, expected_masks) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 105f3ab95c3..de8c54fd682 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -51,6 +51,7 @@ affine_bounding_box, affine_image_tensor, affine_image_pil, + affine_segmentation_mask, rotate_image_tensor, rotate_image_pil, pad_image_tensor, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 5f9f5c54d0e..8c37184b126 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -209,7 +209,7 @@ def affine_bounding_box( # Single point structure is similar to # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) - points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) # 2) Now let's transform the points using affine matrix transformed_points = torch.matmul(points, affine_matrix.T) # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] @@ -225,6 +225,27 @@ def affine_bounding_box( ).view(original_shape) +def affine_segmentation_mask( + img: torch.Tensor, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, +) -> torch.Tensor: + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, center=center) + + center_f = [0.0, 0.0] + if center is not None: + _, height, width = get_dimensions_image_tensor(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
+
+    translate_f = [1.0 * t for t in translate]
+    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
+    return _FT.affine(img, matrix, interpolation=InterpolationMode.NEAREST.value, fill=None)
+
+
 def rotate_image_tensor(
     img: torch.Tensor,
     angle: float,

From 36ed30a6d19cbaa0859a4ec25b7e3e07215eb198 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 14 Mar 2022 18:00:33 +0000
Subject: [PATCH 05/14] Fixed device mismatch issue

Added a cuda/cpu test
Reduced the number of test samples
---
 test/test_prototype_transforms_functional.py | 78 ++++++++++---------
 .../transforms/functional/_geometry.py       |  2 +-
 2 files changed, 43 insertions(+), 37 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 2b0ee947284..2ba6b1115d7 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -6,13 +6,13 @@
 import pytest
 import torch.testing
 import torchvision.prototype.transforms.functional as F
+from common_utils import cpu_and_gpu
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
 from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

-
 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


@@ -278,33 +278,38 @@ def test_eager_vs_scripted(functional_info, sample_input):
     torch.testing.assert_close(eager, scripted)


-@pytest.mark.parametrize("angle", range(-90, 90, 36))
-@pytest.mark.parametrize("translate", range(-10, 10, 5))
+def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
+    rot = math.radians(angle_)
+    cx, cy = center_
+    tx, ty = translate_
+    sx, sy = [math.radians(sh_) for sh_ in shear_]
+
+    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
+    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
+    c_matrix_inv = np.linalg.inv(c_matrix)
+    rs_matrix = np.array(
+        [
+            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
+            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
+            [0, 0, 1],
+        ]
+    )
+    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
+    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
+    rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
+    true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
+    return true_matrix
+
+
+@pytest.mark.parametrize("angle", range(-90, 90, 56))
+@pytest.mark.parametrize("translate", range(-10, 10, 8))
 @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
-@pytest.mark.parametrize("shear", range(-15, 15, 5))
+@pytest.mark.parametrize("shear", range(-15, 15, 8))
 @pytest.mark.parametrize("center", [None, (12, 14)])
 def test_correctness_affine_bounding_box(angle, translate, scale, shear, center):
     def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
-        rot = math.radians(angle_)
-        cx, cy = center_
-        tx, ty = translate_
-        sx, sy = [math.radians(sh_) for sh_ in shear_]
-
-        c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
-        t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
-        c_matrix_inv = np.linalg.inv(c_matrix)
-        rs_matrix = np.array(
-            [
-                [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
-                [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
-                [0, 0, 1],
-            ]
-        )
-        shear_x_matrix = 
np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) - shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) - rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) - true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) - true_matrix = true_matrix[:2, :] + affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_) + affine_matrix = affine_matrix[:2, :] bbox_xyxy = convert_bounding_box_format( bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY @@ -317,7 +322,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], ] ) - transformed_points = np.matmul(points, true_matrix.T) + transformed_points = np.matmul(points, affine_matrix.T) out_bbox = [ np.min(transformed_points[:, 0]), np.min(transformed_points[:, 1]), @@ -328,11 +333,11 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32 ) out_bbox = convert_bounding_box_format( - out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format + out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False ) - return out_bbox + return out_bbox.to(bbox.device) - image_size = (32, 32) + image_size = (32, 38) for bboxes in make_bounding_boxes( image_sizes=[ @@ -351,7 +356,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): center=center, ) if center is None: - center = [s // 2 for s in image_size] + center = [s // 2 for s in image_size[::-1]] bboxes_format = bboxes.format bboxes_image_size = bboxes.image_size @@ -366,14 +371,15 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): expected_bboxes.append( _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) ) - expected_bboxes = torch.stack(expected_bboxes) - if expected_bboxes.shape[0] < 2: - expected_bboxes = expected_bboxes.squeeze(0) - + if len(expected_bboxes) > 1: + expected_bboxes = torch.stack(expected_bboxes) + else: + expected_bboxes = expected_bboxes[0] torch.testing.assert_close(output_bboxes, expected_bboxes) -def test_correctness_affine_bounding_box_on_fixed_input(): +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_affine_bounding_box_on_fixed_input(device): # Check transformation against known expected output image_size = (64, 64) # xyxy format @@ -385,7 +391,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(): ] in_boxes = features.BoundingBox( in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64 - ) + ).to(device) # Tested parameters angle = 63 scale = 0.89 @@ -419,5 +425,5 @@ def test_correctness_affine_bounding_box_on_fixed_input(): ) assert len(output_boxes) == len(expected_bboxes) - for a_out_box, out_box in zip(expected_bboxes, output_boxes): + for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()): np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 5f9f5c54d0e..a0ca51785be 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -209,7 +209,7 @@ def affine_bounding_box( # Single point structure is similar to # [(xmin, 
ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) - points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) # 2) Now let's transform the points using affine matrix transformed_points = torch.matmul(points, affine_matrix.T) # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] From d0030513aa8b342fa743fae78f60111cd402781d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 15 Mar 2022 22:02:44 +0000 Subject: [PATCH 06/14] Added test_correctness_affine_segmentation_mask_on_fixed_input --- test/test_prototype_transforms_functional.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index ea8a2ac3512..8e98b7eb5fa 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -513,3 +513,22 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): else: expected_masks = expected_masks[0] torch.testing.assert_close(output_mask, expected_masks) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_affine_segmentation_mask_on_fixed_input(device): + # Check transformation against known expected output + # Rotate 90 degrees and scale + mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) + mask[0, 2:10, 2:10] = 1 + mask[0, 32 - 9 : 32 - 3, 3:9] = 2 + mask[0, 1:11, 32 - 11 : 32 - 1] = 3 + mask[0, 16 - 4 : 16 + 4, 16 - 4 : 16 + 4] = 4 + + expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1)) + expected_mask = torch.nn.functional.interpolate(expected_mask[None, ...].float(), size=(64, 64), mode="nearest") + expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long() + + out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0]) + + assert out_mask.allclose(expected_mask) From 7e89062bc57bca3f1943489ee6fa5cb658c1f009 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 16 Mar 2022 09:51:32 +0000 Subject: [PATCH 07/14] Updates according to the review --- test/test_prototype_transforms_functional.py | 13 ++++++------ .../transforms/functional/_geometry.py | 20 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 8e98b7eb5fa..b88ab2feb72 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -146,7 +146,7 @@ def make_segmentation_mask(size=None, *, max_value=80, extra_dims=(), dtype=torc def make_segmentation_masks( - image_sizes=((32, 32), (32, 42)), + image_sizes=((32, 32), (32, 42), (38, 24)), dtypes=(torch.long,), extra_dims=((), (4,)), ): @@ -498,9 +498,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): center = [s // 2 for s in mask.shape[-2:][::-1]] if mask.ndim < 4: - masks = [ - mask, - ] + masks = [mask] else: masks = [m for m in mask] @@ -518,17 +516,20 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): @pytest.mark.parametrize("device", cpu_and_gpu()) def test_correctness_affine_segmentation_mask_on_fixed_input(device): # Check transformation against known expected output - # Rotate 90 degrees and scale + + # Create a fixed input segmentation mask with 4 square masks + # in top-left, top-right, 
bottom-right corners and in the center mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) mask[0, 2:10, 2:10] = 1 mask[0, 32 - 9 : 32 - 3, 3:9] = 2 mask[0, 1:11, 32 - 11 : 32 - 1] = 3 mask[0, 16 - 4 : 16 + 4, 16 - 4 : 16 + 4] = 4 + # Rotate 90 degrees and scale expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1)) expected_mask = torch.nn.functional.interpolate(expected_mask[None, ...].float(), size=(64, 64), mode="nearest") expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long() out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0]) - assert out_mask.allclose(expected_mask) + torch.testing.assert_close(out_mask, expected_mask) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index d8b187a20e5..fe10239d242 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -237,17 +237,15 @@ def affine_segmentation_mask( shear: List[float], center: Optional[List[float]] = None, ) -> torch.Tensor: - angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, center=center) - - center_f = [0.0, 0.0] - if center is not None: - _, height, width = get_dimensions_image_tensor(img) - # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] - - translate_f = [1.0 * t for t in translate] - matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) - return _FT.affine(img, matrix, interpolation=InterpolationMode.NEAREST.value, fill=None) + return affine_image_tensor( + img, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=InterpolationMode.NEAREST, + center=center, + ) def rotate_image_tensor( From a2be66655de23c7d9382b1b0aad3cbb718212e13 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 21 Mar 2022 12:06:49 +0000 Subject: [PATCH 08/14] Replaced [None, ...] 
by [None, :]
---
 test/test_prototype_transforms_functional.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index b88ab2feb72..1809824d3ff 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -527,7 +527,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):

     # Rotate 90 degrees and scale
     expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1))
-    expected_mask = torch.nn.functional.interpolate(expected_mask[None, ...].float(), size=(64, 64), mode="nearest")
+    expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest")
     expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long()

     out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])

From 9d6ac7497362139bff953c42f470608ed60ad33d Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 23 Mar 2022 10:42:51 +0000
Subject: [PATCH 09/14] Addressed review comments
---
 test/test_prototype_transforms_functional.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 1809824d3ff..722162aff20 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -138,17 +138,17 @@ def make_one_hot_labels(
         yield make_one_hot_label(extra_dims_)


-def make_segmentation_mask(size=None, *, max_value=80, extra_dims=(), dtype=torch.uint8):
+def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype=torch.long):
     size = size or torch.randint(16, 33, (2,)).tolist()
     shape = (*extra_dims, 1, *size)
-    data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
+    data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
     return features.SegmentationMask(data)


 def make_segmentation_masks(
     image_sizes=((32, 32), (32, 42), (38, 24)),
     dtypes=(torch.long,),
-    extra_dims=((), (4,)),
+    extra_dims=((), (4,), (2, 3)),
 ):
     for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims):
         yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_)
@@ -228,7 +228,7 @@ def resize_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def affine_image_tensor():
     for image, angle, translate, scale, shear in itertools.product(
-        make_images(extra_dims=()),
+        make_segmentation_masks(extra_dims=((), (4,))),
         [-87, 15, 90],  # angle
         [5, -5],  # translate
         [0.77, 1.27],  # scale
@@ -485,7 +485,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
                     expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
         return expected_mask.to(mask.device)

-    for mask in make_segmentation_masks():
+    for mask in make_segmentation_masks(extra_dims=((), (4, ))):
         output_mask = F.affine_segmentation_mask(
             mask,
             angle=angle,

From d17decb73bc55c32b2fa0093c9efc2e4cb5854fa Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 23 Mar 2022 10:59:28 +0000
Subject: [PATCH 10/14] Fixed formatting and more updates according to the review
---
 test/test_prototype_transforms_functional.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 722162aff20..955627670f7 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ 
-146,7 +146,7 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype def make_segmentation_masks( - image_sizes=((32, 32), (32, 42), (38, 24)), + image_sizes=((16, 16), (7, 33), (31, 9)), dtypes=(torch.long,), extra_dims=((), (4,), (2, 3)), ): @@ -485,7 +485,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): expected_mask[0, out_y, out_x] = mask[0, in_y, in_x] return expected_mask.to(mask.device) - for mask in make_segmentation_masks(extra_dims=((), (4, ))): + for mask in make_segmentation_masks(extra_dims=((), (4,))): output_mask = F.affine_segmentation_mask( mask, angle=angle, @@ -515,15 +515,13 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): @pytest.mark.parametrize("device", cpu_and_gpu()) def test_correctness_affine_segmentation_mask_on_fixed_input(device): - # Check transformation against known expected output + # Check transformation against known expected output and CPU/CUDA devices - # Create a fixed input segmentation mask with 4 square masks - # in top-left, top-right, bottom-right corners and in the center + # Create a fixed input segmentation mask with 2 square masks + # in top-left, bottom-left corners mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) mask[0, 2:10, 2:10] = 1 mask[0, 32 - 9 : 32 - 3, 3:9] = 2 - mask[0, 1:11, 32 - 11 : 32 - 1] = 3 - mask[0, 16 - 4 : 16 + 4, 16 - 4 : 16 + 4] = 4 # Rotate 90 degrees and scale expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1)) From f4c22430a6f730f7e5f707d88f662c192016454d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 23 Mar 2022 11:34:29 +0000 Subject: [PATCH 11/14] Fixed bad merge --- test/test_prototype_transforms_functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index b081f1eb5a4..7057fdaa2a9 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -228,7 +228,7 @@ def resize_bounding_box(): @register_kernel_info_from_sample_inputs_fn def affine_image_tensor(): for image, angle, translate, scale, shear in itertools.product( - make_segmentation_masks(extra_dims=((), (4,))), + make_images(extra_dims=((), (4,))), [-87, 15, 90], # angle [5, -5], # translate [0.77, 1.27], # scale @@ -267,7 +267,7 @@ def affine_bounding_box(): @register_kernel_info_from_sample_inputs_fn def affine_segmentation_mask(): for image, angle, translate, scale, shear in itertools.product( - make_images(extra_dims=()), + make_segmentation_masks(extra_dims=((), (4,))), [-87, 15, 90], # angle [5, -5], # translate [0.77, 1.27], # scale @@ -279,6 +279,7 @@ def affine_segmentation_mask(): translate=(translate, translate), scale=scale, shear=(shear, shear), + ) @register_kernel_info_from_sample_inputs_fn From 60376109550559af1781f3be8c7e7ea06e9fe496 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 25 Mar 2022 11:34:46 +0000 Subject: [PATCH 12/14] WIP --- test/test_prototype_transforms_functional.py | 132 +++++++++++++++--- .../transforms/functional/__init__.py | 1 + .../transforms/functional/_geometry.py | 19 +++ torchvision/transforms/functional_tensor.py | 4 + 4 files changed, 138 insertions(+), 18 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 7057fdaa2a9..9458987cd55 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -266,7 +266,7 
@@ def affine_bounding_box(): @register_kernel_info_from_sample_inputs_fn def affine_segmentation_mask(): - for image, angle, translate, scale, shear in itertools.product( + for mask, angle, translate, scale, shear in itertools.product( make_segmentation_masks(extra_dims=((), (4,))), [-87, 15, 90], # angle [5, -5], # translate @@ -274,7 +274,7 @@ def affine_segmentation_mask(): [0, 12], # shear ): yield SampleInput( - image, + mask, angle=angle, translate=(translate, translate), scale=scale, @@ -285,8 +285,12 @@ def affine_segmentation_mask(): @register_kernel_info_from_sample_inputs_fn def rotate_bounding_box(): for bounding_box, angle, expand, center in itertools.product( - make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] # angle # expand # center + make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] ): + if center is not None and expand: + # Skip warning: The provided center argument is ignored if expand is True + continue + yield SampleInput( bounding_box, format=bounding_box.format, @@ -297,6 +301,26 @@ def rotate_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn +def rotate_segmentation_mask(): + for mask, angle, expand, center in itertools.product( + make_segmentation_masks(extra_dims=((), (4,))), + [-87, 15, 90], # angle + [True, False], # expand + [None, [12, 23]] # center + ): + if center is not None and expand: + # Skip warning: The provided center argument is ignored if expand is True + continue + + yield SampleInput( + mask, + angle=angle, + expand=expand, + center=center, + ) + + @pytest.mark.parametrize( "kernel", [ @@ -408,8 +432,9 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): center=center, ) - if center is None: - center = [s // 2 for s in bboxes_image_size[::-1]] + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in bboxes_image_size[::-1]] if bboxes.ndim < 2: bboxes = [bboxes] @@ -418,7 +443,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_): for bbox in bboxes: bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) expected_bboxes.append( - _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) + _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_) ) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) @@ -509,8 +534,10 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): shear=(shear, shear), center=center, ) - if center is None: - center = [s // 2 for s in mask.shape[-2:][::-1]] + + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in mask.shape[-2:][::-1]] if mask.ndim < 4: masks = [mask] @@ -519,7 +546,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): expected_masks = [] for mask in masks: - expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center) + expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_) expected_masks.append(expected_mask) if len(expected_masks) > 1: expected_masks = torch.stack(expected_masks) @@ -549,9 +576,9 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device): @pytest.mark.parametrize("angle", range(-90, 90, 56)) -@pytest.mark.parametrize("expand", [True, False]) -@pytest.mark.parametrize("center", [None, (12, 14)]) +@pytest.mark.parametrize("expand, center", [(True, None), (False, 
None), (False, (12, 14))]) def test_correctness_rotate_bounding_box(angle, expand, center): + def _compute_expected_bbox(bbox, angle_, expand_, center_): affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) affine_matrix = affine_matrix[:2, :] @@ -616,8 +643,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): center=center, ) - if center is None: - center = [s // 2 for s in bboxes_image_size[::-1]] + center_ = center + if center_ is None: + center_ = [s * 0.5 for s in bboxes_image_size[::-1]] if bboxes.ndim < 2: bboxes = [bboxes] @@ -625,19 +653,16 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): expected_bboxes = [] for bbox in bboxes: bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) - expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center)) + expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_)) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) else: expected_bboxes = expected_bboxes[0] - print("input:", bboxes) - print("output_bboxes:", output_bboxes) - print("expected_bboxes:", expected_bboxes) torch.testing.assert_close(output_bboxes, expected_bboxes) @pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2, analysis in progress +@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): # Check transformation against known expected output image_size = (64, 64) @@ -690,3 +715,74 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): assert len(output_boxes) == len(expected_bboxes) for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()): np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box) + + +@pytest.mark.parametrize("angle", range(-90, 90, 56)) +@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) +def test_correctness_rotate_segmentation_mask(angle, expand, center): + + def _compute_expected_mask(mask, angle_, expand_, center_): + assert mask.ndim == 3 and mask.shape[0] == 1 + + image_size = mask.shape[-2:] + print(image_size) + if expand_: + rot = math.radians(angle_) + h, w = image_size + print(h, w) + abs_cos, abs_sin = (abs(math.cos(rot)), abs(math.sin(rot))) + new_size_f = [ + (h * abs_sin + w * abs_cos), + (h * abs_cos + w * abs_sin) + ] + center_ = [s * 0.5 for s in new_size_f] + print(center_) + image_size = (int(new_size_f[1]), int(new_size_f[0])) + print(image_size) + + affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) + + inv_affine_matrix = np.linalg.inv(affine_matrix) + inv_affine_matrix = inv_affine_matrix[:2, :] + + print("image size:", image_size) + expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype) + + for out_y in range(expected_mask.shape[1]): + for out_x in range(expected_mask.shape[2]): + output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0]) + input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32) + in_x, in_y = input_pt[:2] + if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]: + expected_mask[0, out_y, out_x] = mask[0, in_y, in_x] + return expected_mask.to(mask.device) + + for mask in make_segmentation_masks(extra_dims=((), (4,))): + print("\n-- mask:", mask.shape, center) + output_mask = F.rotate_segmentation_mask( + mask, + angle=angle, + expand=expand, + center=center, + ) + + center_ = 
center + if center_ is None: + center_ = [s * 0.5 for s in mask.shape[-2:][::-1]] + + if mask.ndim < 4: + masks = [mask] + else: + masks = [m for m in mask] + + expected_masks = [] + for mask in masks: + expected_mask = _compute_expected_mask(mask, -angle, expand, center_) + expected_masks.append(expected_mask) + if len(expected_masks) > 1: + expected_masks = torch.stack(expected_masks) + else: + expected_masks = expected_masks[0] + print("output_mask:", output_mask.shape) + print("expected_masks:", expected_masks.shape) + torch.testing.assert_close(output_mask, expected_masks) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 51bf73a18f7..e8f25342a18 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -56,6 +56,7 @@ rotate_bounding_box, rotate_image_tensor, rotate_image_pil, + rotate_segmentation_mask, pad_image_tensor, pad_image_pil, pad_bounding_box, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 71882f06270..638504c1eeb 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -361,6 +361,10 @@ def rotate_bounding_box( expand: bool = False, center: Optional[List[float]] = None, ) -> torch.Tensor: + if center is not None and expand: + warnings.warn("The provided center argument is ignored if expand is True") + center = None + original_shape = bounding_box.shape bounding_box = convert_bounding_box_format( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY @@ -373,6 +377,21 @@ def rotate_bounding_box( ).view(original_shape) +def rotate_segmentation_mask( + img: torch.Tensor, + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, +) -> torch.Tensor: + return rotate_image_tensor( + img, + angle=angle, + expand=expand, + interpolation=InterpolationMode.NEAREST, + center=center, + ) + + pad_image_tensor = _FT.pad pad_image_pil = _FP.pad diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 6bcd1ea85da..2e5b15dbb2b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -663,6 +663,10 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] min_vals, _ = new_pts.min(dim=0) max_vals, _ = new_pts.max(dim=0) + # shift points to [0, w] and [0, h] interval to match PIL results + min_vals += torch.tensor((w * 0.5, h * 0.5)) + max_vals += torch.tensor((w * 0.5, h * 0.5)) + # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 tol = 1e-4 cmax = torch.ceil((max_vals / tol).trunc_() * tol) From 8cb3510f3bada7ffafdf372b2d463c3a8a30708d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 28 Mar 2022 11:44:53 +0000 Subject: [PATCH 13/14] Fixed tests --- test/test_prototype_transforms_functional.py | 48 +++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 0b3cbb7b1bc..3876beea5c4 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -307,7 +307,7 @@ def rotate_segmentation_mask(): make_segmentation_masks(extra_dims=((), (4,))), [-87, 15, 90], # angle [True, False], # expand - [None, [12, 23]] # center + [None, [12, 
23]], # center ): if center is not None and expand: # Skip warning: The provided center argument is ignored if expand is True @@ -579,7 +579,6 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device): @pytest.mark.parametrize("angle", range(-90, 90, 56)) @pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) def test_correctness_rotate_bounding_box(angle, expand, center): - def _compute_expected_bbox(bbox, angle_, expand_, center_): affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) affine_matrix = affine_matrix[:2, :] @@ -719,17 +718,38 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) -@pytest.mark.parametrize("angle", range(-90, 90, 56)) +@pytest.mark.parametrize("angle", range(-90, 90, 37)) @pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) def test_correctness_rotate_segmentation_mask(angle, expand, center): - def _compute_expected_mask(mask, angle_, expand_, center_): assert mask.ndim == 3 and mask.shape[0] == 1 - image_size = mask.shape[-2:] affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) - inv_affine_matrix = np.linalg.inv(affine_matrix) + + if expand_: + # Pillow implementation on how to perform expand: + # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054-L2069 + height, width = image_size + points = np.array( + [ + [0.0, 0.0, 1.0], + [0.0, 1.0 * height, 1.0], + [1.0 * width, 1.0 * height, 1.0], + [1.0 * width, 0.0, 1.0], + ] + ) + new_points = points @ inv_affine_matrix.T + min_vals = np.min(new_points, axis=0)[:2] + max_vals = np.max(new_points, axis=0)[:2] + cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4) + cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4) + new_width, new_height = (cmax - cmin).astype("int32").tolist() + tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T + + inv_affine_matrix[:2, 2] = tr[:2] + image_size = [new_height, new_width] + inv_affine_matrix = inv_affine_matrix[:2, :] expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype) @@ -768,3 +788,19 @@ def _compute_expected_mask(mask, angle_, expand_, center_): else: expected_masks = expected_masks[0] torch.testing.assert_close(output_mask, expected_masks) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_rotate_segmentation_mask_on_fixed_input(device): + # Check transformation against known expected output and CPU/CUDA devices + + # Create a fixed input segmentation mask with 2 square masks + # in top-left, bottom-left corners + mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) + mask[0, 2:10, 2:10] = 1 + mask[0, 32 - 9 : 32 - 3, 3:9] = 2 + + # Rotate 90 degrees + expected_mask = torch.rot90(mask, k=1, dims=(-2, -1)) + out_mask = F.rotate_segmentation_mask(mask, 90, expand=False) + torch.testing.assert_close(out_mask, expected_mask) From ee9f3d6cb60357ee77761aaa3fb2f8bf07485a57 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 4 Apr 2022 10:49:51 +0000 Subject: [PATCH 14/14] Updated warning message --- torchvision/prototype/transforms/functional/_geometry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 638504c1eeb..7629766c0e2 100644 --- 
a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -324,7 +324,7 @@ def rotate_image_tensor( center_f = [0.0, 0.0] if center is not None: if expand: - warnings.warn("The provided center argument is ignored if expand is True") + warnings.warn("The provided center argument has no effect on the result if expand is True") else: _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. @@ -345,7 +345,7 @@ def rotate_image_pil( center: Optional[List[float]] = None, ) -> PIL.Image.Image: if center is not None and expand: - warnings.warn("The provided center argument is ignored if expand is True") + warnings.warn("The provided center argument has no effect on the result if expand is True") center = None return _FP.rotate( @@ -362,7 +362,7 @@ def rotate_bounding_box( center: Optional[List[float]] = None, ) -> torch.Tensor: if center is not None and expand: - warnings.warn("The provided center argument is ignored if expand is True") + warnings.warn("The provided center argument has no effect on the result if expand is True") center = None original_shape = bounding_box.shape
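
A note on the matrix conventions used throughout this series: the inverted flag added to _get_inverse_affine_matrix in PATCH 01 selects between the forward transform T * C * RSS * C^-1 (translation, center shift, rotation-scale-shear) and its inverse, which is what the image kernels consume. The test helper _compute_affine_matrix builds the same forward product explicitly. Below is a minimal standalone sketch of that composition; the function name and the sample values are illustrative only, not part of the patches:

    import math

    import numpy as np


    def forward_affine_matrix(angle, translate, scale, shear, center):
        # Compose T * C * RSS * C^-1, with RSS = rotation-with-scale @ shear,
        # mirroring _compute_affine_matrix from the tests above.
        rot = math.radians(angle)
        cx, cy = center
        tx, ty = translate
        sx, sy = [math.radians(s) for s in shear]

        c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
        t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        rs_matrix = np.array(
            [
                [scale * math.cos(rot), -scale * math.sin(rot), 0],
                [scale * math.sin(rot), scale * math.cos(rot), 0],
                [0, 0, 1],
            ]
        )
        shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
        shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
        rss_matrix = rs_matrix @ shear_y_matrix @ shear_x_matrix
        return t_matrix @ c_matrix @ rss_matrix @ np.linalg.inv(c_matrix)


    # Because rotation, scale and shear are applied about `center`, the center
    # point itself moves only by the translation component:
    m = forward_affine_matrix(63.0, (2.0, -3.0), 0.89, (10.0, 5.0), (16.0, 16.0))
    assert np.allclose(m @ np.array([16.0, 16.0, 1.0]), np.array([18.0, 13.0, 1.0]))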
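
The three-step scheme documented inside affine_bounding_box (expand each box into its four corners, transform the corners, take the axis-aligned envelope) is what the correctness tests reimplement in NumPy. A self-contained sketch for XYXY boxes; affine_xyxy_boxes is an illustrative name, not an API introduced by the patches:

    import torch


    def affine_xyxy_boxes(boxes, affine_matrix):
        # boxes: (N, 4) tensor in XYXY format; affine_matrix: (2, 3) forward matrix.
        # 1) Expand each box into its four corners, in homogeneous coordinates.
        points = boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
        points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1)
        # 2) Transform the corners and regroup them per box.
        transformed = (points @ affine_matrix.T).reshape(-1, 4, 2)
        # 3) The output box is the axis-aligned envelope of the transformed corners.
        return torch.cat([transformed.min(dim=1).values, transformed.max(dim=1).values], dim=1)


    boxes = torch.tensor([[20.0, 25.0, 35.0, 45.0]])
    # Horizontal flip on a 64-pixel-wide image: x' = 64 - x, y' = y.
    flip_x = torch.tensor([[-1.0, 0.0, 64.0], [0.0, 1.0, 0.0]])
    print(affine_xyxy_boxes(boxes, flip_x))  # tensor([[29., 25., 44., 45.]])

Note that for a rotated box the envelope is generally larger than the rotated rectangle itself, which is why the tests build the expected boxes with exactly the same corner construction rather than comparing against the input geometry directly.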
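
For expand=True, PATCHes 12 and 13 make the tensor path match Pillow: the output canvas is the envelope of the four image corners rotated about the image center, and the content is then re-centered on that canvas. Because the rotation is about the center, the new size depends only on the absolute cosine and sine of the angle. A simplified sketch follows; the plain int() truncation here stands in for the tolerance-based ceil/floor that _compute_output_size uses to avoid ceiling values like X + 1e-15 up to the next integer:

    import math


    def expanded_canvas_size(angle, width, height):
        # Envelope of the rotated image corners, as in PATCH 13's test helper.
        rot = math.radians(angle)
        abs_cos, abs_sin = abs(math.cos(rot)), abs(math.sin(rot))
        new_w = height * abs_sin + width * abs_cos
        new_h = height * abs_cos + width * abs_sin
        return int(new_w), int(new_h)


    print(expanded_canvas_size(90, 32, 64))  # (64, 32): width and height swap
    print(expanded_canvas_size(45, 32, 32))  # (45, 45): a tilted square needs a larger canvas

The accompanying translation in the test helper, tr = [-(new_w - w) / 2, -(new_h - h) / 2, 1] @ inv_affine_matrix.T, shifts the rotated content onto the enlarged canvas, which is why the provided center argument has no effect when expand is True.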