@@ -409,10 +409,9 @@ def compute_loss(self, targets, head_outputs, anchors):
409
409
410
410
def postprocess_detections (self , head_outputs , anchors , image_shapes ):
411
411
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
412
- # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
413
- class_logits = head_outputs .pop ('cls_logits' )
414
- box_regression = head_outputs .pop ('bbox_regression' )
415
- other_outputs = head_outputs
412
+ # TODO: confirm that RetinaNet can't have other outputs like masks
413
+ class_logits = head_outputs ['cls_logits' ]
414
+ box_regression = head_outputs ['bbox_regression' ]
416
415
417
416
num_classes = class_logits .shape [- 1 ]
418
417
@@ -443,15 +442,11 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
443
442
# non-maximum suppression
444
443
keep = box_ops .batched_nms (boxes_per_image , scores_per_image , labels_per_image , self .nms_thresh )
445
444
446
- det = {
445
+ detections . append ( {
447
446
'boxes' : boxes_per_image [keep ],
448
447
'scores' : scores_per_image [keep ],
449
448
'labels' : labels_per_image [keep ],
450
- }
451
- for k , v in other_outputs .items ():
452
- det [k ] = v [index ][keep ]
453
-
454
- detections .append (det )
449
+ })
455
450
456
451
return detections
457
452
0 commit comments