@@ -346,6 +346,12 @@ def crop_segmentation_mask():
346
346
)
347
347
348
348
349
+ @register_kernel_info_from_sample_inputs_fn
350
+ def vertical_flip_segmentation_mask ():
351
+ for mask in make_segmentation_masks ():
352
+ yield SampleInput (mask )
353
+
354
+
349
355
@pytest .mark .parametrize (
350
356
"kernel" ,
351
357
[
@@ -915,3 +921,15 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
915
921
output_mask = F .crop_segmentation_mask (mask , top , left , height , width )
916
922
expected_mask = _compute_expected_mask (mask , top , left , height , width )
917
923
torch .testing .assert_close (output_mask , expected_mask )
924
+
925
+
926
+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
927
+ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input (device ):
928
+ mask = torch .zeros ((3 , 3 , 3 ), dtype = torch .long , device = device )
929
+ mask [:, 0 , :] = 1
930
+
931
+ out_mask = F .vertical_flip_segmentation_mask (mask )
932
+
933
+ expected_mask = torch .zeros ((3 , 3 , 3 ), dtype = torch .long , device = device )
934
+ expected_mask [:, - 1 , :] = 1
935
+ torch .testing .assert_close (out_mask , expected_mask )
0 commit comments