@@ -410,74 +410,48 @@ def compute_loss(self, targets, head_outputs, anchors):
410
410
def postprocess_detections(self, head_outputs, anchors, image_shapes):
    # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
    """
    Convert raw head outputs into a list of per-image detections.

    For every image the sigmoid class scores are flattened, thresholded with
    ``self.score_thresh``, reduced to at most ``self.detections_per_img``
    best candidates, decoded into boxes, clipped to the image and finally
    filtered with class-aware NMS (``self.nms_thresh``).

    Args:
        head_outputs: head predictions keyed by name. Must contain
            ``'cls_logits'`` and ``'bbox_regression'``; both keys are popped,
            so the caller's dict is mutated. Every remaining entry is treated
            as an extra per-anchor output and is forwarded to the result.
        anchors: one anchor tensor per image.
        image_shapes: one image size per image, passed to
            ``box_ops.clip_boxes_to_image``.

    Returns:
        One dict per image with ``'boxes'``, ``'scores'`` and ``'labels'``,
        plus one entry for each extra key left in ``head_outputs``.
    """
    # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
    class_logits = head_outputs.pop('cls_logits')
    box_regression = head_outputs.pop('bbox_regression')
    other_outputs = head_outputs

    num_classes = class_logits.shape[-1]

    scores = torch.sigmoid(class_logits)

    detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

    for index, (box_regression_per_image, scores_per_image, anchors_per_image, image_shape) in \
            enumerate(zip(box_regression, scores, anchors, image_shapes)):
        # remove low scoring boxes; after flattening, every element is one
        # (anchor, class) pair
        scores_per_image = scores_per_image.flatten()
        keep_idxs = scores_per_image > self.score_thresh
        scores_per_image = scores_per_image[keep_idxs]
        topk_idxs = torch.where(keep_idxs)[0]

        # keep only topk scoring predictions
        # (min() guards against fewer surviving candidates than requested)
        num_topk = min(self.detections_per_img, topk_idxs.size(0))
        scores_per_image, idxs = scores_per_image.topk(num_topk)
        topk_idxs = topk_idxs[idxs]

        # split the flat (anchor, class) index back into its two components
        anchor_idxs = topk_idxs // num_classes
        labels_per_image = topk_idxs % num_classes

        # decode only the anchors that survived the filtering above
        boxes_per_image = self.box_coder.decode_single(box_regression_per_image[anchor_idxs],
                                                       anchors_per_image[anchor_idxs])
        boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)

        # non-maximum suppression, independently done per class
        keep = box_ops.batched_nms(boxes_per_image, scores_per_image, labels_per_image, self.nms_thresh)

        det = {
            'boxes': boxes_per_image[keep],
            'scores': scores_per_image[keep],
            'labels': labels_per_image[keep],
        }

        # `keep` indexes the filtered candidates, whereas the extra head
        # outputs are laid out per anchor: translate back to anchor indices
        # before gathering, otherwise rows of unrelated anchors are returned.
        kept_anchor_idxs = anchor_idxs[keep]
        for k, v in other_outputs.items():
            det[k] = v[index][kept_anchor_idxs]

        detections.append(det)

    return detections
483
457
0 commit comments