Skip to content

Commit 6b6d361

Browse files
YosuaMichael authored and facebook-github-bot committed
[fbsync] Added resized_crop_segmentation_mask op (#5855)
Summary: * [proto] Added crop_bounding_box op * Added `crop_segmentation_mask` op * Fixed failed mypy * Added tests for resized_crop_bounding_box * Fixed code formatting * Added resized_crop_segmentation_mask op * Added tests Reviewed By: jdsgomes, NicolasHug Differential Revision: D36095695 fbshipit-source-id: 808c8c2216084999a3baa6f65d49287039f3cf84
1 parent b53d942 commit 6b6d361

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,14 @@ def resized_crop_bounding_box():
362362
)
363363

364364

365+
@register_kernel_info_from_sample_inputs_fn
366+
def resized_crop_segmentation_mask():
367+
for mask, top, left, height, width, size in itertools.product(
368+
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
369+
):
370+
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
371+
372+
365373
@pytest.mark.parametrize(
366374
"kernel",
367375
[
@@ -998,3 +1006,28 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
9981006
output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
9991007

10001008
torch.testing.assert_close(output_boxes, expected_bboxes)
1009+
1010+
1011+
@pytest.mark.parametrize("device", cpu_and_gpu())
1012+
@pytest.mark.parametrize(
1013+
"top, left, height, width, size",
1014+
[
1015+
[0, 0, 30, 30, (60, 60)],
1016+
[5, 5, 35, 45, (32, 34)],
1017+
],
1018+
)
1019+
def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
1020+
def _compute_expected(mask, top_, left_, height_, width_, size_):
1021+
output = mask.clone()
1022+
output = output[:, top_ : top_ + height_, left_ : left_ + width_]
1023+
output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
1024+
output = output[0, :].long()
1025+
return output
1026+
1027+
in_mask = torch.zeros(1, 100, 100, dtype=torch.long, device=device)
1028+
in_mask[0, 10:20, 10:20] = 1
1029+
in_mask[0, 5:15, 12:23] = 2
1030+
1031+
expected_mask = _compute_expected(in_mask, top, left, height, width, size)
1032+
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
1033+
torch.testing.assert_close(output_mask, expected_mask)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@
4747
resize_segmentation_mask,
4848
center_crop_image_tensor,
4949
center_crop_image_pil,
50+
resized_crop_bounding_box,
5051
resized_crop_image_tensor,
5152
resized_crop_image_pil,
52-
resized_crop_bounding_box,
53+
resized_crop_segmentation_mask,
5354
affine_bounding_box,
5455
affine_image_tensor,
5556
affine_image_pil,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,18 @@ def resized_crop_bounding_box(
555555
return resize_bounding_box(bounding_box, size, (height, width))
556556

557557

558+
def resized_crop_segmentation_mask(
559+
mask: torch.Tensor,
560+
top: int,
561+
left: int,
562+
height: int,
563+
width: int,
564+
size: List[int],
565+
) -> torch.Tensor:
566+
mask = crop_segmentation_mask(mask, top, left, height, width)
567+
return resize_segmentation_mask(mask, size)
568+
569+
558570
def _parse_five_crop_size(size: List[int]) -> List[int]:
559571
if isinstance(size, numbers.Number):
560572
size = [int(size), int(size)]

0 commit comments

Comments
 (0)