@@ -1451,6 +1451,41 @@ def test_box_area_jit(self):
1451
1451
torch .testing .assert_close (scripted_area , expected )
1452
1452
1453
1453
1454
+ class TestBoxAreaCenter :
1455
+ def area_check (self , box , expected , atol = 1e-4 ):
1456
+ out = ops .box_area_center (box )
1457
+ torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
1458
+
1459
+ @pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
1460
+ def test_int_boxes (self , dtype ):
1461
+ box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype ),
1462
+ in_fmt = "xyxy" , out_fmt = "cxcywh" )
1463
+ expected = torch .tensor ([10000 , 0 ], dtype = torch .int32 )
1464
+ self .area_check (box_tensor , expected )
1465
+
1466
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
1467
+ def test_float_boxes (self , dtype ):
1468
+ box_tensor = ops .box_convert (torch .tensor (FLOAT_BOXES , dtype = dtype ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1469
+ expected = torch .tensor ([604723.0806 , 600965.4666 , 592761.0085 ], dtype = dtype )
1470
+ self .area_check (box_tensor , expected )
1471
+
1472
+ def test_float16_box (self ):
1473
+ box_tensor = ops .box_convert (torch .tensor (
1474
+ [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]], dtype = torch .float16
1475
+ ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1476
+
1477
+ expected = torch .tensor ([3.2170 , 3.7108 , 18.5071 ], dtype = torch .float16 )
1478
+ self .area_check (box_tensor , expected , atol = 0.01 )
1479
+
1480
+ def test_box_area_jit (self ):
1481
+ box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ),
1482
+ in_fmt = "xyxy" , out_fmt = "cxcywh" )
1483
+ expected = ops .box_area_center (box_tensor )
1484
+ scripted_fn = torch .jit .script (ops .box_area_center )
1485
+ scripted_area = scripted_fn (box_tensor )
1486
+ torch .testing .assert_close (scripted_area , expected )
1487
+
1488
+
1454
1489
INT_BOXES = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ], [0 , 0 , 25 , 25 ]]
1455
1490
INT_BOXES2 = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]]
1456
1491
FLOAT_BOXES = [
@@ -1459,6 +1494,14 @@ def test_box_area_jit(self):
1459
1494
[279.2440 , 197.9812 , 1189.4746 , 849.2019 ],
1460
1495
]
1461
1496
1497
+ INT_BOXES_CXCYWH = [[50 , 50 , 100 , 100 ], [25 , 25 , 50 , 50 ], [250 , 250 , 100 , 100 ], [10 , 10 , 20 , 20 ]]
1498
+ INT_BOXES2_CXCYWH = [[50 , 50 , 100 , 100 ], [25 , 25 , 50 , 50 ], [250 , 250 , 100 , 100 ]]
1499
+ FLOAT_BOXES_CXCYWH = [
1500
+ [739.4324 , 518.5154 , 908.1572 , 665.8793 ],
1501
+ [738.8228 , 519.9021 , 907.3512 , 662.3295 ],
1502
+ [734.3593 , 523.5916 , 910.2306 , 651.2207 ]
1503
+ ]
1504
+
1462
1505
1463
1506
def gen_box (size , dtype = torch .float ):
1464
1507
xy1 = torch .rand ((size , 2 ), dtype = dtype )
@@ -1525,6 +1568,65 @@ def test_iou_cartesian(self):
1525
1568
self ._run_cartesian_test (ops .box_iou )
1526
1569
1527
1570
1571
+ class TestIouCenterBase :
1572
+ @staticmethod
1573
+ def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1574
+ for dtype in dtypes :
1575
+ actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1576
+ actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
1577
+ expected_box = torch .tensor (expected )
1578
+ out = target_fn (actual_box1 , actual_box2 )
1579
+ torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
1580
+
1581
+ @staticmethod
1582
+ def _run_jit_test (target_fn : Callable , actual_box : List ):
1583
+ box_tensor = torch .tensor (actual_box , dtype = torch .float )
1584
+ expected = target_fn (box_tensor , box_tensor )
1585
+ scripted_fn = torch .jit .script (target_fn )
1586
+ scripted_out = scripted_fn (box_tensor , box_tensor )
1587
+ torch .testing .assert_close (scripted_out , expected )
1588
+
1589
+ @staticmethod
1590
+ def _cartesian_product (boxes1 , boxes2 , target_fn : Callable ):
1591
+ N = boxes1 .size (0 )
1592
+ M = boxes2 .size (0 )
1593
+ result = torch .zeros ((N , M ))
1594
+ for i in range (N ):
1595
+ for j in range (M ):
1596
+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1597
+ return result
1598
+
1599
+ @staticmethod
1600
+ def _run_cartesian_test (target_fn : Callable ):
1601
+ boxes1 = ops .box_convert (gen_box (5 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1602
+ boxes2 = ops .box_convert (gen_box (7 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1603
+ a = TestIouCenterBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1604
+ b = target_fn (boxes1 , boxes2 )
1605
+ torch .testing .assert_close (a , b )
1606
+
1607
+
1608
+ class TestBoxIouCenter (TestIouBase ):
1609
+ int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [0.04 , 0.16 , 0.0 ]]
1610
+ float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
1611
+
1612
+ @pytest .mark .parametrize (
1613
+ "actual_box1, actual_box2, dtypes, atol, expected" ,
1614
+ [
1615
+ pytest .param (INT_BOXES_CXCYWH , INT_BOXES2_CXCYWH , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1616
+ pytest .param (FLOAT_BOXES_CXCYWH , FLOAT_BOXES_CXCYWH , [torch .float16 ], 0.002 , float_expected ),
1617
+ pytest .param (FLOAT_BOXES_CXCYWH , FLOAT_BOXES_CXCYWH , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
1618
+ ],
1619
+ )
1620
+ def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1621
+ self ._run_test (ops .box_iou_center , actual_box1 , actual_box2 , dtypes , atol , expected )
1622
+
1623
+ def test_iou_jit (self ):
1624
+ self ._run_jit_test (ops .box_iou_center , INT_BOXES_CXCYWH )
1625
+
1626
+ def test_iou_cartesian (self ):
1627
+ self ._run_cartesian_test (ops .box_iou_center )
1628
+
1629
+
1528
1630
class TestGeneralizedBoxIou (TestIouBase ):
1529
1631
int_expected = [[1.0 , 0.25 , - 0.7778 ], [0.25 , 1.0 , - 0.8611 ], [- 0.7778 , - 0.8611 , 1.0 ], [0.0625 , 0.25 , - 0.8819 ]]
1530
1632
float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
0 commit comments