Commit 0de780d

Output reshaped outputs from retinanet heads.
1 parent 003a9f8 commit 0de780d

File tree

1 file changed: +49 -42 lines

torchvision/models/detection/retinanet.py

@@ -42,9 +42,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         }

     def forward(self, x):
-        cls_logits = [self.classification_head(feature) for feature in x]
-        bbox_reg = [self.regression_head(feature) for feature in x]
-        return dict(cls_logits=cls_logits, bbox_reg=bbox_reg)
+        return {
+            'cls_logits': self.classification_head(x),
+            'bbox_regression': self.regression_head(x)
+        }


 def sigmoid_focal_loss(
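
For orientation, a hypothetical usage sketch (shapes assumed, not code from this commit): with this change each head consumes the whole list of feature maps, and the returned dict holds one tensor per output, concatenated over levels along the anchor dimension, instead of a per-level list.

import torch

# Hypothetical sketch; A (anchors per location) and K (classes) are assumed.
A, K = 9, 91
features = [torch.rand(2, 256, s, s) for s in (32, 16, 8)]  # dummy FPN levels
HWA = sum(f.shape[2] * f.shape[3] * A for f in features)    # anchors per image
print(HWA)  # 9 * (32*32 + 16*16 + 8*8) = 12096
# With a constructed RetinaNetHead `head`, one would then expect:
#   head(features)['cls_logits'].shape      == (2, HWA, K)
#   head(features)['bbox_regression'].shape == (2, HWA, 4)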
@@ -127,45 +128,48 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         loss = []

-        def permute_classification(tensor):
-            """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). """
-            N, _, H, W = tensor.shape
-            tensor = tensor.view(N, -1, self.num_classes, H, W)
-            tensor = tensor.permute(0, 3, 4, 1, 2)
-            tensor = tensor.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)
-            return tensor
-
-        predicted_classification = head_outputs['cls_logits']
-        predicted_classification = [permute_classification(cls) for cls in predicted_classification]
-        predicted_classification = torch.cat(predicted_classification, dim=1)
+        cls_logits = head_outputs['cls_logits']

-        for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs):
+        for targets_per_image, cls_logits_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, cls_logits, anchors, matched_idxs):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()

             # create the target classification
-            gt_classes_target = torch.zeros_like(predicted_classification_per_image)
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
             gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1

             # find indices for which anchors should be ignored
             valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS

             # compute the classification loss
-            loss.append(sigmoid_focal_loss_jit(
-                predicted_classification_per_image[valid_idxs_per_image],
+            loss.append(sigmoid_focal_loss(
+                cls_logits_per_image[valid_idxs_per_image],
                 gt_classes_target[valid_idxs_per_image],
                 reduction='sum',
             ) / max(1, num_foreground))

         return sum(loss) / len(loss)

     def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        return self.cls_logits(x)
+        all_cls_logits = []
+
+        for features in x:
+            cls_logits = F.relu(self.conv1(features))
+            cls_logits = F.relu(self.conv2(cls_logits))
+            cls_logits = F.relu(self.conv3(cls_logits))
+            cls_logits = F.relu(self.conv4(cls_logits))
+            cls_logits = self.cls_logits(cls_logits)
+
+            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = cls_logits.shape
+            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
+            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
+            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
+
+            all_cls_logits.append(cls_logits)
+
+        return torch.cat(all_cls_logits, dim=1)


 class RetinaNetRegressionHead(nn.Module):
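
As a standalone check (a sketch with assumed sizes, not part of the diff), the view/permute/reshape sequence inlined in forward() above can be exercised on a dummy tensor to confirm the (N, A * K, H, W) to (N, HWA, K) layout change:

import torch

# Sketch with made-up sizes: the conv output packs A anchors x K classes into
# the channel dimension; unpack to one row of K logits per location and anchor.
N, A, K, H, W = 2, 9, 91, 4, 4
t = torch.rand(N, A * K, H, W)      # shape of self.cls_logits(...) output
t = t.view(N, A, K, H, W)           # the diff writes view(N, -1, K, H, W)
t = t.permute(0, 3, 4, 1, 2)        # (N, H, W, A, K)
t = t.reshape(N, -1, K)             # (N, HWA, K)
assert t.shape == (N, H * W * A, K)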
@@ -194,19 +198,9 @@ def __init__(self, in_channels, num_anchors):
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         loss = []

-        def permute_bbox_reg(tensor):
-            """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). """
-            N, _, H, W = tensor.shape
-            tensor = tensor.view(N, -1, 4, H, W)
-            tensor = tensor.permute(0, 3, 4, 1, 2)
-            tensor = tensor.reshape(N, -1, 4)  # Size=(N, HWA, 4)
-            return tensor
-
-        predicted_regression = head_outputs['bbox_reg']
-        predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression]
-        predicted_regression = torch.cat(predicted_regression, dim=1)
+        bbox_regression = head_outputs['bbox_regression']

-        for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
+        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, bbox_regression, anchors, matched_idxs):
             # get the targets corresponding GT for each proposal
             # NB: need to clamp the indices because we can have a single
             # GT in the image, and matched_idxs can be -2, which goes
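
Worth noting (an illustration with assumed shapes, not from the commit): because head_outputs['bbox_regression'] is now one batched (N, HWA, 4) tensor rather than a list of per-level maps, zip(targets, bbox_regression, ...) hands the loop one (HWA, 4) slice per image:

import torch

# Illustrative shapes only: iterating a batched tensor along dim 0 yields one
# per-image slice, so compute_loss needs no per-level bookkeeping anymore.
bbox_regression = torch.rand(2, 100, 4)  # pretend N=2 images, HWA=100 anchors
for bbox_regression_per_image in bbox_regression:
    assert bbox_regression_per_image.shape == (100, 4)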
@@ -219,23 +213,36 @@ def permute_bbox_reg(tensor):

             # select only the foreground boxes
             matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
+            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

             # compute the regression targets
             target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)

             # compute the loss
-            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(predicted_regression_per_image, target_regression) / max(1, num_foreground))
+            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(bbox_regression_per_image, target_regression) / max(1, num_foreground))

         return sum(loss) / max(1, len(loss))

     def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        return self.bbox_reg(x)
+        all_bbox_regression = []
+
+        for features in x:
+            bbox_regression = F.relu(self.conv1(features))
+            bbox_regression = F.relu(self.conv2(bbox_regression))
+            bbox_regression = F.relu(self.conv3(bbox_regression))
+            bbox_regression = F.relu(self.conv4(bbox_regression))
+            bbox_regression = self.bbox_reg(bbox_regression)
+
+            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
+            N, _, H, W = bbox_regression.shape
+            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
+            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
+            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+
+            all_bbox_regression.append(bbox_regression)
+
+        return torch.cat(all_bbox_regression, dim=1)


 class RetinaNet(nn.Module):
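
The regression head's reshape mirrors the classification one, with the K class logits replaced by 4 box deltas; a companion sketch with assumed sizes (not part of the diff):

import torch

# Sketch with made-up sizes: unpack (N, 4 * A, H, W) into one 4-vector of box
# deltas per spatial location and anchor.
N, A, H, W = 2, 9, 4, 4
t = torch.rand(N, 4 * A, H, W)      # shape of self.bbox_reg(...) output
t = t.view(N, A, 4, H, W)           # the diff writes view(N, -1, 4, H, W)
t = t.permute(0, 3, 4, 1, 2)        # (N, H, W, A, 4)
t = t.reshape(N, -1, 4)             # (N, HWA, 4)
assert t.shape == (N, H * W * A, 4)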
