Skip to content

Commit 576793d

Browse files
author
Federico Pozzi
committed
feat: add functional pad on segmentation mask
1 parent 184b136 commit 576793d

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,32 @@ def vertical_flip_segmentation_mask():
352352
yield SampleInput(mask)
353353

354354

355+
@register_kernel_info_from_sample_inputs_fn
356+
def pad_segmentation_mask():
357+
for mask, padding, fill, padding_mode in itertools.product(
358+
make_segmentation_masks(),
359+
[[1], [1, 1], [1, 1, 2, 2]], # padding
360+
[0, 1], # fill
361+
["constant", "symmetric", "edge"], # padding mode,
362+
):
363+
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
364+
continue
365+
if padding_mode == "edge" and fill != 0:
366+
continue
367+
if (
368+
padding_mode == "edge"
369+
and len(padding) == 2
370+
and mask.ndim not in [2, 3]
371+
or len(padding) == 4
372+
and mask.ndim not in [4, 3]
373+
or len(padding) == 1
374+
):
375+
continue
376+
if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
377+
continue
378+
yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode)
379+
380+
355381
@pytest.mark.parametrize(
356382
"kernel",
357383
[
@@ -933,3 +959,15 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
933959
expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
934960
expected_mask[:, -1, :] = 1
935961
torch.testing.assert_close(out_mask, expected_mask)
962+
963+
964+
@pytest.mark.parametrize("device", cpu_and_gpu())
965+
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
966+
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
967+
mask[:, 1, 1] = 0
968+
969+
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1)
970+
971+
expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device)
972+
expected_mask[:, 2, 2] = 0
973+
torch.testing.assert_close(out_mask, expected_mask)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
pad_bounding_box,
6161
pad_image_tensor,
6262
pad_image_pil,
63+
pad_segmentation_mask,
6364
crop_bounding_box,
6465
crop_image_tensor,
6566
crop_image_pil,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,12 @@ def rotate_segmentation_mask(
396396
pad_image_pil = _FP.pad
397397

398398

399+
def pad_segmentation_mask(
400+
segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
401+
) -> torch.Tensor:
402+
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode)
403+
404+
399405
def pad_bounding_box(
400406
bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
401407
) -> torch.Tensor:

0 commit comments

Comments
 (0)