Commit 3dd9e4b

Author: Hans Gaiser
Fix type annotations for TorchScript.

1 parent 5d082db

torchvision/models/detection/retinanet.py (28 additions, 23 deletions)

@@ -3,8 +3,8 @@
 import warnings
 
 import torch
-from torch import nn
-import torch.nn.functional as F
+import torch.nn as nn
+from torch import Tensor
 from torch.jit.annotations import Dict, List, Tuple
 
 from ..utils import load_state_dict_from_url
@@ -39,7 +39,7 @@ def __init__(self, in_channels, num_anchors, num_classes):
         self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)
 
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
-        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
         return {
             'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
             'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
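
The return-type fix above matters because TorchScript actually parses these MyPy-style "# type:" comments when scripting, and rejects a function whose declared return type does not match what it returns. A minimal sketch of the constraint, with illustrative names not taken from the patch:

import torch
from torch import Tensor
from torch.jit.annotations import Dict

@torch.jit.script
def head_losses(classification, bbox_regression):
    # type: (Tensor, Tensor) -> Dict[str, Tensor]
    # Declaring "-> Tensor" here while returning a dict would make
    # torch.jit.script raise a type error at compile time.
    return {'classification': classification.sum(), 'bbox_regression': bbox_regression.sum()}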
@@ -84,9 +84,14 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01
         self.num_classes = num_classes
         self.num_anchors = num_anchors
 
+        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
+        # TorchScript doesn't support class attributes.
+        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
+        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
+
     def compute_loss(self, targets, head_outputs, matched_idxs):
         # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
-        loss = []
+        loss = torch.tensor(0.0)
 
         cls_logits = head_outputs['cls_logits']
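
The new instance attribute works around TorchScript's lack of support for class attributes: a scripted method cannot resolve det_utils.Matcher.BETWEEN_THRESHOLDS, but it can read the same value from self. A stripped-down sketch of the pattern (class names shortened; in torchvision's Matcher the constant is -2):

import torch
from torch import Tensor

class Matcher(object):
    BETWEEN_THRESHOLDS = -2  # class attribute: visible in eager mode, not in TorchScript

class Head(torch.nn.Module):
    def __init__(self):
        super(Head, self).__init__()
        # Copy the class-level constant onto the instance so that scripted
        # methods can reference it as self.BETWEEN_THRESHOLDS.
        self.BETWEEN_THRESHOLDS = Matcher.BETWEEN_THRESHOLDS

    def forward(self, matched_idxs):
        # type: (Tensor) -> Tensor
        # Referencing Matcher.BETWEEN_THRESHOLDS here would fail under torch.jit.script.
        return matched_idxs != self.BETWEEN_THRESHOLDS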

@@ -95,7 +100,7 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             if matched_idxs_per_image.numel() == 0:
                 gt_classes_target = torch.zeros_like(cls_logits_per_image)
                 valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
-                num_foreground = 0
+                num_foreground = torch.tensor(0.0)
             else:
                 # determine only the foreground
                 foreground_idxs_per_image = matched_idxs_per_image >= 0
@@ -109,16 +114,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
                 ] = torch.tensor(1.0)
 
                 # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS
+                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
 
             # compute the classification loss
-            loss.append(sigmoid_focal_loss(
+            loss += sigmoid_focal_loss(
                 cls_logits_per_image[valid_idxs_per_image],
                 gt_classes_target[valid_idxs_per_image],
                 reduction='sum',
-            ) / max(1, num_foreground))
+            ) / max(1, num_foreground)
 
-        return sum(loss) / len(loss)
+        return loss / len(targets)
 
     def forward(self, x):
         # type: (List[Tensor]) -> Tensor
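
Replacing the Python list with a scalar Tensor keeps the loss accumulator consistently typed for TorchScript: the old version mixed an int num_foreground with Tensors and relied on the builtin sum() over a list, which does not script cleanly. The accumulation pattern in isolation (function name illustrative):

import torch
from torch import Tensor
from torch.jit.annotations import List

@torch.jit.script
def average_loss(per_image_losses):
    # type: (List[Tensor]) -> Tensor
    # Accumulate into a Tensor instead of appending to a list and summing.
    loss = torch.tensor(0.0)
    for per_image_loss in per_image_losses:
        loss += per_image_loss
    return loss / len(per_image_losses)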
@@ -170,7 +175,7 @@ def __init__(self, in_channels, num_anchors):
 
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
-        loss = []
+        loss = torch.tensor(0.0)
 
         bbox_regression = head_outputs['bbox_regression']
 
@@ -200,15 +205,13 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
             target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
 
             # compute the loss
-            loss.append(
-                det_utils.smooth_l1_loss(
-                    bbox_regression_per_image,
-                    target_regression,
-                    size_average=False
-                ) / max(1, num_foreground)
-            )
+            loss += det_utils.smooth_l1_loss(
+                bbox_regression_per_image,
+                target_regression,
+                size_average=False
+            ) / max(1, num_foreground)
 
-        return sum(loss) / max(1, len(loss))
+        return loss / max(1, len(targets))
 
     def forward(self, x):
         # type: (List[Tensor]) -> Tensor
@@ -379,7 +382,7 @@ def eager_outputs(self, losses, detections):
         return detections
 
     def compute_loss(self, targets, head_outputs, anchors):
-        # type: (List[Dict[str, Tensor]], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
@@ -392,7 +395,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
 
     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Dict[str, Tensor]
+        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
         # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
 
         class_logits = head_outputs.pop('cls_logits')
@@ -408,7 +411,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
         labels = torch.arange(num_classes, device=device)
         labels = labels.view(1, -1).expand_as(scores)
 
-        detections = []
+        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
 
         for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
                 enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
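
torch.jit.annotate is needed here because TorchScript assumes an unannotated empty literal is List[Tensor] (and {} is Dict[str, Tensor]); any other element type must be declared up front. A small self-contained sketch (function name illustrative):

import torch
from torch import Tensor
from torch.jit.annotations import Dict, List

@torch.jit.script
def collect_detections(boxes, scores):
    # type: (Tensor, Tensor) -> List[Dict[str, Tensor]]
    # Without the annotation, [] would be typed as List[Tensor] and the
    # append of a dict below would fail to compile.
    detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
    detections.append({'boxes': boxes, 'scores': scores})
    return detections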
@@ -421,7 +424,7 @@
             image_boxes = []
             image_scores = []
             image_labels = []
-            image_other_outputs = {k: [] for k in other_outputs.keys()}
+            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})
 
             for class_index in range(num_classes):
                 # remove low scoring boxes
@@ -450,6 +453,8 @@
                 image_labels.append(labels_per_class)
 
                 for k, v in other_outputs_per_class:
+                    if k not in image_other_outputs:
+                        image_other_outputs[k] = []
                     image_other_outputs[k].append(v)
 
             detections.append({
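
The dict comprehension was dropped in favor of an empty annotated dict plus lazy key initialization, presumably because TorchScript did not support dict comprehensions at the time. The combined pattern as a self-contained sketch (names illustrative):

import torch
from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple

@torch.jit.script
def group_by_key(pairs):
    # type: (List[Tuple[str, Tensor]]) -> Dict[str, List[Tensor]]
    # Initialize keys on first use instead of with {k: [] for k in keys}.
    grouped = torch.jit.annotate(Dict[str, List[Tensor]], {})
    for k, v in pairs:
        if k not in grouped:
            grouped[k] = []
        grouped[k].append(v)
    return grouped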
@@ -510,7 +515,7 @@ def forward(self, images, targets=None):
         anchors = self.anchor_generator(images, features)
 
         losses = {}
-        detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
+        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
         if self.training:
             assert targets is not None