@@ -291,6 +291,7 @@ class RetinaNet(nn.Module):
             considered as positive during training.
         bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
             considered as negative during training.
+        topk_candidates (int): Number of best detections to keep before NMS.
 
 
     Example:
@@ -339,7 +340,8 @@ def __init__(self, backbone, num_classes,
                  score_thresh=0.05,
                  nms_thresh=0.5,
                  detections_per_img=300,
-                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
+                 fg_iou_thresh=0.5, bg_iou_thresh=0.4,
+                 topk_candidates=1000):
         super().__init__()
 
         if not hasattr(backbone, "out_channels"):
@@ -382,6 +384,7 @@ def __init__(self, backbone, num_classes,
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
 
         # used only on torchscript mode
         self._has_warned = False
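For orientation, here is a minimal sketch of passing the new `topk_candidates` knob at construction time. It is adapted from the class docstring's `Example:` block; the mobilenet backbone, anchor sizes, input shape, and `pretrained` flag are illustrative and not part of this diff:

```python
import torch
import torchvision
from torchvision.models.detection import RetinaNet
from torchvision.models.detection.anchor_utils import AnchorGenerator

# As in the docstring example: RetinaNet needs a backbone that exposes
# `out_channels` (see the hasattr check in __init__).
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
backbone.out_channels = 1280
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))
model = RetinaNet(backbone, num_classes=2,
                  anchor_generator=anchor_generator,
                  topk_candidates=1000)  # new: per-level cap on candidates kept before NMS
model.eval()
predictions = model([torch.rand(3, 300, 400)])
```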
@@ -408,77 +411,63 @@ def compute_loss(self, targets, head_outputs, anchors):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
 
     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
-        # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
+        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
+        class_logits = head_outputs['cls_logits']
+        box_regression = head_outputs['bbox_regression']
 
-        class_logits = head_outputs.pop('cls_logits')
-        box_regression = head_outputs.pop('bbox_regression')
-        other_outputs = head_outputs
-
-        device = class_logits.device
-        num_classes = class_logits.shape[-1]
-
-        scores = torch.sigmoid(class_logits)
-
-        # create labels for each score
-        labels = torch.arange(num_classes, device=device)
-        labels = labels.view(1, -1).expand_as(scores)
+        num_images = len(image_shapes)
 
         detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
 
-        for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
-                enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
-
-            boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
-            boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
-
-            other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()]
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]
 
             image_boxes = []
             image_scores = []
             image_labels = []
-            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})
 
-            for class_index in range(num_classes):
+            for box_regression_per_level, logits_per_level, anchors_per_level in \
+                    zip(box_regression_per_image, logits_per_image, anchors_per_image):
+                num_classes = logits_per_level.shape[-1]
+
                 # remove low scoring boxes
-                inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
-                other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image]
+                scores_per_level = torch.sigmoid(logits_per_level).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]
 
-                # remove empty boxes
-                keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
+                # keep only topk scoring predictions
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
 
-                # non-maximum suppression, independently done per class
-                keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
+                anchor_idxs = topk_idxs // num_classes
+                labels_per_level = topk_idxs % num_classes
 
-                # keep only topk scoring predictions
-                keep = keep[:self.detections_per_img]
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
+                boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
+                                                               anchors_per_level[anchor_idxs])
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
+
+                image_boxes.append(boxes_per_level)
+                image_scores.append(scores_per_level)
+                image_labels.append(labels_per_level)
 
-                image_boxes.append(boxes_per_class)
-                image_scores.append(scores_per_class)
-                image_labels.append(labels_per_class)
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
 
-                for k, v in other_outputs_per_class:
-                    if k not in image_other_outputs:
-                        image_other_outputs[k] = []
-                    image_other_outputs[k].append(v)
+            # non-maximum suppression
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[:self.detections_per_img]
 
             detections.append({
-                'boxes': torch.cat(image_boxes, dim=0),
-                'scores': torch.cat(image_scores, dim=0),
-                'labels': torch.cat(image_labels, dim=0),
+                'boxes': image_boxes[keep],
+                'scores': image_scores[keep],
+                'labels': image_labels[keep],
             })
 
-            for k, v in image_other_outputs.items():
-                detections[-1].update({k: torch.cat(v, dim=0)})
-
         return detections
 
     def forward(self, images, targets=None):
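The heart of the rewrite is the flattened (anchor, class) indexing: rather than looping over classes, each level's `HWA x K` score matrix is flattened so a single `topk` call ranks every (anchor, class) pair jointly, and the anchor index and class label are then recovered arithmetically. A self-contained sketch of that decomposition, with invented shapes and stand-in values for `score_thresh` and `topk_candidates`:

```python
import torch

num_anchors, num_classes = 6, 4
torch.manual_seed(0)
logits_per_level = torch.randn(num_anchors, num_classes)

# Row-major flattening means flat_index = anchor_index * num_classes + class_index,
# so one ranking covers every (anchor, class) pair at this level.
scores_per_level = torch.sigmoid(logits_per_level).flatten()  # [num_anchors * num_classes]
keep_idxs = scores_per_level > 0.05           # stand-in for self.score_thresh
scores_per_level = scores_per_level[keep_idxs]
topk_idxs = torch.where(keep_idxs)[0]         # flat indices of the surviving pairs

num_topk = min(10, topk_idxs.size(0))         # stand-in for self.topk_candidates
scores_per_level, idxs = scores_per_level.topk(num_topk)
topk_idxs = topk_idxs[idxs]

# Invert the flattening: // recovers the anchor row, % recovers the class column.
anchor_idxs = topk_idxs // num_classes
labels_per_level = topk_idxs % num_classes
assert (logits_per_level.sigmoid()[anchor_idxs, labels_per_level] == scores_per_level).all()
```

Because each candidate now carries its label, one `box_ops.batched_nms` call per image replaces the old per-class `nms` loop: `batched_nms` offsets boxes by label, so boxes of different classes never suppress each other.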
@@ -557,8 +546,23 @@ def forward(self, images, targets=None):
             # compute the losses
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
+            # recover level sizes
+            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
+            HW = 0
+            for v in num_anchors_per_level:
+                HW += v
+            HWA = head_outputs['cls_logits'].size(1)
+            A = HWA // HW
+            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
+
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
+            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
+
             # compute the detections
-            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
             detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
 
         if torch.jit.is_scripting():
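`postprocess_detections` now expects per-level lists, but the head emits tensors concatenated over all pyramid levels along dim 1, so the forward pass must first recover how many anchor predictions each level contributed. A toy illustration of that arithmetic, with all shapes invented for the example:

```python
import torch

# Pretend FPN features from two levels: [N, C, H, W]
features = [torch.zeros(1, 256, 4, 4), torch.zeros(1, 256, 2, 2)]
A = 9                                              # anchors per spatial location
HW = sum(f.size(2) * f.size(3) for f in features)  # 16 + 4 = 20 spatial locations
cls_logits = torch.zeros(1, HW * A, 91)            # concatenated head output [N, HWA, K]

# The head only exposes the concatenated HWA dimension, so A is recovered
# by dividing by the total number of spatial locations across levels.
HWA = cls_logits.size(1)
assert HWA // HW == A
num_anchors_per_level = [f.size(2) * f.size(3) * A for f in features]  # [144, 36]

# Tensor.split along dim=1 then restores one chunk per pyramid level.
per_level = cls_logits.split(num_anchors_per_level, dim=1)
assert [t.size(1) for t in per_level] == num_anchors_per_level
```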