 import warnings

 import torch
-from torch import nn
-import torch.nn.functional as F
+import torch.nn as nn
+from torch import Tensor

 from torch.jit.annotations import Dict, List, Tuple

 from ..utils import load_state_dict_from_url
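A hedged reading of the import swap (the diff itself doesn't say): TorchScript resolves the names appearing in mypy-style type comments against the module's globals, so `Tensor` must be importable as a bare name for annotations like `-> Tensor` to parse. A minimal, hypothetical illustration:

```python
import torch
from torch import Tensor  # makes the bare name `Tensor` resolvable in type comments

def double(x):
    # type: (Tensor) -> Tensor
    return x * 2

scripted = torch.jit.script(double)  # the comment is parsed via module globals
print(scripted(torch.ones(2)))
```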
@@ -39,7 +39,7 @@ def __init__(self, in_channels, num_anchors, num_classes):
         self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)

     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
-        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
         return {
             'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
             'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
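The corrected annotation matters because `torch.jit.script` type-checks these mypy-style comments: declaring `-> Tensor` while actually returning a dict is a scripting error. A minimal sketch with hypothetical names, not part of the PR:

```python
import torch
from torch import Tensor
from torch.jit.annotations import Dict

def combined_losses(cls_loss, reg_loss):
    # type: (Tensor, Tensor) -> Dict[str, Tensor]
    # With the old `-> Tensor` comment this function would fail to script,
    # since it returns a Dict[str, Tensor].
    return {'classification': cls_loss, 'bbox_regression': reg_loss}

scripted = torch.jit.script(combined_losses)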
@@ -84,9 +84,14 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01):
         self.num_classes = num_classes
         self.num_anchors = num_anchors

+        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
+        # TorchScript doesn't support class attributes.
+        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
+        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
+
     def compute_loss(self, targets, head_outputs, matched_idxs):
         # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
-        loss = []
+        loss = torch.tensor(0.0)

         cls_logits = head_outputs['cls_logits']

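Per the linked comment, TorchScript cannot resolve class attributes such as `Matcher.BETWEEN_THRESHOLDS` inside scripted code, but `__init__` runs eagerly, so the constant can be copied onto the instance there. A standalone repro sketch (hypothetical, simplified `Matcher`):

```python
import torch
from torch import Tensor

class Matcher:
    BETWEEN_THRESHOLDS = -2  # class attribute: fine in eager mode, opaque to TorchScript

class Head(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # __init__ executes eagerly, so the class attribute is readable here
        # and can be stashed as an ordinary instance attribute.
        self.BETWEEN_THRESHOLDS = Matcher.BETWEEN_THRESHOLDS

    def forward(self, matched_idxs):
        # type: (Tensor) -> Tensor
        # Writing `Matcher.BETWEEN_THRESHOLDS` here would fail to script.
        return matched_idxs != self.BETWEEN_THRESHOLDS

scripted = torch.jit.script(Head())
print(scripted(torch.tensor([-2, 0, 3])))
```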
@@ -95,7 +100,7 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             if matched_idxs_per_image.numel() == 0:
                 gt_classes_target = torch.zeros_like(cls_logits_per_image)
                 valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
-                num_foreground = 0
+                num_foreground = torch.tensor(0.0)
             else:
                 # determine only the foreground
                 foreground_idxs_per_image = matched_idxs_per_image >= 0
@@ -109,16 +114,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
                 ] = torch.tensor(1.0)

                 # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS
+                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

             # compute the classification loss
-            loss.append(sigmoid_focal_loss(
+            loss += sigmoid_focal_loss(
                 cls_logits_per_image[valid_idxs_per_image],
                 gt_classes_target[valid_idxs_per_image],
                 reduction='sum',
-            ) / max(1, num_foreground))
+            ) / max(1, num_foreground)

-        return sum(loss) / len(loss)
+        return loss / len(targets)

     def forward(self, x):
         # type: (List[Tensor]) -> Tensor
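Replacing the list of per-image losses with a scalar tensor accumulator reads as a TorchScript-friendliness change (my hedged interpretation): `loss` keeps one stable type through the loop instead of going through a `List[Tensor]` and the builtin `sum`. The pattern in isolation:

```python
import torch
from torch import Tensor
from torch.jit.annotations import List

def average(per_image_losses):
    # type: (List[Tensor]) -> Tensor
    loss = torch.tensor(0.0)  # scalar tensor accumulator with a fixed type
    for per_image_loss in per_image_losses:
        loss += per_image_loss
    return loss / len(per_image_losses)

scripted = torch.jit.script(average)
print(scripted([torch.tensor(1.0), torch.tensor(3.0)]))  # tensor(2.)
```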
@@ -170,7 +175,7 @@ def __init__(self, in_channels, num_anchors):

     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
-        loss = []
+        loss = torch.tensor(0.0)

         bbox_regression = head_outputs['bbox_regression']

@@ -200,15 +205,13 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
                 target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)

             # compute the loss
-            loss.append(
-                det_utils.smooth_l1_loss(
-                    bbox_regression_per_image,
-                    target_regression,
-                    size_average=False
-                ) / max(1, num_foreground)
-            )
+            loss += det_utils.smooth_l1_loss(
+                bbox_regression_per_image,
+                target_regression,
+                size_average=False
+            ) / max(1, num_foreground)

-        return sum(loss) / max(1, len(loss))
+        return loss / max(1, len(targets))

     def forward(self, x):
         # type: (List[Tensor]) -> Tensor
@@ -379,7 +382,7 @@ def eager_outputs(self, losses, detections):
         return detections

     def compute_loss(self, targets, head_outputs, anchors):
-        # type: (List[Dict[str, Tensor]], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
@@ -392,7 +395,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Dict[str, Tensor]
+        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
         # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?

         class_logits = head_outputs.pop('cls_logits')
@@ -408,7 +411,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
         labels = torch.arange(num_classes, device=device)
         labels = labels.view(1, -1).expand_as(scores)

-        detections = []
+        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

         for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
                 enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
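`torch.jit.annotate` is needed here because TorchScript assumes an empty list literal is `List[Tensor]`; appending dicts to a plain `[]` would fail type-checking, while the annotated form is a no-op in eager mode. A small hypothetical demo:

```python
import torch
from torch import Tensor
from torch.jit.annotations import Dict, List

def collect(n):
    # type: (int) -> List[Dict[str, Tensor]]
    # Plain `detections = []` would be inferred as List[Tensor] and the
    # append below would be rejected when scripting.
    detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
    for _ in range(n):
        detections.append({'scores': torch.rand(4)})
    return detections

scripted = torch.jit.script(collect)
print(len(scripted(3)))  # 3
```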
@@ -421,7 +424,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
             image_boxes = []
             image_scores = []
             image_labels = []
-            image_other_outputs = {k: [] for k in other_outputs.keys()}
+            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})

             for class_index in range(num_classes):
                 # remove low scoring boxes
@@ -450,6 +453,8 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
                 image_labels.append(labels_per_class)

                 for k, v in other_outputs_per_class:
+                    if k not in image_other_outputs:
+                        image_other_outputs[k] = []
                     image_other_outputs[k].append(v)

             detections.append({
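The original dict comprehension over `other_outputs.keys()` is replaced by an annotated empty dict plus on-demand key creation, presumably because the comprehension wasn't scriptable at the time (hedged; the diff doesn't say). The idiom on its own, with illustrative names:

```python
import torch
from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple

def bucket(pairs):
    # type: (List[Tuple[str, Tensor]]) -> Dict[str, List[Tensor]]
    out = torch.jit.annotate(Dict[str, List[Tensor]], {})
    for k, v in pairs:
        if k not in out:  # create each bucket on first use
            out[k] = []
        out[k].append(v)
    return out

scripted = torch.jit.script(bucket)
print(scripted([('a', torch.ones(1)), ('a', torch.zeros(1))])['a'])
```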
@@ -510,7 +515,7 @@ def forward(self, images, targets=None):
         anchors = self.anchor_generator(images, features)

         losses = {}
-        detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
+        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
         if self.training:
             assert targets is not None