From 969ea03427e804f3ff8a319ff6fce8bb12d63a0d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 10 May 2022 15:25:48 +0100 Subject: [PATCH 1/3] Adding ciou and diou support in `_box_loss()` --- torchvision/models/detection/_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index f4c426691c0..865a31ca797 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -5,7 +5,7 @@ import torch from torch import Tensor, nn from torch.nn import functional as F -from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss +from torchvision.ops import FrozenBatchNorm2d, complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss class BalancedPositiveNegativeSampler: @@ -518,7 +518,7 @@ def _box_loss( bbox_regression_per_image: Tensor, cnf: Optional[Dict[str, float]] = None, ) -> Tensor: - torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}") + torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}") if type == "l1": target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) @@ -527,7 +527,12 @@ def _box_loss( target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0 return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta) - else: # giou + else: bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image) eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7 - return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + if type == "ciou": + return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + elif type == "diou": + return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + else: # giou + return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) From 9142674e5f1e0966267b354c3f88e3d3aebf05e7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 10 May 2022 15:59:00 +0100 Subject: [PATCH 2/3] Fix linter --- torchvision/models/detection/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 865a31ca797..13f2490e833 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -534,5 +534,5 @@ def _box_loss( return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) elif type == "diou": return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) - else: # giou + else: # giou return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) From ae41ab2b4fb3535cb02a8477785b230a2439bac5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 11 May 2022 09:31:50 +0100 Subject: [PATCH 3/3] Addressing comments for nits --- torchvision/models/detection/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 13f2490e833..d808ecffed3 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -532,7 +532,7 @@ def _box_loss( eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7 if type == "ciou": return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) - elif type == "diou": + if type == "diou": return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) - else: # giou - return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + # otherwise giou + return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)