@@ -357,7 +357,7 @@ def pad_segmentation_mask():
357
357
for mask , padding , padding_mode in itertools .product (
358
358
make_segmentation_masks (),
359
359
[[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
360
- ["constant" , "symmetric" , "edge" ], # padding mode,
360
+ ["constant" , "symmetric" , "edge" , "reflect" ], # padding mode,
361
361
):
362
362
if padding_mode == "symmetric" and mask .ndim not in [3 , 4 ]:
363
363
continue
@@ -969,15 +969,24 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
969
969
torch .testing .assert_close (out_mask , expected_mask )
970
970
971
971
972
- @pytest .mark .parametrize ("padding,padding_mode" , [([1 , 2 , 3 , 4 ], "constant" )])
973
- def test_correctness_pad_segmentation_mask (padding , padding_mode ):
974
- def compute_expected_mask ():
975
- h , w = mask .shape [- 2 ], mask .shape [- 1 ]
972
+ @pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , 1.0 , [1 , 2 ]])
973
+ def test_correctness_pad_segmentation_mask (padding ):
974
+ def _parse_padding ():
975
+ if isinstance (padding , int ):
976
+ return [padding ] * 4
977
+ if isinstance (padding , float ):
978
+ return [int (padding )] * 4
979
+ if isinstance (padding , list ):
980
+ if len (padding ) == 1 :
981
+ return padding * 4
982
+ if len (padding ) == 2 :
983
+ return padding * 2 # [left, up, right, down]
984
+
985
+ return padding
976
986
977
- pad_left = padding [0 ]
978
- pad_up = padding [1 ]
979
- pad_right = padding [2 ]
980
- pad_down = padding [3 ]
987
+ def _compute_expected_mask (padding ):
988
+ h , w = mask .shape [- 2 ], mask .shape [- 1 ]
989
+ pad_left , pad_up , pad_right , pad_down = padding
981
990
982
991
new_h = h + pad_up + pad_down
983
992
new_w = w + pad_left + pad_right
@@ -988,8 +997,10 @@ def compute_expected_mask():
988
997
989
998
return expected_mask
990
999
1000
+ padding = _parse_padding ()
1001
+
991
1002
for mask in make_segmentation_masks ():
992
- out_mask = F .pad_segmentation_mask (mask , padding , padding_mode )
1003
+ out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
993
1004
994
- expected_mask = compute_expected_mask ( )
1005
+ expected_mask = _compute_expected_mask ( padding )
995
1006
torch .testing .assert_close (out_mask , expected_mask )
0 commit comments