Skip to content

Commit 6f3f37e

Browse files
author
Federico Pozzi
committed
test: add all padding options
1 parent 01caaa3 commit 6f3f37e

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def pad_segmentation_mask():
357357
for mask, padding, padding_mode in itertools.product(
358358
make_segmentation_masks(),
359359
[[1], [1, 1], [1, 1, 2, 2]], # padding
360-
["constant", "symmetric", "edge"], # padding mode,
360+
["constant", "symmetric", "edge", "reflect"], # padding mode,
361361
):
362362
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
363363
continue
@@ -969,15 +969,24 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
969969
torch.testing.assert_close(out_mask, expected_mask)
970970

971971

972-
@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")])
973-
def test_correctness_pad_segmentation_mask(padding, padding_mode):
974-
def compute_expected_mask():
975-
h, w = mask.shape[-2], mask.shape[-1]
972+
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, 1.0, [1, 2]])
973+
def test_correctness_pad_segmentation_mask(padding):
974+
def _parse_padding():
975+
if isinstance(padding, int):
976+
return [padding] * 4
977+
if isinstance(padding, float):
978+
return [int(padding)] * 4
979+
if isinstance(padding, list):
980+
if len(padding) == 1:
981+
return padding * 4
982+
if len(padding) == 2:
983+
return padding * 2 # [left, up, right, down]
984+
985+
return padding
976986

977-
pad_left = padding[0]
978-
pad_up = padding[1]
979-
pad_right = padding[2]
980-
pad_down = padding[3]
987+
def _compute_expected_mask(padding):
988+
h, w = mask.shape[-2], mask.shape[-1]
989+
pad_left, pad_up, pad_right, pad_down = padding
981990

982991
new_h = h + pad_up + pad_down
983992
new_w = w + pad_left + pad_right
@@ -988,8 +997,10 @@ def compute_expected_mask():
988997

989998
return expected_mask
990999

1000+
padding = _parse_padding()
1001+
9911002
for mask in make_segmentation_masks():
992-
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode)
1003+
out_mask = F.pad_segmentation_mask(mask, padding, "constant")
9931004

994-
expected_mask = compute_expected_mask()
1005+
expected_mask = _compute_expected_mask(padding)
9951006
torch.testing.assert_close(out_mask, expected_mask)

0 commit comments

Comments
 (0)