Commit 0f826a1

YosuaMichael authored and facebook-github-bot committed
[fbsync] Added resized_crop_bounding_box op (#5853)
Summary:
* [proto] Added crop_bounding_box op
* Added tests for resized_crop_bounding_box
* Fixed code formatting

Reviewed By: jdsgomes, NicolasHug
Differential Revision: D36095714
fbshipit-source-id: 5896896af6e6f2cd656b4f157f2828dc93664f85
1 parent 34ccb8d commit 0f826a1

3 files changed: +81 -2 lines changed

test/test_prototype_transforms_functional.py

Lines changed: 67 additions & 2 deletions
@@ -352,6 +352,16 @@ def vertical_flip_segmentation_mask():
         yield SampleInput(mask)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def resized_crop_bounding_box():
+    for bounding_box, top, left, height, width, size in itertools.product(
+        make_bounding_boxes(), [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)]
+    ):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, top=top, left=left, height=height, width=width, size=size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
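The sample-input generator registered above pairs every output of make_bounding_boxes() with the full Cartesian product of its parameter grids, including negative top/left offsets and crop windows both larger and smaller than the resize target. A quick standalone sketch (plain itertools, nothing from the commit) counting the combinations per box:

import itertools

# Same parameter grids as the generator above: 2 * 2 * 2 * 2 * 2 combinations.
tops, lefts, heights, widths, sizes = [-8, 9], [-8, 9], [32, 22], [34, 20], [(32, 32), (16, 18)]
combos = list(itertools.product(tops, lefts, heights, widths, sizes))
print(len(combos))  # 32 sample inputs per bounding box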
@@ -842,14 +852,18 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
 
 
 @pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "format",
+    [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH],
+)
 @pytest.mark.parametrize(
     "top, left, height, width, expected_bboxes",
     [
         [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
         [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
     ],
 )
-def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
+def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes):
 
     # Expected bboxes computed using Albumentations:
     # import numpy as np
@@ -871,14 +885,19 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
         [45.0, 46.0, 56.0, 62.0],
     ]
     in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device)
+    if format != features.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, features.BoundingBoxFormat.XYXY, format)
 
     output_boxes = F.crop_bounding_box(
         in_boxes,
-        in_boxes.format,
+        format,
         top,
         left,
     )
 
+    if format != features.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
+
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
 
 
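The Albumentations-derived expectations in this test follow the plain translation rule for cropping XYXY boxes: subtract left from the x coordinates and top from the y coordinates. A self-contained sketch of that rule (crop_xyxy is a hypothetical helper, not part of the commit), checked against the last input box visible in the diff:

def crop_xyxy(bbox, top, left):
    # Shift an XYXY box into the coordinate frame of the crop window.
    x1, y1, x2, y2 = bbox
    return (x1 - left, y1 - top, x2 - left, y2 - top)

# With the first parametrization (top=8, left=12), [45.0, 46.0, 56.0, 62.0]
# lands at (33.0, 38.0, 44.0, 54.0), matching expected_bboxes above.
assert crop_xyxy((45.0, 46.0, 56.0, 62.0), top=8, left=12) == (33.0, 38.0, 44.0, 54.0)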

@@ -933,3 +952,49 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     expected_mask[:, -1, :] = 1
     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "format",
+    [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH],
+)
+@pytest.mark.parametrize(
+    "top, left, height, width, size",
+    [
+        [0, 0, 30, 30, (60, 60)],
+        [-5, 5, 35, 45, (32, 34)],
+    ],
+)
+def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
+    def _compute_expected(bbox, top_, left_, height_, width_, size_):
+        # bbox should be xyxy
+        bbox[0] = (bbox[0] - left_) * size_[1] / width_
+        bbox[1] = (bbox[1] - top_) * size_[0] / height_
+        bbox[2] = (bbox[2] - left_) * size_[1] / width_
+        bbox[3] = (bbox[3] - top_) * size_[0] / height_
+        return bbox
+
+    image_size = (100, 100)
+    # xyxy format
+    in_boxes = [
+        [10.0, 10.0, 20.0, 20.0],
+        [5.0, 10.0, 15.0, 20.0],
+    ]
+    expected_bboxes = []
+    for in_box in in_boxes:
+        expected_bboxes.append(_compute_expected(list(in_box), top, left, height, width, size))
+    expected_bboxes = torch.tensor(expected_bboxes, device=device)
+
+    in_boxes = features.BoundingBox(
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, device=device
+    )
+    if format != features.BoundingBoxFormat.XYXY:
+        in_boxes = convert_bounding_box_format(in_boxes, features.BoundingBoxFormat.XYXY, format)
+
+    output_boxes = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+
+    if format != features.BoundingBoxFormat.XYXY:
+        output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
+
+    torch.testing.assert_close(output_boxes, expected_bboxes)
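_compute_expected encodes the reference math: translate the XYXY coordinates by the crop origin, then rescale each axis by output size over crop size (size is (height, width)). For the first parametrization (top=0, left=0, height=30, width=30, size=(60, 60)) both axes scale by 60 / 30 = 2, so [10, 10, 20, 20] should map to [20, 20, 40, 40]. A standalone re-derivation (hypothetical helper mirroring the test, not part of the commit):

def compute_expected(bbox, top, left, height, width, size):
    # Translate by the crop origin, then rescale: x by size[1] / width, y by size[0] / height.
    x1, y1, x2, y2 = bbox
    return [
        (x1 - left) * size[1] / width,
        (y1 - top) * size[0] / height,
        (x2 - left) * size[1] / width,
        (y2 - top) * size[0] / height,
    ]

assert compute_expected([10.0, 10.0, 20.0, 20.0], 0, 0, 30, 30, (60, 60)) == [20.0, 20.0, 40.0, 40.0]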

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
     center_crop_image_pil,
     resized_crop_image_tensor,
     resized_crop_image_pil,
+    resized_crop_bounding_box,
     affine_bounding_box,
     affine_image_tensor,
     affine_image_pil,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 13 additions & 0 deletions
@@ -542,6 +542,19 @@ def resized_crop_image_pil(
     return resize_image_pil(img, size, interpolation=interpolation)
 
 
+def resized_crop_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+) -> torch.Tensor:
+    bounding_box = crop_bounding_box(bounding_box, format, top, left)
+    return resize_bounding_box(bounding_box, size, (height, width))
+
+
 def _parse_five_crop_size(size: List[int]) -> List[int]:
     if isinstance(size, numbers.Number):
         size = [int(size), int(size)]
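The new op simply chains the two existing bounding-box kernels: crop_bounding_box shifts the coordinates into the crop window, then resize_bounding_box rescales them from the crop's (height, width) to the target size. A minimal usage sketch, assuming the prototype imports used by the tests in this commit (the prototype API has since evolved, so treat this as illustrative):

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

in_boxes = features.BoundingBox(
    [[10.0, 10.0, 20.0, 20.0]], format=features.BoundingBoxFormat.XYXY, image_size=(100, 100)
)
# Crop a 30x30 window at the origin, then resize it to 60x60: a 2x scale on both axes.
out_boxes = F.resized_crop_bounding_box(in_boxes, in_boxes.format, top=0, left=0, height=30, width=30, size=(60, 60))
# Expected result: [[20.0, 20.0, 40.0, 40.0]]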
