@@ -370,6 +370,16 @@ def resized_crop_segmentation_mask():
370
370
yield SampleInput (mask , top = top , left = left , height = height , width = width , size = size )
371
371
372
372
373
+ @register_kernel_info_from_sample_inputs_fn
374
+ def pad_segmentation_mask ():
375
+ for mask , padding , padding_mode in itertools .product (
376
+ make_segmentation_masks (),
377
+ [[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
378
+ ["constant" , "symmetric" , "edge" , "reflect" ], # padding mode,
379
+ ):
380
+ yield SampleInput (mask , padding = padding , padding_mode = padding_mode )
381
+
382
+
373
383
@pytest .mark .parametrize (
374
384
"kernel" ,
375
385
[
@@ -1031,3 +1041,47 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
1031
1041
expected_mask = _compute_expected (in_mask , top , left , height , width , size )
1032
1042
output_mask = F .resized_crop_segmentation_mask (in_mask , top , left , height , width , size )
1033
1043
torch .testing .assert_close (output_mask , expected_mask )
1044
+
1045
+
1046
+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1047
+ def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
1048
+ mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
1049
+
1050
+ out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ])
1051
+
1052
+ expected_mask = torch .zeros ((1 , 5 , 5 ), dtype = torch .long , device = device )
1053
+ expected_mask [:, 1 :- 1 , 1 :- 1 ] = 1
1054
+ torch .testing .assert_close (out_mask , expected_mask )
1055
+
1056
+
1057
+ @pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , [1 , 2 ]])
1058
+ def test_correctness_pad_segmentation_mask (padding ):
1059
+ def _compute_expected_mask ():
1060
+ def parse_padding ():
1061
+ if isinstance (padding , int ):
1062
+ return [padding ] * 4
1063
+ if isinstance (padding , list ):
1064
+ if len (padding ) == 1 :
1065
+ return padding * 4
1066
+ if len (padding ) == 2 :
1067
+ return padding * 2 # [left, up, right, down]
1068
+
1069
+ return padding
1070
+
1071
+ h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1072
+ pad_left , pad_up , pad_right , pad_down = parse_padding ()
1073
+
1074
+ new_h = h + pad_up + pad_down
1075
+ new_w = w + pad_left + pad_right
1076
+
1077
+ new_shape = (* mask .shape [:- 2 ], new_h , new_w ) if len (mask .shape ) > 2 else (new_h , new_w )
1078
+ expected_mask = torch .zeros (new_shape , dtype = torch .long )
1079
+ expected_mask [..., pad_up :- pad_down , pad_left :- pad_right ] = mask
1080
+
1081
+ return expected_mask
1082
+
1083
+ for mask in make_segmentation_masks ():
1084
+ out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
1085
+
1086
+ expected_mask = _compute_expected_mask ()
1087
+ torch .testing .assert_close (out_mask , expected_mask )
0 commit comments