
Commit 7120dd4

bug fix for 3.x sq and static quant (#1823)
Signed-off-by: Cheng, Zixuan <[email protected]>
1 parent 2764494 commit 7120dd4

File tree

5 files changed: 28 additions (+), 14 deletions (-)

neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
neural_compressor/torch/algorithms/static_quant/static_quant.py
neural_compressor/torch/quantization/quantize.py
test/3x/torch/quantization/test_smooth_quant.py
test/3x/torch/quantization/test_static_quant.py

neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py

Lines changed: 6 additions & 5 deletions
@@ -389,8 +389,9 @@ def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False):
     else:
         model = torch.jit.trace(model, example_inputs, strict=False)
         model = torch.jit.freeze(model.eval())
-        # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
-        # At the 2nd run, the llga pass will be triggered and the model is turned into
-        # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
-        simple_inference(model, example_inputs, iterations=2)
-        return model
+
+    # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
+    # At the 2nd run, the llga pass will be triggered and the model is turned into
+    # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
+    simple_inference(model, example_inputs, iterations=2)
+    return model
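The comments in this hunk describe IPEX's LLGA warm-up: after torch.jit.freeze, the TorchScript module must be executed twice so the profiling graph executor first inserts prim::profile nodes and the second run triggers the llga fusion pass that rewrites the graph with LlgaFusionGroup nodes. As a rough illustration, a warm-up helper in the spirit of simple_inference could look like the sketch below; the name warm_up and its input handling are illustrative, not the repository's actual implementation.

import torch


def warm_up(jit_model, example_inputs, iterations=2):
    # Run the frozen TorchScript model a few times so the profiling graph
    # executor can insert prim::profile nodes; on the second run the LLGA
    # fusion pass fires and produces LlgaFusionGroup nodes in the graph.
    with torch.no_grad():
        for _ in range(iterations):
            if isinstance(example_inputs, (tuple, list)):
                jit_model(*example_inputs)
            elif isinstance(example_inputs, dict):
                jit_model(**example_inputs)
            else:
                jit_model(example_inputs)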

neural_compressor/torch/algorithms/static_quant/static_quant.py

Lines changed: 6 additions & 5 deletions
@@ -176,8 +176,9 @@ def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False):
     else:
         model = torch.jit.trace(model, example_inputs, strict=False)
         model = torch.jit.freeze(model.eval())
-        # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
-        # At the 2nd run, the llga pass will be triggered and the model is turned into
-        # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
-        simple_inference(model, example_inputs, iterations=2)
-        return model
+
+    # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
+    # At the 2nd run, the llga pass will be triggered and the model is turned into
+    # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
+    simple_inference(model, example_inputs, iterations=2)
+    return model
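Both hunks above are the same fix: the warm-up call and the final return model move out of the else: branch up to the function body. Before this change, the use_bf16 branch (not shown in the hunk) never reached simple_inference or return model, so _ipex_post_quant_process fell through and returned None; that is exactly what the new mixed-precision tests below assert against. A torch-free sketch of the control-flow pattern, with purely illustrative names:

def post_process_before_fix(model, use_bf16):
    if use_bf16:
        model = f"bf16-jit({model})"
    else:
        model = f"fp32-jit({model})"
        return model  # only this branch returns; the bf16 path falls through to None


def post_process_after_fix(model, use_bf16):
    if use_bf16:
        model = f"bf16-jit({model})"
    else:
        model = f"fp32-jit({model})"
    return model  # both branches now reach the shared warm-up/return tail


assert post_process_before_fix("m", use_bf16=True) is None
assert post_process_after_fix("m", use_bf16=True) == "bf16-jit(m)"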

neural_compressor/torch/quantization/quantize.py

Lines changed: 3 additions & 3 deletions
@@ -71,10 +71,10 @@ def quantize(
         from neural_compressor.torch.algorithms.smooth_quant import TorchSmoothQuant
 
         sq = TorchSmoothQuant(
-            model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True
+            q_model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True
         )
-        model.sq_info = sq
-        model = sq.transform(
+        q_model.sq_info = sq
+        q_model = sq.transform(
             alpha=quant_config.alpha,
             folding=quant_config.folding,
             auto_alpha_args=quant_config.auto_alpha_args,
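The quantize.py change points the SmoothQuant path at q_model, the working copy that quantize() continues to operate on and ultimately returns, instead of the raw model argument, so the recorded sq_info and the transformed modules end up on the returned object. Below is a hedged usage sketch mirroring the new test added further down; it assumes get_default_sq_config is exported from neural_compressor.torch.quantization (as the 3.x tests use it) and that intel-extension-for-pytorch is installed.

import torch
from neural_compressor.torch.quantization import get_default_sq_config, quantize


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


fp32_model = ToyModel()
example_inputs = torch.randn(2, 8)


def run_fn(model):
    # Calibration: a couple of forward passes is enough for this sketch.
    for _ in range(2):
        model(example_inputs)


quant_config = get_default_sq_config()
q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
assert q_model is not None, "Quantization failed!"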

test/3x/torch/quantization/test_smooth_quant.py

Lines changed: 7 additions & 0 deletions
@@ -193,6 +193,13 @@ def test_smooth_quant_mixed_precision(self):
         assert q_model is not None, "Quantization failed!"
 
         # quantize API
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
         quant_config.excluded_precisions = ["bf16"]
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
+
+        quant_config.folding = True
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"

test/3x/torch/quantization/test_static_quant.py

Lines changed: 6 additions & 1 deletion
@@ -195,9 +195,14 @@ def test_static_quant_with_quantize_API(self):
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     def test_static_quant_mixed_precision(self):
         fp32_model = copy.deepcopy(self.fp32_model)
+        example_inputs = self.input
         quant_config = get_default_static_config()
+        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
+        run_fn(prepared_model)
+        q_model = convert(prepared_model)
+        assert q_model is not None, "Quantization failed!"
+
         quant_config.excluded_precisions = ["bf16"]
-        example_inputs = self.input
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
