@@ -42,9 +42,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        }

    def forward(self, x):
-        cls_logits = [self.classification_head(feature) for feature in x]
-        bbox_reg = [self.regression_head(feature) for feature in x]
-        return dict(cls_logits=cls_logits, bbox_reg=bbox_reg)
+        return {
+            'cls_logits': self.classification_head(x),
+            'bbox_regression': self.regression_head(x)
+        }


def sigmoid_focal_loss(
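A quick sanity check of the new contract may help reviewers: RetinaNetHead.forward now hands the whole list of FPN feature maps to each sub-head and returns a dict whose values are already concatenated across levels, instead of per-level lists. A minimal sketch of that shape contract; the DummyHead stand-in, channel counts, and anchor count below are illustrative, not part of this change:

import torch
import torch.nn as nn

class DummyHead(nn.Module):
    """Illustrative stand-in for a sub-head: consumes the full list of
    feature maps and returns one tensor concatenated across levels."""
    def __init__(self, num_anchors, out_dim):
        super().__init__()
        self.num_anchors = num_anchors
        self.out_dim = out_dim

    def forward(self, x):
        # one row per (spatial position, anchor) per level
        outs = [torch.randn(f.shape[0], f.shape[2] * f.shape[3] * self.num_anchors, self.out_dim)
                for f in x]
        return torch.cat(outs, dim=1)

class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.classification_head = DummyHead(num_anchors=9, out_dim=80)  # K = num_classes
        self.regression_head = DummyHead(num_anchors=9, out_dim=4)       # 4 box deltas

    def forward(self, x):
        return {
            'cls_logits': self.classification_head(x),
            'bbox_regression': self.regression_head(x)
        }

# Two fake FPN levels: 32x32 and 16x16 -> (32*32 + 16*16) * 9 = 11520 anchors.
features = [torch.randn(2, 256, 32, 32), torch.randn(2, 256, 16, 16)]
out = Head()(features)
print(out['cls_logits'].shape)       # torch.Size([2, 11520, 80])
print(out['bbox_regression'].shape)  # torch.Size([2, 11520, 4])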
@@ -127,45 +128,48 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01
    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        loss = []

-        def permute_classification(tensor):
-            """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). """
-            N, _, H, W = tensor.shape
-            tensor = tensor.view(N, -1, self.num_classes, H, W)
-            tensor = tensor.permute(0, 3, 4, 1, 2)
-            tensor = tensor.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)
-            return tensor
-
-        predicted_classification = head_outputs['cls_logits']
-        predicted_classification = [permute_classification(cls) for cls in predicted_classification]
-        predicted_classification = torch.cat(predicted_classification, dim=1)
+        cls_logits = head_outputs['cls_logits']

-        for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs):
+        for targets_per_image, cls_logits_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, cls_logits, anchors, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
-            gt_classes_target = torch.zeros_like(predicted_classification_per_image)
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS

            # compute the classification loss
-            loss.append(sigmoid_focal_loss_jit(
-                predicted_classification_per_image[valid_idxs_per_image],
+            loss.append(sigmoid_focal_loss(
+                cls_logits_per_image[valid_idxs_per_image],
                gt_classes_target[valid_idxs_per_image],
                reduction='sum',
            ) / max(1, num_foreground))

        return sum(loss) / len(loss)

    def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        return self.cls_logits(x)
+        all_cls_logits = []
+
+        for features in x:
+            cls_logits = F.relu(self.conv1(features))
+            cls_logits = F.relu(self.conv2(cls_logits))
+            cls_logits = F.relu(self.conv3(cls_logits))
+            cls_logits = F.relu(self.conv4(cls_logits))
+            cls_logits = self.cls_logits(cls_logits)
+
+            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = cls_logits.shape
+            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
+            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
+            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
+
+            all_cls_logits.append(cls_logits)
+
+        return torch.cat(all_cls_logits, dim=1)


class RetinaNetRegressionHead(nn.Module):
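The view/permute/reshape sequence that this change moves from compute_loss into forward can be checked in isolation: it turns the conv output (N, A * K, H, W) into (N, HWA, K), with anchors varying fastest within each spatial position. A small standalone sketch, all sizes arbitrary:

import torch

N, A, K, H, W = 2, 3, 5, 4, 4
raw = torch.randn(N, A * K, H, W)   # conv output: channels hold A anchors x K classes

x = raw.view(N, A, K, H, W)         # split channels into (anchor, class); the PR uses -1 to infer A
x = x.permute(0, 3, 4, 1, 2)        # -> (N, H, W, A, K)
x = x.reshape(N, -1, K)             # -> (N, H*W*A, K)

# Row (h*W + w) * A + a holds the K logits of anchor a at position (h, w):
h, w, a = 1, 2, 0
assert torch.equal(x[0, (h * W + w) * A + a], raw[0, a * K:(a + 1) * K, h, w])
print(x.shape)                      # torch.Size([2, 48, 5])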
@@ -194,19 +198,9 @@ def __init__(self, in_channels, num_anchors):
    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        loss = []

-        def permute_bbox_reg(tensor):
-            """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). """
-            N, _, H, W = tensor.shape
-            tensor = tensor.view(N, -1, 4, H, W)
-            tensor = tensor.permute(0, 3, 4, 1, 2)
-            tensor = tensor.reshape(N, -1, 4)  # Size=(N, HWA, 4)
-            return tensor
-
-        predicted_regression = head_outputs['bbox_reg']
-        predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression]
-        predicted_regression = torch.cat(predicted_regression, dim=1)
+        bbox_regression = head_outputs['bbox_regression']

-        for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
+        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, bbox_regression, anchors, matched_idxs):
            # get the targets corresponding GT for each proposal
            # NB: need to clamp the indices because we can have a single
            # GT in the image, and matched_idxs can be -2, which goes
@@ -219,23 +213,36 @@ def permute_bbox_reg(tensor):
            # select only the foreground boxes
            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
+            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

            # compute the regression targets
            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)

            # compute the loss
-            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(predicted_regression_per_image, target_regression) / max(1, num_foreground))
+            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(bbox_regression_per_image, target_regression) / max(1, num_foreground))

        return sum(loss) / max(1, len(loss))

    def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        return self.bbox_reg(x)
+        all_bbox_regression = []
+
+        for features in x:
+            bbox_regression = F.relu(self.conv1(features))
+            bbox_regression = F.relu(self.conv2(bbox_regression))
+            bbox_regression = F.relu(self.conv3(bbox_regression))
+            bbox_regression = F.relu(self.conv4(bbox_regression))
+            bbox_regression = self.bbox_reg(bbox_regression)
+
+            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
+            N, _, H, W = bbox_regression.shape
+            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
+            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
+            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+
+            all_bbox_regression.append(bbox_regression)
+
+        return torch.cat(all_bbox_regression, dim=1)


class RetinaNet(nn.Module):
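A note on the normalization pattern both compute_loss methods use: each image's summed loss is divided by max(1, num_foreground) so images without foreground anchors contribute zero rather than NaN, and the per-image losses are then averaged over the batch (the regression head additionally guards that average with max(1, len(loss))). A toy sketch of the scheme; the helper name is ours:

import torch

def per_image_loss(pred, target, num_foreground):
    # sum over foreground boxes, normalized by their count (guarded against zero)
    return torch.nn.SmoothL1Loss(reduction='sum')(pred, target) / max(1, num_foreground)

losses = [
    per_image_loss(torch.randn(3, 4), torch.randn(3, 4), 3),   # image with 3 foreground boxes
    per_image_loss(torch.zeros(0, 4), torch.zeros(0, 4), 0),   # image with none: loss is 0, not NaN
]
print(sum(losses) / max(1, len(losses)))  # finite scalar tensor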