@@ -107,20 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
107
107
# determine only the foreground
108
108
foreground_idxs_per_image = matched_idxs_per_image >= 0
109
109
num_foreground = foreground_idxs_per_image .sum ()
110
- # no matched_idxs means there were no annotations in this image
111
- if matched_idxs_per_image .numel () == 0 :
112
- gt_classes_target = torch .zeros_like (cls_logits_per_image )
113
- valid_idxs_per_image = torch .arange (cls_logits_per_image .shape [0 ], device = cls_logits_per_image .device )
114
- else :
115
- # create the target classification
116
- gt_classes_target = torch .zeros_like (cls_logits_per_image )
117
- gt_classes_target [
118
- foreground_idxs_per_image ,
119
- targets_per_image ['labels' ][matched_idxs_per_image [foreground_idxs_per_image ]]
120
- ] = 1.0
121
-
122
- # find indices for which anchors should be ignored
123
- valid_idxs_per_image = matched_idxs_per_image != self .BETWEEN_THRESHOLDS
110
+
111
+ # create the target classification
112
+ gt_classes_target = torch .zeros_like (cls_logits_per_image )
113
+ gt_classes_target [
114
+ foreground_idxs_per_image ,
115
+ targets_per_image ['labels' ][matched_idxs_per_image [foreground_idxs_per_image ]]
116
+ ] = 1.0
117
+
118
+ # find indices for which anchors should be ignored
119
+ valid_idxs_per_image = matched_idxs_per_image != self .BETWEEN_THRESHOLDS
124
120
125
121
# compute the classification loss
126
122
losses .append (sigmoid_focal_loss (
@@ -190,22 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
190
186
191
187
for targets_per_image , bbox_regression_per_image , anchors_per_image , matched_idxs_per_image in \
192
188
zip (targets , bbox_regression , anchors , matched_idxs ):
193
- # no matched_idxs means there were no annotations in this image
194
- if matched_idxs_per_image .numel () == 0 :
195
- matched_gt_boxes_per_image = torch .zeros_like (bbox_regression_per_image )
196
- else :
197
- # get the targets corresponding GT for each proposal
198
- # NB: need to clamp the indices because we can have a single
199
- # GT in the image, and matched_idxs can be -2, which goes
200
- # out of bounds
201
- matched_gt_boxes_per_image = targets_per_image ['boxes' ][matched_idxs_per_image .clamp (min = 0 )]
202
-
203
189
# determine only the foreground indices, ignore the rest
204
190
foreground_idxs_per_image = torch .where (matched_idxs_per_image >= 0 )[0 ]
205
191
num_foreground = foreground_idxs_per_image .numel ()
206
192
207
193
# select only the foreground boxes
208
- matched_gt_boxes_per_image = matched_gt_boxes_per_image [ foreground_idxs_per_image , : ]
194
+ matched_gt_boxes_per_image = targets_per_image [ 'boxes' ][ matched_idxs_per_image [ foreground_idxs_per_image ] ]
209
195
bbox_regression_per_image = bbox_regression_per_image [foreground_idxs_per_image , :]
210
196
anchors_per_image = anchors_per_image [foreground_idxs_per_image , :]
211
197
@@ -401,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
401
387
matched_idxs = []
402
388
for anchors_per_image , targets_per_image in zip (anchors , targets ):
403
389
if targets_per_image ['boxes' ].numel () == 0 :
404
- matched_idxs .append (torch .empty (( 0 ,) , dtype = torch .int64 ))
390
+ matched_idxs .append (torch .full (( anchors_per_image . size ( 0 ),), - 1 , dtype = torch .int64 ))
405
391
continue
406
392
407
393
match_quality_matrix = box_ops .box_iou (targets_per_image ['boxes' ], anchors_per_image )
0 commit comments