Skip to content

Commit cf93d9e

Browse files
committed
Add box_area_center and box_iou_center for cxcywh format
1 parent 124dfa4 commit cf93d9e

File tree

4 files changed

+163
-0
lines changed

4 files changed

+163
-0
lines changed

docs/source/ops.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ These utility functions perform various operations on bounding boxes.
5050
:template: function.rst
5151

5252
box_area
53+
box_area_center
5354
box_convert
5455
box_iou
56+
box_iou_center
5557
clip_boxes_to_image
5658
complete_box_iou
5759
distance_box_iou

test/test_ops.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,41 @@ def test_box_area_jit(self):
14511451
torch.testing.assert_close(scripted_area, expected)
14521452

14531453

1454+
class TestBoxAreaCenter:
    """Tests for ``ops.box_area_center``, which takes boxes in (cx, cy, w, h) format."""

    def area_check(self, box, expected, atol=1e-4):
        """Assert that ``ops.box_area_center(box)`` equals ``expected`` within ``atol``."""
        out = ops.box_area_center(box)
        # Absolute tolerance only (rtol=0); the output dtype is allowed to differ.
        torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        # Build the (cx, cy, w, h) boxes directly in the integer dtype under test.
        # NOTE(review): going through ops.box_convert(xyxy -> cxcywh) appears to
        # promote integer inputs to float (centers are computed with a division),
        # which would mean the integer path was never exercised -- confirm.
        # [50, 50, 100, 100] is the cxcywh form of the xyxy box [0, 0, 100, 100].
        box_tensor = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        expected = torch.tensor([10000, 0], dtype=torch.int32)
        self.area_check(box_tensor, expected)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh")
        expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
        self.area_check(box_tensor, expected)

    def test_float16_box(self):
        xyxy = torch.tensor(
            [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]],
            dtype=torch.float16,
        )
        box_tensor = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")

        expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        # Looser tolerance than the default for the half-precision inputs.
        self.area_check(box_tensor, expected, atol=0.01)

    def test_box_area_jit(self):
        xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        box_tensor = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        # The scripted function must agree with the eager one.
        expected = ops.box_area_center(box_tensor)
        scripted_fn = torch.jit.script(ops.box_area_center)
        scripted_area = scripted_fn(box_tensor)
        torch.testing.assert_close(scripted_area, expected)
1487+
1488+
14541489
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
14551490
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
14561491
FLOAT_BOXES = [
@@ -1459,6 +1494,14 @@ def test_box_area_jit(self):
14591494
[279.2440, 197.9812, 1189.4746, 849.2019],
14601495
]
14611496

1497+
# Box fixtures in (cx, cy, w, h) order, used by the center-format IoU tests.
INT_BOXES_CXCYWH = [
    [50, 50, 100, 100],
    [25, 25, 50, 50],
    [250, 250, 100, 100],
    [10, 10, 20, 20],
]
INT_BOXES2_CXCYWH = [
    [50, 50, 100, 100],
    [25, 25, 50, 50],
    [250, 250, 100, 100],
]
FLOAT_BOXES_CXCYWH = [
    [739.4324, 518.5154, 908.1572, 665.8793],
    [738.8228, 519.9021, 907.3512, 662.3295],
    [734.3593, 523.5916, 910.2306, 651.2207],
]
1504+
14621505

14631506
def gen_box(size, dtype=torch.float):
14641507
xy1 = torch.rand((size, 2), dtype=dtype)
@@ -1525,6 +1568,65 @@ def test_iou_cartesian(self):
15251568
self._run_cartesian_test(ops.box_iou)
15261569

15271570

1571+
class TestIouCenterBase:
    """Shared helpers for IoU tests whose target function takes (cx, cy, w, h) boxes."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn(box1, box2)`` against ``expected`` once per dtype in ``dtypes``.

        ``actual_box1`` / ``actual_box2`` / ``expected`` are passed to ``torch.tensor``;
        the comparison uses absolute tolerance ``atol`` only and ignores dtype.
        """
        # The expected matrix does not depend on the dtype under test.
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Use fresh locals: rebinding the arguments would make every later
            # iteration call torch.tensor() on a Tensor (copy-construct warning)
            # and derive its values from the previous dtype instead of the raw data.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the TorchScript-compiled ``target_fn`` matches the eager result."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the N x M result matrix by calling ``target_fn`` on one box pair at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                # unsqueeze(0) turns each single box back into a batch of one.
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that the batched call agrees with the pair-by-pair computation."""
        # gen_box output is converted from xyxy so target_fn receives cxcywh boxes.
        boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh")
        boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh")
        a = TestIouCenterBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)
1606+
1607+
1608+
# Inherit the cxcywh-aware helpers: TestIouCenterBase converts the randomly
# generated boxes to (cx, cy, w, h) before the cartesian-product check.
class TestBoxIouCenter(TestIouCenterBase):
    """Tests for ``ops.box_iou_center`` on (cx, cy, w, h) boxes."""

    # Expected pairwise IoU of INT_BOXES_CXCYWH (rows) vs INT_BOXES2_CXCYWH (columns).
    int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]]
    # Expected pairwise IoU of FLOAT_BOXES_CXCYWH against itself.
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(
                INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected
            ),
            pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected),
            pytest.param(
                FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected
            ),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.box_iou_center, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.box_iou_center, INT_BOXES_CXCYWH)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.box_iou_center)
1628+
1629+
15281630
class TestGeneralizedBoxIou(TestIouBase):
15291631
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
15301632
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

torchvision/ops/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from .boxes import (
33
batched_nms,
44
box_area,
5+
box_area_center,
56
box_convert,
67
box_iou,
8+
box_iou_center,
79
clip_boxes_to_image,
810
complete_box_iou,
911
distance_box_iou,
@@ -40,7 +42,9 @@
4042
"clip_boxes_to_image",
4143
"box_convert",
4244
"box_area",
45+
"box_area_center",
4346
"box_iou",
47+
"box_iou_center",
4448
"generalized_box_iou",
4549
"distance_box_iou",
4650
"complete_box_iou",

torchvision/ops/boxes.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,25 @@ def box_area(boxes: Tensor) -> Tensor:
290290
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
291291

292292

293+
def box_area_center(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (cx, cy, w, h) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (cx, cy, w, h) format with
            ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.

    Returns:
        Tensor[N]: the area for each box
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Record usage under this function's own name (previously logged as box_area).
        _log_api_usage_once(box_area_center)
    # Widen the dtype via _upcast before multiplying
    # (presumably guards against overflow in narrow dtypes -- see _upcast).
    boxes = _upcast(boxes)
    # In (cx, cy, w, h) layout the area is simply w * h; the center is irrelevant.
    return boxes[:, 2] * boxes[:, 3]
310+
311+
293312
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
294313
# with slight modifications
295314
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
@@ -328,6 +347,42 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
328347
return iou
329348

330349

350+
def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Pairwise intersection and union areas for boxes given in (cx, cy, w, h) format.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tuple[Tensor[N, M], Tensor[N, M]]: intersection areas and union areas
    """
    areas_a = box_area_center(boxes1)
    areas_b = box_area_center(boxes2)

    # Split each box into its center and half-extent; boxes1 gets an extra
    # axis so that it broadcasts against every box in boxes2.
    center_a = boxes1[:, None, :2]  # [N,1,2]
    half_a = boxes1[:, None, 2:] / 2  # [N,1,2]
    center_b = boxes2[:, :2]  # [M,2]
    half_b = boxes2[:, 2:] / 2  # [M,2]

    # Overlap rectangle: the larger of the two min corners and the smaller of
    # the two max corners.
    min_corner = torch.max(center_a - half_a, center_b - half_b)  # [N,M,2]
    max_corner = torch.min(center_a + half_a, center_b + half_b)  # [N,M,2]

    # Negative extents mean the boxes do not overlap; clamp them to zero.
    overlap = _upcast(max_corner - min_corner).clamp(min=0)  # [N,M,2]
    inter = overlap[:, :, 0] * overlap[:, :, 1]  # [N,M]

    union = areas_a[:, None] + areas_b - inter

    return inter, union
363+
364+
365+
def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two sets of boxes.

    Boxes in both sets must be given as ``(cx, cy, w, h)`` with
    ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: matrix whose entry ``[i, j]`` is the IoU of ``boxes1[i]`` and ``boxes2[j]``
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou_center)
    intersection, union_area = _box_inter_union_center(boxes1, boxes2)
    return intersection / union_area
384+
385+
331386
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
332387
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
333388
"""

0 commit comments

Comments
 (0)