@@ -26,10 +26,12 @@ def forward(self, x):


 model = Model()
+example_inputs = torch.rand([1, 3])


 def run_fn(model):
-    model(torch.randn([1, 3]))
+    for i in range(10):
+        model(example_inputs)


 class TestSmoothQuant:
@@ -40,7 +42,6 @@ def teardown_class(self):
     def test_smooth_quant_default(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.randn([1, 3])
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
@@ -57,7 +58,6 @@ def test_smooth_quant_default(self):
     def test_smooth_quant_fallback(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.randn([1, 3])
         # fallback by op_type
         quant_config.set_local(torch.nn.Linear, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
@@ -87,10 +87,6 @@ def test_sq_linear_params(self, act_sym, act_algo, alpha, folding, scale_sharing
         quant_config = SmoothQuantConfig(
             act_sym=act_sym, act_algo=act_algo, alpha=alpha, folding=folding, scale_sharing=scale_sharing
         )
-        example_inputs = torch.zeros([1, 3])
-
-        def run_fn(model):
-            model(example_inputs)

         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
@@ -102,7 +98,6 @@ def run_fn(model):

     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     def test_sq_ipex_accuracy(self):
-        example_inputs = torch.zeros([1, 3])
         qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
         user_model = copy.deepcopy(model)
         user_model = ipex.quantization.prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True)
@@ -144,7 +139,6 @@ def run_fn(model):
     def test_sq_save_load(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.zeros([1, 3])
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
@@ -171,7 +165,6 @@ def test_sq_save_load(self):
     def test_smooth_quant_with_quantize_API(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.randn([1, 3])
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"

@@ -184,7 +177,6 @@ def test_smooth_quant_with_quantize_API(self):
     def test_smooth_quant_mixed_precision(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()  # do mixed_precision by default.
-        example_inputs = torch.randn([1, 3])

         # prepare/convert API
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
@@ -203,3 +195,43 @@ def test_smooth_quant_mixed_precision(self):
         quant_config.folding = True
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
+
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_smooth_quant_auto(self):
+        fp32_model = copy.deepcopy(model)
+        example_inputs = torch.rand([1, 3])
+
+        def run_fn(model):
+            for i in range(100):
+                model(example_inputs)
+
+        # block-wise
+        quant_config = SmoothQuantConfig(
+            alpha="auto",
+            alpha_min=0.45,
+            alpha_max=0.55,
+            alpha_step=0.01,
+            shared_criterion="mean",
+            do_blockwise=True,
+            folding=False,
+        )
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+        output1 = fp32_model(example_inputs)
+        output2 = q_model(example_inputs)
+        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
+
+        # layer-wise
+        quant_config = SmoothQuantConfig(
+            alpha="auto",
+            alpha_min=0.45,
+            alpha_max=0.55,
+            alpha_step=0.01,
+            shared_criterion="max",
+            do_blockwise=False,
+            folding=False,
+        )
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+        output2 = q_model(example_inputs)
+        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."