1313from .transform import GeneralizedRCNNTransform
1414from .backbone_utils import resnet_fpn_backbone
1515from ...ops .feature_pyramid_network import LastLevelP6P7
16+ from ...ops import boxes as box_ops
1617
1718
1819__all__ = [
@@ -288,9 +289,9 @@ class RetinaNet(nn.Module):
288289 maps.
289290 head (nn.Module): Module run on top of the feature pyramid.
290291 Defaults to a module containing a classification and regression module.
291- pre_nms_top_n (int): number of proposals to keep before applying NMS during testing.
292- post_nms_top_n (int): number of proposals to keep after applying NMS during testing.
292+ score_thresh (float): Score threshold used for postprocessing the detections.
293293 nms_thresh (float): NMS threshold used for postprocessing the detections.
294+ detections_per_img (int): Number of best detections to keep after NMS.
294295 fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
295296 considered as positive during training.
296297 bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
@@ -334,8 +335,9 @@ def __init__(self, backbone, num_classes,
334335 # Anchor parameters
335336 anchor_generator = None , head = None ,
336337 proposal_matcher = None ,
337- pre_nms_top_n = 1000 , post_nms_top_n = 1000 ,
338+ score_thresh = 0.5 ,
338339 nms_thresh = 0.5 ,
340+ detections_per_img = 300 ,
339341 fg_iou_thresh = 0.5 , bg_iou_thresh = 0.4 ):
340342 super (RetinaNet , self ).__init__ ()
341343
@@ -349,7 +351,6 @@ def __init__(self, backbone, num_classes,
349351 assert isinstance (anchor_generator , (AnchorGenerator , type (None )))
350352
351353 if anchor_generator is None :
352- # TODO: Set correct default values
353354 anchor_sizes = [[x , x * 2 ** (1.0 / 3 ), x * 2 ** (2.0 / 3 )] for x in [32 , 64 , 128 , 256 , 512 ]]
354355 aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
355356 anchor_generator = AnchorGenerator (
@@ -369,12 +370,18 @@ def __init__(self, backbone, num_classes,
369370 )
370371 self .proposal_matcher = proposal_matcher
371372
373+ self .box_coder = det_utils .BoxCoder (weights = (1.0 , 1.0 , 1.0 , 1.0 ))
374+
372375 if image_mean is None :
373376 image_mean = [0.485 , 0.456 , 0.406 ]
374377 if image_std is None :
375378 image_std = [0.229 , 0.224 , 0.225 ]
376379 self .transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
377380
381+ self .score_thresh = score_thresh
382+ self .nms_thresh = nms_thresh
383+ self .detections_per_img = detections_per_img
384+
378385 @torch .jit .unused
379386 def eager_outputs (self , losses , detections ):
380387 # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
@@ -390,6 +397,57 @@ def compute_loss(self, targets, head_outputs, anchors):
390397
391398 return self .head .compute_loss (targets , head_outputs , anchors , matched_idxs )
392399
400+ def postprocess_detections (self , class_logits , box_regression , anchors , image_shapes ):
401+ # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
402+ # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
403+ device = class_logits .device
404+ num_classes = class_logits .shape [- 1 ]
405+
406+ scores = torch .sigmoid (class_logits )
407+
408+ # create labels for each score
409+ # the +1 is to make the labels identical to other object detection algorithms that treat background as label 0
410+ labels = torch .arange (num_classes , device = device ) + 1
411+ labels = labels .view (1 , - 1 ).expand_as (scores )
412+
413+ detections = []
414+
415+ for box_regression_per_image , scores_per_image , labels_per_image , anchors_per_image , image_shape in zip (box_regression , scores , labels , anchors , image_shapes ):
416+ boxes_per_image = self .box_coder .decode_single (box_regression_per_image , anchors_per_image )
417+ boxes_per_image = box_ops .clip_boxes_to_image (boxes_per_image , image_shape )
418+
419+ image_boxes = []
420+ image_scores = []
421+ image_labels = []
422+
423+ for class_index in range (num_classes ):
424+ # remove low scoring boxes
425+ inds = torch .nonzero (scores_per_image [:, class_index ] > self .score_thresh ).squeeze (1 )
426+ boxes_per_class , scores_per_class , labels_per_class = boxes_per_image [inds ], scores_per_image [inds , class_index ], labels_per_image [inds , class_index ]
427+
428+ # remove empty boxes
429+ keep = box_ops .remove_small_boxes (boxes_per_class , min_size = 1e-2 )
430+ boxes_per_class , scores_per_class , labels_per_class = boxes_per_class [keep ], scores_per_class [keep ], labels_per_class [keep ]
431+
432+ # non-maximum suppression, independently done per class
433+ keep = box_ops .nms (boxes_per_class , scores_per_class , self .nms_thresh )
434+
435+ # keep only topk scoring predictions
436+ keep = keep [:self .detections_per_img ]
437+ boxes_per_class , scores_per_class , labels_per_class = boxes_per_class [keep ], scores_per_class [keep ], labels_per_class [keep ]
438+
439+ image_boxes .append (boxes_per_class )
440+ image_scores .append (scores_per_class )
441+ image_labels .append (labels_per_class )
442+
443+ detections .append ({
444+ 'boxes' : torch .cat (image_boxes , dim = 0 ),
445+ 'scores' : torch .cat (image_scores , dim = 0 ),
446+ 'labels' : torch .cat (image_labels , dim = 0 ),
447+ })
448+
449+ return detections
450+
393451 def forward (self , images , targets = None ):
394452 # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
395453 """
@@ -446,19 +504,8 @@ def forward(self, images, targets=None):
446504 losses = self .compute_loss (targets , head_outputs , anchors )
447505 else :
448506 # compute the detections
449- # TODO: Implement postprocess_detections
450- boxes , scores , labels = self .postprocess_detections (class_logits , box_regression , anchors )
451- num_images = len (images )
452- for i in range (num_images ):
453- detections .append (
454- {
455- "boxes" : boxes [i ],
456- "labels" : labels [i ],
457- "scores" : scores [i ],
458- }
459- )
460-
461- detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
507+ detections = self .postprocess_detections (head_outputs ['cls_logits' ], head_outputs ['bbox_regression' ], anchors , original_image_sizes )
508+ detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
462509
463510 if torch .jit .is_scripting ():
464511 if not self ._has_warned :
0 commit comments