 from .anchor_utils import DefaultBoxGenerator
 from .backbone_utils import _validate_trainable_layers
 from .transform import GeneralizedRCNNTransform
-from .. import vgg
+from .. import vgg, resnet
 from ..utils import load_state_dict_from_url
 from ...ops import boxes as box_ops

-__all__ = ['SSD', 'ssd300_vgg16']
+__all__ = ['SSD', 'ssd300_vgg16', 'ssd512_resnet50']

 model_urls = {
     'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth',
+    'ssd512_resnet50_coco': None,  # TODO: add weights
 }

 backbone_urls = {
@@ -562,3 +563,114 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
         state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
         model.load_state_dict(state_dict)
     return model
+
+
+class SSDFeatureExtractorResNet(nn.Module):
+    def __init__(self, backbone: resnet.ResNet):
+        super().__init__()
+
+        self.features = nn.Sequential(
+            backbone.conv1,
+            backbone.bn1,
+            backbone.relu,
+            backbone.maxpool,
+            backbone.layer1,
+            backbone.layer2,
+            backbone.layer3,
+            backbone.layer4,
+        )
+
+        # Patch last block's strides to get valid output sizes
+        for m in self.features[-1][0].modules():
+            if hasattr(m, 'stride'):
+                m.stride = 1
+
+        backbone_out_channels = self.features[-1][-1].bn3.num_features
+        extra = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(backbone_out_channels, 256, kernel_size=1, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False),
+                nn.BatchNorm2d(512),
+                nn.ReLU(inplace=True),
+            ),
+            nn.Sequential(
+                nn.Conv2d(512, 256, kernel_size=1, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False),
+                nn.BatchNorm2d(512),
+                nn.ReLU(inplace=True),
+            ),
+            nn.Sequential(
+                nn.Conv2d(512, 128, kernel_size=1, bias=False),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+            ),
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=1, bias=False),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+            ),
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=1, bias=False),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, 256, kernel_size=2, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+            )
+        ])
+        _xavier_init(extra)
+        self.extra = extra
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        x = self.features(x)
+        output = [x]
+
+        for block in self.extra:
+            x = block(x)
+            output.append(x)
+
+        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
+
+
+def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int):
+    backbone = resnet.__dict__[backbone_name](pretrained=pretrained)
+
+    assert 0 <= trainable_layers <= 5
+    layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
+    if trainable_layers == 5:
+        layers_to_train.append('bn1')
+    for name, parameter in backbone.named_parameters():
+        if all([not name.startswith(layer) for layer in layers_to_train]):
+            parameter.requires_grad_(False)
+
+    return SSDFeatureExtractorResNet(backbone)
+
+
+def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
+                    pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)
+
+    if pretrained:
+        pretrained_backbone = False
+
+    backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers)
+    anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]])
+    model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs)
+    if pretrained:
+        weights_name = 'ssd512_resnet50_coco'
+        if model_urls.get(weights_name, None) is None:
+            raise ValueError("No checkpoint is available for model {}".format(weights_name))
+        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
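A small sanity-check sketch (not part of the diff) that exercises the new SSDFeatureExtractorResNet through the private _resnet_extractor helper and prints the six feature maps it yields for a 512x512 input. It assumes the diff above has been applied to torchvision/models/detection/ssd.py and that importing the private helper directly is acceptable for experimentation; the shapes in the comments are derived from the strides in the diff, not from a published reference.

import torch
from torchvision.models.detection.ssd import _resnet_extractor  # private helper added in this diff

# Build the ResNet-50 feature extractor without pretrained weights, all layers trainable.
extractor = _resnet_extractor("resnet50", pretrained=False, trainable_layers=5)
extractor.eval()

with torch.no_grad():
    features = extractor(torch.rand(1, 3, 512, 512))

# Expected: six maps keyed '0'..'5'. Because the first block of layer4 has its strides
# patched to 1, the backbone output stays at 32x32; the extra blocks then downsample to
# 16, 8, 4, 2 and 1, matching the six aspect-ratio groups passed to DefaultBoxGenerator.
for name, feature_map in features.items():
    print(name, tuple(feature_map.shape))
# '0' (1, 2048, 32, 32), '1' (1, 512, 16, 16), '2' (1, 512, 8, 8),
# '3' (1, 256, 4, 4),    '4' (1, 256, 2, 2),   '5' (1, 256, 1, 1)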
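A minimal inference sketch for the new ssd512_resnet50 builder, again assuming the diff is applied and that the detection package re-exports it the same way it does ssd300_vgg16; pretrained=False is required here because no COCO checkpoint URL is registered yet.

import torch
from torchvision.models.detection import ssd512_resnet50

# No COCO checkpoint is registered yet (the model_urls entry is None), so pretrained stays False.
model = ssd512_resnet50(pretrained=False, pretrained_backbone=False, num_classes=91)
model.eval()

# One 512x512 RGB image with values in [0, 1]; the SSD transform resizes inputs to (512, 512).
images = [torch.rand(3, 512, 512)]
with torch.no_grad():
    predictions = model(images)

# Standard torchvision detection output: one dict per image with 'boxes', 'labels' and 'scores'.
print(predictions[0]['boxes'].shape, predictions[0]['scores'].shape)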