+import math
 from collections import OrderedDict

 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.jit.annotations import Dict, List, Tuple

 from ..utils import load_state_dict_from_url

@@ -36,13 +38,62 @@ def __init__(self, in_channels, num_anchors, num_classes):
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         return {
             'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
-            'bbox_reg': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
+            'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
         }

     def forward(self, x):
-        logits = [self.classification_head(feature) for feature in x]
+        cls_logits = [self.classification_head(feature) for feature in x]
         bbox_reg = [self.regression_head(feature) for feature in x]
-        return dict(logits=logits, bbox_reg=bbox_reg)
+        return dict(cls_logits=cls_logits, bbox_reg=bbox_reg)
+
+
+def sigmoid_focal_loss(
+    inputs,
+    targets,
+    alpha: float = 0.25,
+    gamma: float = 2,
+    reduction: str = "none",
+):
+    """
+    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py.
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                 (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+               positive vs negative examples, or -1 to disable the weighting. Default = 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+        reduction: 'none' | 'mean' | 'sum'
+                   'none': No reduction will be applied to the output.
+                   'mean': The output will be averaged.
+                   'sum': The output will be summed.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    p = torch.sigmoid(inputs)
+    ce_loss = F.binary_cross_entropy_with_logits(
+        inputs, targets, reduction="none"
+    )
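+    # p_t is the probability the model assigns to the ground-truth class;
+    # the (1 - p_t) ** gamma factor down-weights easy, well-classified examples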
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
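+    # alpha-weighting balances positive vs. negative examples; a negative alpha disables it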
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    if reduction == "mean":
+        loss = loss.mean()
+    elif reduction == "sum":
+        loss = loss.sum()
+
+    return loss
+
+
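+# TorchScript-compiled variant, following the fvcore implementation referenced in the docstring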
+sigmoid_focal_loss_jit = torch.jit.script(sigmoid_focal_loss)


 class RetinaNetClassificationHead(nn.Module):
@@ -55,21 +106,59 @@ class RetinaNetClassificationHead(nn.Module):
         num_classes (int): number of classes to be predicted
     """

-    def __init__(self, in_channels, num_anchors, num_classes):
+    def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01):
         super(RetinaNetClassificationHead, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1)

         for l in self.children():
             torch.nn.init.normal_(l.weight, std=0.01)
             torch.nn.init.constant_(l.bias, 0)

+        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
+        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
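+        # bias init from the focal loss paper: start every anchor predicting foreground with
+        # probability ~prior_probability, so the many background anchors don't destabilize early training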
+        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
+
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
-        # TODO Implement focal loss, is there an existing function for this?
-        return 0
+        loss = []
+
+        def permute_classification(tensor):
+            """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). """
+            N, _, H, W = tensor.shape
+            tensor = tensor.view(N, -1, self.num_classes, H, W)
+            tensor = tensor.permute(0, 3, 4, 1, 2)
+            tensor = tensor.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
+            return tensor
+
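+        # flatten the per-level outputs to (N, HWA, K) and concatenate across FPN levels
+        # so that each row lines up with the corresponding anchor in matched_idxs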
+        predicted_classification = head_outputs['cls_logits']
+        predicted_classification = [permute_classification(cls) for cls in predicted_classification]
+        predicted_classification = torch.cat(predicted_classification, dim=1)
+
+        for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs):
+            # determine only the foreground
+            foreground_idxs_per_image = matched_idxs_per_image >= 0
+            num_foreground = foreground_idxs_per_image.sum()
+
+            # create the target classification
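+            # (a one-hot row per anchor: foreground anchors get a 1 at their matched GT label,
+            # background and ignored anchors keep an all-zero row)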
+            gt_classes_target = torch.zeros_like(predicted_classification_per_image)
+            gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS
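+            # (anchors whose best IoU falls between the background and foreground thresholds
+            # are marked BETWEEN_THRESHOLDS by the matcher and excluded from the loss)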
+
+            # compute the classification loss
+            loss.append(sigmoid_focal_loss_jit(
+                predicted_classification_per_image[valid_idxs_per_image],
+                gt_classes_target[valid_idxs_per_image],
+                reduction='sum',
+            ) / max(1, num_foreground))
+
+        return sum(loss) / len(loss)

     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -94,18 +183,29 @@ def __init__(self, in_channels, num_anchors):
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1)
+        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)

         for l in self.children():
             torch.nn.init.normal_(l.weight, std=0.01)
-            torch.nn.init.constant_(l.bias, 0)
+            torch.nn.init.zeros_(l.bias)

         self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         loss = []

-        predicted_regression = head_outputs['bbox_reg'][0]
+        def permute_bbox_reg(tensor):
+            """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). """
+            N, _, H, W = tensor.shape
+            tensor = tensor.view(N, -1, 4, H, W)
+            tensor = tensor.permute(0, 3, 4, 1, 2)
+            tensor = tensor.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+            return tensor
+
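+        # flatten each level to (N, HWA, 4) and concatenate across FPN levels, mirroring
+        # the classification head, so regression rows align with the per-image anchors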
+        predicted_regression = head_outputs['bbox_reg']
+        predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression]
+        predicted_regression = torch.cat(predicted_regression, dim=1)
+
         for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
             # get the corresponding GT boxes for each proposal
             # NB: need to clamp the indices because we can have a single
@@ -115,20 +215,20 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):

             # determine only the foreground indices, ignore the rest
             foreground_idxs_per_image = matched_idxs_per_image >= 0
+            num_foreground = foreground_idxs_per_image.sum()

             # select only the foreground boxes
             matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-            print(predicted_regression_per_image.shape)
-            predicted_regression_per_image = predicted_regression_per_image['bbox_reg'][foreground_idxs_per_image, :]
+            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

             # compute the regression targets
-            target_regression = self.box_coder.encode(matched_gt_boxes_per_image, anchors_per_image)
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)

             # compute the loss
-            loss.append(torch.nn.SmoothL1Loss()(predicted_regression_per_image, target_regression))
+            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(predicted_regression_per_image, target_regression) / max(1, num_foreground))
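+            # (sum-reduced Smooth L1 normalized by the foreground count, matching the
+            # normalization of the classification loss)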

-        return sum(loss) / len(loss)
+        return sum(loss) / max(1, len(loss))

     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -251,7 +351,7 @@ def __init__(self, backbone, num_classes,
         self.anchor_generator = anchor_generator

         if head is None:
-            head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()[0])
+            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
         self.head = head

         if proposal_matcher is None: