1
1
from .._utils import IntermediateLayerGetter
2
2
from ..utils import load_state_dict_from_url
3
+ from .. import mobilenet
3
4
from .. import resnet
4
5
from .deeplabv3 import DeepLabHead , DeepLabV3
5
6
from .fcn import FCN , FCNHead
6
7
7
8
8
- __all__ = ['fcn_resnet50' , 'fcn_resnet101' , 'deeplabv3_resnet50' , 'deeplabv3_resnet101' ]
9
+ __all__ = ['fcn_resnet50' , 'fcn_resnet101' , 'fcn_mobilenet_v3_large' , 'deeplabv3_resnet50' , 'deeplabv3_resnet101' ,
10
+ 'deeplabv3_mobilenet_v3_large' ]
9
11
10
12
11
13
model_urls = {
12
14
'fcn_resnet50_coco' : 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth' ,
13
15
'fcn_resnet101_coco' : 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth' ,
16
+ 'fcn_mobilenet_v3_large_coco' : None ,
14
17
'deeplabv3_resnet50_coco' : 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth' ,
15
18
'deeplabv3_resnet101_coco' : 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth' ,
19
+ 'deeplabv3_mobilenet_v3_large_coco' : None ,
16
20
}
17
21
18
22
@@ -23,6 +27,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
23
27
replace_stride_with_dilation = [False , True , True ])
24
28
out_layer = 'layer4'
25
29
aux_layer = 'layer3'
30
+ elif 'mobilenet' in backbone_name :
31
+ backbone = mobilenet .__dict__ [backbone_name ](pretrained = pretrained_backbone ).features
32
+
33
+ # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
34
+ # The first and last blocks are always included because they are the C0 (conv1) and Cn.
35
+ stage_indices = [0 ] + [i for i , b in enumerate (backbone ) if getattr (b , "is_strided" , False )] + [
36
+ len (backbone ) - 1 ]
37
+ out_layer = str (stage_indices [- 1 ])
38
+ aux_layer = str (stage_indices [- 2 ])
26
39
else :
27
40
raise NotImplementedError ('backbone {} is not supported as of now' .format (backbone_name ))
28
41
@@ -71,6 +84,8 @@ def fcn_resnet50(pretrained=False, progress=True,
71
84
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
72
85
contains the same classes as Pascal VOC
73
86
progress (bool): If True, displays a progress bar of the download to stderr
87
+ num_classes (int): number of output classes of the model (including the background)
88
+ aux_loss (bool): If True, it uses an auxiliary loss
74
89
"""
75
90
return _load_model ('fcn' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
76
91
@@ -83,10 +98,26 @@ def fcn_resnet101(pretrained=False, progress=True,
83
98
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
84
99
contains the same classes as Pascal VOC
85
100
progress (bool): If True, displays a progress bar of the download to stderr
101
+ num_classes (int): number of output classes of the model (including the background)
102
+ aux_loss (bool): If True, it uses an auxiliary loss
86
103
"""
87
104
return _load_model ('fcn' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
88
105
89
106
107
+ def fcn_mobilenet_v3_large (pretrained = False , progress = True ,
108
+ num_classes = 21 , aux_loss = None , ** kwargs ):
109
+ """Constructs a Fully-Convolutional Network model with a MobileNetV3-Large backbone.
110
+
111
+ Args:
112
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
113
+ contains the same classes as Pascal VOC
114
+ progress (bool): If True, displays a progress bar of the download to stderr
115
+ num_classes (int): number of output classes of the model (including the background)
116
+ aux_loss (bool): If True, it uses an auxiliary loss
117
+ """
118
+ return _load_model ('fcn' , 'mobilenet_v3_large' , pretrained , progress , num_classes , aux_loss , ** kwargs )
119
+
120
+
90
121
def deeplabv3_resnet50 (pretrained = False , progress = True ,
91
122
num_classes = 21 , aux_loss = None , ** kwargs ):
92
123
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
@@ -95,6 +126,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
95
126
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
96
127
contains the same classes as Pascal VOC
97
128
progress (bool): If True, displays a progress bar of the download to stderr
129
+ num_classes (int): number of output classes of the model (including the background)
130
+ aux_loss (bool): If True, it uses an auxiliary loss
98
131
"""
99
132
return _load_model ('deeplabv3' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
100
133
@@ -107,5 +140,21 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
107
140
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
108
141
contains the same classes as Pascal VOC
109
142
progress (bool): If True, displays a progress bar of the download to stderr
143
+ num_classes (int): number of output classes of the model (including the background)
144
+ aux_loss (bool): If True, it uses an auxiliary loss
110
145
"""
111
146
return _load_model ('deeplabv3' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
147
+
148
+
149
+ def deeplabv3_mobilenet_v3_large (pretrained = False , progress = True ,
150
+ num_classes = 21 , aux_loss = None , ** kwargs ):
151
+ """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
152
+
153
+ Args:
154
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
155
+ contains the same classes as Pascal VOC
156
+ progress (bool): If True, displays a progress bar of the download to stderr
157
+ num_classes (int): number of output classes of the model (including the background)
158
+ aux_loss (bool): If True, it uses an auxiliary loss
159
+ """
160
+ return _load_model ('deeplabv3' , 'mobilenet_v3_large' , pretrained , progress , num_classes , aux_loss , ** kwargs )
0 commit comments