13
13
from .transform import GeneralizedRCNNTransform
14
14
from .backbone_utils import resnet_fpn_backbone
15
15
from ...ops .feature_pyramid_network import LastLevelP6P7
16
+ from ...ops import boxes as box_ops
16
17
17
18
18
19
__all__ = [
@@ -288,9 +289,9 @@ class RetinaNet(nn.Module):
288
289
maps.
289
290
head (nn.Module): Module run on top of the feature pyramid.
290
291
Defaults to a module containing a classification and regression module.
291
- pre_nms_top_n (int): number of proposals to keep before applying NMS during testing.
292
- post_nms_top_n (int): number of proposals to keep after applying NMS during testing.
292
+ score_thresh (float): Score threshold used for postprocessing the detections.
293
293
nms_thresh (float): NMS threshold used for postprocessing the detections.
294
+ detections_per_img (int): Number of best detections to keep after NMS.
294
295
fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
295
296
considered as positive during training.
296
297
bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
@@ -334,8 +335,9 @@ def __init__(self, backbone, num_classes,
334
335
# Anchor parameters
335
336
anchor_generator = None , head = None ,
336
337
proposal_matcher = None ,
337
- pre_nms_top_n = 1000 , post_nms_top_n = 1000 ,
338
+ score_thresh = 0.5 ,
338
339
nms_thresh = 0.5 ,
340
+ detections_per_img = 300 ,
339
341
fg_iou_thresh = 0.5 , bg_iou_thresh = 0.4 ):
340
342
super (RetinaNet , self ).__init__ ()
341
343
@@ -349,7 +351,6 @@ def __init__(self, backbone, num_classes,
349
351
assert isinstance (anchor_generator , (AnchorGenerator , type (None )))
350
352
351
353
if anchor_generator is None :
352
- # TODO: Set correct default values
353
354
anchor_sizes = [[x , x * 2 ** (1.0 / 3 ), x * 2 ** (2.0 / 3 )] for x in [32 , 64 , 128 , 256 , 512 ]]
354
355
aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
355
356
anchor_generator = AnchorGenerator (
@@ -369,12 +370,18 @@ def __init__(self, backbone, num_classes,
369
370
)
370
371
self .proposal_matcher = proposal_matcher
371
372
373
+ self .box_coder = det_utils .BoxCoder (weights = (1.0 , 1.0 , 1.0 , 1.0 ))
374
+
372
375
if image_mean is None :
373
376
image_mean = [0.485 , 0.456 , 0.406 ]
374
377
if image_std is None :
375
378
image_std = [0.229 , 0.224 , 0.225 ]
376
379
self .transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
377
380
381
+ self .score_thresh = score_thresh
382
+ self .nms_thresh = nms_thresh
383
+ self .detections_per_img = detections_per_img
384
+
378
385
@torch .jit .unused
379
386
def eager_outputs (self , losses , detections ):
380
387
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
@@ -390,6 +397,57 @@ def compute_loss(self, targets, head_outputs, anchors):
390
397
391
398
return self .head .compute_loss (targets , head_outputs , anchors , matched_idxs )
392
399
400
+ def postprocess_detections (self , class_logits , box_regression , anchors , image_shapes ):
401
+ # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
402
+ # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
403
+ device = class_logits .device
404
+ num_classes = class_logits .shape [- 1 ]
405
+
406
+ scores = torch .sigmoid (class_logits )
407
+
408
+ # create labels for each score
409
+ # the +1 is to make the labels identical to other object detection algorithms that treat background as label 0
410
+ labels = torch .arange (num_classes , device = device ) + 1
411
+ labels = labels .view (1 , - 1 ).expand_as (scores )
412
+
413
+ detections = []
414
+
415
+ for box_regression_per_image , scores_per_image , labels_per_image , anchors_per_image , image_shape in zip (box_regression , scores , labels , anchors , image_shapes ):
416
+ boxes_per_image = self .box_coder .decode_single (box_regression_per_image , anchors_per_image )
417
+ boxes_per_image = box_ops .clip_boxes_to_image (boxes_per_image , image_shape )
418
+
419
+ image_boxes = []
420
+ image_scores = []
421
+ image_labels = []
422
+
423
+ for class_index in range (num_classes ):
424
+ # remove low scoring boxes
425
+ inds = torch .nonzero (scores_per_image [:, class_index ] > self .score_thresh ).squeeze (1 )
426
+ boxes_per_class , scores_per_class , labels_per_class = boxes_per_image [inds ], scores_per_image [inds , class_index ], labels_per_image [inds , class_index ]
427
+
428
+ # remove empty boxes
429
+ keep = box_ops .remove_small_boxes (boxes_per_class , min_size = 1e-2 )
430
+ boxes_per_class , scores_per_class , labels_per_class = boxes_per_class [keep ], scores_per_class [keep ], labels_per_class [keep ]
431
+
432
+ # non-maximum suppression, independently done per class
433
+ keep = box_ops .nms (boxes_per_class , scores_per_class , self .nms_thresh )
434
+
435
+ # keep only topk scoring predictions
436
+ keep = keep [:self .detections_per_img ]
437
+ boxes_per_class , scores_per_class , labels_per_class = boxes_per_class [keep ], scores_per_class [keep ], labels_per_class [keep ]
438
+
439
+ image_boxes .append (boxes_per_class )
440
+ image_scores .append (scores_per_class )
441
+ image_labels .append (labels_per_class )
442
+
443
+ detections .append ({
444
+ 'boxes' : torch .cat (image_boxes , dim = 0 ),
445
+ 'scores' : torch .cat (image_scores , dim = 0 ),
446
+ 'labels' : torch .cat (image_labels , dim = 0 ),
447
+ })
448
+
449
+ return detections
450
+
393
451
def forward (self , images , targets = None ):
394
452
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
395
453
"""
@@ -446,19 +504,8 @@ def forward(self, images, targets=None):
446
504
losses = self .compute_loss (targets , head_outputs , anchors )
447
505
else :
448
506
# compute the detections
449
- # TODO: Implement postprocess_detections
450
- boxes , scores , labels = self .postprocess_detections (class_logits , box_regression , anchors )
451
- num_images = len (images )
452
- for i in range (num_images ):
453
- detections .append (
454
- {
455
- "boxes" : boxes [i ],
456
- "labels" : labels [i ],
457
- "scores" : scores [i ],
458
- }
459
- )
460
-
461
- detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
507
+ detections = self .postprocess_detections (head_outputs ['cls_logits' ], head_outputs ['bbox_regression' ], anchors , original_image_sizes )
508
+ detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
462
509
463
510
if torch .jit .is_scripting ():
464
511
if not self ._has_warned :
0 commit comments