+import math
 from collections import OrderedDict
 
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.jit.annotations import Dict, List, Tuple
 
 from ..utils import load_state_dict_from_url
 
@@ -36,13 +38,62 @@ def __init__(self, in_channels, num_anchors, num_classes):
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         return {
             'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
-            'bbox_reg': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
+            'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
         }
 
     def forward(self, x):
-        logits = [self.classification_head(feature) for feature in x]
+        cls_logits = [self.classification_head(feature) for feature in x]
         bbox_reg = [self.regression_head(feature) for feature in x]
-        return dict(logits=logits, bbox_reg=bbox_reg)
+        return dict(cls_logits=cls_logits, bbox_reg=bbox_reg)
+
+
+def sigmoid_focal_loss(
+    inputs,
+    targets,
+    alpha: float = 0.25,
+    gamma: float = 2,
+    reduction: str = "none",
+):
+    """
+    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py.
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+            The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+            classification label for each element in inputs
+            (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+            positive vs negative examples, or -1 to disable the weighting. Default = 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+            balance easy vs hard examples.
+        reduction: 'none' | 'mean' | 'sum'
+            'none': No reduction will be applied to the output.
+            'mean': The output will be averaged.
+            'sum': The output will be summed.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    p = torch.sigmoid(inputs)
+    ce_loss = F.binary_cross_entropy_with_logits(
+        inputs, targets, reduction="none"
+    )
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    if reduction == "mean":
+        loss = loss.mean()
+    elif reduction == "sum":
+        loss = loss.sum()
+
+    return loss
+
+
+sigmoid_focal_loss_jit = torch.jit.script(sigmoid_focal_loss)
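As a quick illustration of the focal loss added above, a standalone sketch (values are made up and not part of the patch):

    inputs = torch.full((2, 4), -2.0)   # confident raw logits for the negative class
    targets = torch.zeros((2, 4))       # all-negative binary labels
    loss = sigmoid_focal_loss(inputs, targets, reduction='mean')
    # the (1 - p_t) ** gamma factor down-weights these easy negatives, so the
    # result is far smaller than plain BCE on the same inputs (~0.0014 vs ~0.13)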
 
 
 class RetinaNetClassificationHead(nn.Module):
@@ -55,21 +106,59 @@ class RetinaNetClassificationHead(nn.Module):
         num_classes (int): number of classes to be predicted
     """
 
-    def __init__(self, in_channels, num_anchors, num_classes):
+    def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01):
         super(RetinaNetClassificationHead, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1)
 
         for l in self.children():
             torch.nn.init.normal_(l.weight, std=0.01)
             torch.nn.init.constant_(l.bias, 0)
 
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
+        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
+        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
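+        # NOTE: this bias makes the initial predicted foreground probability
+        # roughly prior_probability (sigmoid(-log((1 - p) / p)) == p), the
+        # initialization recommended in the focal loss paper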
+
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
-        # TODO Implement focal loss, is there an existing function for this?
-        return 0
+        loss = []
+
+        def permute_classification(tensor):
+            """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). """
+            N, _, H, W = tensor.shape
+            tensor = tensor.view(N, -1, self.num_classes, H, W)
+            tensor = tensor.permute(0, 3, 4, 1, 2)
+            tensor = tensor.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
+            return tensor
+
+        predicted_classification = head_outputs['cls_logits']
+        predicted_classification = [permute_classification(cls) for cls in predicted_classification]
+        predicted_classification = torch.cat(predicted_classification, dim=1)
+
+        for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs):
+            # determine only the foreground
+            foreground_idxs_per_image = matched_idxs_per_image >= 0
+            num_foreground = foreground_idxs_per_image.sum()
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(predicted_classification_per_image)
+            gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1
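+            # (advanced indexing: for each foreground anchor this sets the column
+            # of its matched GT label to 1; background rows stay all-zero)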
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS
+
+            # compute the classification loss
+            loss.append(sigmoid_focal_loss_jit(
+                predicted_classification_per_image[valid_idxs_per_image],
+                gt_classes_target[valid_idxs_per_image],
+                reduction='sum',
+            ) / max(1, num_foreground))
+
+        return sum(loss) / len(loss)
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
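For intuition on the (N, A * K, H, W) → (N, HWA, K) reshaping done by permute_classification above, a standalone sketch with made-up sizes (9 anchors, 80 classes):

    import torch
    N, A, K, H, W = 2, 9, 80, 16, 16
    t = torch.randn(N, A * K, H, W)                  # one feature level
    t = t.view(N, A, K, H, W).permute(0, 3, 4, 1, 2)
    t = t.reshape(N, -1, K)
    assert t.shape == (N, H * W * A, K)              # one row of K scores per anchor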
@@ -94,18 +183,29 @@ def __init__(self, in_channels, num_anchors):
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1)
+        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
 
         for l in self.children():
             torch.nn.init.normal_(l.weight, std=0.01)
-            torch.nn.init.constant_(l.bias, 0)
+            torch.nn.init.zeros_(l.bias)
 
         self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
 
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         loss = []
 
-        predicted_regression = head_outputs['bbox_reg'][0]
+        def permute_bbox_reg(tensor):
+            """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). """
+            N, _, H, W = tensor.shape
+            tensor = tensor.view(N, -1, 4, H, W)
+            tensor = tensor.permute(0, 3, 4, 1, 2)
+            tensor = tensor.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+            return tensor
+
+        predicted_regression = head_outputs['bbox_reg']
+        predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression]
+        predicted_regression = torch.cat(predicted_regression, dim=1)
+
         for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
             # get the targets corresponding GT for each proposal
             # NB: need to clamp the indices because we can have a single
@@ -115,20 +215,20 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
 
             # determine only the foreground indices, ignore the rest
             foreground_idxs_per_image = matched_idxs_per_image >= 0
+            num_foreground = foreground_idxs_per_image.sum()
 
             # select only the foreground boxes
             matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-            print(predicted_regression_per_image.shape)
-            predicted_regression_per_image = predicted_regression_per_image['bbox_reg'][foreground_idxs_per_image, :]
+            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
 
             # compute the regression targets
-            target_regression = self.box_coder.encode(matched_gt_boxes_per_image, anchors_per_image)
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
 
             # compute the loss
-            loss.append(torch.nn.SmoothL1Loss()(predicted_regression_per_image, target_regression))
+            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(predicted_regression_per_image, target_regression) / max(1, num_foreground))
 
-        return sum(loss) / len(loss)
+        return sum(loss) / max(1, len(loss))
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
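The regression loss above now sums SmoothL1 over the foreground anchors and divides by their count instead of taking an element-wise mean. A standalone sketch of that normalization (sizes are illustrative):

    import torch
    pred = torch.randn(12, 4)          # 12 foreground anchors, 4 box deltas each
    target = torch.randn(12, 4)
    num_foreground = pred.shape[0]
    loss = torch.nn.SmoothL1Loss(reduction='sum')(pred, target) / max(1, num_foreground)
    # max(1, ...) keeps the loss defined for images without foreground anchors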
@@ -251,7 +351,7 @@ def __init__(self, backbone, num_classes,
         self.anchor_generator = anchor_generator
 
         if head is None:
-            head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()[0])
+            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
         self.head = head
 
         if proposal_matcher is None: