Commit afb1d58 (1 parent: 7ee068a)

General fixes for retinanet model.
File tree: 1 file changed, +77 −34 lines

torchvision/models/detection/retinanet.py

Lines changed: 77 additions & 34 deletions
@@ -29,19 +29,19 @@ class RetinaNetHead(nn.Module):
     """

     def __init__(self, in_channels, num_anchors, num_classes):
-        super(RPNHead, self).__init__()
+        super(RetinaNetHead, self).__init__()
         self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
         self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)

-    def compute_loss(self, outputs, labels, matched_gt_boxes):
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         return {
-            'classification': self.classification_head.compute_loss(outputs, targets, anchor_state),
-            'regression': self.regression_head.compute_loss(outputs, targets, anchor_state),
+            'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
+            'bbox_reg': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
         }

     def forward(self, x):
-        logits = [self.classification_head(feature, targets) for feature in x]
-        bbox_reg = [self.regression_head(feature, targets) for feature in x]
+        logits = [self.classification_head(feature) for feature in x]
+        bbox_reg = [self.regression_head(feature) for feature in x]
         return dict(logits=logits, bbox_reg=bbox_reg)
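For orientation, a minimal usage sketch of the head after this change: forward now takes only the feature maps, and loss inputs are threaded through compute_loss instead. The channel/level sizes are illustrative assumptions, and the per-level output channel counts assume the usual num_anchors * num_classes and num_anchors * 4 output convs, which are defined outside this hunk:

    import torch

    # Hypothetical shapes: FPN features with 256 channels at four scales.
    head = RetinaNetHead(in_channels=256, num_anchors=9, num_classes=91)
    features = [torch.rand(2, 256, s, s) for s in (64, 32, 16, 8)]
    out = head(features)
    # out['logits'][i]   -> (2, 9 * 91, s_i, s_i)  per level (assumed layout)
    # out['bbox_reg'][i] -> (2, 9 * 4,  s_i, s_i)  per level (assumed layout)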

@@ -56,7 +56,7 @@ class RetinaNetClassificationHead(nn.Module):
     """

     def __init__(self, in_channels, num_anchors, num_classes):
-        super(RPNHead, self).__init__()
+        super(RetinaNetClassificationHead, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
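All three super() fixes in this commit correct the same copy-paste from rpn.py's RPNHead. On Python 3, the zero-argument form sidesteps this class of bug entirely, for example:

    class RetinaNetClassificationHead(nn.Module):
        def __init__(self, in_channels, num_anchors, num_classes):
            super().__init__()  # zero-arg super(): immune to class renames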
@@ -67,19 +67,16 @@ def __init__(self, in_channels, num_anchors, num_classes):
             torch.nn.init.normal_(l.weight, std=0.01)
             torch.nn.init.constant_(l.bias, 0)

-    def compute_loss(self, outputs, labels, matched_gt_boxes):
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         # TODO Implement focal loss, is there an existing function for this?
         return 0

     def forward(self, x):
-        logits = []
-        for feature in x:
-            t = F.relu(self.conv1(feature))
-            t = F.relu(self.conv2(t))
-            t = F.relu(self.conv3(t))
-            t = F.relu(self.conv4(t))
-            logits.append(self.cls_logits(t))
-        return logits
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+        x = F.relu(self.conv4(x))
+        return self.cls_logits(x)


 class RetinaNetRegressionHead(nn.Module):
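On the focal-loss TODO above: torchvision had no built-in at this point, so it would have to be written out. A sketch of the standard sigmoid focal loss from the RetinaNet paper (Lin et al., 2017) follows; the alpha/gamma defaults come from the paper, while the one-hot target layout and sum reduction are assumptions:

    import torch
    import torch.nn.functional as F

    def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
        # targets: one-hot tensor with the same shape as logits (assumed layout)
        p = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = p * targets + (1 - p) * (1 - targets)   # probability of the true class
        loss = ce * (1 - p_t) ** gamma                # down-weight easy examples
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        return (alpha_t * loss).sum()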
@@ -92,7 +89,7 @@ class RetinaNetRegressionHead(nn.Module):
     """

     def __init__(self, in_channels, num_anchors):
-        super(RPNHead, self).__init__()
+        super(RetinaNetRegressionHead, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
@@ -103,19 +100,42 @@ def __init__(self, in_channels, num_anchors):
             torch.nn.init.normal_(l.weight, std=0.01)
             torch.nn.init.constant_(l.bias, 0)

-    def compute_loss(self, outputs, labels, matched_gt_boxes):
-        # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py ?
-        return 0
+        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
+        loss = []
+
+        predicted_regression = head_outputs['bbox_reg'][0]
+        for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
+            # get the GT box corresponding to each anchor
+            # NB: need to clamp the indices because we can have a single
+            # GT in the image, and matched_idxs can be -1 or -2, which
+            # goes out of bounds
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
+
+            # determine only the foreground indices, ignore the rest
+            foreground_idxs_per_image = matched_idxs_per_image >= 0
+
+            # select only the foreground boxes, predictions and anchors
+            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
+            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
+
+            # compute the regression targets (per-image tensors, so encode_single)
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+
+            # compute the loss
+            loss.append(F.smooth_l1_loss(predicted_regression_per_image, target_regression))
+
+        return sum(loss) / len(loss)

     def forward(self, x):
-        bbox_reg = []
-        for feature in x:
-            t = F.relu(self.conv1(feature))
-            t = F.relu(self.conv2(t))
-            t = F.relu(self.conv3(t))
-            t = F.relu(self.conv4(t))
-            bbox_reg.append(self.bbox_reg(t))
-        return bbox_reg
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+        x = F.relu(self.conv4(x))
+        return self.bbox_reg(x)


 class RetinaNet(nn.Module):
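For reference, the BoxCoder introduced above computes the standard Faster R-CNN (dx, dy, dw, dh) regression targets: center offsets normalized by anchor size, plus log scale ratios. A small illustrative check (the box values are made up):

    import torch
    from torchvision.models.detection import _utils as det_utils

    box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    anchor = torch.tensor([[0.0, 0.0, 10.0, 10.0]])  # (x1, y1, x2, y2)
    gt = torch.tensor([[2.0, 2.0, 12.0, 12.0]])      # same size, shifted by +2
    # Center offsets are normalized by anchor size; sizes match, so dw = dh = 0.
    print(box_coder.encode_single(gt, anchor))       # ~tensor([[0.2, 0.2, 0.0, 0.0]])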
@@ -206,15 +226,18 @@ def __init__(self, backbone, num_classes,
                  image_mean=None, image_std=None,
                  # Anchor parameters
                  anchor_generator=None, head=None,
+                 proposal_matcher=None,
                  pre_nms_top_n=1000, post_nms_top_n=1000,
                  nms_thresh=0.5,
                  fg_iou_thresh=0.5, bg_iou_thresh=0.4):
+        super(RetinaNet, self).__init__()

         if not hasattr(backbone, "out_channels"):
             raise ValueError(
                 "backbone should contain an attribute out_channels "
                 "specifying the number of output channels (assumed to be the "
                 "same for all the levels)")
+        self.backbone = backbone

         assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

@@ -231,6 +254,14 @@ def __init__(self, backbone, num_classes,
             head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
         self.head = head

+        if proposal_matcher is None:
+            proposal_matcher = det_utils.Matcher(
+                fg_iou_thresh,
+                bg_iou_thresh,
+                allow_low_quality_matches=True,
+            )
+        self.proposal_matcher = proposal_matcher
+
         if image_mean is None:
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
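To make the matcher wiring concrete, a small example of det_utils.Matcher with the thresholds used here; the IoU values are made up, and the -1/-2 sentinels are torchvision's BELOW_LOW_THRESHOLD and BETWEEN_THRESHOLDS markers:

    import torch
    from torchvision.models.detection import _utils as det_utils

    matcher = det_utils.Matcher(0.5, 0.4, allow_low_quality_matches=True)
    # One GT box against three anchors with IoUs 0.6, 0.45, 0.1:
    quality = torch.tensor([[0.6, 0.45, 0.1]])
    print(matcher(quality))  # tensor([ 0, -2, -1]): foreground, ignored, background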
@@ -245,6 +276,13 @@ def eager_outputs(self, losses, detections):

         return detections

+    def compute_loss(self, targets, head_outputs, anchors):
+        matched_idxs = []
+        for anchors_per_image, targets_per_image in zip(anchors, targets):
+            # Matcher expects a match-quality (IoU) matrix, as in rpn.py
+            # (assumes `from torchvision.ops import boxes as box_ops`)
+            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
+            matched_idxs.append(self.proposal_matcher(match_quality_matrix))
+
+        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
+
     def forward(self, images, targets=None):
         # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
         """
@@ -277,8 +315,17 @@ def forward(self, images, targets=None):
         if isinstance(features, torch.Tensor):
             features = OrderedDict([('0', features)])

+        # TODO: Do we want a list or a dict?
+        features = list(features.values())
+
+        assert len(features) == 1 or len(features) == 6
+
+        if len(features) == 6:
+            # skip P2 because it generates too many anchors
+            features = features[1:]
+
         # compute the retinanet heads outputs using the features
-        head_outputs = self.head(images, features, targets)
+        head_outputs = self.head(features)

         # create the set of anchors
         anchors = self.anchor_generator(images, features)
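The len(features) == 6 branch above matches an FPN that returns six levels; dropping the first keeps the five coarser ones, in line with the paper's P3-P7 pyramid (the level names are an inference from the comment, not spelled out in the diff). A toy illustration with assumed keys and shapes:

    import torch
    from collections import OrderedDict

    # Hypothetical 6-level FPN output for a 512x512 input, keyed '0'..'5'
    # with strides 4, 8, 16, 32, 64, 128 (P2..P7).
    features = OrderedDict(
        (str(i), torch.rand(1, 256, 512 // s, 512 // s))
        for i, s in enumerate([4, 8, 16, 32, 64, 128])
    )
    feats = list(features.values())[1:]  # drop P2; keep P3-P7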
@@ -289,11 +336,7 @@ def forward(self, images, targets=None):
             assert targets is not None

             # compute the losses
-            # TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function
-            # so that we can use it here and in rpn.RegionProposalNetwork
-            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
-            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
-            losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes)
+            losses = self.compute_loss(targets, head_outputs, anchors)
         else:
             # compute the detections
             # TODO: Implement postprocess_detections
