@@ -29,19 +29,19 @@ class RetinaNetHead(nn.Module):
    """

    def __init__(self, in_channels, num_anchors, num_classes):
-        super(RPNHead, self).__init__()
+        super(RetinaNetHead, self).__init__()
        self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)

-    def compute_loss(self, outputs, labels, matched_gt_boxes):
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        return {
-            'classification': self.classification_head.compute_loss(outputs, targets, anchor_state),
-            'regression': self.regression_head.compute_loss(outputs, targets, anchor_state),
+            'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
+            'bbox_reg': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
        }

    def forward(self, x):
-        logits = [self.classification_head(feature, targets) for feature in x]
-        bbox_reg = [self.regression_head(feature, targets) for feature in x]
+        logits = [self.classification_head(feature) for feature in x]
+        bbox_reg = [self.regression_head(feature) for feature in x]
        return dict(logits=logits, bbox_reg=bbox_reg)
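Note on the revised `forward`: the per-level loop moves out of the subheads, so `RetinaNetHead.forward` takes the list of feature maps and applies each subhead once per level. A quick smoke test of the new shape contract (the channel/anchor/class counts below are made-up example values, not from this patch):

```python
import torch

# hypothetical sizes: 3 pyramid levels, 256 channels, 9 anchors, 80 classes
head = RetinaNetHead(in_channels=256, num_anchors=9, num_classes=80)
features = [torch.rand(2, 256, s, s) for s in (64, 32, 16)]
outputs = head(features)
assert len(outputs['logits']) == 3 and len(outputs['bbox_reg']) == 3
```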
@@ -56,7 +56,7 @@ class RetinaNetClassificationHead(nn.Module):
    """

    def __init__(self, in_channels, num_anchors, num_classes):
-        super(RPNHead, self).__init__()
+        super(RetinaNetClassificationHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
@@ -67,19 +67,16 @@ def __init__(self, in_channels, num_anchors, num_classes):
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

-    def compute_loss(self, outputs, labels, matched_gt_boxes):
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # TODO: Implement focal loss; is there an existing function for this?
        return 0

    def forward(self, x):
-        logits = []
-        for feature in x:
-            t = F.relu(self.conv1(feature))
-            t = F.relu(self.conv2(t))
-            t = F.relu(self.conv3(t))
-            t = F.relu(self.conv4(t))
-            logits.append(self.cls_logits(t))
-        return logits
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+        x = F.relu(self.conv4(x))
+        return self.cls_logits(x)


class RetinaNetRegressionHead(nn.Module):
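On the focal-loss TODO above: torchvision has no ready-made focal loss at this point, so it has to be written out. For reference, a minimal sketch of the sigmoid focal loss from the RetinaNet paper (Lin et al. 2017), using the paper's defaults `alpha=0.25`, `gamma=2.0`; the one-hot target layout here is an assumption, not something this patch defines:

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # logits and targets have the same shape; targets are one-hot per class
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = p * targets + (1 - p) * (1 - targets)          # prob. assigned to the true label
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).sum()
```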
@@ -92,7 +89,7 @@ class RetinaNetRegressionHead(nn.Module):
    """

    def __init__(self, in_channels, num_anchors):
-        super(RPNHead, self).__init__()
+        super(RetinaNetRegressionHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
@@ -103,19 +100,42 @@ def __init__(self, in_channels, num_anchors):
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

-    def compute_loss(self, outputs, labels, matched_gt_boxes):
-        # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py?
-        return 0
+        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
+        loss = []
+
+        predicted_regression = head_outputs['bbox_reg'][0]
+        for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
+            # get the corresponding GT box for each anchor
+            # NB: need to clamp the indices because an image can have a single
+            # GT box while matched_idxs can be -2, which would index
+            # out of bounds
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
+
+            # determine only the foreground indices, ignore the rest
+            foreground_idxs_per_image = matched_idxs_per_image >= 0
+
+            # select only the foreground boxes, predictions and anchors
+            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
+            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
+
+            # compute the regression targets
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+
+            # compute the loss
+            loss.append(torch.nn.SmoothL1Loss()(predicted_regression_per_image, target_regression))
+
+        return sum(loss) / len(loss)

    def forward(self, x):
-        bbox_reg = []
-        for feature in x:
-            t = F.relu(self.conv1(feature))
-            t = F.relu(self.conv2(t))
-            t = F.relu(self.conv3(t))
-            t = F.relu(self.conv4(t))
-            bbox_reg.append(self.bbox_reg(t))
-        return bbox_reg
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+        x = F.relu(self.conv4(x))
+        return self.bbox_reg(x)


class RetinaNet(nn.Module):
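For context on the regression targets: `det_utils.BoxCoder` implements the standard Faster R-CNN box parameterization. A rough standalone equivalent of what `encode_single` computes (simplified: unit weights, no width/height clamping):

```python
import torch

def encode_boxes(gt, anchors):
    # gt, anchors: [N, 4] tensors in (x1, y1, x2, y2) format
    aw, ah = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
    ax, ay = anchors[:, 0] + 0.5 * aw, anchors[:, 1] + 0.5 * ah
    gw, gh = gt[:, 2] - gt[:, 0], gt[:, 3] - gt[:, 1]
    gx, gy = gt[:, 0] + 0.5 * gw, gt[:, 1] + 0.5 * gh
    # targets: center offsets scaled by anchor size, log size ratios
    return torch.stack([(gx - ax) / aw, (gy - ay) / ah,
                        torch.log(gw / aw), torch.log(gh / ah)], dim=1)
```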
@@ -206,15 +226,18 @@ def __init__(self, backbone, num_classes,
                 image_mean=None, image_std=None,
                 # Anchor parameters
                 anchor_generator=None, head=None,
+                 proposal_matcher=None,
                 pre_nms_top_n=1000, post_nms_top_n=1000,
                 nms_thresh=0.5,
                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
+        super(RetinaNet, self).__init__()

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
+        self.backbone = backbone

        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
@@ -231,6 +254,14 @@ def __init__(self, backbone, num_classes,
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

+        if proposal_matcher is None:
+            proposal_matcher = det_utils.Matcher(
+                fg_iou_thresh,
+                bg_iou_thresh,
+                allow_low_quality_matches=True,
+            )
+        self.proposal_matcher = proposal_matcher
+
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
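A note on the matcher defaults, since the loss code above depends on them: `det_utils.Matcher` maps each anchor to the index of its best ground-truth box, or to `-1` (`BELOW_LOW_THRESHOLD`, background) and `-2` (`BETWEEN_THRESHOLDS`, ignored) based on `fg_iou_thresh`/`bg_iou_thresh`; this is why the regression loss clamps indices and keeps only `matched_idxs >= 0`. A small illustration:

```python
import torch
from torchvision.models.detection import _utils as det_utils

matcher = det_utils.Matcher(0.5, 0.4, allow_low_quality_matches=True)
# one GT box against three anchors; values are IoUs
iou = torch.tensor([[0.7, 0.45, 0.1]])
print(matcher(iou))  # tensor([ 0, -2, -1]): foreground, ignored, background
```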
@@ -245,6 +276,13 @@ def eager_outputs(self, losses, detections):

        return detections

+    def compute_loss(self, targets, head_outputs, anchors):
+        matched_idxs = []
+        for anchors_per_image, targets_per_image in zip(anchors, targets):
+            # Matcher expects a match-quality (IoU) matrix of shape [num_gt, num_anchors]
+            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
+            matched_idxs.append(self.proposal_matcher(match_quality_matrix))
+
+        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
+
    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
        """
@@ -277,8 +315,17 @@ def forward(self, images, targets=None):
        if isinstance(features, torch.Tensor):
            features = OrderedDict([('0', features)])

+        # TODO: Do we want a list or a dict?
+        features = list(features.values())
+
+        assert len(features) == 1 or len(features) == 6
+
+        if len(features) == 6:
+            # skip P2 because it generates too many anchors
+            features = features[1:]
+
        # compute the RetinaNet head outputs using the features
-        head_outputs = self.head(images, features, targets)
+        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)
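The `len(features) == 6` check is a heuristic: the FPN backbone here presumably returns six maps (P2-P7), and RetinaNet runs on P3-P7 per the paper, so P2 is dropped. If the FPN output were keyed by level name (an assumption, not guaranteed by this patch), selecting by key would be less brittle than counting:

```python
# hypothetical: select pyramid levels by name instead of list position
wanted = ('p3', 'p4', 'p5', 'p6', 'p7')
features = [v for k, v in features.items() if k.lower() in wanted]
```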
@@ -289,11 +336,7 @@ def forward(self, images, targets=None):
            assert targets is not None

            # compute the losses
-            # TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function
-            # so that we can use it here and in rpn.RegionProposalNetwork
-            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
-            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
-            losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes)
+            losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # compute the detections
            # TODO: Implement postprocess_detections
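On the final TODO: `postprocess_detections` is not part of this commit, but the usual RetinaNet inference path per image is to threshold scores, keep the top-k candidates, decode regressions against the anchors, and apply class-wise NMS. A rough sketch under those assumptions (names and thresholds are illustrative, not from this patch):

```python
import torch
from torchvision.ops import boxes as box_ops

def postprocess_single_image(logits, bbox_reg, anchors, box_coder,
                             score_thresh=0.05, nms_thresh=0.5, topk=1000):
    # logits: [N, num_classes]; bbox_reg, anchors: [N, 4]
    num_classes = logits.shape[1]
    scores = torch.sigmoid(logits).flatten()
    keep = torch.nonzero(scores > score_thresh).squeeze(1)
    keep = keep[scores[keep].topk(min(topk, keep.numel())).indices]
    anchor_idxs = keep // num_classes   # recover anchor index
    labels = keep % num_classes         # recover class index
    boxes = box_coder.decode_single(bbox_reg[anchor_idxs], anchors[anchor_idxs])
    keep_nms = box_ops.batched_nms(boxes, scores[keep], labels, nms_thresh)
    return boxes[keep_nms], scores[keep][keep_nms], labels[keep_nms]
```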