
Commit a4ec036

Adding quantizable MobileNetV3 architecture.
1 parent fc728c1 commit a4ec036

File tree

2 files changed: +138 −1 lines
Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all
+from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all
 
-__all__ = mv2_all
+__all__ = mv2_all + mv3_all
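
With the re-export in place, the new constructors become part of the quantization package's public API. A minimal sketch of the intended import path (assuming this __init__.py lives under torchvision.models.quantization, alongside the existing MobileNetV2 entry points):

    from torchvision.models.quantization import mobilenet_v3_large

    # Float variant of the quantizable architecture added by this commit.
    model = mobilenet_v3_large(pretrained=False, quantize=False)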
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
from torch import nn, Tensor
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3, \
    SqueezeExcitation, model_urls, _mobilenet_v3_conf
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from typing import Any, List
from .utils import _replace_relu, quantize_model


__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large', 'mobilenet_v3_small']

# TODO: Add URLs
quant_model_urls = {
    'mobilenet_v3_large_qnnpack': None,
    'mobilenet_v3_small_qnnpack': None,
}


class QuantizableSqueezeExcitation(SqueezeExcitation):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Quantized tensors cannot be multiplied with the bare * operator
        # in eager mode; FloatFunctional wraps the mul so it can be
        # observed and converted to a quantized op.
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, input: Tensor) -> Tensor:
        return self.skip_mul.mul(self._scale(input, False), input)

    def fuse_model(self):
        fuse_modules(self, ['fc1', 'relu'], inplace=True)


class QuantizableInvertedResidual(InvertedResidual):
    def __init__(self, *args, **kwargs):
        # Swap in the quantization-friendly SE layer; the residual add
        # likewise goes through FloatFunctional so it can be observed.
        super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        if self.use_res_connect:
            return self.skip_add.add(x, self.block(x))
        else:
            return self.block(x)


class QuantizableMobileNetV3(MobileNetV3):
    def __init__(self, *args, **kwargs):
        """
        MobileNet V3 main class

        Args:
            Inherits args from floating point MobileNetV3
        """
        super().__init__(*args, **kwargs)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        for m in self.modules():
            if type(m) == ConvBNActivation:
                # ConvBNActivation is an nn.Sequential of Conv2d (0),
                # BatchNorm2d (1) and an activation (2); the activation
                # can only be folded in when it is a plain ReLU.
                modules_to_fuse = ['0', '1']
                if type(m[2]) == nn.ReLU:
                    modules_to_fuse.append('2')
                fuse_modules(m, modules_to_fuse, inplace=True)
            elif type(m) == QuantizableSqueezeExcitation:
                m.fuse_model()


def _mobilenet_v3_model(
    arch: str,
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    pretrained: bool,
    progress: bool,
    quantize: bool,
    **kwargs: Any
):
    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        backend = 'qnnpack'
        quantize_model(model, backend)
        model_url = quant_model_urls.get(arch + '_' + backend, None)
    else:
        assert pretrained in [True, False]
        model_url = model_urls.get(arch, None)

    if pretrained:
        if model_url is None:
            raise ValueError("No checkpoint is available for {}".format(arch))
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)

    return model


def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs):
    """
    Constructs a MobileNetV3 Large architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs;
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        progress (bool): If True, displays a progress bar of the download to stderr.
        quantize (bool): If True, returns a quantized model, else returns a float model.
    """
    arch = "mobilenet_v3_large"
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)


def mobilenet_v3_small(pretrained=False, progress=True, quantize=False, **kwargs):
    """
    Constructs a MobileNetV3 Small architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs;
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        progress (bool): If True, displays a progress bar of the download to stderr.
        quantize (bool): If True, returns a quantized model, else returns a float model.
    """
    arch = "mobilenet_v3_small"
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)
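
For context, a sketch of how these entry points might be exercised. Since quant_model_urls has no checkpoints yet, pretrained=True combined with quantize=True would raise the ValueError above. The manual eager-mode flow below mirrors what the quantize_model helper from .utils presumably does; that helper is not part of this diff, so the exact steps are an assumption, not a copy of it:

    import torch
    from torchvision.models.quantization import mobilenet_v3_large

    # Build the float QuantizableMobileNetV3 (quant/dequant stubs included).
    model = mobilenet_v3_large(pretrained=False, quantize=False)
    model.eval()

    # QNNPACK is the mobile CPU backend this commit targets.
    torch.backends.quantized.engine = 'qnnpack'
    model.fuse_model()  # fold Conv+BN(+ReLU) and the SE block's Conv+ReLU
    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
    torch.quantization.prepare(model, inplace=True)

    # Calibrate the observers on representative data (random batches
    # here, purely for illustration).
    with torch.no_grad():
        for _ in range(8):
            model(torch.rand(1, 3, 224, 224))

    torch.quantization.convert(model, inplace=True)  # int8 weights, CPU-only inference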
