Skip to content

Commit b425177

Browse files
author
Federico Pozzi
committed
fix: pr comments
1 parent f0611e5 commit b425177

File tree

1 file changed

+21
-30
lines changed

1 file changed

+21
-30
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -369,26 +369,20 @@ def resized_crop_segmentation_mask():
369369
):
370370
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
371371

372+
372373
@register_kernel_info_from_sample_inputs_fn
373374
def pad_segmentation_mask():
374375
for mask, padding, padding_mode in itertools.product(
375376
make_segmentation_masks(),
376377
[[1], [1, 1], [1, 1, 2, 2]], # padding
377378
["constant", "symmetric", "edge", "reflect"], # padding mode,
378379
):
379-
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
380-
continue
381-
if (
382-
padding_mode == "edge"
383-
and len(padding) == 2
384-
and mask.ndim not in [2, 3]
385-
or len(padding) == 4
386-
and mask.ndim not in [4, 3]
387-
or len(padding) == 1
388-
):
380+
if padding_mode == "symmetric" and mask.ndim not in [2, 3, 4]:
389381
continue
390-
if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
382+
383+
if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [2, 3, 4]:
391384
continue
385+
392386
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
393387

394388

@@ -1054,6 +1048,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
10541048
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
10551049
torch.testing.assert_close(output_mask, expected_mask)
10561050

1051+
10571052
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
10581053
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
10591054

@@ -1064,24 +1059,22 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
10641059
torch.testing.assert_close(out_mask, expected_mask)
10651060

10661061

1067-
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, 1.0, [1, 2]])
1062+
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
10681063
def test_correctness_pad_segmentation_mask(padding):
1069-
def _parse_padding():
1070-
if isinstance(padding, int):
1071-
return [padding] * 4
1072-
if isinstance(padding, float):
1073-
return [int(padding)] * 4
1074-
if isinstance(padding, list):
1075-
if len(padding) == 1:
1076-
return padding * 4
1077-
if len(padding) == 2:
1078-
return padding * 2 # [left, up, right, down]
1079-
1080-
return padding
1081-
1082-
def _compute_expected_mask(padding):
1064+
def _compute_expected_mask():
1065+
def parse_padding():
1066+
if isinstance(padding, int):
1067+
return [padding] * 4
1068+
if isinstance(padding, list):
1069+
if len(padding) == 1:
1070+
return padding * 4
1071+
if len(padding) == 2:
1072+
return padding * 2 # [left, up, right, down]
1073+
1074+
return padding
1075+
10831076
h, w = mask.shape[-2], mask.shape[-1]
1084-
pad_left, pad_up, pad_right, pad_down = padding
1077+
pad_left, pad_up, pad_right, pad_down = parse_padding()
10851078

10861079
new_h = h + pad_up + pad_down
10871080
new_w = w + pad_left + pad_right
@@ -1092,10 +1085,8 @@ def _compute_expected_mask(padding):
10921085

10931086
return expected_mask
10941087

1095-
padding = _parse_padding()
1096-
10971088
for mask in make_segmentation_masks():
10981089
out_mask = F.pad_segmentation_mask(mask, padding, "constant")
10991090

1100-
expected_mask = _compute_expected_mask(padding)
1091+
expected_mask = _compute_expected_mask()
11011092
torch.testing.assert_close(out_mask, expected_mask)

0 commit comments

Comments
 (0)