|
4 | 4 | from .. import resnet
|
5 | 5 | from .deeplabv3 import DeepLabHead, DeepLabV3
|
6 | 6 | from .fcn import FCN, FCNHead
|
| 7 | +from .lraspp import LRASPP |
7 | 8 |
|
8 | 9 |
|
9 | 10 | __all__ = ['fcn_resnet50', 'fcn_resnet101', 'fcn_mobilenet_v3_large', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
|
10 |
| - 'deeplabv3_mobilenet_v3_large'] |
| 11 | + 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large'] |
11 | 12 |
|
12 | 13 |
|
13 | 14 | model_urls = {
|
|
17 | 18 | 'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
|
18 | 19 | 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
|
19 | 20 | 'deeplabv3_mobilenet_v3_large_coco': None,
|
| 21 | + 'lraspp_mobilenet_v3_large_coco': None, |
20 | 22 | }
|
21 | 23 |
|
22 | 24 |
|
@@ -69,13 +71,34 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
|
69 | 71 | aux_loss = True
|
70 | 72 | model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
|
71 | 73 | if pretrained:
|
72 |
| - arch = arch_type + '_' + backbone + '_coco' |
73 |
| - model_url = model_urls[arch] |
74 |
| - if model_url is None: |
75 |
| - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) |
76 |
| - else: |
77 |
| - state_dict = load_state_dict_from_url(model_url, progress=progress) |
78 |
| - model.load_state_dict(state_dict) |
| 74 | + _load_weights(model, arch_type, backbone, progress) |
| 75 | + return model |
| 76 | + |
| 77 | + |
| 78 | +def _load_weights(model, arch_type, backbone, progress): |
| 79 | + arch = arch_type + '_' + backbone + '_coco' |
| 80 | + model_url = model_urls[arch] |
| 81 | + if model_url is None: |
| 82 | + raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) |
| 83 | + else: |
| 84 | + state_dict = load_state_dict_from_url(model_url, progress=progress) |
| 85 | + model.load_state_dict(state_dict) |
| 86 | + |
| 87 | + |
| 88 | +def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True): |
| 89 | + backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features |
| 90 | + |
| 91 | + # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. |
| 92 | + # The first and last blocks are always included because they are the C0 (conv1) and Cn. |
| 93 | + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] |
| 94 | + low_pos = stage_indices[-4] # use C2 here which has output_stride = 8 |
| 95 | + high_pos = stage_indices[-1] # use C5 which has output_stride = 16 |
| 96 | + low_channels = backbone[low_pos].out_channels |
| 97 | + high_channels = backbone[high_pos].out_channels |
| 98 | + |
| 99 | + backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'}) |
| 100 | + |
| 101 | + model = LRASPP(backbone, low_channels, high_channels, num_classes) |
79 | 102 | return model
|
80 | 103 |
|
81 | 104 |
|
@@ -161,3 +184,24 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
|
161 | 184 | aux_loss (bool): If True, it uses an auxiliary loss
|
162 | 185 | """
|
163 | 186 | return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
|
| 187 | + |
| 188 | + |
| 189 | +def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs): |
| 190 | + """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. |
| 191 | +
|
| 192 | + Args: |
| 193 | + pretrained (bool): If True, returns a model pre-trained on COCO train2017 which |
| 194 | + contains the same classes as Pascal VOC |
| 195 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 196 | + num_classes (int): number of output classes of the model (including the background) |
| 197 | + """ |
| 198 | + if kwargs.pop("aux_loss", False): |
| 199 | + raise NotImplementedError('This model does not use auxiliary loss') |
| 200 | + |
| 201 | + backbone_name = 'mobilenet_v3_large' |
| 202 | + model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs) |
| 203 | + |
| 204 | + if pretrained: |
| 205 | + _load_weights(model, 'lraspp', backbone_name, progress) |
| 206 | + |
| 207 | + return model |
0 commit comments