Skip to content

Commit 1ab143e

Browse files
committed
Adding quantizable MobileNetV3 architecture.
1 parent fc728c1 commit 1ab143e

File tree

2 files changed

+143
-1
lines changed

2 files changed

+143
-1
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all
2+
from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all
23

3-
__all__ = mv2_all
4+
__all__ = mv2_all + mv3_all
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from torch import nn, Tensor
2+
from torchvision.models.utils import load_state_dict_from_url
3+
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
4+
SqueezeExcitation, model_urls, _mobilenet_v3_conf
5+
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
6+
from typing import Any, List
7+
from .utils import _replace_relu, quantize_model
8+
9+
10+
__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large', 'mobilenet_v3_small']
11+
12+
# TODO: Add URLs
13+
quant_model_urls = {
14+
'mobilenet_v3_large_qnnpack': None,
15+
'mobilenet_v3_small_qnnpack': None,
16+
}
17+
18+
19+
class QuantizableSqueezeExcitation(SqueezeExcitation):
20+
def __init__(self, *args, **kwargs):
21+
super().__init__(*args, **kwargs)
22+
self.skip_mul = nn.quantized.FloatFunctional()
23+
24+
def forward(self, input: Tensor) -> Tensor:
25+
return self.skip_mul.mul(self._scale(input, False), input)
26+
27+
def fuse_model(self):
28+
fuse_modules(self, ['fc1', 'relu'], inplace=True)
29+
30+
31+
class QuantizableInvertedResidual(InvertedResidual):
32+
def __init__(self, *args, **kwargs):
33+
super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs)
34+
self.skip_add = nn.quantized.FloatFunctional()
35+
36+
def forward(self, x):
37+
if self.use_res_connect:
38+
return self.skip_add.add(x, self.block(x))
39+
else:
40+
return self.block(x)
41+
42+
def fuse_model(self):
43+
for idx in range(len(self.block)):
44+
if type(self.block[idx]) == SqueezeExcitation:
45+
fuse_modules(self.block[idx], ['fc1', 'relu'], inplace=True)
46+
47+
48+
class QuantizableMobileNetV3(MobileNetV3):
49+
def __init__(self, *args, **kwargs):
50+
"""
51+
MobileNet V3 main class
52+
53+
Args:
54+
Inherits args from floating point MobileNetV3
55+
"""
56+
super().__init__(*args, **kwargs)
57+
self.quant = QuantStub()
58+
self.dequant = DeQuantStub()
59+
60+
def forward(self, x):
61+
x = self.quant(x)
62+
x = self._forward_impl(x)
63+
x = self.dequant(x)
64+
return x
65+
66+
def fuse_model(self):
67+
for m in self.modules():
68+
if type(m) == ConvBNActivation:
69+
modules_to_fuse = ['0', '1']
70+
if type(m[2]) == nn.ReLU:
71+
modules_to_fuse.append('2')
72+
fuse_modules(m, modules_to_fuse, inplace=True)
73+
elif type(m) in {QuantizableInvertedResidual, QuantizableSqueezeExcitation}:
74+
m.fuse_model()
75+
76+
77+
def _mobilenet_v3_model(
78+
arch: str,
79+
inverted_residual_setting: List[InvertedResidualConfig],
80+
last_channel: int,
81+
pretrained: bool,
82+
progress: bool,
83+
quantize: bool,
84+
**kwargs: Any
85+
):
86+
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
87+
_replace_relu(model)
88+
89+
if quantize:
90+
backend = 'qnnpack'
91+
quantize_model(model, backend)
92+
model_url = quant_model_urls.get(arch + '_' + backend, None)
93+
else:
94+
assert pretrained in [True, False]
95+
model_url = model_urls.get(arch, None)
96+
97+
if pretrained:
98+
if model_url is None:
99+
raise ValueError("No checkpoint is available for {}".format(arch))
100+
state_dict = load_state_dict_from_url(model_url, progress=progress)
101+
model.load_state_dict(state_dict)
102+
103+
return model
104+
105+
106+
def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs):
107+
"""
108+
Constructs a MobileNetV3 Large architecture from
109+
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
110+
111+
Note that quantize = True returns a quantized model with 8 bit
112+
weights. Quantized models only support inference and run on CPUs.
113+
GPU inference is not yet supported
114+
115+
Args:
116+
pretrained (bool): If True, returns a model pre-trained on ImageNet.
117+
progress (bool): If True, displays a progress bar of the download to stderr
118+
quantize (bool): If True, returns a quantized model, else returns a float model
119+
"""
120+
arch = "mobilenet_v3_large"
121+
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
122+
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)
123+
124+
125+
def mobilenet_v3_small(pretrained=False, progress=True, quantize=False, **kwargs):
126+
"""
127+
Constructs a MobileNetV3 Small architecture from
128+
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
129+
130+
Note that quantize = True returns a quantized model with 8 bit
131+
weights. Quantized models only support inference and run on CPUs.
132+
GPU inference is not yet supported
133+
134+
Args:
135+
pretrained (bool): If True, returns a model pre-trained on ImageNet.
136+
progress (bool): If True, displays a progress bar of the download to stderr
137+
quantize (bool): If True, returns a quantized model, else returns a float model
138+
"""
139+
arch = "mobilenet_v3_small"
140+
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
141+
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)

0 commit comments

Comments
 (0)