Skip to content

Commit 403396c

Browse files
committed
Adding support for reduced tail on MobileNetV3.
1 parent 5d0a664 commit 403396c

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

torchvision/models/mobilenetv3.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -201,19 +201,25 @@ def _mobilenet_v3(
201201
return model
202202

203203

204-
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
204+
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
205+
**kwargs: Any) -> MobileNetV3:
205206
"""
206207
Constructs a large MobileNetV3 architecture from
207208
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
208209
209210
Args:
210211
pretrained (bool): If True, returns a model pre-trained on ImageNet
211212
progress (bool): If True, displays a progress bar of the download to stderr
213+
reduced_tail (bool): If True, reduces the channel counts of all feature layers
214+
between C4 and C5 by a factor of 2. It is used to reduce the channel redundancy in the
215+
backbone for Detection and Segmentation.
212216
"""
213217
width_mult = 1.0
214218
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
215219
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
216220

221+
reduce_divider = 2 if reduced_tail else 1
222+
217223
inverted_residual_setting = [
218224
bneck_conf(16, 3, 16, 16, False, "RE", 1),
219225
bneck_conf(16, 3, 64, 24, False, "RE", 2),
@@ -227,28 +233,34 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
227233
bneck_conf(80, 3, 184, 80, False, "HS", 1),
228234
bneck_conf(80, 3, 480, 112, True, "HS", 1),
229235
bneck_conf(112, 3, 672, 112, True, "HS", 1),
230-
bneck_conf(112, 5, 672, 160, True, "HS", 2),
231-
bneck_conf(160, 5, 960, 160, True, "HS", 1),
232-
bneck_conf(160, 5, 960, 160, True, "HS", 1),
236+
bneck_conf(112, 5, 672, 160, True, "HS", 2), # C4
237+
bneck_conf(160 // reduce_divider, 5, 960, 160, True, "HS", 1),
238+
bneck_conf(160 // reduce_divider, 5, 960, 160, True, "HS", 1),
233239
]
234-
last_channel = adjust_channels(1280)
240+
last_channel = adjust_channels(1280 // reduce_divider) # C5
235241

236242
return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
237243

238244

239-
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
245+
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
246+
**kwargs: Any) -> MobileNetV3:
240247
"""
241248
Constructs a small MobileNetV3 architecture from
242249
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
243250
244251
Args:
245252
pretrained (bool): If True, returns a model pre-trained on ImageNet
246253
progress (bool): If True, displays a progress bar of the download to stderr
254+
reduced_tail (bool): If True, reduces the channel counts of all feature layers
255+
between C4 and C5 by a factor of 2. It is used to reduce the channel redundancy in the
256+
backbone for Detection and Segmentation.
247257
"""
248258
width_mult = 1.0
249259
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
250260
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
251261

262+
reduce_divider = 2 if reduced_tail else 1
263+
252264
inverted_residual_setting = [
253265
bneck_conf(16, 3, 16, 16, True, "RE", 2),
254266
bneck_conf(16, 3, 72, 24, False, "RE", 2),
@@ -258,10 +270,10 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
258270
bneck_conf(40, 5, 240, 40, True, "HS", 1),
259271
bneck_conf(40, 5, 120, 48, True, "HS", 1),
260272
bneck_conf(48, 5, 144, 48, True, "HS", 1),
261-
bneck_conf(48, 5, 288, 96, True, "HS", 2),
262-
bneck_conf(96, 5, 576, 96, True, "HS", 1),
263-
bneck_conf(96, 5, 576, 96, True, "HS", 1),
273+
bneck_conf(48, 5, 288, 96, True, "HS", 2), # C4
274+
bneck_conf(96 // reduce_divider, 5, 576, 96, True, "HS", 1),
275+
bneck_conf(96 // reduce_divider, 5, 576, 96, True, "HS", 1),
264276
]
265-
last_channel = adjust_channels(1024)
277+
last_channel = adjust_channels(1024 // reduce_divider) # C5
266278

267279
return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)

0 commit comments

Comments
 (0)