Skip to content

Commit 274c6a1

Browse files
committed
Workaround for loading correct weights of quant model.
1 parent 3d69b2a commit 274c6a1

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

torchvision/models/quantization/mobilenetv3.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
import torch
12
from torch import nn, Tensor
23
from torchvision.models.utils import load_state_dict_from_url
34
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
45
SqueezeExcitation, model_urls, _mobilenet_v3_conf
56
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
6-
from typing import Any, List
7-
from .utils import _replace_relu, quantize_model
7+
from typing import Any, List, Optional
8+
from .utils import _replace_relu
89

910

1011
__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large', 'mobilenet_v3_small']
1112

12-
# TODO: Add URLs
1313
quant_model_urls = {
14-
'mobilenet_v3_large_qnnpack': None,
14+
'mobilenet_v3_large_qnnpack':
15+
"https://github.com/datumbox/torchvision-models/raw/main/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
1516
'mobilenet_v3_small_qnnpack': None,
1617
}
1718

@@ -69,6 +70,18 @@ def fuse_model(self):
6970
m.fuse_model()
7071

7172

73+
def _load_weights(
74+
arch: str,
75+
model: QuantizableMobileNetV3,
76+
model_url: Optional[str],
77+
progress: bool,
78+
):
79+
if model_url is None:
80+
raise ValueError("No checkpoint is available for {}".format(arch))
81+
state_dict = load_state_dict_from_url(model_url, progress=progress)
82+
model.load_state_dict(state_dict)
83+
84+
7285
def _mobilenet_v3_model(
7386
arch: str,
7487
inverted_residual_setting: List[InvertedResidualConfig],
@@ -83,17 +96,18 @@ def _mobilenet_v3_model(
8396

8497
if quantize:
8598
backend = 'qnnpack'
86-
quantize_model(model, backend)
87-
model_url = quant_model_urls.get(arch + '_' + backend, None)
99+
100+
model.fuse_model()
101+
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
102+
torch.quantization.prepare_qat(model, inplace=True)
103+
104+
if pretrained:
105+
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
106+
107+
torch.quantization.convert(model, inplace=True)
88108
else:
89-
assert pretrained in [True, False]
90-
model_url = model_urls.get(arch, None)
91-
92-
if pretrained:
93-
if model_url is None:
94-
raise ValueError("No checkpoint is available for {}".format(arch))
95-
state_dict = load_state_dict_from_url(model_url, progress=progress)
96-
model.load_state_dict(state_dict)
109+
if pretrained:
110+
_load_weights(arch, model, model_urls.get(arch, None), progress)
97111

98112
return model
99113

0 commit comments

Comments
 (0)