from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone
from ...ops.feature_pyramid_network import LastLevelP6P7
+ from ...ops import boxes as box_ops

__all__ = [
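The new `box_ops` alias brings in the box utilities that the `postprocess_detections` method below relies on. As a quick illustration with toy numbers (not part of the diff), the three helpers used later behave like this:

```python
import torch
from torchvision.ops import boxes as box_ops

# three boxes in (x1, y1, x2, y2) format; the second heavily overlaps the first
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])

# clip to a 55x55 image: the last box is cut to (50, 50, 55, 55)
clipped = box_ops.clip_boxes_to_image(boxes, (55, 55))

# drop boxes whose width or height falls below min_size (none here)
keep = box_ops.remove_small_boxes(clipped, min_size=1e-2)

# NMS suppresses the second box (IoU with the first is ~0.68, above 0.5)
keep = box_ops.nms(clipped[keep], scores[keep], iou_threshold=0.5)
print(keep)  # tensor([0, 2])
```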
@@ -288,9 +289,9 @@ class RetinaNet(nn.Module):
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
-        pre_nms_top_n (int): number of proposals to keep before applying NMS during testing.
-        post_nms_top_n (int): number of proposals to keep after applying NMS during testing.
+        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
+        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
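For readers following along, the reworked knobs are passed at construction time. A minimal sketch, assuming the FPN backbone from the import at the top of the file is accepted unchanged at this point of the PR (the values mirror the defaults in the signature below):

```python
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone('resnet50', pretrained=False)
model = RetinaNet(backbone, num_classes=91,
                  score_thresh=0.5,        # drop detections scoring below this
                  nms_thresh=0.5,          # IoU threshold for per-class NMS
                  detections_per_img=300)  # cap on detections kept after NMS
```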
@@ -334,8 +335,9 @@ def __init__(self, backbone, num_classes,
                 # Anchor parameters
                 anchor_generator=None, head=None,
                 proposal_matcher=None,
-                 pre_nms_top_n=1000, post_nms_top_n=1000,
+                 score_thresh=0.5,
                 nms_thresh=0.5,
+                 detections_per_img=300,
                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
        super(RetinaNet, self).__init__()

@@ -349,7 +351,6 @@ def __init__(self, backbone, num_classes,
        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

        if anchor_generator is None:
-            # TODO: Set correct default values
            anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]]
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(
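The TODO removal looks justified: the comprehension yields the three intermediate octave scales per pyramid level used in the RetinaNet paper. Worked numbers for the first level:

```python
# each level gets scales of 2^0, 2^(1/3) and 2^(2/3) times its base size
anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]]
print(anchor_sizes[0])  # [32, 40.3174..., 50.7968...]
```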
@@ -369,12 +370,18 @@ def __init__(self, backbone, num_classes,
            )
        self.proposal_matcher = proposal_matcher

+        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
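The unit `weights=(1.0, 1.0, 1.0, 1.0)` mean the regression deltas are consumed as-is, rather than with the `(10., 10., 5., 5.)` weighting the R-CNN box head defaults to. A quick round-trip check (illustration only, not in the diff):

```python
import torch
from torchvision.models.detection import _utils as det_utils

box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
anchors = torch.tensor([[0., 0., 10., 10.]])
gt = torch.tensor([[2., 2., 12., 12.]])

# encode the target relative to the anchor, then decode it back
deltas = box_coder.encode_single(gt, anchors)
decoded = box_coder.decode_single(deltas, anchors)
print(torch.allclose(decoded, gt))  # True
```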
@@ -390,6 +397,46 @@ def compute_loss(self, targets, head_outputs, anchors):

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

+    def postprocess_detections(self, class_logits, box_regression, anchors, image_shapes):
+        # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
+        # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        scores = torch.sigmoid(class_logits)
+
+        # create labels for each score
+        labels = torch.arange(num_classes, device=device)
+        labels = labels.view(1, -1).expand_as(scores)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape in zip(box_regression, scores, labels, anchors, image_shapes):
+            boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
+            boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
+
+            for class_index in range(num_classes):
+                # remove low scoring boxes
+                inds = torch.nonzero(scores_per_image[:, class_index] > self.score_thresh).squeeze(1)
+                boxes_per_class, scores_per_class, labels_per_class = boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
+
+                # remove empty boxes
+                keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
+                boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
+
+                # non-maximum suppression, independently done per class
+                keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
+                # keep only topk scoring predictions
+                keep = keep[:self.detections_per_img]
+                boxes_per_class, scores_per_class, labels_per_class = boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
+
+                all_boxes.append(boxes_per_class)
+                all_scores.append(scores_per_class)
+                all_labels.append(labels_per_class)
+
+        return all_boxes, all_scores, all_labels
+
    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
        """
@@ -446,9 +493,8 @@ def forward(self, images, targets=None):
            losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # compute the detections
-            # TODO: Implement postprocess_detections
-            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, anchors)
-            num_images = len(images)
+            boxes, scores, labels = self.postprocess_detections(head_outputs['cls_logits'], head_outputs['bbox_regression'], anchors, original_image_sizes)
+            num_images = len(original_image_sizes)
            for i in range(num_images):
                detections.append(
                    {
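Downstream of this change, inference usage would look like the following sketch; the output keys `boxes`/`scores`/`labels` match the other torchvision detection models and are assumed here since the dict body is truncated in the hunk:

```python
import torch

# `model` built as in the earlier sketch
model.eval()
images = [torch.rand(3, 480, 640)]  # list of CHW tensors with values in [0, 1]
with torch.no_grad():
    detections = model(images)
print(detections[0]['boxes'].shape)   # [num_detections, 4]
print(detections[0]['scores'].shape)  # [num_detections]
print(detections[0]['labels'].shape)  # [num_detections]
```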