@@ -279,26 +279,33 @@ def _affine_bounding_box_xyxy(
279
279
bounding_box : torch .Tensor ,
280
280
image_size : Tuple [int , int ],
281
281
angle : float ,
282
- translate : List [float ],
283
- scale : float ,
284
- shear : List [float ],
282
+ translate : Optional [ List [float ]] = None ,
283
+ scale : Optional [ float ] = None ,
284
+ shear : Optional [ List [float ]] = None ,
285
285
center : Optional [List [float ]] = None ,
286
286
expand : bool = False ,
287
287
) -> torch .Tensor :
288
- # This is just a dummy value to avoid raising an error in `_affine_parse_args` although we don't have an
289
- # interpolation mode for bounding boxes.
290
- interpolation = InterpolationMode .NEAREST
291
- angle , translate , shear , center = _affine_parse_args (angle , translate , scale , shear , interpolation , center )
288
+ dtype = bounding_box .dtype if torch .is_floating_point (bounding_box ) else torch .float32
289
+ device = bounding_box .device
290
+
291
+ if translate is None :
292
+ translate = [0.0 , 0.0 ]
293
+
294
+ if scale is None :
295
+ scale = 1.0
296
+
297
+ if shear is None :
298
+ shear = [0.0 , 0.0 ]
292
299
293
300
if center is None :
294
301
height , width = image_size
295
- center = [width * 0.5 , height * 0.5 ]
296
-
297
- dtype = bounding_box .dtype if torch .is_floating_point (bounding_box ) else torch .float32
298
- device = bounding_box .device
302
+ center_f = [width * 0.5 , height * 0.5 ]
303
+ else :
304
+ center_f = [float (c ) for c in center ]
299
305
306
+ translate_f = [float (t ) for t in translate ]
300
307
affine_matrix = torch .tensor (
301
- _get_inverse_affine_matrix (center , angle , translate , scale , shear , inverted = False ),
308
+ _get_inverse_affine_matrix (center_f , angle , translate_f , scale , shear , inverted = False ),
302
309
dtype = dtype ,
303
310
device = device ,
304
311
).view (2 , 3 )
@@ -521,16 +528,7 @@ def rotate_bounding_box(
521
528
bounding_box , old_format = format , new_format = features .BoundingBoxFormat .XYXY
522
529
).view (- 1 , 4 )
523
530
524
- out_bboxes = _affine_bounding_box_xyxy (
525
- bounding_box ,
526
- image_size ,
527
- angle = - angle ,
528
- translate = [0.0 , 0.0 ],
529
- scale = 1.0 ,
530
- shear = [0.0 , 0.0 ],
531
- center = center ,
532
- expand = expand ,
533
- )
531
+ out_bboxes = _affine_bounding_box_xyxy (bounding_box , image_size , angle = - angle , center = center , expand = expand )
534
532
535
533
return convert_format_bounding_box (
536
534
out_bboxes , old_format = features .BoundingBoxFormat .XYXY , new_format = format , copy = False
0 commit comments