1
+ import torch
1
2
from torch import nn , Tensor
2
3
from torchvision .models .utils import load_state_dict_from_url
3
4
from torchvision .models .mobilenetv3 import InvertedResidual , InvertedResidualConfig , ConvBNActivation , MobileNetV3 ,\
4
5
SqueezeExcitation , model_urls , _mobilenet_v3_conf
5
6
from torch .quantization import QuantStub , DeQuantStub , fuse_modules
6
- from typing import Any , List
7
- from .utils import _replace_relu , quantize_model
7
+ from typing import Any , List , Optional
8
+ from .utils import _replace_relu
8
9
9
10
10
11
__all__ = ['QuantizableMobileNetV3' , 'mobilenet_v3_large' , 'mobilenet_v3_small' ]
11
12
12
- # TODO: Add URLs
13
13
quant_model_urls = {
14
- 'mobilenet_v3_large_qnnpack' : None ,
14
+ 'mobilenet_v3_large_qnnpack' :
15
+ "https://github.com/datumbox/torchvision-models/raw/main/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth" ,
15
16
'mobilenet_v3_small_qnnpack' : None ,
16
17
}
17
18
@@ -69,6 +70,18 @@ def fuse_model(self):
69
70
m .fuse_model ()
70
71
71
72
73
+ def _load_weights (
74
+ arch : str ,
75
+ model : QuantizableMobileNetV3 ,
76
+ model_url : Optional [str ],
77
+ progress : bool ,
78
+ ):
79
+ if model_url is None :
80
+ raise ValueError ("No checkpoint is available for {}" .format (arch ))
81
+ state_dict = load_state_dict_from_url (model_url , progress = progress )
82
+ model .load_state_dict (state_dict )
83
+
84
+
72
85
def _mobilenet_v3_model (
73
86
arch : str ,
74
87
inverted_residual_setting : List [InvertedResidualConfig ],
@@ -83,17 +96,18 @@ def _mobilenet_v3_model(
83
96
84
97
if quantize :
85
98
backend = 'qnnpack'
86
- quantize_model (model , backend )
87
- model_url = quant_model_urls .get (arch + '_' + backend , None )
99
+
100
+ model .fuse_model ()
101
+ model .qconfig = torch .quantization .get_default_qat_qconfig (backend )
102
+ torch .quantization .prepare_qat (model , inplace = True )
103
+
104
+ if pretrained :
105
+ _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress )
106
+
107
+ torch .quantization .convert (model , inplace = True )
88
108
else :
89
- assert pretrained in [True , False ]
90
- model_url = model_urls .get (arch , None )
91
-
92
- if pretrained :
93
- if model_url is None :
94
- raise ValueError ("No checkpoint is available for {}" .format (arch ))
95
- state_dict = load_state_dict_from_url (model_url , progress = progress )
96
- model .load_state_dict (state_dict )
109
+ if pretrained :
110
+ _load_weights (arch , model , model_urls .get (arch , None ), progress )
97
111
98
112
return model
99
113
0 commit comments