5
5
import torch
6
6
from torch import Tensor , nn
7
7
from torch .nn import functional as F
8
- from torchvision .ops import FrozenBatchNorm2d , generalized_box_iou_loss
8
+ from torchvision .ops import FrozenBatchNorm2d , complete_box_iou_loss , distance_box_iou_loss , generalized_box_iou_loss
9
9
10
10
11
11
class BalancedPositiveNegativeSampler :
@@ -518,7 +518,7 @@ def _box_loss(
518
518
bbox_regression_per_image : Tensor ,
519
519
cnf : Optional [Dict [str , float ]] = None ,
520
520
) -> Tensor :
521
- torch ._assert (type in ["l1" , "smooth_l1" , "giou" ], f"Unsupported loss: { type } " )
521
+ torch ._assert (type in ["l1" , "smooth_l1" , "ciou" , "diou" , " giou" ], f"Unsupported loss: { type } " )
522
522
523
523
if type == "l1" :
524
524
target_regression = box_coder .encode_single (matched_gt_boxes_per_image , anchors_per_image )
@@ -527,7 +527,12 @@ def _box_loss(
527
527
target_regression = box_coder .encode_single (matched_gt_boxes_per_image , anchors_per_image )
528
528
beta = cnf ["beta" ] if cnf is not None and "beta" in cnf else 1.0
529
529
return F .smooth_l1_loss (bbox_regression_per_image , target_regression , reduction = "sum" , beta = beta )
530
- else : # giou
530
+ else :
531
531
bbox_per_image = box_coder .decode_single (bbox_regression_per_image , anchors_per_image )
532
532
eps = cnf ["eps" ] if cnf is not None and "eps" in cnf else 1e-7
533
+ if type == "ciou" :
534
+ return complete_box_iou_loss (bbox_per_image , matched_gt_boxes_per_image , reduction = "sum" , eps = eps )
535
+ if type == "diou" :
536
+ return distance_box_iou_loss (bbox_per_image , matched_gt_boxes_per_image , reduction = "sum" , eps = eps )
537
+ # otherwise giou
533
538
return generalized_box_iou_loss (bbox_per_image , matched_gt_boxes_per_image , reduction = "sum" , eps = eps )
0 commit comments