Skip to content

Commit e30f9b0

Browse files
federicopozzi33Federico Pozzivfdev-5
authored
test: add functional vertical flip tests on segmentation mask (#5860)
* test: add functional vertical flip tests on segmentation mask * refactor: improve test readibility * Update test_prototype_transforms_functional.py Co-authored-by: Federico Pozzi <[email protected]> Co-authored-by: vfdev <[email protected]>
1 parent 6f016dd commit e30f9b0

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,12 @@ def crop_segmentation_mask():
346346
)
347347

348348

349+
@register_kernel_info_from_sample_inputs_fn
350+
def vertical_flip_segmentation_mask():
351+
for mask in make_segmentation_masks():
352+
yield SampleInput(mask)
353+
354+
349355
@pytest.mark.parametrize(
350356
"kernel",
351357
[
@@ -915,3 +921,15 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
915921
output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
916922
expected_mask = _compute_expected_mask(mask, top, left, height, width)
917923
torch.testing.assert_close(output_mask, expected_mask)
924+
925+
926+
@pytest.mark.parametrize("device", cpu_and_gpu())
927+
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
928+
mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
929+
mask[:, 0, :] = 1
930+
931+
out_mask = F.vertical_flip_segmentation_mask(mask)
932+
933+
expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
934+
expected_mask[:, -1, :] = 1
935+
torch.testing.assert_close(out_mask, expected_mask)

0 commit comments

Comments
 (0)