@@ -371,16 +371,13 @@ def resized_crop_segmentation_mask():
371
371
372
372
@register_kernel_info_from_sample_inputs_fn
373
373
def pad_segmentation_mask ():
374
- for mask , padding , fill , padding_mode in itertools .product (
374
+ for mask , padding , padding_mode in itertools .product (
375
375
make_segmentation_masks (),
376
376
[[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
377
- [0 , 1 ], # fill
378
377
["constant" , "symmetric" , "edge" ], # padding mode,
379
378
):
380
379
if padding_mode == "symmetric" and mask .ndim not in [3 , 4 ]:
381
380
continue
382
- if padding_mode == "edge" and fill != 0 :
383
- continue
384
381
if (
385
382
padding_mode == "edge"
386
383
and len (padding ) == 2
@@ -392,7 +389,7 @@ def pad_segmentation_mask():
392
389
continue
393
390
if padding_mode == "edge" and mask .ndim not in [2 , 3 , 4 , 5 ]:
394
391
continue
395
- yield SampleInput (mask , padding = padding , fill = fill , padding_mode = padding_mode )
392
+ yield SampleInput (mask , padding = padding , padding_mode = padding_mode )
396
393
397
394
398
395
@pytest .mark .parametrize (
@@ -1059,10 +1056,35 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
1059
1056
1060
1057
def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
1061
1058
mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
1062
- mask [:, 1 , 1 ] = 0
1063
1059
1064
- out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ], fill = 1 )
1060
+ out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ])
1065
1061
1066
- expected_mask = torch .ones ((1 , 3 + 1 + 1 , 3 + 1 + 1 ), dtype = torch .long , device = device )
1067
- expected_mask [:, 2 , 2 ] = 0
1062
+ expected_mask = torch .zeros ((1 , 5 , 5 ), dtype = torch .long , device = device )
1063
+ expected_mask [:, 1 : - 1 , 1 : - 1 ] = 1
1068
1064
torch .testing .assert_close (out_mask , expected_mask )
1065
+
1066
+
1067
+ @pytest .mark .parametrize ("padding,padding_mode" , [([1 , 2 , 3 , 4 ], "constant" )])
1068
+ def test_correctness_pad_segmentation_mask (padding , padding_mode ):
1069
+ def compute_expected_mask ():
1070
+ h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1071
+
1072
+ pad_left = padding [0 ]
1073
+ pad_up = padding [1 ]
1074
+ pad_right = padding [2 ]
1075
+ pad_down = padding [3 ]
1076
+
1077
+ new_h = h + pad_up + pad_down
1078
+ new_w = w + pad_left + pad_right
1079
+
1080
+ new_shape = (* mask .shape [:- 2 ], new_h , new_w ) if len (mask .shape ) > 2 else (new_h , new_w )
1081
+ expected_mask = torch .zeros (new_shape , dtype = torch .long )
1082
+ expected_mask [..., pad_up :- pad_down , pad_left :- pad_right ] = mask
1083
+
1084
+ return expected_mask
1085
+
1086
+ for mask in make_segmentation_masks ():
1087
+ out_mask = F .pad_segmentation_mask (mask , padding , padding_mode )
1088
+
1089
+ expected_mask = compute_expected_mask ()
1090
+ torch .testing .assert_close (out_mask , expected_mask )
0 commit comments