Skip to content

Commit f787a9d

Browse files
committed
Add postprocessing of detections.
1 parent 0de780d commit f787a9d

File tree

1 file changed

+64
-17
lines changed

1 file changed

+64
-17
lines changed

torchvision/models/detection/retinanet.py

Lines changed: 64 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from .transform import GeneralizedRCNNTransform
1414
from .backbone_utils import resnet_fpn_backbone
1515
from ...ops.feature_pyramid_network import LastLevelP6P7
16+
from ...ops import boxes as box_ops
1617

1718

1819
__all__ = [
@@ -288,9 +289,9 @@ class RetinaNet(nn.Module):
288289
maps.
289290
head (nn.Module): Module run on top of the feature pyramid.
290291
Defaults to a module containing a classification and regression module.
291-
pre_nms_top_n (int): number of proposals to keep before applying NMS during testing.
292-
post_nms_top_n (int): number of proposals to keep after applying NMS during testing.
292+
score_thresh (float): Score threshold used for postprocessing the detections.
293293
nms_thresh (float): NMS threshold used for postprocessing the detections.
294+
detections_per_img (int): Number of best detections to keep after NMS.
294295
fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
295296
considered as positive during training.
296297
bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
@@ -334,8 +335,9 @@ def __init__(self, backbone, num_classes,
334335
# Anchor parameters
335336
anchor_generator=None, head=None,
336337
proposal_matcher=None,
337-
pre_nms_top_n=1000, post_nms_top_n=1000,
338+
score_thresh=0.5,
338339
nms_thresh=0.5,
340+
detections_per_img=300,
339341
fg_iou_thresh=0.5, bg_iou_thresh=0.4):
340342
super(RetinaNet, self).__init__()
341343

@@ -349,7 +351,6 @@ def __init__(self, backbone, num_classes,
349351
assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
350352

351353
if anchor_generator is None:
352-
# TODO: Set correct default values
353354
anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]]
354355
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
355356
anchor_generator = AnchorGenerator(
@@ -369,12 +370,18 @@ def __init__(self, backbone, num_classes,
369370
)
370371
self.proposal_matcher = proposal_matcher
371372

373+
self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
374+
372375
if image_mean is None:
373376
image_mean = [0.485, 0.456, 0.406]
374377
if image_std is None:
375378
image_std = [0.229, 0.224, 0.225]
376379
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
377380

381+
self.score_thresh = score_thresh
382+
self.nms_thresh = nms_thresh
383+
self.detections_per_img = detections_per_img
384+
378385
@torch.jit.unused
379386
def eager_outputs(self, losses, detections):
380387
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
@@ -390,6 +397,57 @@ def compute_loss(self, targets, head_outputs, anchors):
390397

391398
return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
392399

400+
def postprocess_detections(self, class_logits, box_regression, anchors, image_shapes):
    # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
    """
    Turn raw head outputs into a list of per-image detections.

    For every image the regression output is decoded against the anchors,
    the boxes are clipped to the image and then, independently per class,
    low-scoring and degenerate boxes are dropped and NMS is applied.

    Arguments:
        class_logits (Tensor): classification logits; the last dimension
            indexes the classes and the first one indexes the images.
        box_regression (Tensor): regression output, one entry per image,
            decoded with ``self.box_coder`` against the matching anchors.
        anchors (List[Tensor]): anchors, one tensor per image.
        image_shapes (List[Tuple[int, int]]): per-image size the decoded
            boxes are clipped to.
            NOTE(review): this has to be in the same coordinate frame as
            ``anchors`` for the clipping to be meaningful - confirm at the
            call site.

    Returns:
        List[Dict[str, Tensor]]: one dict per image with the keys
        ``'boxes'``, ``'scores'`` and ``'labels'``. Labels start at 1.
        At most ``self.detections_per_img`` detections are returned per
        image.
    """
    # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
    device = class_logits.device
    num_classes = class_logits.shape[-1]

    scores = torch.sigmoid(class_logits)

    # create labels for each score
    # the +1 is to make the labels identical to other object detection
    # algorithms that treat background as label 0
    labels = torch.arange(num_classes, device=device) + 1
    labels = labels.view(1, -1).expand_as(scores)

    detections = []

    for box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape in zip(
            box_regression, scores, labels, anchors, image_shapes):
        boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
        boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)

        image_boxes = []
        image_scores = []
        image_labels = []

        for class_index in range(num_classes):
            # remove low scoring boxes
            inds = torch.nonzero(scores_per_image[:, class_index] > self.score_thresh).squeeze(1)
            boxes_per_class = boxes_per_image[inds]
            scores_per_class = scores_per_image[inds, class_index]
            labels_per_class = labels_per_image[inds, class_index]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
            boxes_per_class = boxes_per_class[keep]
            scores_per_class = scores_per_class[keep]
            labels_per_class = labels_per_class[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)

            # A single class can never contribute more than the per-image
            # budget, so truncating here only bounds the intermediate size.
            keep = keep[:self.detections_per_img]
            boxes_per_class = boxes_per_class[keep]
            scores_per_class = scores_per_class[keep]
            labels_per_class = labels_per_class[keep]

            image_boxes.append(boxes_per_class)
            image_scores.append(scores_per_class)
            image_labels.append(labels_per_class)

        all_boxes = torch.cat(image_boxes, dim=0)
        all_scores = torch.cat(image_scores, dim=0)
        all_labels = torch.cat(image_labels, dim=0)

        # The per-class truncation above still allows up to
        # num_classes * detections_per_img results; enforce the limit over
        # the whole image by keeping only the best scoring detections.
        if all_scores.shape[0] > self.detections_per_img:
            all_scores, top_idxs = torch.topk(all_scores, self.detections_per_img)
            all_boxes = all_boxes[top_idxs]
            all_labels = all_labels[top_idxs]

        detections.append({
            'boxes': all_boxes,
            'scores': all_scores,
            'labels': all_labels,
        })

    return detections
450+
393451
def forward(self, images, targets=None):
394452
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
395453
"""
@@ -446,19 +504,8 @@ def forward(self, images, targets=None):
446504
losses = self.compute_loss(targets, head_outputs, anchors)
447505
else:
448506
# compute the detections
449-
# TODO: Implement postprocess_detections
450-
boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, anchors)
451-
num_images = len(images)
452-
for i in range(num_images):
453-
detections.append(
454-
{
455-
"boxes": boxes[i],
456-
"labels": labels[i],
457-
"scores": scores[i],
458-
}
459-
)
460-
461-
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
507+
detections = self.postprocess_detections(head_outputs['cls_logits'], head_outputs['bbox_regression'], anchors, original_image_sizes)
508+
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
462509

463510
if torch.jit.is_scripting():
464511
if not self._has_warned:

0 commit comments

Comments
 (0)