@@ -374,7 +374,7 @@ def pad_segmentation_mask():
374
374
for mask , padding , padding_mode in itertools .product (
375
375
make_segmentation_masks (),
376
376
[[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
377
- ["constant" , "symmetric" , "edge" ], # padding mode,
377
+ ["constant" , "symmetric" , "edge" , "reflect" ], # padding mode,
378
378
):
379
379
if padding_mode == "symmetric" and mask .ndim not in [3 , 4 ]:
380
380
continue
@@ -1064,15 +1064,24 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
1064
1064
torch .testing .assert_close (out_mask , expected_mask )
1065
1065
1066
1066
1067
- @pytest .mark .parametrize ("padding,padding_mode" , [([1 , 2 , 3 , 4 ], "constant" )])
1068
- def test_correctness_pad_segmentation_mask (padding , padding_mode ):
1069
- def compute_expected_mask ():
1070
- h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1067
+ @pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , 1.0 , [1 , 2 ]])
1068
+ def test_correctness_pad_segmentation_mask (padding ):
1069
+ def _parse_padding ():
1070
+ if isinstance (padding , int ):
1071
+ return [padding ] * 4
1072
+ if isinstance (padding , float ):
1073
+ return [int (padding )] * 4
1074
+ if isinstance (padding , list ):
1075
+ if len (padding ) == 1 :
1076
+ return padding * 4
1077
+ if len (padding ) == 2 :
1078
+ return padding * 2 # [left, up, right, down]
1079
+
1080
+ return padding
1071
1081
1072
- pad_left = padding [0 ]
1073
- pad_up = padding [1 ]
1074
- pad_right = padding [2 ]
1075
- pad_down = padding [3 ]
1082
+ def _compute_expected_mask (padding ):
1083
+ h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1084
+ pad_left , pad_up , pad_right , pad_down = padding
1076
1085
1077
1086
new_h = h + pad_up + pad_down
1078
1087
new_w = w + pad_left + pad_right
@@ -1083,8 +1092,10 @@ def compute_expected_mask():
1083
1092
1084
1093
return expected_mask
1085
1094
1095
+ padding = _parse_padding ()
1096
+
1086
1097
for mask in make_segmentation_masks ():
1087
- out_mask = F .pad_segmentation_mask (mask , padding , padding_mode )
1098
+ out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
1088
1099
1089
- expected_mask = compute_expected_mask ( )
1100
+ expected_mask = _compute_expected_mask ( padding )
1090
1101
torch .testing .assert_close (out_mask , expected_mask )
0 commit comments