
Commit 57284f6

YosuaMichael, datumbox, and pmeier authored and committed
[fbsync] Added CIOU loss function (#5776)
Summary:
* added ciou loss
* "formatting with flake8 and ufmt"
* formatting with ufmt and flake8
* minor changes
* changes as per the suggestions
* added reference in torchvision/ops/__init__.py
* sample test
* tests formatted
* added description
* formatting
* edited tests
* changes in tests
* added tests for multiple boxes
* minor edits
* minor edit
* doc added
* minor edits
* Update test_ops.py
* formatting test file
* changes as per the suggestions
* formatting and adding some more tests
* bounding box added
* removed unnecessary comment
* added docstring
* added type annotations
* removed potential bug
* Update torchvision/ops/boxes.py
* Update torchvision/ops/boxes.py
* Update test/test_ops.py

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095722

fbshipit-source-id: be54569f72c54380a832e0798b83a0cfe6558c86

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Philip Meier <[email protected]>
1 parent ccc0a92 commit 57284f6

File tree: 5 files changed, +230 -0 lines


docs/source/ops.rst

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ Operators
     box_convert
     box_iou
     clip_boxes_to_image
+    complete_box_iou
+    complete_box_iou_loss
     deform_conv2d
     drop_block2d
     drop_block3d

test/test_ops.py

Lines changed: 86 additions & 0 deletions
@@ -1258,6 +1258,43 @@ def test_giou_jit(self) -> None:
         self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


+class TestCompleteBoxIou(BoxTestBase):
+    def _target_fn(self) -> Tuple[bool, Callable]:
+        return (True, ops.complete_box_iou)
+
+    def _generate_int_input() -> List[List[int]]:
+        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
+
+    def _generate_int_expected() -> List[List[float]]:
+        return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
+
+    def _generate_float_input() -> List[List[float]]:
+        return [
+            [285.3538, 185.5758, 1193.5110, 851.4551],
+            [285.1472, 188.7374, 1192.4984, 851.0669],
+            [279.2440, 197.9812, 1189.4746, 849.2019],
+        ]
+
+    def _generate_float_expected() -> List[List[float]]:
+        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
+
+    @pytest.mark.parametrize(
+        "test_input, dtypes, tolerance, expected",
+        [
+            pytest.param(
+                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
+            ),
+            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.002, _generate_float_expected()),
+            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
+        ],
+    )
+    def test_complete_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
+        self._run_test(test_input, dtypes, tolerance, expected)
+
+    def test_ciou_jit(self) -> None:
+        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+
+
 class TestMasksToBoxes:
     def test_masks_box(self):
         def masks_box_check(masks, expected, tolerance=1e-4):

@@ -1578,6 +1615,7 @@ def test_giou_loss(self, dtype, device) -> None:
         box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
         box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
         box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
+
         box1s = torch.stack([box2, box2], dim=0)
         box2s = torch.stack([box3, box4], dim=0)

@@ -1623,5 +1661,53 @@ def test_empty_inputs(self, dtype, device) -> None:
         assert loss.numel() == 0, "giou_loss for two empty box should be empty"


+class TestCIOULoss:
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_ciou_loss(self, dtype, device):
+        box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
+        box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
+        box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
+        box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
+
+        box1s = torch.stack([box2, box2], dim=0)
+        box2s = torch.stack([box3, box4], dim=0)
+
+        def assert_ciou_loss(box1, box2, expected_output, reduction="none"):
+
+            output = ops.complete_box_iou_loss(box1, box2, reduction=reduction)
+            expected_output = torch.tensor(expected_output, device=device)
+            tol = 1e-5 if dtype != torch.half else 1e-3
+            torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
+
+        assert_ciou_loss(box1, box1, 0.0)
+
+        assert_ciou_loss(box1, box2, 0.8125)
+
+        assert_ciou_loss(box1, box3, 1.1923)
+
+        assert_ciou_loss(box1, box4, 1.2500)
+
+        assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean")
+        assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum")
+
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    def test_empty_inputs(self, dtype, device) -> None:
+        box1 = torch.randn([0, 4], dtype=dtype).requires_grad_()
+        box2 = torch.randn([0, 4], dtype=dtype).requires_grad_()
+
+        loss = ops.complete_box_iou_loss(box1, box2, reduction="mean")
+        loss.backward()
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol)
+        assert box1.grad is not None, "box1.grad should not be None after backward is called"
+        assert box2.grad is not None, "box2.grad should not be None after backward is called"
+
+        loss = ops.complete_box_iou_loss(box1, box2, reduction="none")
+        assert loss.numel() == 0, "ciou_loss for two empty box should be empty"
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
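
The 0.8125 expectation for box1 vs. box2 above can be verified by hand: the boxes overlap with IoU = 1/4, the squared distance between their centers is 0.5 against a squared enclosing-box diagonal of 8, and the aspect-ratio term v vanishes because both boxes are square, giving 1 - 0.25 + 0.5/8 + 0 = 0.8125. A minimal sketch of that check (illustrative only, not part of the patch, assuming a torchvision build that includes this commit):

import torch
from torchvision import ops

box1 = torch.tensor([[-1.0, -1.0, 1.0, 1.0]])  # 2x2 box centered at the origin
box2 = torch.tensor([[0.0, 0.0, 1.0, 1.0]])    # 1x1 box inside one quadrant of box1

# IoU = 1/4, center distance^2 = 0.5, enclosing diagonal^2 = 8, v = 0 (both boxes square):
# loss = 1 - 0.25 + 0.5 / 8 + 0 = 0.8125
print(ops.complete_box_iou_loss(box1, box2))  # tensor([0.8125])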

torchvision/ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,9 +7,11 @@
     box_area,
     box_iou,
     generalized_box_iou,
+    complete_box_iou,
     masks_to_boxes,
 )
 from .boxes import box_convert
+from .ciou_loss import complete_box_iou_loss
 from .deform_conv import deform_conv2d, DeformConv2d
 from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
 from .feature_pyramid_network import FeaturePyramidNetwork

torchvision/ops/boxes.py

Lines changed: 48 additions & 0 deletions
@@ -311,6 +311,54 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
     return iou - (areai - union) / areai


+def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
+    """
+    Return complete intersection-over-union (Jaccard index) between two sets of boxes.
+    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+    Args:
+        boxes1 (Tensor[N, 4]): first set of boxes
+        boxes2 (Tensor[M, 4]): second set of boxes
+        eps (float, optional): small number to prevent division by zero. Default: 1e-7
+    Returns:
+        Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
+        for every element in boxes1 and boxes2
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(complete_box_iou)
+
+    boxes1 = _upcast(boxes1)
+    boxes2 = _upcast(boxes2)
+
+    inter, union = _box_inter_union(boxes1, boxes2)
+    iou = inter / union
+
+    lti = torch.min(boxes1[:, None, :2], boxes2[:, None, :2])
+    rbi = torch.max(boxes1[:, None, 2:], boxes2[:, None, 2:])
+
+    whi = (rbi - lti).clamp(min=0)  # [N,M,2]
+    diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
+
+    # centers of boxes
+    x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
+    y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
+    x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
+    y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
+    # The distance between boxes' centers squared.
+    centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2
+
+    w_pred = boxes1[:, 2] - boxes1[:, 0]
+    h_pred = boxes1[:, 3] - boxes1[:, 1]
+
+    w_gt = boxes2[:, 2] - boxes2[:, 0]
+    h_gt = boxes2[:, 3] - boxes2[:, 1]
+
+    v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
+    with torch.no_grad():
+        alpha = v / (1 - iou + v + eps)
+    return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
+
+
 def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
     """
     Compute the bounding boxes around the provided masks.
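
Like box_iou and generalized_box_iou above it, complete_box_iou returns the full N x M matrix of pairwise values: plain IoU minus a normalized center-distance penalty rho^2 / c^2 and an aspect-ratio penalty alpha * v, where v = (4 / pi^2) * (arctan(w_gt / h_gt) - arctan(w / h))^2 and alpha = v / (1 - IoU + v). A minimal usage sketch (illustrative only, not part of the patch):

import torch
from torchvision import ops

boxes1 = torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 50.0, 50.0]])
boxes2 = torch.tensor([[0.0, 0.0, 100.0, 100.0], [200.0, 200.0, 300.0, 300.0]])

ciou = ops.complete_box_iou(boxes1, boxes2)  # shape [2, 2]
# Identical boxes score 1.0; distant non-overlapping pairs can go below 0
# because of the center-distance penalty.
print(ciou)

For element-wise comparisons of two [N, 4] tensors, the diagonal of this matrix agrees with 1 - complete_box_iou_loss(boxes1, boxes2) up to the eps terms.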

torchvision/ops/ciou_loss.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+import torch
+
+from ..utils import _log_api_usage_once
+from .giou_loss import _upcast
+
+
+def complete_box_iou_loss(
+    boxes1: torch.Tensor,
+    boxes2: torch.Tensor,
+    reduction: str = "none",
+    eps: float = 1e-7,
+) -> torch.Tensor:
+
+    """
+    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
+    boxes do not overlap. This loss function considers important geometrical
+    factors such as overlap area, normalized central point distance and aspect ratio.
+    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
+
+    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two boxes should have the
+    same dimensions.
+
+    Args:
+        boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes
+        boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
+            applied to the output. ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed. Default: ``'none'``
+        eps (float): small number to prevent division by zero. Default: 1e-7
+
+    Reference:
+
+        Complete Intersection over Union Loss (Zhaohui Zheng et al.)
+        https://arxiv.org/abs/1911.08287
+
+    """
+
+    # Original implementation: https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
+
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(complete_box_iou_loss)
+
+    boxes1 = _upcast(boxes1)
+    boxes2 = _upcast(boxes2)
+
+    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
+    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
+
+    # Intersection keypoints
+    xkis1 = torch.max(x1, x1g)
+    ykis1 = torch.max(y1, y1g)
+    xkis2 = torch.min(x2, x2g)
+    ykis2 = torch.min(y2, y2g)
+
+    intsct = torch.zeros_like(x1)
+    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
+    intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
+    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
+    iou = intsct / union
+
+    # smallest enclosing box
+    xc1 = torch.min(x1, x1g)
+    yc1 = torch.min(y1, y1g)
+    xc2 = torch.max(x2, x2g)
+    yc2 = torch.max(y2, y2g)
+    diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
+
+    # centers of boxes
+    x_p = (x2 + x1) / 2
+    y_p = (y2 + y1) / 2
+    x_g = (x1g + x2g) / 2
+    y_g = (y1g + y2g) / 2
+    distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
+
+    # width and height of boxes
+    w_pred = x2 - x1
+    h_pred = y2 - y1
+    w_gt = x2g - x1g
+    h_gt = y2g - y1g
+    v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
+    with torch.no_grad():
+        alpha = v / (1 - iou + v + eps)
+
+    loss = 1 - iou + (distance / diag_len) + alpha * v
+    if reduction == "mean":
+        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
+    elif reduction == "sum":
+        loss = loss.sum()
+
+    return loss
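
The loss is differentiable with respect to both inputs; alpha is computed under torch.no_grad, so it acts as a constant weight during backpropagation. A minimal training-style sketch (illustrative only, not part of the patch, assuming a torchvision build that includes this commit):

import torch
from torchvision import ops

pred = torch.tensor([[0.0, 0.0, 50.0, 50.0]], requires_grad=True)  # predicted box
target = torch.tensor([[10.0, 10.0, 60.0, 60.0]])                  # ground-truth box

loss = ops.complete_box_iou_loss(pred, target, reduction="mean")
loss.backward()  # gradients reach the predicted coordinates
print(loss.item(), pred.grad)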
