diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index dac43717d30..0735eff6575 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -11,8 +11,10 @@ from torch.nn.functional import one_hot
 from torchvision.prototype import features
 from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
+from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
 
+
 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
 
 
@@ -380,6 +382,37 @@ def pad_segmentation_mask():
         yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def perspective_bounding_box():
+    for bounding_box, perspective_coeffs in itertools.product(
+        make_bounding_boxes(),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+    ):
+        yield SampleInput(
+            bounding_box,
+            format=bounding_box.format,
+            perspective_coeffs=perspective_coeffs,
+        )
+
+
+@register_kernel_info_from_sample_inputs_fn
+def perspective_segmentation_mask():
+    for mask, perspective_coeffs in itertools.product(
+        make_segmentation_masks(extra_dims=((), (4,))),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+    ):
+        yield SampleInput(
+            mask,
+            perspective_coeffs=perspective_coeffs,
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -985,7 +1018,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     ],
 )
 def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
-    def _compute_expected(bbox, top_, left_, height_, width_, size_):
+    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
         # bbox should be xyxy
         bbox[0] = (bbox[0] - left_) * size_[1] / width_
         bbox[1] = (bbox[1] - top_) * size_[0] / height_
@@ -1001,7 +1034,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
     ]
     expected_bboxes = []
     for in_box in in_boxes:
-        expected_bboxes.append(_compute_expected(list(in_box), top, left, height, width, size))
+        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
     expected_bboxes = torch.tensor(expected_bboxes, device=device)
 
     in_boxes = features.BoundingBox(
@@ -1027,7 +1060,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
     ],
 )
 def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
-    def _compute_expected(mask, top_, left_, height_, width_, size_):
+    def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
         output = mask.clone()
         output = output[:, top_ : top_ + height_, left_ : left_ + width_]
         output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
@@ -1038,7 +1071,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
 
     in_mask[0, 10:20, 10:20] = 1
     in_mask[0, 5:15, 12:23] = 2
 
-    expected_mask = _compute_expected(in_mask, top, left, height, width, size)
+    expected_mask = _compute_expected_mask(in_mask, top, left, height, width, size)
     output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
     torch.testing.assert_close(output_mask, expected_mask)
@@ -1085,3 +1118,158 @@ def parse_padding():
 
     expected_mask = _compute_expected_mask()
     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "startpoints, endpoints",
+    [
+        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+    ],
+)
+def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
+    def _compute_expected_bbox(bbox, pcoeffs_):
+        m1 = np.array(
+            [
+                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
+                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
+            ]
+        )
+        m2 = np.array(
+            [
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+            ]
+        )
+
+        bbox_xyxy = convert_bounding_box_format(
+            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
+        )
+        points = np.array(
+            [
+                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
+            ]
+        )
+        numer = np.matmul(points, m1.T)
+        denom = np.matmul(points, m2.T)
+        transformed_points = numer / denom
+        out_bbox = [
+            np.min(transformed_points[:, 0]),
+            np.min(transformed_points[:, 1]),
+            np.max(transformed_points[:, 0]),
+            np.max(transformed_points[:, 1]),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYXY,
+            image_size=bbox.image_size,
+            dtype=torch.float32,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(
+            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        )
+
+    image_size = (32, 38)
+
+    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
+    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[
+            image_size,
+        ],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_bboxes = F.perspective_bounding_box(
+            bboxes,
+            bboxes_format,
+            perspective_coeffs=pcoeffs,
+        )
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "startpoints, endpoints",
+    [
+        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+    ],
+)
+def test_correctness_perspective_segmentation_mask(device, startpoints, endpoints):
+    def _compute_expected_mask(mask, pcoeffs_):
+        assert mask.ndim == 3 and mask.shape[0] == 1
+        m1 = np.array(
+            [
+                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
+                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
+            ]
+        )
+        m2 = np.array(
+            [
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+            ]
+        )
+
+        expected_mask = torch.zeros_like(mask.cpu())
+        for out_y in range(expected_mask.shape[1]):
+            for out_x in range(expected_mask.shape[2]):
+                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
+
+                numer = np.matmul(output_pt, m1.T)
+                denom = np.matmul(output_pt, m2.T)
+                input_pt = np.floor(numer / denom).astype(np.int32)
+
+                in_x, in_y = input_pt[:2]
+                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
+                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
+        return expected_mask.to(mask.device)
+
+    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
+
+    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+        mask = mask.to(device)
+
+        output_mask = F.perspective_segmentation_mask(
+            mask,
+            perspective_coeffs=pcoeffs,
+        )
+
+        if mask.ndim < 4:
+            masks = [mask]
+        else:
+            masks = [m for m in mask]
+
+        expected_masks = []
+        for mask in masks:
+            expected_mask = _compute_expected_mask(mask, pcoeffs)
+            expected_masks.append(expected_mask)
+        if len(expected_masks) > 1:
+            expected_masks = torch.stack(expected_masks)
+        else:
+            expected_masks = expected_masks[0]
+        torch.testing.assert_close(output_mask, expected_masks)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index c13a94035ea..59c53e70c26 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -67,8 +67,10 @@
     crop_image_tensor,
     crop_image_pil,
     crop_segmentation_mask,
+    perspective_bounding_box,
     perspective_image_tensor,
     perspective_image_pil,
+    perspective_segmentation_mask,
     vertical_flip_image_tensor,
     vertical_flip_image_pil,
     vertical_flip_bounding_box,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 602f865f724..ae28bd84874 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -472,6 +472,95 @@ def perspective_image_pil(
     return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
 
 
+def perspective_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    perspective_coeffs: List[float],
+) -> torch.Tensor:
+
+    if len(perspective_coeffs) != 8:
+        raise ValueError("Argument perspective_coeffs should have 8 float values")
+
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
+    device = bounding_box.device
+
+    # perspective_coeffs are computed as endpoint -> start point
+    # We have to invert perspective_coeffs for bboxes:
+    # (x, y) - end point and (x_out, y_out) - start point
+    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    # and we would like to get:
+    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
+    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
+    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
+    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
+    # and compute inv_coeffs in terms of coeffs
+
+    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
+    if denom == 0:
+        raise RuntimeError(
+            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
+            f"Denominator is zero, denom={denom}"
+        )
+
+    inv_coeffs = [
+        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
+        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
+        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
+        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
+        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
+        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
+        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
+        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
+    ]
+
+    theta1 = torch.tensor(
+        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
+        dtype=dtype,
+        device=device,
+    )
+
+    theta2 = torch.tensor(
+        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
+    )
+
+    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
+    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
+    # Single point structure is similar to
+    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
+    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
+    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
+    # 2) Now let's transform the points using perspective matrices
+    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+
+    numer_points = torch.matmul(points, theta1.T)
+    denom_points = torch.matmul(points, theta2.T)
+    transformed_points = numer_points / denom_points
+
+    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
+    # and compute bounding box from 4 transformed points:
+    transformed_points = transformed_points.view(-1, 4, 2)
+    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
+    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
+    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+
+    # out_bboxes should be of shape [N boxes, 4]
+
+    return convert_bounding_box_format(
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(original_shape)
+
+
+def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
+    return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST)
+
+
 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
         return [int(output_size), int(output_size)]
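
The closed-form `inv_coeffs` used in `perspective_bounding_box` above is the adjugate of the 3x3 perspective matrix rescaled so that its bottom-right entry stays 1, which is why dividing by `denom = c0*c4 - c1*c3` (the adjugate's bottom-right entry) is enough and the full determinant never appears. A minimal cross-check sketch of that identity, using only NumPy; the helper names `coeffs_to_matrix` and `invert_coeffs` are made up for this illustration and are not part of the patch:

```python
import numpy as np


def coeffs_to_matrix(c):
    # Forward mapping as a homography: (x, y, 1) -> (c0*x + c1*y + c2, c3*x + c4*y + c5, c6*x + c7*y + 1)
    return np.array([[c[0], c[1], c[2]], [c[3], c[4], c[5]], [c[6], c[7], 1.0]])


def invert_coeffs(c):
    # Same closed form as in perspective_bounding_box: adjugate entries divided by
    # denom = c0*c4 - c1*c3, so the result is again normalized to 1 in the corner.
    denom = c[0] * c[4] - c[1] * c[3]
    return [
        (c[4] - c[5] * c[7]) / denom,
        (-c[1] + c[2] * c[7]) / denom,
        (c[1] * c[5] - c[2] * c[4]) / denom,
        (-c[3] + c[5] * c[6]) / denom,
        (c[0] - c[2] * c[6]) / denom,
        (-c[0] * c[5] + c[2] * c[3]) / denom,
        (-c[4] * c[6] + c[3] * c[7]) / denom,
        (-c[0] * c[7] + c[1] * c[6]) / denom,
    ]


coeffs = [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018]
closed_form = coeffs_to_matrix(invert_coeffs(coeffs))

# Reference: invert the full 3x3 homography and rescale so the bottom-right entry is 1 again.
reference = np.linalg.inv(coeffs_to_matrix(coeffs))
reference = reference / reference[2, 2]

np.testing.assert_allclose(closed_form, reference, rtol=1e-7, atol=1e-10)
```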
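For context on how the two new kernels fit together, here is a rough usage sketch. It assumes the prototype API exactly as added in this patch (`features.BoundingBox`, `F.perspective_bounding_box`, `F.perspective_segmentation_mask`) plus arbitrary toy sizes, so treat it as illustrative rather than an official example:

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import _get_perspective_coeffs

# Coefficients describe the endpoint -> startpoint mapping, as elsewhere in torchvision.
startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
coeffs = _get_perspective_coeffs(startpoints, endpoints)

# Boxes: the kernel inverts the coefficients internally and maps the four corners.
bboxes = features.BoundingBox(
    [[10.0, 10.0, 20.0, 20.0]], format=features.BoundingBoxFormat.XYXY, image_size=(26, 34)
)
out_boxes = F.perspective_bounding_box(bboxes, bboxes.format, perspective_coeffs=coeffs)  # same shape as input

# Masks: simply the image kernel with nearest interpolation, so integer labels are preserved.
mask = torch.zeros(1, 26, 34, dtype=torch.long)
mask[0, 10:20, 10:20] = 1
out_mask = F.perspective_segmentation_mask(mask, perspective_coeffs=coeffs)  # shape (1, 26, 34)
```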