1
1
from .._utils import IntermediateLayerGetter
2
2
from ..utils import load_state_dict_from_url
3
+ from .. import mobilenet
3
4
from .. import resnet
4
5
from .deeplabv3 import DeepLabHead , DeepLabV3
5
6
from .fcn import FCN , FCNHead
6
7
7
8
8
- __all__ = ['fcn_resnet50' , 'fcn_resnet101' , 'deeplabv3_resnet50' , 'deeplabv3_resnet101' ]
9
+ __all__ = ['fcn_resnet50' , 'fcn_resnet101' , 'fcn_mobilenet_v3_large' , 'deeplabv3_resnet50' , 'deeplabv3_resnet101' ,
10
+ 'deeplabv3_mobilenet_v3_large' ]
9
11
10
12
11
13
model_urls = {
12
14
'fcn_resnet50_coco' : 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth' ,
13
15
'fcn_resnet101_coco' : 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth' ,
16
+ 'fcn_mobilenet_v3_large_coco' : None ,
14
17
'deeplabv3_resnet50_coco' : 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth' ,
15
18
'deeplabv3_resnet101_coco' : 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth' ,
19
+ 'deeplabv3_mobilenet_v3_large_coco' : None ,
16
20
}
17
21
18
22
@@ -22,7 +26,22 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
22
26
pretrained = pretrained_backbone ,
23
27
replace_stride_with_dilation = [False , True , True ])
24
28
out_layer = 'layer4'
29
+ out_inplanes = 2048
25
30
aux_layer = 'layer3'
31
+ aux_inplanes = 1024
32
+ elif 'mobilenet' in backbone_name :
33
+ backbone = mobilenet .__dict__ [backbone_name ](pretrained = pretrained_backbone ).features
34
+
35
+ # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
36
+ # The first and last blocks are always included because they are the C0 (conv1) and Cn.
37
+ stage_indices = [0 ] + [i for i , b in enumerate (backbone ) if getattr (b , "is_strided" , False )] + [
38
+ len (backbone ) - 1 ]
39
+ out_pos = stage_indices [- 1 ]
40
+ out_layer = str (out_pos )
41
+ out_inplanes = backbone [out_pos ].out_channels
42
+ aux_pos = stage_indices [- 2 ]
43
+ aux_layer = str (aux_pos )
44
+ aux_inplanes = backbone [aux_pos ].out_channels
26
45
else :
27
46
raise NotImplementedError ('backbone {} is not supported as of now' .format (backbone_name ))
28
47
@@ -33,15 +52,13 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
33
52
34
53
aux_classifier = None
35
54
if aux :
36
- inplanes = 1024
37
- aux_classifier = FCNHead (inplanes , num_classes )
55
+ aux_classifier = FCNHead (aux_inplanes , num_classes )
38
56
39
57
model_map = {
40
58
'deeplabv3' : (DeepLabHead , DeepLabV3 ),
41
59
'fcn' : (FCNHead , FCN ),
42
60
}
43
- inplanes = 2048
44
- classifier = model_map [name ][0 ](inplanes , num_classes )
61
+ classifier = model_map [name ][0 ](out_inplanes , num_classes )
45
62
base_model = model_map [name ][1 ]
46
63
47
64
model = base_model (backbone , classifier , aux_classifier )
@@ -71,6 +88,8 @@ def fcn_resnet50(pretrained=False, progress=True,
71
88
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
72
89
contains the same classes as Pascal VOC
73
90
progress (bool): If True, displays a progress bar of the download to stderr
91
+ num_classes (int): number of output classes of the model (including the background)
92
+ aux_loss (bool): If True, it uses an auxiliary loss
74
93
"""
75
94
return _load_model ('fcn' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
76
95
@@ -83,10 +102,26 @@ def fcn_resnet101(pretrained=False, progress=True,
83
102
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
84
103
contains the same classes as Pascal VOC
85
104
progress (bool): If True, displays a progress bar of the download to stderr
105
+ num_classes (int): number of output classes of the model (including the background)
106
+ aux_loss (bool): If True, it uses an auxiliary loss
86
107
"""
87
108
return _load_model ('fcn' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
88
109
89
110
111
+ def fcn_mobilenet_v3_large (pretrained = False , progress = True ,
112
+ num_classes = 21 , aux_loss = None , ** kwargs ):
113
+ """Constructs a Fully-Convolutional Network model with a MobileNetV3-Large backbone.
114
+
115
+ Args:
116
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
117
+ contains the same classes as Pascal VOC
118
+ progress (bool): If True, displays a progress bar of the download to stderr
119
+ num_classes (int): number of output classes of the model (including the background)
120
+ aux_loss (bool): If True, it uses an auxiliary loss
121
+ """
122
+ return _load_model ('fcn' , 'mobilenet_v3_large' , pretrained , progress , num_classes , aux_loss , ** kwargs )
123
+
124
+
90
125
def deeplabv3_resnet50 (pretrained = False , progress = True ,
91
126
num_classes = 21 , aux_loss = None , ** kwargs ):
92
127
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
@@ -95,6 +130,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
95
130
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
96
131
contains the same classes as Pascal VOC
97
132
progress (bool): If True, displays a progress bar of the download to stderr
133
+ num_classes (int): number of output classes of the model (including the background)
134
+ aux_loss (bool): If True, it uses an auxiliary loss
98
135
"""
99
136
return _load_model ('deeplabv3' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
100
137
@@ -107,5 +144,21 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
107
144
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
108
145
contains the same classes as Pascal VOC
109
146
progress (bool): If True, displays a progress bar of the download to stderr
147
+ num_classes (int): number of output classes of the model (including the background)
148
+ aux_loss (bool): If True, it uses an auxiliary loss
110
149
"""
111
150
return _load_model ('deeplabv3' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
151
+
152
+
153
+ def deeplabv3_mobilenet_v3_large (pretrained = False , progress = True ,
154
+ num_classes = 21 , aux_loss = None , ** kwargs ):
155
+ """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
156
+
157
+ Args:
158
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
159
+ contains the same classes as Pascal VOC
160
+ progress (bool): If True, displays a progress bar of the download to stderr
161
+ num_classes (int): number of output classes of the model (including the background)
162
+ aux_loss (bool): If True, it uses an auxiliary loss
163
+ """
164
+ return _load_model ('deeplabv3' , 'mobilenet_v3_large' , pretrained , progress , num_classes , aux_loss , ** kwargs )
0 commit comments