Skip to content

Commit f0611e5

Browse files
author
Federico Pozzi
committed
test: add all padding options
1 parent 2a9b597 commit f0611e5

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def pad_segmentation_mask():
374374
for mask, padding, padding_mode in itertools.product(
375375
make_segmentation_masks(),
376376
[[1], [1, 1], [1, 1, 2, 2]], # padding
377-
["constant", "symmetric", "edge"], # padding mode,
377+
["constant", "symmetric", "edge", "reflect"], # padding mode,
378378
):
379379
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
380380
continue
@@ -1064,15 +1064,24 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
10641064
torch.testing.assert_close(out_mask, expected_mask)
10651065

10661066

1067-
@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")])
1068-
def test_correctness_pad_segmentation_mask(padding, padding_mode):
1069-
def compute_expected_mask():
1070-
h, w = mask.shape[-2], mask.shape[-1]
1067+
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, 1.0, [1, 2]])
1068+
def test_correctness_pad_segmentation_mask(padding):
1069+
def _parse_padding():
1070+
if isinstance(padding, int):
1071+
return [padding] * 4
1072+
if isinstance(padding, float):
1073+
return [int(padding)] * 4
1074+
if isinstance(padding, list):
1075+
if len(padding) == 1:
1076+
return padding * 4
1077+
if len(padding) == 2:
1078+
return padding * 2 # [left, up, right, down]
1079+
1080+
return padding
10711081

1072-
pad_left = padding[0]
1073-
pad_up = padding[1]
1074-
pad_right = padding[2]
1075-
pad_down = padding[3]
1082+
def _compute_expected_mask(padding):
1083+
h, w = mask.shape[-2], mask.shape[-1]
1084+
pad_left, pad_up, pad_right, pad_down = padding
10761085

10771086
new_h = h + pad_up + pad_down
10781087
new_w = w + pad_left + pad_right
@@ -1083,8 +1092,10 @@ def compute_expected_mask():
10831092

10841093
return expected_mask
10851094

1095+
padding = _parse_padding()
1096+
10861097
for mask in make_segmentation_masks():
1087-
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode)
1098+
out_mask = F.pad_segmentation_mask(mask, padding, "constant")
10881099

1089-
expected_mask = compute_expected_mask()
1100+
expected_mask = _compute_expected_mask(padding)
10901101
torch.testing.assert_close(out_mask, expected_mask)

0 commit comments

Comments
 (0)