Skip to content

Commit 3577009

Browse files
author
Federico Pozzi
committed
test: add functional vertical flip tests on segmentation mask
1 parent a64c674 commit 3577009

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,12 @@ def crop_bounding_box():
332332
)
333333

334334

335+
@register_kernel_info_from_sample_inputs_fn
336+
def vertical_flip_segmentation_mask():
337+
for mask in make_segmentation_masks(extra_dims=((), (4,))):
338+
yield SampleInput(mask)
339+
340+
335341
@pytest.mark.parametrize(
336342
"kernel",
337343
[
@@ -860,3 +866,26 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
860866
)
861867

862868
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
869+
870+
871+
@pytest.mark.parametrize("device", cpu_and_gpu())
872+
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
873+
mask = torch.tensor(
874+
[
875+
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
876+
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
877+
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
878+
],
879+
device=device,
880+
)
881+
882+
expected_mask = torch.tensor(
883+
[
884+
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
885+
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
886+
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
887+
],
888+
device=device,
889+
)
890+
out_mask = F.vertical_flip_segmentation_mask(mask)
891+
torch.testing.assert_close(out_mask, expected_mask)

0 commit comments

Comments
 (0)