@@ -382,6 +382,15 @@ def pad_segmentation_mask():
382382 yield SampleInput (mask , padding = padding , padding_mode = padding_mode )
383383
384384
385+ @register_kernel_info_from_sample_inputs_fn
386+ def pad_bounding_box ():
387+ for bounding_box , padding in itertools .product (
388+ make_bounding_boxes (),
389+ [[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]],
390+ ):
391+ yield SampleInput (bounding_box , padding = padding , format = bounding_box .format )
392+
393+
385394@register_kernel_info_from_sample_inputs_fn
386395def perspective_bounding_box ():
387396 for bounding_box , perspective_coeffs in itertools .product (
@@ -1103,22 +1112,67 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
11031112 torch .testing .assert_close (out_mask , expected_mask )
11041113
11051114
1115+ def _parse_padding (padding ):
1116+ if isinstance (padding , int ):
1117+ return [padding ] * 4
1118+ if isinstance (padding , list ):
1119+ if len (padding ) == 1 :
1120+ return padding * 4
1121+ if len (padding ) == 2 :
1122+ return padding * 2 # [left, up, right, down]
1123+
1124+ return padding
1125+
1126+
1127+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1128+ @pytest .mark .parametrize ("padding" , [[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]])
1129+ def test_correctness_pad_bounding_box (device , padding ):
1130+ def _compute_expected_bbox (bbox , padding_ ):
1131+ pad_left , pad_up , _ , _ = _parse_padding (padding_ )
1132+
1133+ bbox_format = bbox .format
1134+ bbox_dtype = bbox .dtype
1135+ bbox = convert_bounding_box_format (bbox , old_format = bbox_format , new_format = features .BoundingBoxFormat .XYXY )
1136+
1137+ bbox [0 ::2 ] += pad_left
1138+ bbox [1 ::2 ] += pad_up
1139+
1140+ bbox = convert_bounding_box_format (
1141+ bbox , old_format = features .BoundingBoxFormat .XYXY , new_format = bbox_format , copy = False
1142+ )
1143+ if bbox .dtype != bbox_dtype :
1144+ # Temporary cast to original dtype
1145+ # e.g. float32 -> int
1146+ bbox = bbox .to (bbox_dtype )
1147+ return bbox
1148+
1149+ for bboxes in make_bounding_boxes ():
1150+ bboxes = bboxes .to (device )
1151+ bboxes_format = bboxes .format
1152+ bboxes_image_size = bboxes .image_size
1153+
1154+ output_boxes = F .pad_bounding_box (bboxes , padding , format = bboxes_format )
1155+
1156+ if bboxes .ndim < 2 :
1157+ bboxes = [bboxes ]
1158+
1159+ expected_bboxes = []
1160+ for bbox in bboxes :
1161+ bbox = features .BoundingBox (bbox , format = bboxes_format , image_size = bboxes_image_size )
1162+ expected_bboxes .append (_compute_expected_bbox (bbox , padding ))
1163+
1164+ if len (expected_bboxes ) > 1 :
1165+ expected_bboxes = torch .stack (expected_bboxes )
1166+ else :
1167+ expected_bboxes = expected_bboxes [0 ]
1168+ torch .testing .assert_close (output_boxes , expected_bboxes )
1169+
1170+
11061171@pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , [1 , 2 ]])
11071172def test_correctness_pad_segmentation_mask (padding ):
1108- def _compute_expected_mask ():
1109- def parse_padding ():
1110- if isinstance (padding , int ):
1111- return [padding ] * 4
1112- if isinstance (padding , list ):
1113- if len (padding ) == 1 :
1114- return padding * 4
1115- if len (padding ) == 2 :
1116- return padding * 2 # [left, up, right, down]
1117-
1118- return padding
1119-
1173+ def _compute_expected_mask (mask , padding_ ):
11201174 h , w = mask .shape [- 2 ], mask .shape [- 1 ]
1121- pad_left , pad_up , pad_right , pad_down = parse_padding ( )
1175+ pad_left , pad_up , pad_right , pad_down = _parse_padding ( padding_ )
11221176
11231177 new_h = h + pad_up + pad_down
11241178 new_w = w + pad_left + pad_right
@@ -1132,7 +1186,7 @@ def parse_padding():
11321186 for mask in make_segmentation_masks ():
11331187 out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
11341188
1135- expected_mask = _compute_expected_mask ()
1189+ expected_mask = _compute_expected_mask (mask , padding )
11361190 torch .testing .assert_close (out_mask , expected_mask )
11371191
11381192
0 commit comments