Skip to content

Commit 231a525

Browse files
committed
Add Lite R-ASPP with MobileNetV3 backbone.
1 parent 10a51cf commit 231a525

File tree

4 files changed

+101
-8
lines changed

4 files changed

+101
-8
lines changed
Binary file not shown.

test/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def get_available_video_models():
6565
"fcn_resnet50",
6666
"fcn_resnet101",
6767
"fcn_mobilenet_v3_large",
68+
"lraspp_mobilenet_v3_large",
6869
)
6970

7071

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from collections import OrderedDict
2+
3+
from torch import nn
4+
from torch.nn import functional as F
5+
6+
7+
__all__ = ["LRASPP"]
8+
9+
10+
class LRASPP(nn.Module):
    """Lite R-ASPP segmentation network.

    The backbone must return a mapping holding two feature maps under the
    keys ``"s8"`` and ``"s16"``. The ``"s16"`` map is projected to
    ``inter_channels`` and gated channel-wise by a sigmoid attention vector
    computed from its own global average. The gated map is upsampled to the
    ``"s8"`` resolution, both maps are classified with 1x1 convolutions, and
    the summed logits are upsampled to the input resolution.

    Args:
        backbone (nn.Module): feature extractor whose output can be indexed
            with ``"s8"`` and ``"s16"``.
        s8_channels (int): number of channels of the ``"s8"`` feature map.
        s16_channels (int): number of channels of the ``"s16"`` feature map.
        num_classes (int): number of output channels of the logits.
        inter_channels (int): width of the projected ``"s16"`` branch.
    """

    def __init__(self, backbone, s8_channels, s16_channels, num_classes, inter_channels=128):
        super().__init__()
        self.backbone = backbone

        # NOTE: the attribute names below become state-dict keys; renaming
        # them would break checkpoint compatibility.

        # Projection branch: 1x1 conv (bias folded into BN) -> BN -> ReLU.
        self.cbr = nn.Sequential(
            nn.Conv2d(s16_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        # Attention branch: global average pool -> 1x1 conv -> sigmoid,
        # yielding one gate in (0, 1) per intermediate channel.
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(s16_channels, inter_channels, 1, bias=False),
            nn.Sigmoid(),
        )

        self.s8_classifier = nn.Conv2d(s8_channels, num_classes, 1)
        self.s16_classifier = nn.Conv2d(inter_channels, num_classes, 1)

    def forward(self, input):
        """Compute segmentation logits for ``input``.

        Args:
            input (Tensor): batch of images; its last two dimensions are taken
                as the spatial size of the output.

        Returns:
            OrderedDict: a single entry ``"out"`` holding the logits, resized
            to the spatial size of ``input``.
        """
        input_shape = input.shape[-2:]
        features = self.backbone(input)

        s8 = features["s8"]
        s16 = features["s16"]

        # Gate the projected features; the (N, C, 1, 1) gate broadcasts over H and W.
        x = self.cbr(s16) * self.scale(s16)
        # Bring the gated features to the "s8" resolution so both sets of logits can be summed.
        x = F.interpolate(x, size=s8.shape[-2:], mode='bilinear', align_corners=False)

        out = self.s8_classifier(s8) + self.s16_classifier(x)
        out = F.interpolate(out, size=input_shape, mode='bilinear', align_corners=False)

        result = OrderedDict()
        result["out"] = out

        return result

torchvision/models/segmentation/segmentation.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from .. import resnet
55
from .deeplabv3 import DeepLabHead, DeepLabV3
66
from .fcn import FCN, FCNHead
7+
from .lraspp import LRASPP
78

89

910
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'fcn_mobilenet_v3_large', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
10-
'deeplabv3_mobilenet_v3_large']
11+
'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large']
1112

1213

1314
model_urls = {
@@ -17,6 +18,7 @@
1718
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
1819
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
1920
'deeplabv3_mobilenet_v3_large_coco': None,
21+
'lraspp_mobilenet_v3_large_coco': None,
2022
}
2123

2224

@@ -64,18 +66,39 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
6466
return model
6567

6668

69+
def _segm_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
    """Build an ``LRASPP`` model on top of a dilated MobileNetV3 feature extractor.

    Args:
        backbone_name (str): name of the constructor looked up in the
            ``mobilenetv3`` module (e.g. ``'mobilenet_v3_large'``).
        num_classes (int): number of output classes of the model.
        pretrained_backbone (bool): forwarded as ``pretrained`` to the
            backbone constructor.

    Returns:
        LRASPP: the assembled segmentation model.
    """
    backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features

    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    s8_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    s16_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    s8_channels = backbone[s8_pos].out_channels
    s16_channels = backbone[s16_pos].out_channels

    # Expose the two selected blocks under the keys that LRASPP.forward reads.
    backbone = IntermediateLayerGetter(backbone, return_layers={str(s8_pos): 's8', str(s16_pos): 's16'})

    model = LRASPP(backbone, s8_channels, s16_channels, num_classes)
    return model
84+
85+
86+
def _load_weights(model, arch_type, backbone, progress):
    """Download the COCO checkpoint for an architecture and load it into ``model``.

    The checkpoint URL is looked up in ``model_urls`` under the key
    ``'<arch_type>_<backbone>_coco'``.

    Args:
        model (nn.Module): model that receives the weights (modified in place).
        arch_type (str): architecture prefix of the key, e.g. ``'lraspp'``.
        backbone (str): backbone part of the key, e.g. ``'mobilenet_v3_large'``.
        progress (bool): forwarded to ``load_state_dict_from_url``.

    Raises:
        KeyError: if the architecture has no entry in ``model_urls``.
        NotImplementedError: if the entry exists but no checkpoint URL is registered yet.
    """
    arch = arch_type + '_' + backbone + '_coco'
    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(state_dict)
94+
95+
6796
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load its COCO checkpoint.

    Args:
        arch_type (str): architecture name passed to ``_segm_model``.
        backbone (str): backbone name passed to ``_segm_model``.
        pretrained (bool): if True, load the full-model checkpoint via ``_load_weights``.
        progress (bool): forwarded to ``_load_weights``.
        num_classes (int): number of output classes of the model.
        aux_loss (bool): whether to build the auxiliary head; forced to True when ``pretrained`` is set.
        **kwargs: forwarded to ``_segm_model``.

    Returns:
        The constructed model.
    """
    if pretrained:
        # NOTE(review): presumably the published checkpoints include the auxiliary head, hence the override.
        aux_loss = True
        # The full-model checkpoint overwrites every backbone weight (or loading raises),
        # so downloading separate backbone weights first would be wasted work.
        kwargs["pretrained_backbone"] = False
    model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
    if pretrained:
        _load_weights(model, arch_type, backbone, progress)
    return model
80103

81104

@@ -161,3 +184,24 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
161184
aux_loss (bool): If True, it uses an auxiliary loss
162185
"""
163186
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
187+
188+
189+
def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        **kwargs: forwarded to the model builder (e.g. ``pretrained_backbone``)

    Raises:
        NotImplementedError: if a truthy ``aux_loss`` is passed (this model has no auxiliary
            head), or if ``pretrained`` is True while no checkpoint URL is registered for
            this architecture.
    """
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError('This model does not use auxiliary loss')

    backbone_name = 'mobilenet_v3_large'
    if pretrained:
        # The full-model checkpoint overwrites every backbone weight (or loading raises),
        # so downloading separate backbone weights first would be wasted work.
        kwargs["pretrained_backbone"] = False
    model = _segm_mobilenetv3(backbone_name, num_classes, **kwargs)

    if pretrained:
        _load_weights(model, 'lraspp', backbone_name, progress)

    return model

0 commit comments

Comments
 (0)