
Commit ee24dba

Fix onnxrt backend recover function (#1788)
Signed-off-by: Mengni Wang <[email protected]>
1 parent b6237cf commit ee24dba
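
The recover utility rebuilds a quantized model from the original fp32 model plus the tuning history that quantization.fit writes into the workspace; this commit fixes that path for the ONNX Runtime backend, adds handling for layer-wise quantized models, and makes smooth-quant models exit with an explicit error. A minimal usage sketch, assuming a placeholder fp32_model and calib_dataloader and a default config rather than the settings used in the tests below:

    from neural_compressor import PostTrainingQuantConfig, quantization, set_workspace
    from neural_compressor.utils.utility import recover

    set_workspace("nc_workspace")  # history.snapshot ends up under this workspace

    config = PostTrainingQuantConfig()  # placeholder config; the tests use task-specific settings
    q_model = quantization.fit(fp32_model, config, calib_dataloader=calib_dataloader)

    # Rebuild the quantized model from the fp32 model and the saved tuning history,
    # replaying tuning record 0, the same call the updated tests make.
    recover_model = recover(fp32_model, "nc_workspace/history.snapshot", 0)
    assert recover_model.model == q_model.model

The tests added in this commit follow the same pattern with concrete models.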

3 files changed (+125 -90 lines)

neural_compressor/adaptor/onnxrt.py (+111 -86)

@@ -274,6 +274,9 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         if ort_version < ONNXRT152_VERSION:  # pragma: no cover
             logger.warning("Quantize input needs onnxruntime 1.5.2 or newer.")
             return model
+        if ort_version < ONNXRT170_VERSION and self.format == "qdq":
+            logger.error("QDQ mode needs onnxruntime1.7.0 or newer.")
+            exit(0)
         if model.model.opset_import[0].version < 11:  # pragma: no cover
             logger.warning("Quantize input needs model opset 11 or newer.")
         if self.backend == "DnnlExecutionProvider" and any(
@@ -289,17 +292,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                     "please upgrade it manually to run with bf16 data type"
                 )
                 exit(0)
-
-        from neural_compressor.adaptor.ox_utils.util import QuantizationMode
-
-        if self.format == "qlinearops":
-            format = QuantizationMode.QLinearOps
-        elif self.format == "qdq":
-            assert ort_version >= ONNXRT170_VERSION, "QDQ mode needs onnxruntime1.7.0 or newer"
-            format = "qdq"
-        else:
-            format = QuantizationMode.IntegerOps
-
         self.quantizable_ops = self._query_quantizable_ops(model.model)
         quantize_config = self._cfg_to_quantize_config(tune_cfg)
 
@@ -405,43 +397,11 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                 )
             else:
                 quantize_params = None
+        q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
         self.quantize_params = quantize_params
-
-        from neural_compressor import options
-        from neural_compressor.adaptor.ox_utils.quantizer import Quantizer
-
-        quantizer = Quantizer(
-            tmp_model,
-            quantize_config,
-            format,
-            self.static,
-            quantize_params,
-            self.quantizable_op_types,
-            self.query_handler.get_fallback_list(),
-            self.reduce_range,
-            (
-                options.onnxrt.qdq_setting.AddQDQPairToWeight
-                if "add_qdq_pair_to_weight" not in self.recipes
-                else self.recipes.get("add_qdq_pair_to_weight", False)
-            ),
-            (
-                options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin
-                if "optypes_to_exclude_output_quant" not in self.recipes
-                else self.recipes.get("optypes_to_exclude_output_quant", [])
-            ),
-            (
-                options.onnxrt.qdq_setting.DedicatedQDQPair
-                if "dedicated_qdq_pair" not in self.recipes
-                else self.recipes.get("dedicated_qdq_pair", False)
-            ),
-            self.backend,
-        )
-        quantizer.quantize_model()
-        tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
-        tmp_model.model = quantizer.model.model
-        self.quantize_config = quantize_config  # update so other methods can know current configs
+        tmp_model = self._quantize_model(tmp_model, quantize_config, quantize_params)
+        tmp_model.q_config = q_config
         self._dump_model_op_stats(tmp_model)
-        tmp_model.topological_sort()
 
         # if the model is large and acc tuning is required, save it to workspace
         if not self.performance_only and tmp_model.is_large_model:  # pragma: no cover
@@ -496,13 +456,21 @@ def _get_split_model_quantize_params(
         )
         return split_quantize_params, dataloder_for_next_split_model
 
-    def _quantize_split_model(self, split_model, quantize_config, quantize_params, quantized_model_merged):
-        """Quantize split model, and merge the quantized models to generate final model."""
+    def _quantize_model(self, model, quantize_config, quantize_params):
+        """Quantize model."""
         from neural_compressor import options
         from neural_compressor.adaptor.ox_utils.quantizer import Quantizer
+        from neural_compressor.adaptor.ox_utils.util import QuantizationMode
+
+        if self.format == "qlinearops":
+            format = QuantizationMode.QLinearOps
+        elif self.format == "qdq":
+            format = "qdq"
+        else:
+            format = QuantizationMode.IntegerOps
 
         quantizer = Quantizer(
-            split_model,
+            model,
             quantize_config,
             format,
             self.static,
@@ -528,14 +496,19 @@ def _quantize_split_model(self, split_model, quantize_config, quantize_params, q
             self.backend,
         )
         quantizer.quantize_model()
-        split_model.model = quantizer.model.model
-        split_model.topological_sort()
+        model.model = quantizer.model.model
+        self.quantize_config = quantize_config  # update so other methods can know current configs
+        model.topological_sort()
+        return model
 
+    def _quantize_split_model(self, split_model, quantize_config, quantize_params, quantized_model_merged):
+        """Quantize split model, and merge the quantized models to generate final model."""
+        split_model = self._quantize_model(split_model, quantize_config, quantize_params)
         if quantized_model_merged is None:
-            quantized_model_merged = quantizer.model
+            quantized_model_merged = split_model
             quantized_model_merged.write_external_data_to_new_location(overwrite=True)
         else:
-            quantized_model_merged.merge_split_models(quantizer.model)
+            quantized_model_merged.merge_split_models(split_model)
 
         return quantized_model_merged
 
@@ -640,57 +613,109 @@ def recover(self, model, q_config):
         """
         self._pre_optimize(model)
         model = self.pre_optimized_model
+
         ort_version = Version(ort.__version__)
         if ort_version < ONNXRT152_VERSION:  # pragma: no cover
             logger.warning("Quantize input needs onnxruntime 1.5.2 or newer.")
             return model
         if model.model.opset_import[0].version < 11:  # pragma: no cover
             logger.warning("Quantize input needs model opset 11 or newer.")
+        if ort_version < ONNXRT170_VERSION and self.format == "qdq":
+            logger.error("QDQ mode needs onnxruntime1.7.0 or newer.")
+            exit(0)
+        if self.backend == "DnnlExecutionProvider" and any(
+            [i.domain in ["", "ai.onnx"] and i.version < 15 for i in model.model.opset_import]
+        ):  # pragma: no cover
+            from onnx import version_converter
+
+            try:
+                model = self._rename_node(ONNXModel(version_converter.convert_version(model.model, 15)))
+            except:
+                logging.warning(
+                    "Fail to upgrade model opset_import to >= 15, "
+                    "please upgrade it manually to run with bf16 data type"
+                )
+                exit(0)
 
         from neural_compressor.adaptor.ox_utils.util import QuantizationMode
 
-        if self.format in ["qlinearops"]:
+        if self.format == "qlinearops":
             format = QuantizationMode.QLinearOps
         elif self.format == "qdq":
-            assert ort_version >= ONNXRT170_VERSION, "QDQ mode needs onnxruntime1.7.0 or newer"
-            format = self.format
+            format = "qdq"
         else:
             format = QuantizationMode.IntegerOps
-        from neural_compressor import options
-        from neural_compressor.adaptor.ox_utils.quantizer import Quantizer
 
         self.quantizable_ops = self._query_quantizable_ops(model.model)
         quantize_params, tune_cfg = self._parse_qconfig(q_config)
         quantize_config = self._cfg_to_quantize_config(tune_cfg)
-        quantizer = Quantizer(
-            model.model,
-            quantize_config,
-            format,
-            self.static,
-            quantize_params,
-            self.quantizable_op_types,
-            self.query_handler.get_fallback_list(),
-            self.reduce_range,
-            (
-                options.onnxrt.qdq_setting.AddQDQPairToWeight
-                if not options.onnxrt.qdq_setting.AddQDQPairToWeight
-                else self.recipes.get("add_qdq_pair_to_weight", False)
-            ),
-            (
-                options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin
-                if options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin is not None
-                else self.recipes.get("optypes_to_exclude_output_quant", [])
-            ),
-            (
-                options.onnxrt.qdq_setting.DedicatedQDQPair
-                if not options.onnxrt.qdq_setting.DedicatedQDQPair
-                else self.recipes.get("dedicated_qdq_pair", False)
-            ),
-        )
 
-        quantizer.quantize_model()
-        model.model = quantizer.model.model
-        model.topological_sort()
+        if self._need_smooth_quant(tune_cfg):
+            logger.error("Don't support to recover quantized model with smooth quant from original fp32 model.")
+            exit(0)
+
+        if self.recipes.get("layer_wise_quant", False) and not self.dynamic:
+            # layer-wise quantization
+            # details refer to docs/source/quantization_weight_only.md#layer-wise-quantization
+            _model_to_split = copy.deepcopy(model)
+
+            split_nodes = _model_to_split.find_split_nodes()
+            logger.info(
+                "Will split model into {} parts to do layer-wise quantization".format(
+                    len([node.name for node in split_nodes]) + 1
+                )
+            )
+            logger.debug(
+                "Will split model with these nodes for layer-wise quantization: {}".format(
                    [node.name for node in split_nodes]
+                )
+            )
+
+            split_idx = 1
+            model_to_split = [_model_to_split]
+            quantized_model_merged = None
+
+            while len(model_to_split) != 0:
+                split_model = model_to_split.pop(0)
+                split_node = split_nodes.pop(0)
+                save_both_split_models = True if len(split_nodes) == 0 else False
+                shape_infer = True if split_idx == 1 else False
+
+                # split model with given split_node
+                split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
+                    split_node.name, model.model_path, shape_infer, save_both_split_models
+                )
+                if not save_both_split_models:
+                    # append split_model_part_2 to do next split
+                    model_to_split.append(split_model_part_2)
+
+                logger.info("Quantize split model {}".format(split_idx))
+
+                # quantize split model
+                quantized_model_merged = self._quantize_split_model(
+                    split_model_part_1, quantize_config, quantize_params, quantized_model_merged
+                )
+
+                split_idx += 1
+
+                # if this is the last split, then quantize the last split model
+                if save_both_split_models:
+                    logger.info("Quantize split model {}".format(split_idx))
+
+                    # quantize split model
+                    quantized_model_merged = self._quantize_split_model(
+                        split_model_part_2, quantize_config, quantize_params, quantized_model_merged
+                    )
+                    quantized_model_merged.re_org_output(model.output())  # re-org output as the origin output
+
+            model.model = quantized_model_merged.model
+            self._dump_model_op_stats(model)
+            model.check_is_large_model()
+
+        else:
+            model = self._quantize_model(model, quantize_config, quantize_params)
+
+        self._dump_model_op_stats(model)
         return model
 
     def _parse_qconfig(self, q_config):

test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py (+9 -3)

@@ -14,13 +14,14 @@
 from packaging.version import Version
 from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
 
-from neural_compressor import PostTrainingQuantConfig, quantization
+from neural_compressor import PostTrainingQuantConfig, quantization, set_workspace
 from neural_compressor.adaptor import FRAMEWORKS
 from neural_compressor.adaptor.pytorch import get_torch_version
 from neural_compressor.conf.config import conf
 from neural_compressor.data import DATALOADERS, DataLoader, Datasets
 from neural_compressor.experimental import Benchmark, Quantization, common
 from neural_compressor.model import Model
+from neural_compressor.utils.utility import recover
 
 
 def build_static_yaml():
@@ -898,6 +899,7 @@ def setUpClass(self):
         self.albert_model = onnx.load(self.albert_export_path)
         self.gather_matmul_model = build_matmul_gather_model()
         build_benchmark()
+        set_workspace("nc_workspace")
 
     @classmethod
     def tearDownClass(self):
@@ -1390,8 +1392,6 @@ def test_adaptor(self):
         self.assertNotEqual(q_model, None)
 
         # check recover model function
-        from neural_compressor.utils.utility import recover
-
         model = recover(self.mb_v2_model, "./nc_workspace/recover/history.snapshot", 0)
         self.assertTrue(model.model == q_model.model)
 
@@ -1489,6 +1489,10 @@ def eval(model):
         q_model = quantization.fit(self.matmul_model, config, calib_dataloader=self.matmul_dataloader, eval_func=eval)
         self.assertTrue("QLinearMatMul" in [i.op_type for i in q_model.nodes()])
 
+        q_model = quantization.fit(self.matmul_model, config, calib_dataloader=self.matmul_dataloader)
+        recover_model = recover(self.matmul_model, "nc_workspace/history.snapshot", 0)
+        self.assertTrue(q_model.model == recover_model.model)
+
         config = PostTrainingQuantConfig(approach="dynamic")
         q_model = quantization.fit(self.matmul_model, config, calib_dataloader=self.matmul_dataloader, eval_func=eval)
         self.assertTrue("MatMulInteger" in [i.op_type for i in q_model.nodes()])
@@ -1535,6 +1539,8 @@ def test_smooth_quant(self):
         )
         q_model = quantization.fit(self.conv_model, config, calib_dataloader=self.cv_dataloader)
         self.assertEqual(len([i for i in q_model.nodes() if i.op_type == "Mul"]), 2)
+        with self.assertRaises(SystemExit):
+            recover_model = recover(self.conv_model, "nc_workspace/history.snapshot", 0)
 
     def test_smooth_quant_args(self):
         from neural_compressor.model.onnx_model import ONNXModel

test/adaptor/onnxrt_adaptor/test_layer_wise.py (+5 -1)

@@ -7,8 +7,9 @@
 import onnxruntime as ort
 from transformers import AutoTokenizer
 
-from neural_compressor import PostTrainingQuantConfig, quantization
+from neural_compressor import PostTrainingQuantConfig, quantization, set_workspace
 from neural_compressor.utils.constant import FP32
+from neural_compressor.utils.utility import recover
 
 
 def Inference(model_path, data):
@@ -44,6 +45,7 @@ def setUpClass(self):
 
         self.model = onnx.load("tiny-llama/decoder_model.onnx")
         self.dataloader = DummyNLPDataloader("yujiepan/llama-2-tiny-3layers-random")
+        set_workspace("nc_workspace")
 
     @classmethod
     def tearDownClass(self):
@@ -57,6 +59,8 @@ def test_layer_wise_W8A8_quant(self):
             calibration_sampling_size=[1], recipes={"layer_wise_quant": True}, op_type_dict={"^((?!(MatMul)).)*$": FP32}
         )
         q_model = quantization.fit("tiny-llama/decoder_model.onnx", config, calib_dataloader=self.dataloader)
+        recover_model = recover("tiny-llama/decoder_model.onnx", "nc_workspace/history.snapshot", 0)
+        self.assertTrue(recover_model.model == q_model.model)
         q_model.save(layerwise_quantized_model_path)
 
         # not layer-wise quantization
