Skip to content

Commit 2a9b597

Browse files
author
Federico Pozzi
committed
test: add basic correctness test with random masks
1 parent 8500e03 commit 2a9b597

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -371,16 +371,13 @@ def resized_crop_segmentation_mask():
371371

372372
@register_kernel_info_from_sample_inputs_fn
373373
def pad_segmentation_mask():
374-
for mask, padding, fill, padding_mode in itertools.product(
374+
for mask, padding, padding_mode in itertools.product(
375375
make_segmentation_masks(),
376376
[[1], [1, 1], [1, 1, 2, 2]], # padding
377-
[0, 1], # fill
378377
["constant", "symmetric", "edge"], # padding mode,
379378
):
380379
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
381380
continue
382-
if padding_mode == "edge" and fill != 0:
383-
continue
384381
if (
385382
padding_mode == "edge"
386383
and len(padding) == 2
@@ -392,7 +389,7 @@ def pad_segmentation_mask():
392389
continue
393390
if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
394391
continue
395-
yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode)
392+
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
396393

397394

398395
@pytest.mark.parametrize(
@@ -1059,10 +1056,35 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
10591056

10601057
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
10611058
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
1062-
mask[:, 1, 1] = 0
10631059

1064-
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1)
1060+
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
10651061

1066-
expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device)
1067-
expected_mask[:, 2, 2] = 0
1062+
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
1063+
expected_mask[:, 1:-1, 1:-1] = 1
10681064
torch.testing.assert_close(out_mask, expected_mask)
1065+
1066+
1067+
@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")])
1068+
def test_correctness_pad_segmentation_mask(padding, padding_mode):
1069+
def compute_expected_mask():
1070+
h, w = mask.shape[-2], mask.shape[-1]
1071+
1072+
pad_left = padding[0]
1073+
pad_up = padding[1]
1074+
pad_right = padding[2]
1075+
pad_down = padding[3]
1076+
1077+
new_h = h + pad_up + pad_down
1078+
new_w = w + pad_left + pad_right
1079+
1080+
new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
1081+
expected_mask = torch.zeros(new_shape, dtype=torch.long)
1082+
expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask
1083+
1084+
return expected_mask
1085+
1086+
for mask in make_segmentation_masks():
1087+
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode)
1088+
1089+
expected_mask = compute_expected_mask()
1090+
torch.testing.assert_close(out_mask, expected_mask)

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,9 @@ def rotate_segmentation_mask(
397397

398398

399399
def pad_segmentation_mask(
400-
segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
400+
segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
401401
) -> torch.Tensor:
402-
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode)
402+
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
403403

404404

405405
def pad_bounding_box(

0 commit comments

Comments
 (0)