
Commit 8303ff9

Add rough implementation of RetinaNet.
1 parent 07cbb46 commit 8303ff9

1 file changed: 371 additions, 0 deletions
@@ -0,0 +1,371 @@
from collections import OrderedDict
import warnings

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.jit.annotations import Dict, List, Optional, Tuple

from ..utils import load_state_dict_from_url

from . import _utils as det_utils
from .rpn import AnchorGenerator
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone


__all__ = [
    "RetinaNet", "retinanet_resnet50_fpn",
]


class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
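
    Example (a rough usage sketch; assumes three FPN-style feature maps with 256 channels)::

        >>> head = RetinaNetHead(in_channels=256, num_anchors=9, num_classes=91)
        >>> features = [torch.rand(1, 256, s, s) for s in (64, 32, 16)]
        >>> outputs = head(features)
        >>> # outputs['logits'] and outputs['bbox_reg'] hold one tensor per feature map level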
"""
28+
29+
def __init__(self, in_channels, num_anchors, num_classes):
30+
super(RPNHead, self).__init__()
31+
self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
32+
self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)
33+
34+
def compute_loss(self, outputs, labels, matched_gt_boxes):
35+
return {
36+
'classification': self.classification_head.compute_loss(outputs, targets, anchor_state),
37+
'regression': self.regression_head.compute_loss(outputs, targets, anchor_state),
38+
}
39+
40+
def forward(self, x):
41+
logits = [self.classification_head(feature, targets) for feature in x]
42+
bbox_reg = [self.regression_head(feature, targets) for feature in x]
43+
return dict(logits=logits, bbox_reg=bbox_reg)
44+
45+
46+
class RetinaNetClassificationHead(nn.Module):
47+
"""
48+
A classification head for use in RetinaNet.
49+
50+
Arguments:
51+
in_channels (int): number of channels of the input feature
52+
num_anchors (int): number of anchors to be predicted
53+
num_classes (int): number of classes to be predicted
54+
"""
55+
56+
def __init__(self, in_channels, num_anchors, num_classes):
57+
super(RPNHead, self).__init__()
58+
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
59+
self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
60+
self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
61+
self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
62+
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1)
63+
64+
for l in self.children():
65+
torch.nn.init.normal_(l.weight, std=0.01)
66+
torch.nn.init.constant_(l.bias, 0)
67+
68+
def compute_loss(self, outputs, labels, matched_gt_boxes):
69+
# TODO Implement focal loss, is there an existing function for this?
70+
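        # A possible sketch (an assumption, not the final implementation): with per-anchor
        # one-hot targets built from `labels`, the focal loss of Lin et al. could look like
        #
        #     p = torch.sigmoid(cls_logits)
        #     ce = F.binary_cross_entropy_with_logits(cls_logits, gt_classes_target, reduction='none')
        #     p_t = p * gt_classes_target + (1 - p) * (1 - gt_classes_target)
        #     alpha_t = alpha * gt_classes_target + (1 - alpha) * (1 - gt_classes_target)
        #     loss = (alpha_t * (1 - p_t) ** gamma * ce).sum() / num_foreground_anchors
        #
        # where `cls_logits`, `gt_classes_target`, `alpha`, `gamma` and `num_foreground_anchors`
        # are placeholders that this rough head does not define yet.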
        return 0

    def forward(self, x):
        logits = []
        for feature in x:
            t = F.relu(self.conv1(feature))
            t = F.relu(self.conv2(t))
            t = F.relu(self.conv3(t))
            t = F.relu(self.conv4(t))
            logits.append(self.cls_logits(t))
        return logits


class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    def __init__(self, in_channels, num_anchors):
        super(RetinaNetRegressionHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)

        for l in self.children():
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

    def compute_loss(self, outputs, labels, matched_gt_boxes):
        # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py ?
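        # A possible sketch (an assumption): restrict the loss to foreground anchors and compare
        # the predicted deltas against the encoded regression targets, e.g.
        #
        #     loss = F.smooth_l1_loss(bbox_reg[foreground_idxs], regression_targets[foreground_idxs],
        #                             reduction='sum') / max(1, num_foreground_anchors)
        #
        # where `bbox_reg`, `regression_targets`, `foreground_idxs` and `num_foreground_anchors`
        # are placeholders that this rough head does not define yet; rpn.py uses plain L1 instead.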
        return 0

    def forward(self, x):
        bbox_reg = []
        for feature in x:
            t = F.relu(self.conv1(feature))
            t = F.relu(self.conv2(t))
            t = F.relu(self.conv3(t))
            t = F.relu(self.conv4(t))
            bbox_reg.append(self.bbox_reg(t))
        return bbox_reg


class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
          between 0 and H and 0 and W
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
          0 and H and 0 and W
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (excluding the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        pre_nms_top_n (int): number of proposals to keep before applying NMS during testing.
        post_nms_top_n (int): number of proposals to keep after applying NMS during testing.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(self, backbone, num_classes,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # Anchor parameters
                 anchor_generator=None, head=None,
                 pre_nms_top_n=1000, post_nms_top_n=1000,
                 nms_thresh=0.5,
                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
        super(RetinaNet, self).__init__()

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
        self.backbone = backbone

        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

        if anchor_generator is None:
            # one anchor size per feature map level, each with three aspect ratios
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(
                anchor_sizes, aspect_ratios
            )
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        # same box encoding as the RPN; used to build the regression targets in forward()
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # kept for the matching / post-processing steps that still need to be implemented
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.fg_iou_thresh = fg_iou_thresh
        self.bg_iou_thresh = bg_iou_thresh

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] that contains additional fields
                like `scores` and `labels`.

        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        # get the original image sizes
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([('0', features)])
        # the heads and the anchor generator expect a list of feature maps
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
        if self.training:
            assert targets is not None

            # compute the losses
            # TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function
            # so that we can use it here and in rpn.RegionProposalNetwork
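            # assign_targets_to_anchors does not exist here yet; like its namesake on
            # rpn.RegionProposalNetwork, it is expected to match each anchor against the
            # ground-truth boxes using fg_iou_thresh / bg_iou_thresh and to return per-anchor
            # class labels together with the matched ground-truth box for every anchor.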
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes)
        else:
            # compute the detections
            # TODO: Implement postprocess_detections
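            # A rough outline of what postprocess_detections could do (an assumption, not
            # implemented here): for every image and feature level, apply sigmoid to the logits,
            # drop low-scoring anchors, decode the regression deltas with self.box_coder.decode,
            # clip the boxes to the image, run NMS with nms_thresh and keep the post_nms_top_n
            # best detections.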
            boxes, scores, labels = self.postprocess_detections(head_outputs['logits'], head_outputs['bbox_reg'], anchors)
            num_images = len(original_image_sizes)
            for i in range(num_images):
                detections.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections)


model_urls = {
    'retinanet_resnet50_fpn_coco':
        '#TODO',
}


def retinanet_resnet50_fpn(pretrained=False, progress=True,
                           num_classes=91, pretrained_backbone=True, **kwargs):
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
          ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores for each prediction

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Arguments:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
    model = RetinaNet(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
