
Commit 23a877a

vincentqb authored and facebook-github-bot committed
Add MobileNetV3 architecture for Segmentation (#3276)
Summary:

* Making _segm_resnet() generic and reusable.
* Adding fcn and deeplabv3 directly on mobilenetv3 backbone.
* Adding tests for segmentation models.
* Rename is_strided with _is_cn.
* Add dilation support on MobileNetV3 for Segmentation.
* Add Lite R-ASPP with MobileNetV3 backbone.
* Add pretrained model weights.
* Removing model fcn_mobilenet_v3_large.
* Adding docs and imports.
* Fixing typo and readme.

Reviewed By: datumbox

Differential Revision: D26156380

fbshipit-source-id: e62528b52728804a40da79c1311562a7f1c2afbd
1 parent dfab2df commit 23a877a

13 files changed: +250 -72 lines

docs/source/models.rst

Lines changed: 11 additions & 1 deletion
@@ -271,7 +271,8 @@ The models subpackage contains definitions for the following model
 architectures for semantic segmentation:
 
 - `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
-- `DeepLabV3 ResNet50, ResNet101 <https://arxiv.org/abs/1706.05587>`_
+- `DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large <https://arxiv.org/abs/1706.05587>`_
+- `LR-ASPP MobileNetV3-Large <https://arxiv.org/abs/1905.02244>`_
 
 As with image classification models, all pre-trained models expect input images normalized in the same way.
 The images have to be loaded in to a range of ``[0, 1]`` and then normalized using
@@ -298,6 +299,8 @@ FCN ResNet50 60.5 91.4
 FCN ResNet101 63.7 91.9
 DeepLabV3 ResNet50 66.4 92.4
 DeepLabV3 ResNet101 67.4 92.4
+DeepLabV3 MobileNetV3-Large 60.3 91.2
+LR-ASPP MobileNetV3-Large 57.9 91.2
 ================================ ============= ====================
 
 
@@ -313,6 +316,13 @@ DeepLabV3
 
 .. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50
 .. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101
+.. autofunction:: torchvision.models.segmentation.deeplabv3_mobilenet_v3_large
+
+
+LR-ASPP
+-------
+
+.. autofunction:: torchvision.models.segmentation.lraspp_mobilenet_v3_large
 
 
 Object Detection, Instance Segmentation and Person Keypoint Detection
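
The new entries are straightforward to exercise. A minimal usage sketch (not part of the commit), assuming a torchvision build that includes this change; the mean/std values are the ones documented in models.rst:

```python
import torch
from torchvision import models

# Both new models follow the same OrderedDict output convention as the
# existing segmentation models.
model = models.segmentation.lraspp_mobilenet_v3_large(pretrained=True).eval()

# Pre-trained weights expect inputs scaled to [0, 1] and then normalized.
x = torch.rand(1, 3, 520, 520)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
x = (x - mean) / std

with torch.no_grad():
    out = model(x)["out"]  # (1, 21, 520, 520): per-pixel logits for 21 classes
```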

hubconf.py

Lines changed: 1 addition & 1 deletion
@@ -18,4 +18,4 @@
 
 # segmentation
 from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
-    deeplabv3_resnet50, deeplabv3_resnet101
+    deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
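
Since hubconf.py now exposes the new constructors, they can also be pulled through torch.hub. A sketch, assuming network access and that the branch being resolved carries this commit:

```python
import torch

# 'pytorch/vision' resolves to the GitHub repo; the entry-point name matches
# the import added to hubconf.py above.
model = torch.hub.load('pytorch/vision', 'deeplabv3_mobilenet_v3_large', pretrained=True)
model.eval()
```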

references/segmentation/README.md

Lines changed: 10 additions & 0 deletions
@@ -31,3 +31,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
 ```
+
+## deeplabv3_mobilenet_v3_large
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001
+```
+
+## lraspp_mobilenet_v3_large
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001
+```
(2 binary files changed; contents not shown)

test/test_models.py

Lines changed: 2 additions & 0 deletions
@@ -61,8 +61,10 @@ def get_available_video_models():
     "wide_resnet101_2",
     "deeplabv3_resnet50",
     "deeplabv3_resnet101",
+    "deeplabv3_mobilenet_v3_large",
     "fcn_resnet50",
     "fcn_resnet101",
+    "lraspp_mobilenet_v3_large",
 )
 
 
torchvision/models/detection/backbone_utils.py

Lines changed: 2 additions & 2 deletions
@@ -136,9 +136,9 @@ def mobilenet_backbone(
 ):
     backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
 
-    # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
-    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)
 
     # find the index of the layer from which we wont freeze
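
The rename matters because, once dilation is in play, a block that opens a new stage may no longer have an actual stride greater than 1; the flag now records "starts a Cn stage" rather than "is strided". A toy illustration (not from the diff) of the index gathering above:

```python
from torch import nn

class Block(nn.Module):
    """Stand-in for an inverted-residual block that flags stage boundaries."""
    def __init__(self, is_cn: bool):
        super().__init__()
        self._is_cn = is_cn

# Blocks 1 and 3 begin new stages (C1, C2); first and last are C0 and Cn.
backbone = nn.Sequential(*[Block(f) for f in (False, True, False, True, False)])
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
print(stage_indices)  # [0, 1, 3, 4]
```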

torchvision/models/mobilenetv2.py

Lines changed: 5 additions & 3 deletions
@@ -38,14 +38,16 @@ def __init__(
         groups: int = 1,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
+        dilation: int = 1,
     ) -> None:
-        padding = (kernel_size - 1) // 2
+        padding = (kernel_size - 1) // 2 * dilation
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         if activation_layer is None:
             activation_layer = nn.ReLU6
         super(ConvBNReLU, self).__init__(
-            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
+                      bias=False),
             norm_layer(out_planes),
             activation_layer(inplace=True)
         )
@@ -88,7 +90,7 @@ def __init__(
         ])
         self.conv = nn.Sequential(*layers)
         self.out_channels = oup
-        self.is_strided = stride > 1
+        self._is_cn = stride > 1
 
     def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
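
The updated padding rule generalizes "same" padding to dilated kernels: a dilated kernel of size k covers dilation*(k-1)+1 input pixels, so stride-1 convolutions keep their spatial size. A quick check (mine, not from the commit):

```python
import torch
from torch import nn

for dilation in (1, 2, 4):
    padding = (3 - 1) // 2 * dilation
    conv = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=padding, dilation=dilation)
    y = conv(torch.rand(1, 8, 56, 56))
    print(dilation, y.shape)  # spatial size stays 56x56 for every dilation
```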

torchvision/models/mobilenetv3.py

Lines changed: 42 additions & 44 deletions
@@ -38,14 +38,15 @@ def forward(self, input: Tensor) -> Tensor:
 class InvertedResidualConfig:
 
     def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
-                 activation: str, stride: int, width_mult: float):
+                 activation: str, stride: int, dilation: int, width_mult: float):
         self.input_channels = self.adjust_channels(input_channels, width_mult)
         self.kernel = kernel
         self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
         self.out_channels = self.adjust_channels(out_channels, width_mult)
         self.use_se = use_se
         self.use_hs = activation == "HS"
         self.stride = stride
+        self.dilation = dilation
 
     @staticmethod
     def adjust_channels(channels: int, width_mult: float):
@@ -70,9 +71,10 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                        norm_layer=norm_layer, activation_layer=activation_layer))
 
         # depthwise
+        stride = 1 if cnf.dilation > 1 else cnf.stride
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
-                                       stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer,
-                                       activation_layer=activation_layer))
+                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
+                                       norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
             layers.append(SqueezeExcitation(cnf.expanded_channels))
 
@@ -82,7 +84,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
 
         self.block = nn.Sequential(*layers)
         self.out_channels = cnf.out_channels
-        self.is_strided = cnf.stride > 1
+        self._is_cn = cnf.stride > 1
 
     def forward(self, input: Tensor) -> Tensor:
         result = self.block(input)
@@ -194,78 +196,74 @@ def _mobilenet_v3(
     return model
 
 
-def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
-                       **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
     """
     Constructs a large MobileNetV3 architecture from
     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
 
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
-        reduced_tail (bool): If True, reduces the channel counts of all feature layers
-            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
-            backbone for Detection and Segmentation.
     """
+    # non-public config parameters
+    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
+    dilation = 2 if kwargs.pop('_dilated', False) else 1
     width_mult = 1.0
+
     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
 
-    reduce_divider = 2 if reduced_tail else 1
-
     inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, False, "RE", 1),
-        bneck_conf(16, 3, 64, 24, False, "RE", 2),  # C1
-        bneck_conf(24, 3, 72, 24, False, "RE", 1),
-        bneck_conf(24, 5, 72, 40, True, "RE", 2),  # C2
-        bneck_conf(40, 5, 120, 40, True, "RE", 1),
-        bneck_conf(40, 5, 120, 40, True, "RE", 1),
-        bneck_conf(40, 3, 240, 80, False, "HS", 2),  # C3
-        bneck_conf(80, 3, 200, 80, False, "HS", 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1),
-        bneck_conf(80, 3, 480, 112, True, "HS", 1),
-        bneck_conf(112, 3, 672, 112, True, "HS", 1),
-        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2),  # C4
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
+        bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
+        bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
+        bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
+        bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
+        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+        bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
+        bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
+        bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
+        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
     ]
     last_channel = adjust_channels(1280 // reduce_divider)  # C5
 
     return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
 
 
-def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
-                       **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
     """
     Constructs a small MobileNetV3 architecture from
     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
 
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
-        reduced_tail (bool): If True, reduces the channel counts of all feature layers
-            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
-            backbone for Detection and Segmentation.
     """
+    # non-public config parameters
+    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
+    dilation = 2 if kwargs.pop('_dilated', False) else 1
    width_mult = 1.0
+
     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
 
-    reduce_divider = 2 if reduced_tail else 1
-
     inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, True, "RE", 2),  # C1
-        bneck_conf(16, 3, 72, 24, False, "RE", 2),  # C2
-        bneck_conf(24, 3, 88, 24, False, "RE", 1),
-        bneck_conf(24, 5, 96, 40, True, "HS", 2),  # C3
-        bneck_conf(40, 5, 240, 40, True, "HS", 1),
-        bneck_conf(40, 5, 240, 40, True, "HS", 1),
-        bneck_conf(40, 5, 120, 48, True, "HS", 1),
-        bneck_conf(48, 5, 144, 48, True, "HS", 1),
-        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2),  # C4
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
+        bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
+        bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
+        bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
+        bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
+        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+        bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
+        bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
+        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
     ]
     last_channel = adjust_channels(1024 // reduce_divider)  # C5
 
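
Taken together, `_dilated` swaps the stride of the C4/C5 tail for dilation, so the backbone keeps an output stride of 16 instead of 32 while preserving the receptive field, which is what the segmentation heads rely on. A sketch of the effect (mine; note both flags are deliberately non-public and may change):

```python
import torch
from torchvision.models import mobilenet_v3_large

plain = mobilenet_v3_large().features
dilated = mobilenet_v3_large(_dilated=True).features  # non-public flag

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    print(plain(x).shape)    # output stride 32: (1, 960, 7, 7)
    print(dilated(x).shape)  # output stride 16: (1, 960, 14, 14)
```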

torchvision/models/segmentation/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from .segmentation import *
 from .fcn import *
 from .deeplabv3 import *
+from .lraspp import *

torchvision/models/segmentation/_utils.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 from collections import OrderedDict
 
-import torch
 from torch import nn
 from torch.nn import functional as F
 
torchvision/models/segmentation/lraspp.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+from collections import OrderedDict
+
+from torch import nn, Tensor
+from torch.nn import functional as F
+from typing import Dict
+
+
+__all__ = ["LRASPP"]
+
+
+class LRASPP(nn.Module):
+    """
+    Implements a Lite R-ASPP Network for semantic segmentation from
+    `"Searching for MobileNetV3"
+    <https://arxiv.org/abs/1905.02244>`_.
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            The backbone should return an OrderedDict[Tensor], with the key being
+            "high" for the high level feature map and "low" for the low level feature map.
+        low_channels (int): the number of channels of the low level features.
+        high_channels (int): the number of channels of the high level features.
+        num_classes (int): number of output classes of the model (including the background).
+        inter_channels (int, optional): the number of channels for intermediate computations.
+    """
+
+    def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
+        super().__init__()
+        self.backbone = backbone
+        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
+
+    def forward(self, input):
+        features = self.backbone(input)
+        out = self.classifier(features)
+        out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
+
+        result = OrderedDict()
+        result["out"] = out
+
+        return result
+
+
+class LRASPPHead(nn.Module):
+
+    def __init__(self, low_channels, high_channels, num_classes, inter_channels):
+        super().__init__()
+        self.cbr = nn.Sequential(
+            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
+            nn.BatchNorm2d(inter_channels),
+            nn.ReLU(inplace=True)
+        )
+        self.scale = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
+            nn.Sigmoid(),
+        )
+        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
+        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)
+
+    def forward(self, input: Dict[str, Tensor]) -> Tensor:
+        low = input["low"]
+        high = input["high"]
+
+        x = self.cbr(high)
+        s = self.scale(high)
+        x = x * s
+        x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)
+
+        return self.low_classifier(low) + self.high_classifier(x)
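
The head fuses a low-resolution, high-channel feature map (gated by the global-pooled `scale` branch) with a higher-resolution, low-channel one, and sums the two per-class predictions. A standalone sketch of the module above; the channel counts and spatial sizes here are illustrative, not the ones wired up by the segmentation builder:

```python
import torch

head = LRASPPHead(low_channels=40, high_channels=960, num_classes=21, inter_channels=128)
features = {
    "low": torch.rand(1, 40, 60, 60),    # higher resolution, fewer channels
    "high": torch.rand(1, 960, 30, 30),  # lower resolution, more channels
}
logits = head(features)
print(logits.shape)  # torch.Size([1, 21, 60, 60]), i.e. the "low" branch resolution
```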
