Skip to content

Commit 38d4d67

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] Adding ciou and diou support in _box_loss() (#5984)
Summary:
* Adding ciou and diou support in `_box_loss()`
* Fix linter
* Addressing comments for nits

Reviewed By: datumbox

Differential Revision: D36413352

fbshipit-source-id: 8a53c68ff2a58966f02edae7f210c8db5240a4d2
1 parent f1d007d commit 38d4d67

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

torchvision/models/detection/_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from torch import Tensor, nn
77
from torch.nn import functional as F
8-
from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss
8+
from torchvision.ops import FrozenBatchNorm2d, complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss
99

1010

1111
class BalancedPositiveNegativeSampler:
@@ -518,7 +518,7 @@ def _box_loss(
518518
bbox_regression_per_image: Tensor,
519519
cnf: Optional[Dict[str, float]] = None,
520520
) -> Tensor:
521-
torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}")
521+
torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
522522

523523
if type == "l1":
524524
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
@@ -527,7 +527,12 @@ def _box_loss(
527527
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
528528
beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
529529
return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
530-
else: # giou
530+
else:
531531
bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
532532
eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
533+
if type == "ciou":
534+
return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
535+
if type == "diou":
536+
return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
537+
# otherwise giou
533538
return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)

0 commit comments

Comments
 (0)