@@ -354,16 +354,13 @@ def vertical_flip_segmentation_mask():
354
354
355
355
@register_kernel_info_from_sample_inputs_fn
356
356
def pad_segmentation_mask ():
357
- for mask , padding , fill , padding_mode in itertools .product (
357
+ for mask , padding , padding_mode in itertools .product (
358
358
make_segmentation_masks (),
359
359
[[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
360
- [0 , 1 ], # fill
361
360
["constant" , "symmetric" , "edge" ], # padding mode,
362
361
):
363
362
if padding_mode == "symmetric" and mask .ndim not in [3 , 4 ]:
364
363
continue
365
- if padding_mode == "edge" and fill != 0 :
366
- continue
367
364
if (
368
365
padding_mode == "edge"
369
366
and len (padding ) == 2
@@ -375,7 +372,7 @@ def pad_segmentation_mask():
375
372
continue
376
373
if padding_mode == "edge" and mask .ndim not in [2 , 3 , 4 , 5 ]:
377
374
continue
378
- yield SampleInput (mask , padding = padding , fill = fill , padding_mode = padding_mode )
375
+ yield SampleInput (mask , padding = padding , padding_mode = padding_mode )
379
376
380
377
381
378
@pytest .mark .parametrize (
@@ -964,10 +961,35 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
964
961
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
965
962
def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
966
963
mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
967
- mask [:, 1 , 1 ] = 0
968
964
969
- out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ], fill = 1 )
965
+ out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ])
970
966
971
- expected_mask = torch .ones ((1 , 3 + 1 + 1 , 3 + 1 + 1 ), dtype = torch .long , device = device )
972
- expected_mask [:, 2 , 2 ] = 0
967
+ expected_mask = torch .zeros ((1 , 5 , 5 ), dtype = torch .long , device = device )
968
+ expected_mask [:, 1 : - 1 , 1 : - 1 ] = 1
973
969
torch .testing .assert_close (out_mask , expected_mask )
970
+
971
+
972
+ @pytest .mark .parametrize ("padding,padding_mode" , [([1 , 2 , 3 , 4 ], "constant" )])
973
+ def test_correctness_pad_segmentation_mask (padding , padding_mode ):
974
+ def compute_expected_mask ():
975
+ h , w = mask .shape [- 2 ], mask .shape [- 1 ]
976
+
977
+ pad_left = padding [0 ]
978
+ pad_up = padding [1 ]
979
+ pad_right = padding [2 ]
980
+ pad_down = padding [3 ]
981
+
982
+ new_h = h + pad_up + pad_down
983
+ new_w = w + pad_left + pad_right
984
+
985
+ new_shape = (* mask .shape [:- 2 ], new_h , new_w ) if len (mask .shape ) > 2 else (new_h , new_w )
986
+ expected_mask = torch .zeros (new_shape , dtype = torch .long )
987
+ expected_mask [..., pad_up :- pad_down , pad_left :- pad_right ] = mask
988
+
989
+ return expected_mask
990
+
991
+ for mask in make_segmentation_masks ():
992
+ out_mask = F .pad_segmentation_mask (mask , padding , padding_mode )
993
+
994
+ expected_mask = compute_expected_mask ()
995
+ torch .testing .assert_close (out_mask , expected_mask )
0 commit comments