Skip to content

Commit 8061535

Browse files
committed
Adding fcn and deeplabv3 directly on mobilenetv3 backbone.
1 parent 77da44c commit 8061535

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

torchvision/models/segmentation/segmentation.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
from .._utils import IntermediateLayerGetter
22
from ..utils import load_state_dict_from_url
3+
from .. import mobilenet
34
from .. import resnet
45
from .deeplabv3 import DeepLabHead, DeepLabV3
56
from .fcn import FCN, FCNHead
67

78

8-
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
9+
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'fcn_mobilenet_v3_large', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
10+
'deeplabv3_mobilenet_v3_large']
911

1012

1113
model_urls = {
1214
'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth',
1315
'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
16+
'fcn_mobilenet_v3_large_coco': None,
1417
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
1518
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
19+
'deeplabv3_mobilenet_v3_large_coco': None,
1620
}
1721

1822

@@ -23,6 +27,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
2327
replace_stride_with_dilation=[False, True, True])
2428
out_layer = 'layer4'
2529
aux_layer = 'layer3'
30+
elif 'mobilenet' in backbone_name:
31+
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained_backbone).features
32+
33+
# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
34+
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
35+
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [
36+
len(backbone) - 1]
37+
out_layer = str(stage_indices[-1])
38+
aux_layer = str(stage_indices[-2])
2639
else:
2740
raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
2841

@@ -71,6 +84,8 @@ def fcn_resnet50(pretrained=False, progress=True,
7184
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
7285
contains the same classes as Pascal VOC
7386
progress (bool): If True, displays a progress bar of the download to stderr
87+
num_classes (int): number of output classes of the model (including the background)
88+
aux_loss (bool): If True, it uses an auxiliary loss
7489
"""
7590
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
7691

@@ -83,10 +98,26 @@ def fcn_resnet101(pretrained=False, progress=True,
8398
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
8499
contains the same classes as Pascal VOC
85100
progress (bool): If True, displays a progress bar of the download to stderr
101+
num_classes (int): number of output classes of the model (including the background)
102+
aux_loss (bool): If True, it uses an auxiliary loss
86103
"""
87104
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
88105

89106

107+
def fcn_mobilenet_v3_large(pretrained=False, progress=True,
108+
num_classes=21, aux_loss=None, **kwargs):
109+
"""Constructs a Fully-Convolutional Network model with a MobileNetV3-Large backbone.
110+
111+
Args:
112+
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
113+
contains the same classes as Pascal VOC
114+
progress (bool): If True, displays a progress bar of the download to stderr
115+
num_classes (int): number of output classes of the model (including the background)
116+
aux_loss (bool): If True, it uses an auxiliary loss
117+
"""
118+
return _load_model('fcn', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
119+
120+
90121
def deeplabv3_resnet50(pretrained=False, progress=True,
91122
num_classes=21, aux_loss=None, **kwargs):
92123
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
@@ -95,6 +126,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
95126
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
96127
contains the same classes as Pascal VOC
97128
progress (bool): If True, displays a progress bar of the download to stderr
129+
num_classes (int): number of output classes of the model (including the background)
130+
aux_loss (bool): If True, it uses an auxiliary loss
98131
"""
99132
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
100133

@@ -107,5 +140,21 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
107140
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
108141
contains the same classes as Pascal VOC
109142
progress (bool): If True, displays a progress bar of the download to stderr
143+
num_classes (int): number of output classes of the model (including the background)
144+
aux_loss (bool): If True, it uses an auxiliary loss
110145
"""
111146
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
147+
148+
149+
def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
150+
num_classes=21, aux_loss=None, **kwargs):
151+
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
152+
153+
Args:
154+
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
155+
contains the same classes as Pascal VOC
156+
progress (bool): If True, displays a progress bar of the download to stderr
157+
num_classes (int): number of output classes of the model (including the background)
158+
aux_loss (bool): If True, it uses an auxiliary loss
159+
"""
160+
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)

0 commit comments

Comments
 (0)