diff --git a/docs/source/ops.rst b/docs/source/ops.rst
index 3e996ecd6c8..82d8d0a84f0 100644
--- a/docs/source/ops.rst
+++ b/docs/source/ops.rst
@@ -19,6 +19,8 @@ Operators
     box_convert
     box_iou
     clip_boxes_to_image
+    complete_box_iou
+    complete_box_iou_loss
     deform_conv2d
     drop_block2d
     drop_block3d
diff --git a/test/test_ops.py b/test/test_ops.py
index c61922204a3..c546710b271 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1258,6 +1258,43 @@ def test_giou_jit(self) -> None:
         self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


+class TestCompleteBoxIou(BoxTestBase):
+    def _target_fn(self) -> Tuple[bool, Callable]:
+        return (True, ops.complete_box_iou)
+
+    def _generate_int_input() -> List[List[int]]:
+        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
+
+    def _generate_int_expected() -> List[List[float]]:
+        return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
+
+    def _generate_float_input() -> List[List[float]]:
+        return [
+            [285.3538, 185.5758, 1193.5110, 851.4551],
+            [285.1472, 188.7374, 1192.4984, 851.0669],
+            [279.2440, 197.9812, 1189.4746, 849.2019],
+        ]
+
+    def _generate_float_expected() -> List[List[float]]:
+        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
+
+    @pytest.mark.parametrize(
+        "test_input, dtypes, tolerance, expected",
+        [
+            pytest.param(
+                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
+            ),
+            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.002, _generate_float_expected()),
+            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
+        ],
+    )
+    def test_complete_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
+        self._run_test(test_input, dtypes, tolerance, expected)
+
+    def test_ciou_jit(self) -> None:
+        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+
+
 class TestMasksToBoxes:
     def test_masks_box(self):
         def masks_box_check(masks, expected, tolerance=1e-4):
@@ -1578,6 +1615,7 @@ def test_giou_loss(self, dtype, device) -> None:
         box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
         box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
         box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
+
         box1s = torch.stack([box2, box2], dim=0)
         box2s = torch.stack([box3, box4], dim=0)

@@ -1623,5 +1661,53 @@ def test_empty_inputs(self, dtype, device) -> None:
         assert loss.numel() == 0, "giou_loss for two empty box should be empty"

+
+class TestCIOULoss:
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_ciou_loss(self, dtype, device):
+        box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
+        box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
+        box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
+        box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
+
+        box1s = torch.stack([box2, box2], dim=0)
+        box2s = torch.stack([box3, box4], dim=0)
+
+        def assert_ciou_loss(box1, box2, expected_output, reduction="none"):
+
+            output = ops.complete_box_iou_loss(box1, box2, reduction=reduction)
+            expected_output = torch.tensor(expected_output, device=device)
+            tol = 1e-5 if dtype != torch.half else 1e-3
+            torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
+
+        assert_ciou_loss(box1, box1, 0.0)
+
+        assert_ciou_loss(box1, box2, 0.8125)
+
+        assert_ciou_loss(box1, box3, 1.1923)
+
+        assert_ciou_loss(box1, box4, 1.2500)
+
+        assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean")
+        assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum")
+
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    def test_empty_inputs(self, dtype, device) -> None:
+        box1 = torch.randn([0, 4], dtype=dtype).requires_grad_()
+        box2 = torch.randn([0, 4], dtype=dtype).requires_grad_()
+
+        loss = ops.complete_box_iou_loss(box1, box2, reduction="mean")
+        loss.backward()
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol)
+        assert box1.grad is not None, "box1.grad should not be None after backward is called"
+        assert box2.grad is not None, "box2.grad should not be None after backward is called"
+
+        loss = ops.complete_box_iou_loss(box1, box2, reduction="none")
+        assert loss.numel() == 0, "ciou_loss for two empty box should be empty"
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py
index ceb78250415..9d99db7125c 100644
--- a/torchvision/ops/__init__.py
+++ b/torchvision/ops/__init__.py
@@ -7,9 +7,11 @@
     box_area,
     box_iou,
     generalized_box_iou,
+    complete_box_iou,
     masks_to_boxes,
 )
 from .boxes import box_convert
+from .ciou_loss import complete_box_iou_loss
 from .deform_conv import deform_conv2d, DeformConv2d
 from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
 from .feature_pyramid_network import FeaturePyramidNetwork
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index 23c1001438c..3239ba0d60a 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -311,6 +311,54 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
     return iou - (areai - union) / areai

+
+def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
+    """
+    Return complete intersection-over-union (Jaccard index) between two sets of boxes.
+    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+    Args:
+        boxes1 (Tensor[N, 4]): first set of boxes
+        boxes2 (Tensor[M, 4]): second set of boxes
+        eps (float, optional): small number to prevent division by zero. Default: 1e-7
+    Returns:
+        Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
+        for every element in boxes1 and boxes2
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(complete_box_iou)
+
+    boxes1 = _upcast(boxes1)
+    boxes2 = _upcast(boxes2)
+
+    inter, union = _box_inter_union(boxes1, boxes2)
+    iou = inter / union
+
+    lti = torch.min(boxes1[:, None, :2], boxes2[:, None, :2])
+    rbi = torch.max(boxes1[:, None, 2:], boxes2[:, None, 2:])
+
+    whi = (rbi - lti).clamp(min=0)  # [N,M,2]
+    diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
+
+    # centers of boxes
+    x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
+    y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
+    x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
+    y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
+    # The distance between boxes' centers squared.
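+    # Normalizing this squared center distance by ``diagonal_distance_squared`` (the
+    # squared diagonal of the smallest box enclosing both boxes) gives the distance
+    # penalty that complete IoU adds on top of plain IoU.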
+    centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2
+
+    w_pred = boxes1[:, 2] - boxes1[:, 0]
+    h_pred = boxes1[:, 3] - boxes1[:, 1]
+
+    w_gt = boxes2[:, 2] - boxes2[:, 0]
+    h_gt = boxes2[:, 3] - boxes2[:, 1]
+
+    v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
+    with torch.no_grad():
+        alpha = v / (1 - iou + v + eps)
+    return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
+
+
 def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
     """
     Compute the bounding boxes around the provided masks.
diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py
new file mode 100644
index 00000000000..d53e2d6af2a
--- /dev/null
+++ b/torchvision/ops/ciou_loss.py
@@ -0,0 +1,92 @@
+import torch
+
+from ..utils import _log_api_usage_once
+from .giou_loss import _upcast
+
+
+def complete_box_iou_loss(
+    boxes1: torch.Tensor,
+    boxes2: torch.Tensor,
+    reduction: str = "none",
+    eps: float = 1e-7,
+) -> torch.Tensor:
+
+    """
+    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
+    boxes do not overlap. This loss function considers important geometrical
+    factors such as overlap area, normalized central point distance and aspect ratio.
+    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
+
+    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two boxes should have the
+    same dimensions.
+
+    Args:
+        boxes1 : (Tensor[N, 4] or Tensor[4]) first set of boxes
+        boxes2 : (Tensor[N, 4] or Tensor[4]) second set of boxes
+        reduction : (string, optional) Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
+            applied to the output. ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed. Default: ``'none'``
+        eps : (float) small number to prevent division by zero. Default: 1e-7
+
+    Reference:
+
+        Complete Intersection over Union Loss (Zhaohui Zheng et al.)
+        https://arxiv.org/abs/1911.08287
+
+    """
+
+    # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
+
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(complete_box_iou_loss)
+
+    boxes1 = _upcast(boxes1)
+    boxes2 = _upcast(boxes2)
+
+    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
+    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
+
+    # Intersection keypoints
+    xkis1 = torch.max(x1, x1g)
+    ykis1 = torch.max(y1, y1g)
+    xkis2 = torch.min(x2, x2g)
+    ykis2 = torch.min(y2, y2g)
+
+    intsct = torch.zeros_like(x1)
+    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
+    intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
+    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
+    iou = intsct / union
+
+    # smallest enclosing box
+    xc1 = torch.min(x1, x1g)
+    yc1 = torch.min(y1, y1g)
+    xc2 = torch.max(x2, x2g)
+    yc2 = torch.max(y2, y2g)
+    diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
+
+    # centers of boxes
+    x_p = (x2 + x1) / 2
+    y_p = (y2 + y1) / 2
+    x_g = (x1g + x2g) / 2
+    y_g = (y1g + y2g) / 2
+    distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
+
+    # width and height of boxes
+    w_pred = x2 - x1
+    h_pred = y2 - y1
+    w_gt = x2g - x1g
+    h_gt = y2g - y1g
+    v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
+    with torch.no_grad():
+        alpha = v / (1 - iou + v + eps)
+
+    loss = 1 - iou + (distance / diag_len) + alpha * v
+    if reduction == "mean":
+        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
+    elif reduction == "sum":
+        loss = loss.sum()
+
+    return loss
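As a quick usage sketch of the two ops introduced above (box coordinates are illustrative only, mirroring the values used in the tests):

    import torch
    from torchvision import ops

    boxes = torch.tensor(
        [[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 50.0, 50.0], [200.0, 200.0, 300.0, 300.0]]
    )

    # Pairwise complete IoU between two sets of boxes, as exercised by TestCompleteBoxIou.
    ciou_matrix = ops.complete_box_iou(boxes, boxes)

    # Element-wise complete IoU loss between matched box pairs; "mean" reduces to a scalar.
    pred = torch.tensor([[-1.0, -1.0, 1.0, 1.0]], requires_grad=True)
    target = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
    loss = ops.complete_box_iou_loss(pred, target, reduction="mean")
    loss.backward()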