Commit 003a9f8

Implement loss for retinanet heads.
1 parent afb1d58 commit 003a9f8

File tree

4 files changed: +129 -26 lines changed


torchvision/models/detection/_utils.py

Lines changed: 10 additions & 9 deletions
@@ -3,7 +3,7 @@
 import math
 
 import torch
-from torch.jit.annotations import List, Tuple
+from torch.jit.annotations import List, Tuple, Optional
 from torch import Tensor
 import torchvision
 
@@ -257,8 +257,7 @@ class Matcher(object):
     def __init__(self,
                  high_threshold,
                  low_threshold,
-                 allow_low_quality_matches=False,
-                 box_similarity=None):
+                 allow_low_quality_matches=False):
         # type: (float, float, bool)
         """
         Args:
@@ -280,22 +279,24 @@ def __init__(self,
         self.low_threshold = low_threshold
         self.allow_low_quality_matches = allow_low_quality_matches
 
-        if box_similarity is None:
-            box_similarity = box_ops.box_iou
-        self.box_similarity = box_similarity
+        # if box_similarity is None:
+        #     box_similarity = box_ops.box_iou
+        # self.box_similarity = box_similarity
 
     def __call__(self, gt_boxes, anchors_per_image):
         """
         Args:
-            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
-            pairwise quality between M ground-truth elements and N predicted elements.
+            gt_boxes (Tensor[float]): an Mx4 tensor, containing M detections.
+
+            anchors_per_image (Tensor[float]): an Nx4 tensor, containing
+            the anchors for a specific image.
 
         Returns:
             matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
             [0, M - 1] or a negative value indicating that prediction i could not
             be matched.
         """
-        match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
+        match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)  # self.box_similarity(gt_boxes, anchors_per_image)
 
         if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training
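
Note: with the box_similarity hook commented out, Matcher.__call__ now always builds its match-quality matrix with box_ops.box_iou. A minimal sketch (not part of the commit) of what that MxN matrix looks like for toy boxes:

import torch
from torchvision.ops import box_iou

# Illustration only: 2 ground-truth boxes (M=2) and 3 anchors (N=3),
# both in (x1, y1, x2, y2) format as used throughout torchvision.
gt_boxes = torch.tensor([[0., 0., 10., 10.],
                         [20., 20., 30., 30.]])
anchors_per_image = torch.tensor([[0., 0., 10., 10.],
                                  [5., 5., 15., 15.],
                                  [100., 100., 110., 110.]])

match_quality_matrix = box_iou(gt_boxes, anchors_per_image)
print(match_quality_matrix.shape)  # torch.Size([2, 3]), i.e. MxN
# Each anchor (column) is then assigned its best ground truth (row), or a
# negative value, by the Matcher according to high_threshold / low_threshold.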

torchvision/models/detection/anchor_utils.py

Lines changed: 3 additions & 0 deletions
@@ -1,7 +1,10 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
+import torchvision
 from torch import nn
 
+from torch.jit.annotations import List, Optional, Dict
+
 
 class AnchorGenerator(nn.Module):
     __annotations__ = {

torchvision/models/detection/backbone_utils.py

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ class BackboneWithFPN(nn.Module):
     Attributes:
         out_channels (int): the number of channels in the FPN
     """
-    def __init__(self, backbone, return_layers, in_channels_list, out_channels):
     def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=LastLevelMaxPool()):
         super(BackboneWithFPN, self).__init__()
         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

torchvision/models/detection/retinanet.py

Lines changed: 116 additions & 16 deletions
@@ -1,8 +1,10 @@
+import math
 from collections import OrderedDict
 
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.jit.annotations import Dict, List, Tuple
 
 from ..utils import load_state_dict_from_url
 
@@ -36,13 +38,62 @@ def __init__(self, in_channels, num_anchors, num_classes):
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         return {
             'classification': self.classification_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
-            'bbox_reg': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
+            'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
         }
 
     def forward(self, x):
-        logits = [self.classification_head(feature) for feature in x]
+        cls_logits = [self.classification_head(feature) for feature in x]
         bbox_reg = [self.regression_head(feature) for feature in x]
-        return dict(logits=logits, bbox_reg=bbox_reg)
+        return dict(cls_logits=cls_logits, bbox_reg=bbox_reg)
+
+
+def sigmoid_focal_loss(
+    inputs,
+    targets,
+    alpha: float = 0.25,
+    gamma: float = 2,
+    reduction: str = "none",
+):
+    """
+    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                 (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+               positive vs negative examples or -1 for ignore. Default = 0.25
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+        reduction: 'none' | 'mean' | 'sum'
+                 'none': No reduction will be applied to the output.
+                 'mean': The output will be averaged.
+                 'sum': The output will be summed.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    p = torch.sigmoid(inputs)
+    ce_loss = F.binary_cross_entropy_with_logits(
+        inputs, targets, reduction="none"
+    )
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    if reduction == "mean":
+        loss = loss.mean()
+    elif reduction == "sum":
+        loss = loss.sum()
+
+    return loss
+
+
+sigmoid_focal_loss_jit = torch.jit.script(sigmoid_focal_loss)
 
 
 class RetinaNetClassificationHead(nn.Module):
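
Note: a quick numeric illustration (not part of the commit) of how the focal term down-weights easy examples relative to plain BCE, assuming sigmoid_focal_loss as defined in the hunk above is in scope:

import torch
import torch.nn.functional as F

# Three logits against all-positive targets: confidently right, unsure,
# and confidently wrong. Values in the comments are approximate.
logits = torch.tensor([4.0, 0.0, -4.0])
targets = torch.ones(3)

bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
focal = sigmoid_focal_loss(logits, targets, reduction="none")

print(bce)    # ~[0.018, 0.693, 4.019]
print(focal)  # ~[1.5e-06, 0.043, 0.969]; the easy example is suppressed hardest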
@@ -55,21 +106,59 @@ class RetinaNetClassificationHead(nn.Module):
         num_classes (int): number of classes to be predicted
     """
 
-    def __init__(self, in_channels, num_anchors, num_classes):
+    def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01):
         super(RetinaNetClassificationHead, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1)
 
         for l in self.children():
             torch.nn.init.normal_(l.weight, std=0.01)
             torch.nn.init.constant_(l.bias, 0)
 
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
+        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
+        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
+
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
-        # TODO Implement focal loss, is there an existing function for this?
-        return 0
+        loss = []
+
+        def permute_classification(tensor):
+            """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). """
+            N, _, H, W = tensor.shape
+            tensor = tensor.view(N, -1, self.num_classes, H, W)
+            tensor = tensor.permute(0, 3, 4, 1, 2)
+            tensor = tensor.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
+            return tensor
+
+        predicted_classification = head_outputs['cls_logits']
+        predicted_classification = [permute_classification(cls) for cls in predicted_classification]
+        predicted_classification = torch.cat(predicted_classification, dim=1)
+
+        for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs):
+            # determine only the foreground
+            foreground_idxs_per_image = matched_idxs_per_image >= 0
+            num_foreground = foreground_idxs_per_image.sum()
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(predicted_classification_per_image)
+            gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS
+
+            # compute the classification loss
+            loss.append(sigmoid_focal_loss_jit(
+                predicted_classification_per_image[valid_idxs_per_image],
+                gt_classes_target[valid_idxs_per_image],
+                reduction='sum',
+            ) / max(1, num_foreground))
+
+        return sum(loss) / len(loss)
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
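
Note on the new prior_probability bias (sketch, not in the commit): initializing cls_logits.bias to -log((1 - pi) / pi) makes every anchor start out predicting foreground with probability roughly pi, the RetinaNet trick that stops the huge number of easy negatives from dominating the focal loss early in training:

import math
import torch

prior_probability = 0.01  # pi in the RetinaNet paper
bias = -math.log((1 - prior_probability) / prior_probability)
print(bias)  # ~ -4.595

# With weights ~ N(0, 0.01), pre-activations are roughly the bias alone,
# so the initial foreground probability is sigmoid(bias) ~ pi:
print(torch.sigmoid(torch.tensor(bias)))  # tensor(0.0100)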
@@ -94,18 +183,29 @@ def __init__(self, in_channels, num_anchors):
         self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1)
+        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
 
         for l in self.children():
             torch.nn.init.normal_(l.weight, std=0.01)
-            torch.nn.init.constant_(l.bias, 0)
+            torch.nn.init.zeros_(l.bias)
 
         self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
 
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         loss = []
 
-        predicted_regression = head_outputs['bbox_reg'][0]
+        def permute_bbox_reg(tensor):
+            """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). """
+            N, _, H, W = tensor.shape
+            tensor = tensor.view(N, -1, 4, H, W)
+            tensor = tensor.permute(0, 3, 4, 1, 2)
+            tensor = tensor.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+            return tensor
+
+        predicted_regression = head_outputs['bbox_reg']
+        predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression]
+        predicted_regression = torch.cat(predicted_regression, dim=1)
+
         for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
             # get the targets corresponding GT for each proposal
             # NB: need to clamp the indices because we can have a single
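
Shape walk-through (illustrative sketch mirroring the permute helper above, with made-up sizes): a regression map with A anchors per location goes from (N, 4*A, H, W) to (N, H*W*A, 4) so per-anchor rows line up with the flattened anchors:

import torch

N, A, H, W = 2, 9, 5, 5            # batch, anchors per location, feature map size
raw = torch.randn(N, 4 * A, H, W)  # head output for one feature level

t = raw.view(N, A, 4, H, W)        # split the channel dim into (A, 4)
t = t.permute(0, 3, 4, 1, 2)       # -> (N, H, W, A, 4)
t = t.reshape(N, -1, 4)            # -> (N, H*W*A, 4)
print(t.shape)                     # torch.Size([2, 225, 4])
# torch.cat over levels then yields one (N, sum_HWA, 4) tensor per batch,
# which can be indexed per image alongside the flattened anchors.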
@@ -115,20 +215,20 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
 
             # determine only the foreground indices, ignore the rest
             foreground_idxs_per_image = matched_idxs_per_image >= 0
+            num_foreground = foreground_idxs_per_image.sum()
 
             # select only the foreground boxes
             matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-            print(predicted_regression_per_image.shape)
-            predicted_regression_per_image = predicted_regression_per_image['bbox_reg'][foreground_idxs_per_image, :]
+            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
 
             # compute the regression targets
-            target_regression = self.box_coder.encode(matched_gt_boxes_per_image, anchors_per_image)
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
 
             # compute the loss
-            loss.append(torch.nn.SmoothL1Loss()(predicted_regression_per_image, target_regression))
+            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(predicted_regression_per_image, target_regression) / max(1, num_foreground))
 
-        return sum(loss) / len(loss)
+        return sum(loss) / max(1, len(loss))
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
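
Why the switch from the default mean reduction to reduction='sum' divided by num_foreground matters (toy numbers, illustration only): the mean would average over 4 * num_foreground coordinates, silently scaling the loss down by the number of box deltas per anchor:

import torch

pred = torch.randn(3, 4)    # 3 foreground anchors, 4 box deltas each
target = torch.randn(3, 4)
num_foreground = pred.shape[0]

mean_red = torch.nn.SmoothL1Loss()(pred, target)  # sum / 12 elements
per_fg = torch.nn.SmoothL1Loss(reduction='sum')(pred, target) / max(1, num_foreground)  # sum / 3 anchors

print(per_fg / mean_red)  # tensor(4.); the two differ exactly by the 4 coordinates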
@@ -251,7 +351,7 @@ def __init__(self, backbone, num_classes,
         self.anchor_generator = anchor_generator
 
         if head is None:
-            head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()[0])
+            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
         self.head = head
 
         if proposal_matcher is None:
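
The reordering above fixes a swapped-argument bug: the head's constructor is RetinaNetHead(in_channels, num_anchors, num_classes), per its __init__ earlier in this file, so the anchors-per-location count must come second. A usage sketch with hypothetical values:

from torchvision.models.detection.retinanet import RetinaNetHead  # module path as in this commit

# Hypothetical values, for illustration only:
out_channels = 256  # typical FPN output width
num_anchors = 9     # anchors per spatial location (e.g. 3 scales x 3 ratios)
num_classes = 91    # e.g. COCO

head = RetinaNetHead(out_channels, num_anchors, num_classes)  # order now matches __init__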
