@@ -408,44 +408,56 @@ def compute_loss(self, targets, head_outputs, anchors):
408
408
return self .head .compute_loss (targets , head_outputs , anchors , matched_idxs )
409
409
410
410
def postprocess_detections(self, head_outputs, anchors, image_shapes):
    # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
    # Convert per-level head outputs into final per-image detections.
    #
    # head_outputs: dict with keys 'cls_logits' and 'bbox_regression'; each value is
    #   a list with one tensor per FPN level, batched over images on dim 0
    #   (caller is expected to have split the flat head outputs per level).
    # anchors: per-image list of per-level anchor tensors, aligned with the
    #   head outputs above.
    # image_shapes: (height, width) of each padded input image, used to clip boxes.
    #
    # Returns one dict per image with 'boxes', 'scores', 'labels' tensors,
    # after score thresholding, per-level top-k selection, and class-aware NMS.
    class_logits = head_outputs['cls_logits']
    box_regression = head_outputs['bbox_regression']

    num_images = len(image_shapes)

    detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

    for index in range(num_images):
        # Gather this image's slice from every pyramid level.
        box_regression_per_image = [br[index] for br in box_regression]
        logits_per_image = [cl[index] for cl in class_logits]
        anchors_per_image, image_shape = anchors[index], image_shapes[index]

        image_boxes = []
        image_scores = []
        image_labels = []

        # Filtering and top-k are done independently per level, then the
        # survivors of all levels are concatenated for the final output.
        for box_regression_per_level, logits_per_level, anchors_per_level in \
                zip(box_regression_per_image, logits_per_image, anchors_per_image):
            num_classes = logits_per_level.shape[-1]

            # remove low scoring boxes; flatten to (num_anchors * num_classes,)
            # so one index encodes both the anchor and the class
            scores_per_level = torch.sigmoid(logits_per_level).flatten()
            keep_idxs = scores_per_level > self.score_thresh
            scores_per_level = scores_per_level[keep_idxs]
            topk_idxs = torch.where(keep_idxs)[0]

            # keep only topk scoring predictions for this level
            num_topk = min(self.detections_per_img, topk_idxs.size(0))
            scores_per_level, idxs = scores_per_level.topk(num_topk)
            topk_idxs = topk_idxs[idxs]

            # decode the flat index back into (anchor, class)
            anchor_idxs = topk_idxs // num_classes
            labels_per_level = topk_idxs % num_classes

            boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
                                                           anchors_per_level[anchor_idxs])
            boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

            # non-maximum suppression (class-aware via the labels argument)
            keep = box_ops.batched_nms(boxes_per_level, scores_per_level, labels_per_level, self.nms_thresh)

            image_boxes.append(boxes_per_level[keep])
            image_scores.append(scores_per_level[keep])
            image_labels.append(labels_per_level[keep])

        detections.append({
            'boxes': torch.cat(image_boxes, dim=0),
            'scores': torch.cat(image_scores, dim=0),
            'labels': torch.cat(image_labels, dim=0),
        })

    return detections
@@ -526,8 +538,24 @@ def forward(self, images, targets=None):
526
538
# compute the losses
527
539
losses = self .compute_loss (targets , head_outputs , anchors )
528
540
else :
541
+ # recover level sizes
542
+ feature_sizes_per_level = [x .size (2 ) * x .size (3 ) for x in features ]
543
+ HW = 0
544
+ for v in feature_sizes_per_level :
545
+ HW += v
546
+ HWA = head_outputs ['cls_logits' ].size (1 )
547
+ A = HWA // HW
548
+ feature_sizes_per_level = [hw * A for hw in feature_sizes_per_level ]
549
+
550
+ # split outputs per level
551
+ split_head_outputs : Dict [str , List [Tensor ]] = {}
552
+ for k in head_outputs :
553
+ split_head_outputs [k ] = [x .permute (1 , 0 , 2 ) for x in
554
+ head_outputs [k ].permute (1 , 0 , 2 ).split_with_sizes (feature_sizes_per_level )]
555
+ split_anchors = [list (a .split_with_sizes (feature_sizes_per_level )) for a in anchors ]
556
+
529
557
# compute the detections
530
- detections = self .postprocess_detections (head_outputs , anchors , images .image_sizes )
558
+ detections = self .postprocess_detections (split_head_outputs , split_anchors , images .image_sizes )
531
559
detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
532
560
533
561
if torch .jit .is_scripting ():
0 commit comments