Skip to content

Commit 6d9f0bf

Browse files
YosuaMichaelFederico Pozzi
authored andcommitted
[fbsync] feat: add functional pad on segmentation mask (#5866)
Summary: * feat: add functional pad on segmentation mask * test: add basic correctness test with random masks * test: add all padding options * fix: pr comments * fix: tests * refactor: reshape tensor in 4d, then pad Reviewed By: jdsgomes, NicolasHug Differential Revision: D36095691 fbshipit-source-id: 1e31988216fea1664c1fd48ee39598d28bac8308 Co-authored-by: Federico Pozzi <[email protected]>
1 parent 57284f6 commit 6d9f0bf

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,16 @@ def resized_crop_segmentation_mask():
370370
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
371371

372372

373+
@register_kernel_info_from_sample_inputs_fn
374+
def pad_segmentation_mask():
375+
for mask, padding, padding_mode in itertools.product(
376+
make_segmentation_masks(),
377+
[[1], [1, 1], [1, 1, 2, 2]], # padding
378+
["constant", "symmetric", "edge", "reflect"], # padding mode,
379+
):
380+
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
381+
382+
373383
@pytest.mark.parametrize(
374384
"kernel",
375385
[
@@ -1031,3 +1041,47 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
10311041
expected_mask = _compute_expected(in_mask, top, left, height, width, size)
10321042
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
10331043
torch.testing.assert_close(output_mask, expected_mask)
1044+
1045+
1046+
@pytest.mark.parametrize("device", cpu_and_gpu())
1047+
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
1048+
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
1049+
1050+
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
1051+
1052+
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
1053+
expected_mask[:, 1:-1, 1:-1] = 1
1054+
torch.testing.assert_close(out_mask, expected_mask)
1055+
1056+
1057+
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
1058+
def test_correctness_pad_segmentation_mask(padding):
1059+
def _compute_expected_mask():
1060+
def parse_padding():
1061+
if isinstance(padding, int):
1062+
return [padding] * 4
1063+
if isinstance(padding, list):
1064+
if len(padding) == 1:
1065+
return padding * 4
1066+
if len(padding) == 2:
1067+
return padding * 2 # [left, up, right, down]
1068+
1069+
return padding
1070+
1071+
h, w = mask.shape[-2], mask.shape[-1]
1072+
pad_left, pad_up, pad_right, pad_down = parse_padding()
1073+
1074+
new_h = h + pad_up + pad_down
1075+
new_w = w + pad_left + pad_right
1076+
1077+
new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
1078+
expected_mask = torch.zeros(new_shape, dtype=torch.long)
1079+
expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask
1080+
1081+
return expected_mask
1082+
1083+
for mask in make_segmentation_masks():
1084+
out_mask = F.pad_segmentation_mask(mask, padding, "constant")
1085+
1086+
expected_mask = _compute_expected_mask()
1087+
torch.testing.assert_close(out_mask, expected_mask)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
pad_bounding_box,
6363
pad_image_tensor,
6464
pad_image_pil,
65+
pad_segmentation_mask,
6566
crop_bounding_box,
6667
crop_image_tensor,
6768
crop_image_pil,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,20 @@ def rotate_segmentation_mask(
396396
pad_image_pil = _FP.pad
397397

398398

399+
def pad_segmentation_mask(
400+
segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
401+
) -> torch.Tensor:
402+
num_masks, height, width = segmentation_mask.shape[-3:]
403+
extra_dims = segmentation_mask.shape[:-3]
404+
405+
padded_mask = pad_image_tensor(
406+
img=segmentation_mask.view(-1, num_masks, height, width), padding=padding, fill=0, padding_mode=padding_mode
407+
)
408+
409+
new_height, new_width = padded_mask.shape[-2:]
410+
return padded_mask.view(extra_dims + (num_masks, new_height, new_width))
411+
412+
399413
def pad_bounding_box(
400414
bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
401415
) -> torch.Tensor:

0 commit comments

Comments
 (0)