[proto] Added resized_crop_bounding_box op #5853

Merged: 9 commits, Apr 25, 2022
69 changes: 67 additions & 2 deletions test/test_prototype_transforms_functional.py
@@ -352,6 +352,16 @@ def vertical_flip_segmentation_mask():
         yield SampleInput(mask)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def resized_crop_bounding_box():
+    for bounding_box, top, left, height, width, size in itertools.product(
+        make_bounding_boxes(), [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)]
+    ):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, top=top, left=left, height=height, width=width, size=size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -842,14 +852,18 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
 
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "format",
+    [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH],
+)
 @pytest.mark.parametrize(
     "top, left, height, width, expected_bboxes",
     [
         [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
         [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
     ],
 )
-def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
+def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes):
 
     # Expected bboxes computed using Albumentations:
     # import numpy as np
@@ -871,14 +885,19 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
         [45.0, 46.0, 56.0, 62.0],
     ]
     in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device)
+    if format != features.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, features.BoundingBoxFormat.XYXY, format)
 
     output_boxes = F.crop_bounding_box(
         in_boxes,
-        in_boxes.format,
+        format,
         top,
         left,
     )
 
+    if format != features.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
+
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
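
Note: the translation that `crop_bounding_box` applies to XYXY boxes is easy to verify by hand. The helper below is a hypothetical standalone sketch (not part of this PR) that reproduces the third expected box of the first parametrized case:

# Hypothetical sketch (not in the PR): cropping at (top, left) translates
# XYXY coordinates into the crop's frame; boxes may land partially or fully
# outside the crop, which is where the negative expected values come from.
def crop_xyxy(box, top, left):
    x1, y1, x2, y2 = box
    return [x1 - left, y1 - top, x2 - left, y2 - top]

assert crop_xyxy([45.0, 46.0, 56.0, 62.0], top=8, left=12) == [33.0, 38.0, 44.0, 54.0]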


@@ -933,3 +952,49 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     expected_mask[:, -1, :] = 1
     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "format",
+    [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH],
+)
+@pytest.mark.parametrize(
+    "top, left, height, width, size",
+    [
+        [0, 0, 30, 30, (60, 60)],
+        [-5, 5, 35, 45, (32, 34)],
+    ],
+)
+def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
+    def _compute_expected(bbox, top_, left_, height_, width_, size_):
+        # bbox should be xyxy
+        bbox[0] = (bbox[0] - left_) * size_[1] / width_
+        bbox[1] = (bbox[1] - top_) * size_[0] / height_
+        bbox[2] = (bbox[2] - left_) * size_[1] / width_
+        bbox[3] = (bbox[3] - top_) * size_[0] / height_
+        return bbox
+
+    image_size = (100, 100)
+    # xyxy format
+    in_boxes = [
+        [10.0, 10.0, 20.0, 20.0],
+        [5.0, 10.0, 15.0, 20.0],
+    ]
+    expected_bboxes = []
+    for in_box in in_boxes:
+        expected_bboxes.append(_compute_expected(list(in_box), top, left, height, width, size))
+    expected_bboxes = torch.tensor(expected_bboxes, device=device)
+
+    in_boxes = features.BoundingBox(
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, device=device
+    )
+    if format != features.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, features.BoundingBoxFormat.XYXY, format)
+
+    output_boxes = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+
+    if format != features.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
+
+    torch.testing.assert_close(output_boxes, expected_bboxes)
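
The `_compute_expected` helper above is the whole op in miniature: translate by (left, top), then scale each axis by output size over crop extent, with `size` given as (height, width). A standalone sketch of the same math (an illustration, not code from this PR), applied to the second parametrized case:

# Sketch: crop translates the box, resize scales each axis independently.
def resized_crop_xyxy(box, top, left, height, width, size):
    scale_y, scale_x = size[0] / height, size[1] / width
    x1, y1, x2, y2 = box
    return [(x1 - left) * scale_x, (y1 - top) * scale_y, (x2 - left) * scale_x, (y2 - top) * scale_y]

print(resized_crop_xyxy([10.0, 10.0, 20.0, 20.0], top=-5, left=5, height=35, width=45, size=(32, 34)))
# -> [3.777..., 13.714..., 11.333..., 22.857...]
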
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -49,6 +49,7 @@
     center_crop_image_pil,
     resized_crop_image_tensor,
     resized_crop_image_pil,
+    resized_crop_bounding_box,
     affine_bounding_box,
     affine_image_tensor,
     affine_image_pil,
13 changes: 13 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -542,6 +542,19 @@ def resized_crop_image_pil(
     return resize_image_pil(img, size, interpolation=interpolation)
 
 
+def resized_crop_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+) -> torch.Tensor:
+    bounding_box = crop_bounding_box(bounding_box, format, top, left)
+    return resize_bounding_box(bounding_box, size, (height, width))
+
+
 def _parse_five_crop_size(size: List[int]) -> List[int]:
     if isinstance(size, numbers.Number):
         size = [int(size), int(size)]
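
The new kernel is simply the composition of the existing `crop_bounding_box` and `resize_bounding_box` kernels, so it inherits their semantics, including boxes that end up outside the crop window. A hypothetical usage sketch, assuming the prototype import paths as of this PR:

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

box = features.BoundingBox(
    [[10.0, 10.0, 20.0, 20.0]], format=features.BoundingBoxFormat.XYXY, image_size=(100, 100)
)
# Crop the top-left 50x50 region, then resize it back to 100x100:
# every coordinate is doubled.
out = F.resized_crop_bounding_box(box, box.format, top=0, left=0, height=50, width=50, size=[100, 100])
print(out.tolist())  # [[20.0, 20.0, 40.0, 40.0]]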