|
| 1 | +import torch |
| 2 | + |
| 3 | +from collections import OrderedDict |
| 4 | +from functools import partial |
| 5 | +from torch import nn, Tensor |
| 6 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
| 7 | + |
| 8 | +from . import _utils as det_utils |
| 9 | +from .ssd import SSD, SSDScoringHead |
| 10 | +from .anchor_utils import DefaultBoxGenerator |
| 11 | +from .backbone_utils import _validate_trainable_layers |
| 12 | +from .. import mobilenet |
| 13 | +from ..mobilenetv3 import ConvBNActivation |
| 14 | +from ..utils import load_state_dict_from_url |
| 15 | + |
| 16 | + |
| 17 | +__all__ = ['ssdlite320_mobilenet_v3_large'] |
| 18 | + |
| 19 | +model_urls = { |
| 20 | + 'ssdlite320_mobilenet_v3_large_coco': |
| 21 | + 'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth' |
| 22 | +} |
| 23 | + |
| 24 | + |
| 25 | +def _prediction_block(in_channels: int, out_channels: int, kernel_size: int, |
| 26 | + norm_layer: Callable[..., nn.Module]) -> nn.Sequential: |
| 27 | + return nn.Sequential( |
| 28 | + # 3x3 depthwise with stride 1 and padding 1 |
| 29 | + ConvBNActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, |
| 30 | + norm_layer=norm_layer, activation_layer=nn.ReLU6), |
| 31 | + |
| 32 | + # 1x1 projetion to output channels |
| 33 | + nn.Conv2d(in_channels, out_channels, 1) |
| 34 | + ) |
| 35 | + |
| 36 | + |
| 37 | +def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential: |
| 38 | + activation = nn.ReLU6 |
| 39 | + intermediate_channels = out_channels // 2 |
| 40 | + return nn.Sequential( |
| 41 | + # 1x1 projection to half output channels |
| 42 | + ConvBNActivation(in_channels, intermediate_channels, kernel_size=1, |
| 43 | + norm_layer=norm_layer, activation_layer=activation), |
| 44 | + |
| 45 | + # 3x3 depthwise with stride 2 and padding 1 |
| 46 | + ConvBNActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2, |
| 47 | + groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation), |
| 48 | + |
| 49 | + # 1x1 projetion to output channels |
| 50 | + ConvBNActivation(intermediate_channels, out_channels, kernel_size=1, |
| 51 | + norm_layer=norm_layer, activation_layer=activation), |
| 52 | + ) |
| 53 | + |
| 54 | + |
| 55 | +def _normal_init(conv: nn.Module): |
| 56 | + for layer in conv.modules(): |
| 57 | + if isinstance(layer, nn.Conv2d): |
| 58 | + torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03) |
| 59 | + if layer.bias is not None: |
| 60 | + torch.nn.init.constant_(layer.bias, 0.0) |
| 61 | + |
| 62 | + |
| 63 | +class SSDLiteHead(nn.Module): |
| 64 | + def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, |
| 65 | + norm_layer: Callable[..., nn.Module]): |
| 66 | + super().__init__() |
| 67 | + self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer) |
| 68 | + self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer) |
| 69 | + |
| 70 | + def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: |
| 71 | + return { |
| 72 | + 'bbox_regression': self.regression_head(x), |
| 73 | + 'cls_logits': self.classification_head(x), |
| 74 | + } |
| 75 | + |
| 76 | + |
| 77 | +class SSDLiteClassificationHead(SSDScoringHead): |
| 78 | + def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, |
| 79 | + norm_layer: Callable[..., nn.Module]): |
| 80 | + cls_logits = nn.ModuleList() |
| 81 | + for channels, anchors in zip(in_channels, num_anchors): |
| 82 | + cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer)) |
| 83 | + _normal_init(cls_logits) |
| 84 | + super().__init__(cls_logits, num_classes) |
| 85 | + |
| 86 | + |
| 87 | +class SSDLiteRegressionHead(SSDScoringHead): |
| 88 | + def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]): |
| 89 | + bbox_reg = nn.ModuleList() |
| 90 | + for channels, anchors in zip(in_channels, num_anchors): |
| 91 | + bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer)) |
| 92 | + _normal_init(bbox_reg) |
| 93 | + super().__init__(bbox_reg, 4) |
| 94 | + |
| 95 | + |
| 96 | +class SSDLiteFeatureExtractorMobileNet(nn.Module): |
| 97 | + def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool, |
| 98 | + **kwargs: Any): |
| 99 | + super().__init__() |
| 100 | + # non-public config parameters |
| 101 | + min_depth = kwargs.pop('_min_depth', 16) |
| 102 | + width_mult = kwargs.pop('_width_mult', 1.0) |
| 103 | + |
| 104 | + assert not backbone[c4_pos].use_res_connect |
| 105 | + self.features = nn.Sequential( |
| 106 | + nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer |
| 107 | + nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]), # from C4 depthwise until end |
| 108 | + ) |
| 109 | + |
| 110 | + get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731 |
| 111 | + extra = nn.ModuleList([ |
| 112 | + _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer), |
| 113 | + _extra_block(get_depth(512), get_depth(256), norm_layer), |
| 114 | + _extra_block(get_depth(256), get_depth(256), norm_layer), |
| 115 | + _extra_block(get_depth(256), get_depth(128), norm_layer), |
| 116 | + ]) |
| 117 | + _normal_init(extra) |
| 118 | + |
| 119 | + self.extra = extra |
| 120 | + self.rescaling = rescaling |
| 121 | + |
| 122 | + def forward(self, x: Tensor) -> Dict[str, Tensor]: |
| 123 | + # Rescale from [0, 1] to [-1, -1] |
| 124 | + if self.rescaling: |
| 125 | + x = 2.0 * x - 1.0 |
| 126 | + |
| 127 | + # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations. |
| 128 | + output = [] |
| 129 | + for block in self.features: |
| 130 | + x = block(x) |
| 131 | + output.append(x) |
| 132 | + |
| 133 | + for block in self.extra: |
| 134 | + x = block(x) |
| 135 | + output.append(x) |
| 136 | + |
| 137 | + return OrderedDict([(str(i), v) for i, v in enumerate(output)]) |
| 138 | + |
| 139 | + |
| 140 | +def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int, |
| 141 | + norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any): |
| 142 | + backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress, |
| 143 | + norm_layer=norm_layer, **kwargs).features |
| 144 | + if not pretrained: |
| 145 | + # Change the default initialization scheme if not pretrained |
| 146 | + _normal_init(backbone) |
| 147 | + |
| 148 | + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. |
| 149 | + # The first and last blocks are always included because they are the C0 (conv1) and Cn. |
| 150 | + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] |
| 151 | + num_stages = len(stage_indices) |
| 152 | + |
| 153 | + # find the index of the layer from which we wont freeze |
| 154 | + assert 0 <= trainable_layers <= num_stages |
| 155 | + freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] |
| 156 | + |
| 157 | + for b in backbone[:freeze_before]: |
| 158 | + for parameter in b.parameters(): |
| 159 | + parameter.requires_grad_(False) |
| 160 | + |
| 161 | + return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs) |
| 162 | + |
| 163 | + |
| 164 | +def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91, |
| 165 | + pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None, |
| 166 | + norm_layer: Optional[Callable[..., nn.Module]] = None, |
| 167 | + **kwargs: Any): |
| 168 | + """ |
| 169 | + Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details. |
| 170 | +
|
| 171 | + Example: |
| 172 | +
|
| 173 | + >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True) |
| 174 | + >>> model.eval() |
| 175 | + >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)] |
| 176 | + >>> predictions = model(x) |
| 177 | +
|
| 178 | + Args: |
| 179 | + norm_layer: |
| 180 | + **kwargs: |
| 181 | + pretrained (bool): If True, returns a model pre-trained on COCO train2017 |
| 182 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 183 | + num_classes (int): number of output classes of the model (including the background) |
| 184 | + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet |
| 185 | + trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. |
| 186 | + Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. |
| 187 | + norm_layer (callable, optional): Module specifying the normalization layer to use. |
| 188 | + """ |
| 189 | + trainable_backbone_layers = _validate_trainable_layers( |
| 190 | + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6) |
| 191 | + |
| 192 | + if pretrained: |
| 193 | + pretrained_backbone = False |
| 194 | + |
| 195 | + # Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected |
| 196 | + rescaling = reduce_tail = not pretrained_backbone |
| 197 | + |
| 198 | + if norm_layer is None: |
| 199 | + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) |
| 200 | + |
| 201 | + backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers, |
| 202 | + norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0) |
| 203 | + |
| 204 | + size = (320, 320) |
| 205 | + anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) |
| 206 | + out_channels = det_utils.retrieve_out_channels(backbone, size) |
| 207 | + num_anchors = anchor_generator.num_anchors_per_location() |
| 208 | + assert len(out_channels) == len(anchor_generator.aspect_ratios) |
| 209 | + |
| 210 | + defaults = { |
| 211 | + "score_thresh": 0.001, |
| 212 | + "nms_thresh": 0.55, |
| 213 | + "detections_per_img": 300, |
| 214 | + "topk_candidates": 300, |
| 215 | + "image_mean": [0., 0., 0.], |
| 216 | + "image_std": [1., 1., 1.], |
| 217 | + } |
| 218 | + kwargs = {**defaults, **kwargs} |
| 219 | + model = SSD(backbone, anchor_generator, size, num_classes, |
| 220 | + head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs) |
| 221 | + |
| 222 | + if pretrained: |
| 223 | + weights_name = 'ssdlite320_mobilenet_v3_large_coco' |
| 224 | + if model_urls.get(weights_name, None) is None: |
| 225 | + raise ValueError("No checkpoint is available for model {}".format(weights_name)) |
| 226 | + state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) |
| 227 | + model.load_state_dict(state_dict) |
| 228 | + return model |
0 commit comments