@@ -379,12 +379,15 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
             np.max(transformed_points[:, 1]),
         ]
         out_bbox = features.BoundingBox(
-            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=bbox.image_size, dtype=torch.float32
+            out_bbox,
+            format=features.BoundingBoxFormat.XYXY,
+            image_size=bbox.image_size,
+            dtype=torch.float32,
+            device=bbox.device,
         )
-        out_bbox = convert_bounding_box_format(
+        return convert_bounding_box_format(
             out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
         )
-        return out_bbox.to(bbox.device)

     image_size = (32, 38)

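Note: the hunk above makes two coupled changes to the expected-bbox helper: the box is now allocated on `bbox.device` at construction time, and the helper returns the `convert_bounding_box_format(...)` result directly instead of rebinding `out_bbox` and copying with `.to(bbox.device)` afterwards. A minimal sketch of the resulting helper tail, with the import paths assumed from the prototype API of the time rather than taken from this diff:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms.functional import convert_bounding_box_format

    def _expected_tail(out_bbox, bbox):  # hypothetical extraction of the helper's last steps
        out_bbox = features.BoundingBox(
            out_bbox,
            format=features.BoundingBoxFormat.XYXY,
            image_size=bbox.image_size,
            dtype=torch.float32,
            device=bbox.device,  # allocate on the input's device, so no trailing .to() copy
        )
        # copy=False avoids an extra copy when no conversion is needed; returning the
        # result directly drops the old rebind-then-move step.
        return convert_bounding_box_format(
            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
        )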
@@ -439,8 +442,8 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
         [1, 1, 5, 5],
     ]
     in_boxes = features.BoundingBox(
-        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
-    ).to(device)
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device
+    )
     # Tested parameters
     angle = 63
     scale = 0.89
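The constructor hunks here and in the rotate test below replace a post-hoc `.to(device)` with a `device=` argument, so the boxes are created on the target device rather than created on the default device and then copied. A minimal sketch of the difference, using hypothetical `data` and `image_size` values and `"cpu"` in place of the test's parametrized `device`:

    import torch
    from torchvision.prototype import features  # assumed prototype namespace

    data = [[1.0, 1.0, 5.0, 5.0]]  # hypothetical XYXY box

    # Before: materialized on the default device, then copied over.
    boxes = features.BoundingBox(
        data, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32)
    ).to("cpu")

    # After: created on the target device directly, skipping the intermediate copy.
    boxes = features.BoundingBox(
        data, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), device="cpu"
    )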
@@ -473,9 +476,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
         shear=(0, 0),
     )

-    assert len(output_boxes) == len(expected_bboxes)
-    for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
-        np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
+    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


@pytest.mark.parametrize("angle", [-54, 56])
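The assertion swap works because `torch.testing.assert_close` accepts plain Python sequences as well as tensors: comparing `output_boxes.tolist()` against the expected nested list covers the old length check and the per-box `np.testing.assert_allclose` loop in one call, and `.tolist()` already moves the data to the host, so no explicit `.cpu()` is needed. A small illustration with made-up values:

    import torch

    output_boxes = torch.tensor([[1.0, 1.0, 5.0, 5.0]])  # hypothetical kernel output
    expected_bboxes = [[1.0, 1.0, 5.0, 5.0]]             # plain Python floats

    # Both sides are nested lists after .tolist(); assert_close checks length,
    # structure, and values (with default rtol/atol) and raises on mismatch.
    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)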
@@ -589,12 +590,15 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         out_bbox[3] -= tr_y

         out_bbox = features.BoundingBox(
-            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float32
+            out_bbox,
+            format=features.BoundingBoxFormat.XYXY,
+            image_size=image_size,
+            dtype=torch.float32,
+            device=bbox.device,
         )
-        out_bbox = convert_bounding_box_format(
+        return convert_bounding_box_format(
             out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
         )
-        return out_bbox.to(bbox.device)

     image_size = (32, 38)

@@ -630,9 +634,6 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         expected_bboxes = torch.stack(expected_bboxes)
     else:
         expected_bboxes = expected_bboxes[0]
-    print("input:", bboxes)
-    print("output_bboxes:", output_bboxes)
-    print("expected_bboxes:", expected_bboxes)
     torch.testing.assert_close(output_bboxes, expected_bboxes)


@@ -649,8 +650,8 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10],
     ]
     in_boxes = features.BoundingBox(
-        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
-    ).to(device)
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device
+    )
     # Tested parameters
     angle = 45
     center = None if expand else [12, 23]
@@ -687,6 +688,4 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         center=center,
     )

-    assert len(output_boxes) == len(expected_bboxes)
-    for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
-        np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
+    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)