diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 6c99720114a..2da3aa4696a 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -352,6 +352,16 @@ def vertical_flip_segmentation_mask():
         yield SampleInput(mask)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def resized_crop_bounding_box():
+    for bounding_box, top, left, height, width, size in itertools.product(
+        make_bounding_boxes(), [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)]
+    ):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, top=top, left=left, height=height, width=width, size=size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -842,6 +852,10 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
 
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "format",
+    [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH],
+)
 @pytest.mark.parametrize(
     "top, left, height, width, expected_bboxes",
     [
@@ -849,7 +863,7 @@
         [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
     ],
 )
-def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
+def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes):
 
     # Expected bboxes computed using Albumentations:
     # import numpy as np
@@ -871,14 +885,19 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
         [45.0, 46.0, 56.0, 62.0],
     ]
     in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device)
+    if format != features.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, features.BoundingBoxFormat.XYXY, format)
 
     output_boxes = F.crop_bounding_box(
         in_boxes,
-        in_boxes.format,
+        format,
         top,
         left,
     )
 
+    if format != features.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
+
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
 
 
@@ -933,3 +952,49 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     expected_mask[:, -1, :] = 1
     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "format",
+    [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH],
+)
+@pytest.mark.parametrize(
+    "top, left, height, width, size",
+    [
+        [0, 0, 30, 30, (60, 60)],
+        [-5, 5, 35, 45, (32, 34)],
+    ],
+)
+def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
+    def _compute_expected(bbox, top_, left_, height_, width_, size_):
+        # bbox should be xyxy
+        bbox[0] = (bbox[0] - left_) * size_[1] / width_
+        bbox[1] = (bbox[1] - top_) * size_[0] / height_
+        bbox[2] = (bbox[2] - left_) * size_[1] / width_
+        bbox[3] = (bbox[3] - top_) * size_[0] / height_
+        return bbox
+
+    image_size = (100, 100)
+    # xyxy format
+    in_boxes = [
+        [10.0, 10.0, 20.0, 20.0],
+        [5.0, 10.0, 15.0, 20.0],
+    ]
+    expected_bboxes = []
+    for in_box in in_boxes:
+        expected_bboxes.append(_compute_expected(list(in_box), top, left, height, width, size))
+    expected_bboxes = torch.tensor(expected_bboxes, device=device)
+
+    in_boxes = features.BoundingBox(
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, device=device
+    )
+    if format != features.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, features.BoundingBoxFormat.XYXY, format)
+
+    output_boxes = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+
+    if format != features.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
+
+    torch.testing.assert_close(output_boxes, expected_bboxes)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index bbfa9584d88..7069f17c414 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -49,6 +49,7 @@
     center_crop_image_pil,
     resized_crop_image_tensor,
     resized_crop_image_pil,
+    resized_crop_bounding_box,
     affine_bounding_box,
     affine_image_tensor,
     affine_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 65673203941..fc1eddfd230 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -542,6 +542,19 @@ def resized_crop_image_pil(
     return resize_image_pil(img, size, interpolation=interpolation)
 
 
+def resized_crop_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+) -> torch.Tensor:
+    bounding_box = crop_bounding_box(bounding_box, format, top, left)
+    return resize_bounding_box(bounding_box, size, (height, width))
+
+
 def _parse_five_crop_size(size: List[int]) -> List[int]:
     if isinstance(size, numbers.Number):
         size = [int(size), int(size)]
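
A minimal usage sketch of the new kernel, assuming the prototype API added in this patch (the import paths, box values, and crop parameters below are illustrative only): the box coordinates are first shifted by the crop offsets via crop_bounding_box, then scaled by size / (height, width) via resize_bounding_box.

    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    # One XYXY box on a 100x100 image (illustrative values).
    box = features.BoundingBox(
        [[10.0, 10.0, 20.0, 20.0]],
        format=features.BoundingBoxFormat.XYXY,
        image_size=(100, 100),
    )

    # Crop a 30x30 window at the top-left corner, then resize the crop to 60x60:
    # crop_bounding_box shifts coordinates by (top, left) = (0, 0), and
    # resize_bounding_box scales them by 60 / 30 = 2 along both axes.
    out = F.resized_crop_bounding_box(box, box.format, top=0, left=0, height=30, width=30, size=[60, 60])
    print(out)  # -> [[20.0, 20.0, 40.0, 40.0]]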