@@ -332,6 +332,12 @@ def crop_bounding_box():
332
332
)
333
333
334
334
335
+ @register_kernel_info_from_sample_inputs_fn
336
+ def vertical_flip_segmentation_mask ():
337
+ for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
338
+ yield SampleInput (mask )
339
+
340
+
335
341
@pytest .mark .parametrize (
336
342
"kernel" ,
337
343
[
@@ -860,3 +866,26 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
860
866
)
861
867
862
868
torch .testing .assert_close (output_boxes .tolist (), expected_bboxes )
869
+
870
+
871
+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
872
+ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input (device ):
873
+ mask = torch .tensor (
874
+ [
875
+ [[1 , 1 , 1 , 1 , 1 ], [0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ]],
876
+ [[1 , 1 , 1 , 1 , 1 ], [0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ]],
877
+ [[1 , 1 , 1 , 1 , 1 ], [0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ]],
878
+ ],
879
+ device = device ,
880
+ )
881
+
882
+ expected_mask = torch .tensor (
883
+ [
884
+ [[0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 1 ]],
885
+ [[0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 1 ]],
886
+ [[0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 1 ]],
887
+ ],
888
+ device = device ,
889
+ )
890
+ out_mask = F .vertical_flip_segmentation_mask (mask )
891
+ torch .testing .assert_close (out_mask , expected_mask )
0 commit comments