@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -24,14 +24,22 @@
 __all__ = [
     "FasterRCNN",
     "FasterRCNN_ResNet50_FPN_Weights",
+    "FasterRCNN_ResNet50_FPN_V2_Weights",
     "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
     "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
     "fasterrcnn_resnet50_fpn",
+    "fasterrcnn_resnet50_fpn_v2",
     "fasterrcnn_mobilenet_v3_large_fpn",
     "fasterrcnn_mobilenet_v3_large_320_fpn",
 ]
 
 
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
 class FasterRCNN(GeneralizedRCNN):
     """
     Implements Faster R-CNN.
@@ -216,9 +224,7 @@ def __init__(
         out_channels = backbone.out_channels
 
         if rpn_anchor_generator is None:
-            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
-            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-            rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
+            rpn_anchor_generator = _default_anchorgen()
         if rpn_head is None:
             rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
 
@@ -298,6 +304,43 @@ def forward(self, x):
         return x
 
 
+class FastRCNNConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
 class FastRCNNPredictor(nn.Module):
     """
     Standard classification + bounding box regression layers
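
As an aside on the head added in the hunk above: for the configuration the V2 builder passes in later (an (out_channels, 7, 7) input, four convolution layers of 256 channels, one fully connected layer of 1024, BatchNorm), FastRCNNConvFCHead expands to roughly the stack sketched below. This is an illustrative sketch only, assuming Conv2dNormActivation's defaults (3x3 convolution, stride 1, padding 1, no conv bias when a norm layer is supplied, ReLU activation); it is not taken from the diff, and the Kaiming initialization applied in __init__ above is omitted for brevity.

import torch
from torch import nn

# Rough, hand-expanded equivalent of
#     FastRCNNConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
# where each Conv2dNormActivation block is a 3x3 conv + BatchNorm + ReLU.
blocks = []
for _ in range(4):
    blocks += [
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
    ]
box_head = nn.Sequential(
    *blocks,
    nn.Flatten(),                  # 256 x 7 x 7 -> 12544
    nn.Linear(256 * 7 * 7, 1024),
    nn.ReLU(inplace=True),
)

pooled = torch.rand(8, 256, 7, 7)  # pooled RoI features, one 7x7 map per proposal
print(box_head(pooled).shape)      # torch.Size([8, 1024])
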
@@ -349,6 +392,10 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
     DEFAULT = COCO_V1
 
 
+class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+    pass
+
+
 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
@@ -481,6 +528,66 @@ def fasterrcnn_resnet50_fpn(
     return model
 
 
+def fasterrcnn_resnet50_fpn_v2(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone.
+
+    Reference: `"Benchmarking Detection Transfer Learning with Vision Transformers"
+    <https://arxiv.org/abs/2111.11429>`_.
+
+    See :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details.
+
+    Args:
+        weights (FasterRCNN_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+    """
+    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = FastRCNNConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = FasterRCNN(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
+
 def _fasterrcnn_mobilenet_v3_large_fpn(
     *,
     weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
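
A minimal usage sketch of the new fasterrcnn_resnet50_fpn_v2 builder, assuming a torchvision build that includes this change; no pretrained weights are loaded here, since the FasterRCNN_ResNet50_FPN_V2_Weights enum above is still empty:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2

# Build the improved model with randomly initialized weights; num_classes
# defaults to 91 (COCO) when neither weights nor num_classes is given.
model = fasterrcnn_resnet50_fpn_v2(weights=None, weights_backbone=None)
model.eval()

# In eval mode the model takes a list of CHW tensors and returns one dict
# per image with "boxes", "labels" and "scores".
images = [torch.rand(3, 480, 640)]
with torch.no_grad():
    predictions = model(images)
print(predictions[0]["boxes"].shape, predictions[0]["labels"].shape)

Relative to fasterrcnn_resnet50_fpn, the V2 recipe keeps the ResNet-50-FPN backbone but builds the FPN and box head with BatchNorm, uses a deeper RPN head (conv_depth=2), and replaces the default two-layer MLP box head with the 4-conv + 1-fc FastRCNNConvFCHead, following the recipe from the referenced paper.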