from collections import OrderedDict
import warnings

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple

from ..utils import load_state_dict_from_url

from .rpn import AnchorGenerator
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone


__all__ = [
    "RetinaNet", "retinanet_resnet50_fpn",
]


class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
    """

    def __init__(self, in_channels, num_anchors, num_classes):
        super(RetinaNetHead, self).__init__()
        self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)

    def compute_loss(self, outputs, labels, matched_gt_boxes):
        return {
            'classification': self.classification_head.compute_loss(outputs, labels, matched_gt_boxes),
            'regression': self.regression_head.compute_loss(outputs, labels, matched_gt_boxes),
        }

    def forward(self, x):
        # each sub-head already iterates over the list of feature maps
        logits = self.classification_head(x)
        bbox_reg = self.regression_head(x)
        return dict(logits=logits, bbox_reg=bbox_reg)


class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
    """

    def __init__(self, in_channels, num_anchors, num_classes):
        super(RetinaNetClassificationHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        # padding=1 keeps the spatial size, so the predictions line up with the anchors
        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)

        for l in self.children():
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

    def compute_loss(self, outputs, labels, matched_gt_boxes):
        # TODO Implement focal loss, is there an existing function for this?
        return 0

    def forward(self, x):
        logits = []
        for feature in x:
            t = F.relu(self.conv1(feature))
            t = F.relu(self.conv2(t))
            t = F.relu(self.conv3(t))
            t = F.relu(self.conv4(t))
            logits.append(self.cls_logits(t))
        return logits


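# The TODO in RetinaNetClassificationHead.compute_loss above asks about focal loss. The helper
# below is a minimal sketch of sigmoid focal loss (Lin et al., "Focal Loss for Dense Object
# Detection") that could be plugged in there once the per-level logits and one-hot targets have
# been flattened to [num_anchors, num_classes]. The function name, the reduction choice and the
# alpha/gamma defaults (taken from the paper) are assumptions of this sketch, not part of the
# draft above.
def _sigmoid_focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0):
    # probability assigned to each (anchor, class) entry
    p = torch.sigmoid(logits)
    ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # down-weight well-classified examples with the (1 - p_t) ** gamma modulating factor
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * ((1 - p_t) ** gamma) * ce_loss
    # the caller is expected to normalize, e.g. by the number of foreground anchors
    return loss.sum()

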
class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    def __init__(self, in_channels, num_anchors):
        super(RetinaNetRegressionHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        # padding=1 keeps the spatial size, so the predictions line up with the anchors
        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)

        for l in self.children():
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

    def compute_loss(self, outputs, labels, matched_gt_boxes):
        # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py ?
        return 0

    def forward(self, x):
        bbox_reg = []
        for feature in x:
            t = F.relu(self.conv1(feature))
            t = F.relu(self.conv2(t))
            t = F.relu(self.conv3(t))
            t = F.relu(self.conv4(t))
            bbox_reg.append(self.bbox_reg(t))
        return bbox_reg


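# The TODO in RetinaNetRegressionHead.compute_loss above asks whether to use SmoothL1 or plain L1.
# The helper below sketches the SmoothL1 variant; it assumes the predicted and target box deltas
# have already been gathered for the foreground anchors only (both of shape [num_foreground, 4]).
# The function name and the normalization by the number of foreground anchors are assumptions of
# this sketch.
def _smooth_l1_regression_loss_sketch(pred_deltas, target_deltas):
    num_foreground = max(1, pred_deltas.shape[0])
    return F.smooth_l1_loss(pred_deltas, target_deltas, reduction="sum") / num_foreground

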
class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
          between 0 and H and 0 and W
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
          0 and H and 0 and W
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (excluding the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        pre_nms_top_n (int): number of proposals to keep before applying NMS during testing.
        post_nms_top_n (int): number of proposals to keep after applying NMS during testing.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(self, backbone, num_classes,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # Anchor parameters
                 anchor_generator=None, head=None,
                 pre_nms_top_n=1000, post_nms_top_n=1000,
                 nms_thresh=0.5,
                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
        super(RetinaNet, self).__init__()

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
        self.backbone = backbone

        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

        if anchor_generator is None:
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(
                anchor_sizes, aspect_ratios
            )
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(backbone.out_channels,
                                 anchor_generator.num_anchors_per_location()[0],
                                 num_classes)
        self.head = head

        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.fg_iou_thresh = fg_iou_thresh
        self.bg_iou_thresh = bg_iou_thresh

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        # used to avoid emitting the scripting warning more than once
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[Dict[Tensor]] with fields
                like `boxes`, `labels` and `scores`.

        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        # get the original image sizes
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([('0', features)])
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
        if self.training:
            assert targets is not None

            # compute the losses
            # TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function
            # so that we can use it here and in rpn.RegionProposalNetwork
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes)
        else:
            # compute the detections
            # TODO: Implement postprocess_detections
            boxes, scores, labels = self.postprocess_detections(head_outputs['logits'],
                                                                head_outputs['bbox_reg'],
                                                                anchors)
            num_images = len(original_image_sizes)
            for i in range(num_images):
                detections.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections)


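# The forward() method above relies on a postprocess_detections method that is still marked TODO.
# The helper below is a rough per-image sketch of what it could do, assuming the class logits and
# box regression outputs have already been flattened to [num_anchors, num_classes] and
# [num_anchors, 4], and that a box coder with a decode_single method (e.g. the BoxCoder used
# elsewhere in this package) is available. The function name, argument layout and default
# thresholds are assumptions of this sketch, not part of the draft above.
def _postprocess_single_image_sketch(class_logits, box_regression, anchors, box_coder, image_shape,
                                     score_thresh=0.05, nms_thresh=0.5, detections_per_img=300):
    from torchvision.ops import boxes as box_ops

    scores = torch.sigmoid(class_logits)                      # [num_anchors, num_classes]
    boxes = box_coder.decode_single(box_regression, anchors)  # [num_anchors, 4]
    boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

    # keep the (anchor, class) pairs whose score clears the threshold
    keep_idxs = torch.nonzero(scores > score_thresh)          # [K, 2]: (anchor index, class index)
    anchor_idxs, labels = keep_idxs[:, 0], keep_idxs[:, 1]
    scores = scores[anchor_idxs, labels]
    boxes = boxes[anchor_idxs]

    # class-aware NMS, then keep only the highest-scoring detections
    keep = box_ops.batched_nms(boxes, scores, labels, nms_thresh)
    keep = keep[:detections_per_img]
    return boxes[keep], scores[keep], labels[keep]

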
model_urls = {
    'retinanet_resnet50_fpn_coco':
        '#TODO',
}


def retinanet_resnet50_fpn(pretrained=False, progress=True,
                           num_classes=91, pretrained_backbone=True, **kwargs):
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
          ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores of each prediction

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Arguments:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
    model = RetinaNet(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model