@@ -369,26 +369,20 @@ def resized_crop_segmentation_mask():
369
369
):
370
370
yield SampleInput (mask , top = top , left = left , height = height , width = width , size = size )
371
371
372
+
372
373
@register_kernel_info_from_sample_inputs_fn
373
374
def pad_segmentation_mask ():
374
375
for mask , padding , padding_mode in itertools .product (
375
376
make_segmentation_masks (),
376
377
[[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
377
378
["constant" , "symmetric" , "edge" , "reflect" ], # padding mode,
378
379
):
379
- if padding_mode == "symmetric" and mask .ndim not in [3 , 4 ]:
380
- continue
381
- if (
382
- padding_mode == "edge"
383
- and len (padding ) == 2
384
- and mask .ndim not in [2 , 3 ]
385
- or len (padding ) == 4
386
- and mask .ndim not in [4 , 3 ]
387
- or len (padding ) == 1
388
- ):
380
+ if padding_mode == "symmetric" and mask .ndim not in [2 , 3 , 4 ]:
389
381
continue
390
- if padding_mode == "edge" and mask .ndim not in [2 , 3 , 4 , 5 ]:
382
+
383
+ if (padding_mode == "edge" or padding_mode == "reflect" ) and mask .ndim not in [2 , 3 , 4 ]:
391
384
continue
385
+
392
386
yield SampleInput (mask , padding = padding , padding_mode = padding_mode )
393
387
394
388
@@ -1054,6 +1048,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
1054
1048
output_mask = F .resized_crop_segmentation_mask (in_mask , top , left , height , width , size )
1055
1049
torch .testing .assert_close (output_mask , expected_mask )
1056
1050
1051
+
1057
1052
def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
1058
1053
mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
1059
1054
@@ -1064,24 +1059,22 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
1064
1059
torch .testing .assert_close (out_mask , expected_mask )
1065
1060
1066
1061
1067
- @pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , 1.0 , [1 , 2 ]])
1062
+ @pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , [1 , 2 ]])
1068
1063
def test_correctness_pad_segmentation_mask (padding ):
1069
- def _parse_padding ():
1070
- if isinstance (padding , int ):
1071
- return [padding ] * 4
1072
- if isinstance (padding , float ):
1073
- return [int (padding )] * 4
1074
- if isinstance (padding , list ):
1075
- if len (padding ) == 1 :
1076
- return padding * 4
1077
- if len (padding ) == 2 :
1078
- return padding * 2 # [left, up, right, down]
1079
-
1080
- return padding
1081
-
1082
- def _compute_expected_mask (padding ):
1064
+ def _compute_expected_mask ():
1065
+ def parse_padding ():
1066
+ if isinstance (padding , int ):
1067
+ return [padding ] * 4
1068
+ if isinstance (padding , list ):
1069
+ if len (padding ) == 1 :
1070
+ return padding * 4
1071
+ if len (padding ) == 2 :
1072
+ return padding * 2 # [left, up, right, down]
1073
+
1074
+ return padding
1075
+
1083
1076
h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1084
- pad_left , pad_up , pad_right , pad_down = padding
1077
+ pad_left , pad_up , pad_right , pad_down = parse_padding ()
1085
1078
1086
1079
new_h = h + pad_up + pad_down
1087
1080
new_w = w + pad_left + pad_right
@@ -1092,10 +1085,8 @@ def _compute_expected_mask(padding):
1092
1085
1093
1086
return expected_mask
1094
1087
1095
- padding = _parse_padding ()
1096
-
1097
1088
for mask in make_segmentation_masks ():
1098
1089
out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
1099
1090
1100
- expected_mask = _compute_expected_mask (padding )
1091
+ expected_mask = _compute_expected_mask ()
1101
1092
torch .testing .assert_close (out_mask , expected_mask )
0 commit comments