Skip to content

Commit 462d59a

Browse files
committed
Adding fcn and deeplabv3 directly on mobilenetv3 backbone.
1 parent 77da44c commit 462d59a

File tree

1 file changed

+58
-5
lines changed

1 file changed

+58
-5
lines changed

torchvision/models/segmentation/segmentation.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
from .._utils import IntermediateLayerGetter
22
from ..utils import load_state_dict_from_url
3+
from .. import mobilenet
34
from .. import resnet
45
from .deeplabv3 import DeepLabHead, DeepLabV3
56
from .fcn import FCN, FCNHead
67

78

8-
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
9+
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'fcn_mobilenet_v3_large', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
10+
'deeplabv3_mobilenet_v3_large']
911

1012

1113
model_urls = {
1214
'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth',
1315
'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
16+
'fcn_mobilenet_v3_large_coco': None,
1417
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
1518
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
19+
'deeplabv3_mobilenet_v3_large_coco': None,
1620
}
1721

1822

@@ -22,7 +26,22 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
2226
pretrained=pretrained_backbone,
2327
replace_stride_with_dilation=[False, True, True])
2428
out_layer = 'layer4'
29+
out_inplanes = 2048
2530
aux_layer = 'layer3'
31+
aux_inplanes = 1024
32+
elif 'mobilenet' in backbone_name:
33+
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained_backbone).features
34+
35+
# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
36+
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
37+
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [
38+
len(backbone) - 1]
39+
out_pos = stage_indices[-1]
40+
out_layer = str(out_pos)
41+
out_inplanes = backbone[out_pos].out_channels
42+
aux_pos = stage_indices[-2]
43+
aux_layer = str(aux_pos)
44+
aux_inplanes = backbone[aux_pos].out_channels
2645
else:
2746
raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
2847

@@ -33,15 +52,13 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
3352

3453
aux_classifier = None
3554
if aux:
36-
inplanes = 1024
37-
aux_classifier = FCNHead(inplanes, num_classes)
55+
aux_classifier = FCNHead(aux_inplanes, num_classes)
3856

3957
model_map = {
4058
'deeplabv3': (DeepLabHead, DeepLabV3),
4159
'fcn': (FCNHead, FCN),
4260
}
43-
inplanes = 2048
44-
classifier = model_map[name][0](inplanes, num_classes)
61+
classifier = model_map[name][0](out_inplanes, num_classes)
4562
base_model = model_map[name][1]
4663

4764
model = base_model(backbone, classifier, aux_classifier)
@@ -71,6 +88,8 @@ def fcn_resnet50(pretrained=False, progress=True,
7188
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
7289
contains the same classes as Pascal VOC
7390
progress (bool): If True, displays a progress bar of the download to stderr
91+
num_classes (int): number of output classes of the model (including the background)
92+
aux_loss (bool): If True, it uses an auxiliary loss
7493
"""
7594
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
7695

@@ -83,10 +102,26 @@ def fcn_resnet101(pretrained=False, progress=True,
83102
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
84103
contains the same classes as Pascal VOC
85104
progress (bool): If True, displays a progress bar of the download to stderr
105+
num_classes (int): number of output classes of the model (including the background)
106+
aux_loss (bool): If True, it uses an auxiliary loss
86107
"""
87108
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
88109

89110

111+
def fcn_mobilenet_v3_large(pretrained=False, progress=True,
112+
num_classes=21, aux_loss=None, **kwargs):
113+
"""Constructs a Fully-Convolutional Network model with a MobileNetV3-Large backbone.
114+
115+
Args:
116+
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
117+
contains the same classes as Pascal VOC
118+
progress (bool): If True, displays a progress bar of the download to stderr
119+
num_classes (int): number of output classes of the model (including the background)
120+
aux_loss (bool): If True, it uses an auxiliary loss
121+
"""
122+
return _load_model('fcn', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
123+
124+
90125
def deeplabv3_resnet50(pretrained=False, progress=True,
91126
num_classes=21, aux_loss=None, **kwargs):
92127
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
@@ -95,6 +130,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
95130
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
96131
contains the same classes as Pascal VOC
97132
progress (bool): If True, displays a progress bar of the download to stderr
133+
num_classes (int): number of output classes of the model (including the background)
134+
aux_loss (bool): If True, it uses an auxiliary loss
98135
"""
99136
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
100137

@@ -107,5 +144,21 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
107144
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
108145
contains the same classes as Pascal VOC
109146
progress (bool): If True, displays a progress bar of the download to stderr
147+
num_classes (int): number of output classes of the model (including the background)
148+
aux_loss (bool): If True, it uses an auxiliary loss
110149
"""
111150
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
151+
152+
153+
def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
154+
num_classes=21, aux_loss=None, **kwargs):
155+
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
156+
157+
Args:
158+
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
159+
contains the same classes as Pascal VOC
160+
progress (bool): If True, displays a progress bar of the download to stderr
161+
num_classes (int): number of output classes of the model (including the background)
162+
aux_loss (bool): If True, it uses an auxiliary loss
163+
"""
164+
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)

0 commit comments

Comments
 (0)