from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .utility import (
    CpuInfo,
    cfg_to_qconfig,
    dump_model_op_stats,
+    generate_xpu_qconfig,
    get_ipex_version,
    get_quantizable_ops_recursively,
    ipex_config_path,
@@ -56,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}):
        """
        super().__init__(quant_config)
        self.user_cfg = OrderedDict()
+        self.device = auto_detect_accelerator().current_device()

    def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
        """Prepares a given model for quantization.
@@ -70,43 +73,61 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
        """
        assert example_inputs is not None, "Please provide example_inputs for static quantization."

-        _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
-            model, example_inputs
-        )
-        # update json file in ipex_config_path; map ipex op_name to pt op_name
-        self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
-        model.eval()
+        if self.device == "cpu":
+            _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
+                model, example_inputs
+            )
+            # update json file in ipex_config_path; map ipex op_name to pt op_name
+            self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        else:  # pragma: no cover
+            model = model.to("xpu")

-        use_bf16 = self.quant_config.get("use_bf16", None)
+        model.eval()

        # Check save_qconf_summary part is a workaround for IPEX bug.
-        # Sometimes the prepared model from get_op_capablitiy loss this attribute
-        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
-            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
-
-            if ipex_ver.release >= Version("2.1").release:
-                # HistogramObserver will cause a performance issue.
-                # static_qconfig = ipex.quantization.default_static_qconfig_mapping
-                qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-                from torch.ao.quantization import QConfigMapping
-
-                static_qconfig = QConfigMapping().set_global(qconfig)
-            else:
-                static_qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-            if isinstance(example_inputs, dict):
-                model = ipex.quantization.prepare(
-                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
-                )
+        # Sometimes the prepared model from get_op_capablitiy loss this attributes
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
+            from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig
+
+            if self.device != "cpu":  # pragma: no cover
+                from torch.quantization.quantize_jit import prepare_jit
+
+                with torch.no_grad():
+                    modelJit = torch.jit.trace(model, example_inputs)
+                qconfig = generate_xpu_qconfig(self.quant_config)
+                model = prepare_jit(modelJit, qconfig, inplace)
            else:
-                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
+                if ipex_ver.release >= Version("2.1").release:
+                    # HistogramObserver will cause a performance issue.
+                    # static_qconfig = ipex.quantization.default_static_qconfig_mapping
+                    qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                    from torch.ao.quantization import QConfigMapping
+
+                    static_qconfig = QConfigMapping().set_global(qconfig)
+                else:  # pragma: no cover
+                    static_qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                if isinstance(example_inputs, dict):
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+                    )
+                else:
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_inputs=example_inputs, inplace=inplace
+                    )
+
+        if self.device == "cpu":
+            model.load_qconf_summary(qconf_summary=ipex_config_path)

-        model.load_qconf_summary(qconf_summary=ipex_config_path)
        return model
@@ -124,18 +145,27 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
        from neural_compressor.torch.algorithms.static_quant import save

-        model.save_qconf_summary(qconf_summary=ipex_config_path)
-        model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
+        if self.device != "cpu":  # pragma: no cover
+            from torch.quantization.quantize_jit import convert_jit

-        with open(ipex_config_path, "r") as f:
-            model.tune_cfg = json.load(f)
-        model.ipex_config_path = ipex_config_path
+            model = convert_jit(model, inplace)
+            simple_inference(model, example_inputs, iterations=2)
+            model.qconfig = self.quant_config["op"]
+            dump_model_op_stats(model.qconfig)
+        else:
+            model.save_qconf_summary(qconf_summary=ipex_config_path)
+            model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

-        dump_model_op_stats(self.user_cfg)
+            with open(ipex_config_path, "r") as f:
+                model.tune_cfg = json.load(f)
+            model.ipex_config_path = ipex_config_path
+
+            dump_model_op_stats(self.user_cfg)

-        logger.info("Static quantization done.")
        model.ori_save = model.save
        model.save = MethodType(save, model)
+
+        logger.info("Static quantization done.")
        return model
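For reference, below is a minimal standalone sketch of the JIT-based XPU path this change introduces: trace, attach observers with prepare_jit, calibrate, then convert_jit. It is an illustration only, not the committed code; calib_dataloader and the literal qconfig dict are placeholders, and the real path builds its qconfig from the user's quant_config via generate_xpu_qconfig and needs an XPU-enabled IPEX build.

import torch
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
from torch.ao.quantization.quantize_jit import convert_jit, prepare_jit


def xpu_static_quant_sketch(model, example_inputs, calib_dataloader):
    # Move the model to the XPU device and switch to inference mode,
    # mirroring the else-branch added to prepare() above.
    model = model.to("xpu").eval()

    # Plain per-tensor activation / per-channel weight qconfig; the committed
    # code derives this from quant_config via generate_xpu_qconfig() instead.
    qconfig = QConfig(
        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
        weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
    )

    with torch.no_grad():
        # Trace first, then insert observers into the TorchScript graph.
        traced = torch.jit.trace(model, example_inputs)
        prepared = prepare_jit(traced, {"": qconfig}, True)

        # Calibration pass: run a few batches so the observers collect ranges.
        for inputs in calib_dataloader:
            prepared(inputs.to("xpu"))

        # Swap observed ops for quantized kernels, as convert() does via convert_jit.
        return convert_jit(prepared, True)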
0 commit comments
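The CPU branch is unchanged in substance. For comparison, this is roughly the standard IPEX static-quantization recipe it wraps, sketched under the assumption that _ipex_post_quant_process performs the usual convert/trace/freeze steps; calib_dataloader is again a placeholder.

import torch
import intel_extension_for_pytorch as ipex


def cpu_static_quant_sketch(model, example_inputs, calib_dataloader):
    model = model.eval()

    # IPEX >= 2.1 ships a ready-made static qconfig mapping; the quantizer above
    # builds an equivalent MinMax/PerChannelMinMax QConfig to avoid HistogramObserver.
    static_qconfig = ipex.quantization.default_static_qconfig_mapping

    prepared = ipex.quantization.prepare(
        model, static_qconfig, example_inputs=example_inputs, inplace=False
    )

    # Calibration pass populates the observers recorded in the qconf summary.
    for inputs in calib_dataloader:
        prepared(inputs)

    # Convert, then JIT-trace and freeze for deployment.
    converted = ipex.quantization.convert(prepared)
    with torch.no_grad():
        traced = torch.jit.trace(converted, example_inputs)
        traced = torch.jit.freeze(traced)
    return traced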