diff --git a/.gitignore b/.gitignore index 4ed0749da06..c332dda4f4a 100644 --- a/.gitignore +++ b/.gitignore @@ -22,5 +22,6 @@ htmlcov gen.yml .mypy_cache .vscode/ +.idea/ *.orig *-checkpoint.ipynb \ No newline at end of file diff --git a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl index 5b5079f20f8..548b0a22e1c 100644 Binary files a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl and b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl differ diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 770c5dcb1b0..d128ecb5699 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -291,6 +291,7 @@ class RetinaNet(nn.Module): considered as positive during training. bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be considered as negative during training. + topk_candidates (int): Number of best detections to keep before NMS. Example: @@ -339,7 +340,8 @@ def __init__(self, backbone, num_classes, score_thresh=0.05, nms_thresh=0.5, detections_per_img=300, - fg_iou_thresh=0.5, bg_iou_thresh=0.4): + fg_iou_thresh=0.5, bg_iou_thresh=0.4, + topk_candidates=1000): super().__init__() if not hasattr(backbone, "out_channels"): @@ -382,6 +384,7 @@ def __init__(self, backbone, num_classes, self.score_thresh = score_thresh self.nms_thresh = nms_thresh self.detections_per_img = detections_per_img + self.topk_candidates = topk_candidates # used only on torchscript mode self._has_warned = False @@ -408,77 +411,63 @@ def compute_loss(self, targets, head_outputs, anchors): return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) def postprocess_detections(self, head_outputs, anchors, image_shapes): - # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] - # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ? + # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] + class_logits = head_outputs['cls_logits'] + box_regression = head_outputs['bbox_regression'] - class_logits = head_outputs.pop('cls_logits') - box_regression = head_outputs.pop('bbox_regression') - other_outputs = head_outputs - - device = class_logits.device - num_classes = class_logits.shape[-1] - - scores = torch.sigmoid(class_logits) - - # create labels for each score - labels = torch.arange(num_classes, device=device) - labels = labels.view(1, -1).expand_as(scores) + num_images = len(image_shapes) detections = torch.jit.annotate(List[Dict[str, Tensor]], []) - for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \ - enumerate(zip(box_regression, scores, labels, anchors, image_shapes)): - - boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image) - boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape) - - other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()] + for index in range(num_images): + box_regression_per_image = [br[index] for br in box_regression] + logits_per_image = [cl[index] for cl in class_logits] + anchors_per_image, image_shape = anchors[index], image_shapes[index] image_boxes = [] image_scores = [] image_labels = [] - image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {}) - for class_index in range(num_classes): + for box_regression_per_level, logits_per_level, anchors_per_level in \ + zip(box_regression_per_image, logits_per_image, anchors_per_image): + num_classes = logits_per_level.shape[-1] + # remove low scoring boxes - inds = torch.gt(scores_per_image[:, class_index], self.score_thresh) - boxes_per_class, scores_per_class, labels_per_class = \ - boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index] - other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image] + scores_per_level = torch.sigmoid(logits_per_level).flatten() + keep_idxs = scores_per_level > self.score_thresh + scores_per_level = scores_per_level[keep_idxs] + topk_idxs = torch.where(keep_idxs)[0] - # remove empty boxes - keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2) - boxes_per_class, scores_per_class, labels_per_class = \ - boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] - other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] + # keep only topk scoring predictions + num_topk = min(self.topk_candidates, topk_idxs.size(0)) + scores_per_level, idxs = scores_per_level.topk(num_topk) + topk_idxs = topk_idxs[idxs] - # non-maximum suppression, independently done per class - keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh) + anchor_idxs = topk_idxs // num_classes + labels_per_level = topk_idxs % num_classes - # keep only topk scoring predictions - keep = keep[:self.detections_per_img] - boxes_per_class, scores_per_class, labels_per_class = \ - boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep] - other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class] + boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs], + anchors_per_level[anchor_idxs]) + boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape) + + image_boxes.append(boxes_per_level) + image_scores.append(scores_per_level) + image_labels.append(labels_per_level) - image_boxes.append(boxes_per_class) - image_scores.append(scores_per_class) - image_labels.append(labels_per_class) + image_boxes = torch.cat(image_boxes, dim=0) + image_scores = torch.cat(image_scores, dim=0) + image_labels = torch.cat(image_labels, dim=0) - for k, v in other_outputs_per_class: - if k not in image_other_outputs: - image_other_outputs[k] = [] - image_other_outputs[k].append(v) + # non-maximum suppression + keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) + keep = keep[:self.detections_per_img] detections.append({ - 'boxes': torch.cat(image_boxes, dim=0), - 'scores': torch.cat(image_scores, dim=0), - 'labels': torch.cat(image_labels, dim=0), + 'boxes': image_boxes[keep], + 'scores': image_scores[keep], + 'labels': image_labels[keep], }) - for k, v in image_other_outputs.items(): - detections[-1].update({k: torch.cat(v, dim=0)}) - return detections def forward(self, images, targets=None): @@ -557,8 +546,23 @@ def forward(self, images, targets=None): # compute the losses losses = self.compute_loss(targets, head_outputs, anchors) else: + # recover level sizes + num_anchors_per_level = [x.size(2) * x.size(3) for x in features] + HW = 0 + for v in num_anchors_per_level: + HW += v + HWA = head_outputs['cls_logits'].size(1) + A = HWA // HW + num_anchors_per_level = [hw * A for hw in num_anchors_per_level] + + # split outputs per level + split_head_outputs: Dict[str, List[Tensor]] = {} + for k in head_outputs: + split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1)) + split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors] + # compute the detections - detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes) + detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes) detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) if torch.jit.is_scripting():