Commit 10a51cf (parent f64bfed)

Add dilation support on MobileNetV3 for Segmentation.

5 files changed: +50 -50 lines (2 binary files changed, not shown)

torchvision/models/mobilenetv2.py
Lines changed: 4 additions & 2 deletions

@@ -38,14 +38,16 @@ def __init__(
         groups: int = 1,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
+        dilation: int = 1,
     ) -> None:
-        padding = (kernel_size - 1) // 2
+        padding = (kernel_size - 1) // 2 * dilation
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         if activation_layer is None:
             activation_layer = nn.ReLU6
         super(ConvBNReLU, self).__init__(
-            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
+                      bias=False),
             norm_layer(out_planes),
             activation_layer(inplace=True)
         )
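
Why the padding changes with dilation: a dilated k x k kernel spans an effective extent of (k - 1) * dilation + 1 pixels, so "same" padding for odd kernels becomes (k - 1) // 2 * dilation instead of (k - 1) // 2. A minimal standalone sketch (not part of the commit; the channel and input sizes are illustrative):

    import torch
    import torch.nn as nn

    kernel_size, dilation = 3, 2
    padding = (kernel_size - 1) // 2 * dilation          # 2, as computed in the diff above
    conv = nn.Conv2d(16, 16, kernel_size, stride=1, padding=padding,
                     dilation=dilation, bias=False)
    x = torch.randn(1, 16, 32, 32)
    print(conv(x).shape)                                 # torch.Size([1, 16, 32, 32]): spatial size preserved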

torchvision/models/mobilenetv3.py
Lines changed: 41 additions & 43 deletions

@@ -38,14 +38,15 @@ def forward(self, input: Tensor) -> Tensor:
 class InvertedResidualConfig:

     def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
-                 activation: str, stride: int, width_mult: float):
+                 activation: str, stride: int, dilation: int, width_mult: float):
         self.input_channels = self.adjust_channels(input_channels, width_mult)
         self.kernel = kernel
         self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
         self.out_channels = self.adjust_channels(out_channels, width_mult)
         self.use_se = use_se
         self.use_hs = activation == "HS"
         self.stride = stride
+        self.dilation = dilation

     @staticmethod
     def adjust_channels(channels: int, width_mult: float):

@@ -70,9 +71,10 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                            norm_layer=norm_layer, activation_layer=activation_layer))

         # depthwise
+        stride = 1 if cnf.dilation > 1 else cnf.stride
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
-                                       stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer,
-                                       activation_layer=activation_layer))
+                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
+                                       norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
             layers.append(SqueezeExcitation(cnf.expanded_channels))
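
The `stride = 1 if cnf.dilation > 1 else cnf.stride` line is what turns a strided stage into a dilated one: a block configured with dilation > 1 keeps stride 1, so the feature map stops shrinking while the depthwise kernel still covers a larger receptive field. A rough, illustrative sketch of the effect on the overall output stride (not code from the commit):

    # (stride, dilation) of the C4/C5 tail blocks when the model is built with _dilated=True
    tail_configs = [(2, 2), (1, 2), (1, 2)]
    output_stride = 16                      # stride already accumulated before C4 in MobileNetV3-Large
    for stride, dilation in tail_configs:
        output_stride *= 1 if dilation > 1 else stride
    print(output_stride)                    # 16: the dilated tail no longer downsamples (32 without dilation)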

@@ -194,78 +196,74 @@ def _mobilenet_v3(
     return model


-def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
-                       **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
     """
     Constructs a large MobileNetV3 architecture from
     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
-        reduced_tail (bool): If True, reduces the channel counts of all feature layers
-            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
-            backbone for Detection and Segmentation.
     """
+    # non-public config parameters
+    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
+    dilation = 2 if kwargs.pop('_dilated', False) else 1
     width_mult = 1.0
+
     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

-    reduce_divider = 2 if reduced_tail else 1
-
     inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, False, "RE", 1),
-        bneck_conf(16, 3, 64, 24, False, "RE", 2),  # C1
-        bneck_conf(24, 3, 72, 24, False, "RE", 1),
-        bneck_conf(24, 5, 72, 40, True, "RE", 2),  # C2
-        bneck_conf(40, 5, 120, 40, True, "RE", 1),
-        bneck_conf(40, 5, 120, 40, True, "RE", 1),
-        bneck_conf(40, 3, 240, 80, False, "HS", 2),  # C3
-        bneck_conf(80, 3, 200, 80, False, "HS", 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1),
-        bneck_conf(80, 3, 480, 112, True, "HS", 1),
-        bneck_conf(112, 3, 672, 112, True, "HS", 1),
-        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2),  # C4
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
+        bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
+        bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
+        bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
+        bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
+        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+        bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
+        bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
+        bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
+        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
     ]
     last_channel = adjust_channels(1280 // reduce_divider)  # C5

     return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)


-def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
-                       **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
     """
     Constructs a small MobileNetV3 architecture from
     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
-        reduced_tail (bool): If True, reduces the channel counts of all feature layers
-            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
-            backbone for Detection and Segmentation.
     """
+    # non-public config parameters
+    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
+    dilation = 2 if kwargs.pop('_dilated', False) else 1
     width_mult = 1.0
+
     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

-    reduce_divider = 2 if reduced_tail else 1
-
     inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, True, "RE", 2),  # C1
-        bneck_conf(16, 3, 72, 24, False, "RE", 2),  # C2
-        bneck_conf(24, 3, 88, 24, False, "RE", 1),
-        bneck_conf(24, 5, 96, 40, True, "HS", 2),  # C3
-        bneck_conf(40, 5, 240, 40, True, "HS", 1),
-        bneck_conf(40, 5, 240, 40, True, "HS", 1),
-        bneck_conf(40, 5, 120, 48, True, "HS", 1),
-        bneck_conf(48, 5, 144, 48, True, "HS", 1),
-        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2),  # C4
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
+        bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
+        bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
+        bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
+        bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
+        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+        bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
+        bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
+        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
     ]
     last_channel = adjust_channels(1024 // reduce_divider)  # C5
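
The `_reduced_tail` and `_dilated` flags are popped out of `**kwargs` rather than exposed as named parameters, so the public signatures stay unchanged while the detection and segmentation builders can still reach them. A hypothetical direct call (not from the commit; ordinary users are expected to go through the task-specific builders):

    from torchvision.models import mobilenetv3

    # '_dilated' keeps the output stride at 16; '_reduced_tail' halves the C4/C5 channel counts.
    backbone = mobilenetv3.mobilenet_v3_large(pretrained=False, _dilated=True, _reduced_tail=True).features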

torchvision/models/segmentation/segmentation.py
Lines changed: 5 additions & 5 deletions

@@ -1,6 +1,6 @@
 from .._utils import IntermediateLayerGetter
 from ..utils import load_state_dict_from_url
-from .. import mobilenet
+from .. import mobilenetv3
 from .. import resnet
 from .deeplabv3 import DeepLabHead, DeepLabV3
 from .fcn import FCN, FCNHead

@@ -29,16 +29,16 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
         out_inplanes = 2048
         aux_layer = 'layer3'
         aux_inplanes = 1024
-    elif 'mobilenet' in backbone_name:
-        backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained_backbone).features
+    elif 'mobilenet_v3' in backbone_name:
+        backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features

         # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
         # The first and last blocks are always included because they are the C0 (conv1) and Cn.
         stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
-        out_pos = stage_indices[-1]
+        out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
         out_layer = str(out_pos)
         out_inplanes = backbone[out_pos].out_channels
-        aux_pos = stage_indices[-2]
+        aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
         aux_layer = str(aux_pos)
         aux_inplanes = backbone[aux_pos].out_channels
     else:
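
For intuition on the index bookkeeping above, here is a small self-contained sketch (hypothetical stand-in blocks, not the real backbone) of how stage_indices ends up selecting C5 as the output layer and C2 as the auxiliary layer for MobileNetV3-Large:

    # Stand-ins for the feature blocks; the real InvertedResidual blocks are expected to
    # expose '_is_cn' on the first block of each downsampling stage.
    class Block:
        def __init__(self, is_cn):
            self._is_cn = is_cn

    # Assumed layout of features: stem conv [0], the 15 bottlenecks above [1..15], last 1x1 conv [16];
    # the stride-2 bottlenecks (marked C1..C4 in the config) land at indices 2, 4, 7 and 13.
    backbone = [Block(i in (2, 4, 7, 13)) for i in range(17)]

    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    print(stage_indices)         # [0, 2, 4, 7, 13, 16]
    out_pos = stage_indices[-1]  # 16 -> C5 (output_stride 16 with the dilated tail)
    aux_pos = stage_indices[-4]  # 4  -> C2 (output_stride 8)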
