Skip to content

Commit 33ec8c1

Browse files
author
Federico Pozzi
committed
fix: tests
1 parent b425177 commit 33ec8c1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,10 +377,10 @@ def pad_segmentation_mask():
377377
[[1], [1, 1], [1, 1, 2, 2]], # padding
378378
["constant", "symmetric", "edge", "reflect"], # padding mode,
379379
):
380-
if padding_mode == "symmetric" and mask.ndim not in [2, 3, 4]:
380+
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
381381
continue
382382

383-
if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [2, 3, 4]:
383+
if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [3, 4]:
384384
continue
385385

386386
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
@@ -1049,6 +1049,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
10491049
torch.testing.assert_close(output_mask, expected_mask)
10501050

10511051

1052+
@pytest.mark.parametrize("device", cpu_and_gpu())
10521053
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
10531054
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
10541055

0 commit comments

Comments
 (0)