Skip to content

Commit db13442

Browse files
YosuaMichael and facebook-github-bot
authored and committed
[fbsync] [proto] Added functional crop_bounding_box op (#5781)
Summary: * [proto] Added crop_bounding_box op * Removed "pass" * Updated comment * Removed unused args from signature (Note: this ignores all push blocking failures!) Reviewed By: jdsgomes, NicolasHug Differential Revision: D36095685 fbshipit-source-id: a4222d23aa60ea0cb0713a389646adbed55f289d
1 parent e581dd0 commit db13442

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,17 @@ def rotate_segmentation_mask():
321321
)
322322

323323

324+
@register_kernel_info_from_sample_inputs_fn
def crop_bounding_box():
    """Yield sample inputs for the crop_bounding_box kernel.

    Covers every bounding-box fixture crossed with negative, zero and
    positive crop offsets (negative values exercise the implicit-pad path).
    """
    offsets = [-8, 0, 9]
    for bounding_box in make_bounding_boxes():
        for top in offsets:
            for left in offsets:
                yield SampleInput(
                    bounding_box,
                    format=bounding_box.format,
                    top=top,
                    left=left,
                )
333+
334+
324335
@pytest.mark.parametrize(
325336
"kernel",
326337
[
@@ -808,3 +819,44 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
808819
expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
809820
out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
810821
torch.testing.assert_close(out_mask, expected_mask)
822+
823+
824+
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
    """Check crop_bounding_box against values precomputed with Albumentations."""
    # Reference recipe used to generate expected_bboxes (Albumentations):
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    image_size = (64, 76)
    # Input boxes in XYXY format.
    input_coords = [
        [10.0, 15.0, 25.0, 35.0],
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
    input_boxes = features.BoundingBox(
        input_coords, format=features.BoundingBoxFormat.XYXY, image_size=image_size, device=device
    )

    output_boxes = F.crop_bounding_box(
        input_boxes,
        input_boxes.format,
        top,
        left,
    )

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@
5757
rotate_image_tensor,
5858
rotate_image_pil,
5959
rotate_segmentation_mask,
60+
pad_bounding_box,
6061
pad_image_tensor,
6162
pad_image_pil,
62-
pad_bounding_box,
63+
crop_bounding_box,
6364
crop_image_tensor,
6465
crop_image_pil,
6566
perspective_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,27 @@ def pad_bounding_box(
419419
crop_image_pil = _FP.crop
420420

421421

422+
def crop_bounding_box(
    bounding_box: torch.Tensor,
    format: features.BoundingBoxFormat,
    top: int,
    left: int,
) -> torch.Tensor:
    """Translate bounding boxes into the coordinate frame of a crop.

    Args:
        bounding_box: tensor of boxes; any shape whose last dimension holds 4 coords.
        format: the format the input boxes are in (and the output is returned in).
        top: vertical offset of the crop origin. Negative values act as implicit padding.
        left: horizontal offset of the crop origin. Negative values act as implicit padding.

    Returns:
        The shifted boxes, in the original ``format`` and with the original shape.
    """
    original_shape = bounding_box.shape

    # Work in XYXY so both corners can be shifted uniformly.
    xyxy_boxes = convert_bounding_box_format(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    ).view(-1, 4)

    # Crop — or implicitly pad when left/top are negative — by shifting the frame.
    # NOTE(review): the in-place subtraction assumes convert_bounding_box_format
    # returned a tensor independent of the caller's input — confirm its copy semantics.
    xyxy_boxes[:, 0::2] -= left
    xyxy_boxes[:, 1::2] -= top

    converted_back = convert_bounding_box_format(
        xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
    )
    return converted_back.view(original_shape)
441+
442+
422443
def perspective_image_tensor(
423444
img: torch.Tensor,
424445
perspective_coeffs: List[float],

0 commit comments

Comments
 (0)