Skip to content

Commit 359d941

Browse files
committed
Add Lite R-ASPP with MobileNetV3 backbone.
1 parent 10a51cf commit 359d941

File tree

5 files changed

+120
-9
lines changed

5 files changed

+120
-9
lines changed
Binary file not shown.

test/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def get_available_video_models():
6565
"fcn_resnet50",
6666
"fcn_resnet101",
6767
"fcn_mobilenet_v3_large",
68+
"lraspp_mobilenet_v3_large",
6869
)
6970

7071

torchvision/models/segmentation/_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections import OrderedDict
22

3-
import torch
43
from torch import nn
54
from torch.nn import functional as F
65

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from collections import OrderedDict
2+
3+
from torch import nn, Tensor
4+
from torch.nn import functional as F
5+
from typing import Dict
6+
7+
8+
__all__ = ["LRASPP"]
9+
10+
11+
class LRASPP(nn.Module):
12+
"""
13+
Implements a Lite R-ASPP Network for semantic segmentation.
14+
15+
Args:
16+
backbone (nn.Module): the network used to compute the features for the model.
17+
The backbone should return an OrderedDict[Tensor], with the key being
18+
"high" for the high level feature map and "low" for the low level feature map.
19+
low_channels (int): the number of channels of the low level features.
20+
high_channels (int): the number of channels of the high level features.
21+
num_classes (int): number of output classes of the model (including the background).
22+
inter_channels (int, optional): the number of channels for intermediate computations.
23+
"""
24+
25+
def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
26+
super().__init__()
27+
self.backbone = backbone
28+
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
29+
30+
def forward(self, input):
31+
features = self.backbone(input)
32+
out = self.classifier(features)
33+
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
34+
35+
result = OrderedDict()
36+
result["out"] = out
37+
38+
return result
39+
40+
41+
class LRASPPHead(nn.Module):
42+
43+
def __init__(self, low_channels, high_channels, num_classes, inter_channels):
44+
super().__init__()
45+
self.cbr = nn.Sequential(
46+
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
47+
nn.BatchNorm2d(inter_channels),
48+
nn.ReLU(inplace=True)
49+
)
50+
self.scale = nn.Sequential(
51+
nn.AdaptiveAvgPool2d(1),
52+
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
53+
nn.Sigmoid(),
54+
)
55+
self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
56+
self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)
57+
58+
def forward(self, input: Dict[str, Tensor]) -> Tensor:
59+
low = input["low"]
60+
high = input["high"]
61+
62+
x = self.cbr(high)
63+
s = self.scale(high)
64+
x = x * s
65+
x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)
66+
67+
return self.low_classifier(low) + self.high_classifier(x)

torchvision/models/segmentation/segmentation.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from .. import resnet
55
from .deeplabv3 import DeepLabHead, DeepLabV3
66
from .fcn import FCN, FCNHead
7+
from .lraspp import LRASPP
78

89

910
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'fcn_mobilenet_v3_large', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
10-
'deeplabv3_mobilenet_v3_large']
11+
'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large']
1112

1213

1314
model_urls = {
@@ -17,6 +18,7 @@
1718
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
1819
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
1920
'deeplabv3_mobilenet_v3_large_coco': None,
21+
'lraspp_mobilenet_v3_large_coco': None,
2022
}
2123

2224

@@ -69,13 +71,34 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
6971
aux_loss = True
7072
model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
7173
if pretrained:
72-
arch = arch_type + '_' + backbone + '_coco'
73-
model_url = model_urls[arch]
74-
if model_url is None:
75-
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
76-
else:
77-
state_dict = load_state_dict_from_url(model_url, progress=progress)
78-
model.load_state_dict(state_dict)
74+
_load_weights(model, arch_type, backbone, progress)
75+
return model
76+
77+
78+
def _load_weights(model, arch_type, backbone, progress):
79+
arch = arch_type + '_' + backbone + '_coco'
80+
model_url = model_urls[arch]
81+
if model_url is None:
82+
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
83+
else:
84+
state_dict = load_state_dict_from_url(model_url, progress=progress)
85+
model.load_state_dict(state_dict)
86+
87+
88+
def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
89+
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
90+
91+
# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
92+
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
93+
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
94+
low_pos = stage_indices[-4] # use C2 here which has output_stride = 8
95+
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
96+
low_channels = backbone[low_pos].out_channels
97+
high_channels = backbone[high_pos].out_channels
98+
99+
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'})
100+
101+
model = LRASPP(backbone, low_channels, high_channels, num_classes)
79102
return model
80103

81104

@@ -161,3 +184,24 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
161184
aux_loss (bool): If True, it uses an auxiliary loss
162185
"""
163186
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
187+
188+
189+
def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
190+
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
191+
192+
Args:
193+
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
194+
contains the same classes as Pascal VOC
195+
progress (bool): If True, displays a progress bar of the download to stderr
196+
num_classes (int): number of output classes of the model (including the background)
197+
"""
198+
if kwargs.pop("aux_loss", False):
199+
raise NotImplementedError('This model does not use auxiliary loss')
200+
201+
backbone_name = 'mobilenet_v3_large'
202+
model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
203+
204+
if pretrained:
205+
_load_weights(model, 'lraspp', backbone_name, progress)
206+
207+
return model

0 commit comments

Comments
 (0)