@@ -107,21 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
107
107
# determine only the foreground indices
108
108
foreground_idxs_per_image = matched_idxs_per_image >= 0
109
109
num_foreground = foreground_idxs_per_image .sum ()
110
- # no matched_idxs means there were no annotations in this image
111
- # TODO: enable support for images without annotations that works on distributed
112
- if False : # matched_idxs_per_image.numel() == 0:
113
- gt_classes_target = torch .zeros_like (cls_logits_per_image )
114
- valid_idxs_per_image = torch .arange (cls_logits_per_image .shape [0 ])
115
- else :
116
- # create the target classification
117
- gt_classes_target = torch .zeros_like (cls_logits_per_image )
118
- gt_classes_target [
119
- foreground_idxs_per_image ,
120
- targets_per_image ['labels' ][matched_idxs_per_image [foreground_idxs_per_image ]]
121
- ] = 1.0
122
-
123
- # find indices for which anchors should be ignored
124
- valid_idxs_per_image = matched_idxs_per_image != self .BETWEEN_THRESHOLDS
110
+
111
+ # create the target classification
112
+ gt_classes_target = torch .zeros_like (cls_logits_per_image )
113
+ gt_classes_target [
114
+ foreground_idxs_per_image ,
115
+ targets_per_image ['labels' ][matched_idxs_per_image [foreground_idxs_per_image ]]
116
+ ] = 1.0
117
+
118
+ # find the valid anchors, i.e. those not marked BETWEEN_THRESHOLDS (which are ignored)
119
+ valid_idxs_per_image = matched_idxs_per_image != self .BETWEEN_THRESHOLDS
125
120
126
121
# compute the classification loss
127
122
losses .append (sigmoid_focal_loss (
@@ -191,23 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
191
186
192
187
for targets_per_image , bbox_regression_per_image , anchors_per_image , matched_idxs_per_image in \
193
188
zip (targets , bbox_regression , anchors , matched_idxs ):
194
- # no matched_idxs means there were no annotations in this image
195
- # TODO enable support for images without annotations with distributed support
196
- # if matched_idxs_per_image.numel() == 0:
197
- # continue
198
-
199
- # get the targets corresponding GT for each proposal
200
- # NB: need to clamp the indices because we can have a single
201
- # GT in the image, and matched_idxs can be -2, which goes
202
- # out of bounds
203
- matched_gt_boxes_per_image = targets_per_image ['boxes' ][matched_idxs_per_image .clamp (min = 0 )]
204
-
205
189
# determine only the foreground indices, ignore the rest
206
- foreground_idxs_per_image = matched_idxs_per_image >= 0
207
- num_foreground = foreground_idxs_per_image .sum ()
190
+ foreground_idxs_per_image = torch . where ( matched_idxs_per_image >= 0 )[ 0 ]
191
+ num_foreground = foreground_idxs_per_image .numel ()
208
192
209
193
# select only the foreground boxes
210
- matched_gt_boxes_per_image = matched_gt_boxes_per_image [ foreground_idxs_per_image , : ]
194
+ matched_gt_boxes_per_image = targets_per_image [ 'boxes' ][ matched_idxs_per_image [ foreground_idxs_per_image ] ]
211
195
bbox_regression_per_image = bbox_regression_per_image [foreground_idxs_per_image , :]
212
196
anchors_per_image = anchors_per_image [foreground_idxs_per_image , :]
213
197
@@ -403,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
403
387
matched_idxs = []
404
388
for anchors_per_image , targets_per_image in zip (anchors , targets ):
405
389
if targets_per_image ['boxes' ].numel () == 0 :
406
- matched_idxs .append (torch .empty (( 0 ,), dtype = torch .int32 ))
390
+ matched_idxs .append (torch .full (( anchors_per_image . size ( 0 ) ,), - 1 , dtype = torch .int64 ))
407
391
continue
408
392
409
393
match_quality_matrix = box_ops .box_iou (targets_per_image ['boxes' ], anchors_per_image )
0 commit comments