
Commit 6142e48

Support ONNXRT layer-wise W8A8 quantization (#1389)
Signed-off-by: yuwenzho <[email protected]>
1 parent 8447d70 commit 6142e48
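
For context, the feature added here is driven by the layer_wise_quant recipe that the ONNXRT adaptor checks below via self.recipes.get("layer_wise_quant", False). A minimal usage sketch, assuming the neural_compressor 2.x post-training quantization API; the model path and the dummy calibration dataset are placeholders, not taken from this commit:

# Minimal sketch (not part of this diff): enabling ONNXRT layer-wise W8A8
# quantization through the "layer_wise_quant" recipe read in onnxrt.py below.
# The model path and the dummy calibration dataset are placeholders.
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.data import DataLoader, Datasets

# Dummy calibration data; replace with real samples matching the model inputs.
dataset = Datasets("onnxrt_qlinearops")["dummy"](shape=(10, 3, 224, 224))
calib_dataloader = DataLoader(framework="onnxruntime", dataset=dataset)

config = PostTrainingQuantConfig(
    approach="static",                   # layer-wise quantization only takes effect on the static path
    recipes={"layer_wise_quant": True},  # toggles the branch added in quantize() below
)
q_model = quantization.fit("large_model_with_external_data.onnx", config, calib_dataloader=calib_dataloader)
q_model.save("quantized_model")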

File tree

5 files changed: +648, -42 lines changed


neural_compressor/adaptor/onnxrt.py

Lines changed: 168 additions & 21 deletions
@@ -35,6 +35,7 @@
 from neural_compressor.adaptor.ox_utils.util import ONNXRT_BACKENDS, PROVIDERS, to_numpy
 from neural_compressor.adaptor.query import QueryBackendCapability
 from neural_compressor.data.dataloaders.base_dataloader import BaseDataLoader
+from neural_compressor.model.onnx_model import ONNXModel
 from neural_compressor.utils.utility import GLOBAL_STATE, MODE, CpuInfo, LazyImport, Statistics, dump_elapsed_time
 
 onnx = LazyImport("onnx")
@@ -267,8 +268,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         ):  # pragma: no cover
             from onnx import version_converter
 
-            from neural_compressor.model.onnx_model import ONNXModel
-
             try:
                 model = self._rename_node(ONNXModel(version_converter.convert_version(model.model, 15)))
             except:
@@ -308,18 +307,146 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
 
         iterations = tune_cfg.get("calib_iteration", 1)
         calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-        if not self.dynamic:
-            calib_iterations = self._reset_calib_iter(data_loader, calib_sampling_size, iterations)
-            quantize_params = self._get_quantize_params(tmp_model, data_loader, quantize_config, calib_iterations)
+
+        if self.recipes.get("layer_wise_quant", False) and not self.dynamic:
+            # layer-wise quantization
+            # details refer to docs/source/quantization_weight_only.md#layer-wise-quantization
+            _model_to_split = copy.deepcopy(tmp_model)
+
+            split_nodes = _model_to_split.find_split_nodes()
+            logger.info(
+                "Will split model into {} parts to do layer-wise quantization".format(
+                    len([node.name for node in split_nodes]) + 1
+                )
+            )
+            logger.debug(
+                "Will split model with these nodes for layer-wise quantization: {}".format(
+                    [node.name for node in split_nodes]
+                )
+            )
+
+            split_idx = 1
+            model_to_split = [_model_to_split]
+            dataloader_for_split_model = [data_loader]
+            quantize_params = {}
+            quantized_model_merged = None
+
+            while len(model_to_split) != 0:
+                split_model = model_to_split.pop(0)
+                split_node = split_nodes.pop(0)
+                save_both_split_models = True if len(split_nodes) == 0 else False
+                shape_infer = True if split_idx == 1 else False
+
+                # split model with given split_node
+                split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
+                    split_node.name, tmp_model.model_path, shape_infer, save_both_split_models
+                )
+                if not save_both_split_models:
+                    # append split_model_part_2 to do next split
+                    model_to_split.append(split_model_part_2)
+
+                logger.info("Quantize split model {}".format(split_idx))
+                # get quantize params of split model
+                split_quantize_params, dataloder_for_next_split_model = self._get_split_model_quantize_params(
+                    split_model_part_1, dataloader_for_split_model, quantize_config, calib_sampling_size, iterations
+                )
+                dataloader_for_split_model.append(dataloder_for_next_split_model)
+                quantize_params.update(split_quantize_params)
+
+                # quantize split model
+                quantized_model_merged = self._quantize_split_model(
+                    split_model_part_1, quantize_config, split_quantize_params, quantized_model_merged
+                )
+
+                split_idx += 1
+
+                # if this is the last split, then quantize the last split model
+                if save_both_split_models:
+                    logger.info("Quantize split model {}".format(split_idx))
+                    # get quantize params of split model
+                    split_quantize_params, dataloder_for_next_split_model = self._get_split_model_quantize_params(
+                        split_model_part_2, dataloader_for_split_model, quantize_config, calib_sampling_size, iterations
+                    )
+                    quantize_params.update(split_quantize_params)
+
+                    # quantize split model
+                    quantized_model_merged = self._quantize_split_model(
+                        split_model_part_2, quantize_config, split_quantize_params, quantized_model_merged
+                    )
+                    quantized_model_merged.re_org_output(tmp_model.output())  # re-org output as the origin output
+
+            self.quantize_params = quantize_params
+            tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
+            tmp_model.model = quantized_model_merged.model
+            self.quantize_config = quantize_config  # update so other methods can know current configs
+            self._dump_model_op_stats(tmp_model)
+            tmp_model.topological_sort()
+            tmp_model.check_is_large_model()
+            return tmp_model
+
         else:
-            quantize_params = None
-        self.quantize_params = quantize_params
+            if not self.dynamic:
+                calib_iterations = self._reset_calib_iter(data_loader, calib_sampling_size, iterations)
+                quantize_params, _ = self._get_quantize_params(
+                    tmp_model, data_loader, quantize_config, calib_iterations
+                )
+            else:
+                quantize_params = None
+            self.quantize_params = quantize_params
+
+            from neural_compressor import options
+            from neural_compressor.adaptor.ox_utils.quantizer import Quantizer
 
+            quantizer = Quantizer(
+                tmp_model,
+                quantize_config,
+                format,
+                self.static,
+                quantize_params,
+                self.quantizable_op_types,
+                self.query_handler.get_fallback_list(),
+                self.reduce_range,
+                options.onnxrt.qdq_setting.AddQDQPairToWeight
+                if "add_qdq_pair_to_weight" not in self.recipes
+                else self.recipes.get("add_qdq_pair_to_weight", False),
+                options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin
+                if "optypes_to_exclude_output_quant" not in self.recipes
+                else self.recipes.get("optypes_to_exclude_output_quant", []),
+                options.onnxrt.qdq_setting.DedicatedQDQPair
+                if "dedicated_qdq_pair" not in self.recipes
+                else self.recipes.get("dedicated_qdq_pair", False),
+                self.backend,
+            )
+            quantizer.quantize_model()
+            tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
+            tmp_model.model = quantizer.model.model
+            self.quantize_config = quantize_config  # update so other methods can know current configs
+            self._dump_model_op_stats(tmp_model)
+            tmp_model.topological_sort()
+            return tmp_model
+
+    def _get_split_model_quantize_params(
+        self, split_model, split_dataloader, quantize_config, calib_sampling_size, iterations
+    ):
+        """Get quantize params for current split model and get dataloader for next split model."""
+        dataloader = split_dataloader.pop(0)
+        calib_iterations = self._reset_calib_iter(dataloader, calib_sampling_size, iterations)
+        split_quantize_params, dataloder_for_next_split_model = self._get_quantize_params(
+            split_model,
+            dataloader,
+            quantize_config,
+            calib_iterations,
+            split_model_input_names=split_model.input(),
+        )
+        return split_quantize_params, dataloder_for_next_split_model
+
+    def _quantize_split_model(self, split_model, quantize_config, quantize_params, quantized_model_merged):
+        """Quantize split model, and merge the quantized models to generate final model."""
         from neural_compressor import options
         from neural_compressor.adaptor.ox_utils.quantizer import Quantizer
 
         quantizer = Quantizer(
-            tmp_model,
+            split_model,
             quantize_config,
             format,
             self.static,
@@ -339,12 +466,16 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
             self.backend,
         )
         quantizer.quantize_model()
-        tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params)
-        tmp_model.model = quantizer.model.model
-        self.quantize_config = quantize_config  # update so other methods can know current configs
-        self._dump_model_op_stats(tmp_model)
-        tmp_model.topological_sort()
-        return tmp_model
+        split_model.model = quantizer.model.model
+        split_model.topological_sort()
+
+        if quantized_model_merged is None:
+            quantized_model_merged = quantizer.model
+            quantized_model_merged.write_external_data_to_new_location(overwrite=True)
+        else:
+            quantized_model_merged.merge_split_models(quantizer.model)
+
+        return quantized_model_merged
 
     def _check_backend_available(self, backend):
         """Check backend is available or not."""
@@ -570,7 +701,7 @@ def _dump_model_op_stats(self, model):
         Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
         self.optype_statistics = field_names, output_data
 
-    def _get_quantize_params(self, model, data_loader, quantize_config, iterations):
+    def _get_quantize_params(self, model, data_loader, quantize_config, iterations, **kwargs):
         from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment
         from neural_compressor.model.onnx_model import ONNXModel
 
@@ -588,10 +719,12 @@ def _get_quantize_params(self, model, data_loader, quantize_config, iterations):
             iterations=list(range(0, iterations)),
             backend=self.backend,
             reduce_range=self.reduce_range,
+            **kwargs,
         )
         self.min_max = augment.dump_minmax(quantize_config)
         quantize_params = augment.dump_calibration(quantize_config, min_max=self.min_max)
-        return quantize_params
+        dataloder_for_next_split_model = augment.dataloder_for_next_split_model
+        return quantize_params, dataloder_for_next_split_model
 
     def inspect_tensor(
         self,
@@ -606,7 +739,6 @@ def inspect_tensor(
     ):
         """The function is used by tune strategy class for dumping tensor info."""
         from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment
-        from neural_compressor.model.onnx_model import ONNXModel
         from neural_compressor.utils.utility import dump_data_to_local
 
         if not isinstance(model, ONNXModel):
@@ -763,6 +895,9 @@ def _pre_optimize(self, model, level=1):
         }
         if not isinstance(self.query_handler.get_graph_optimization(), list):
             level = self.query_handler.get_graph_optimization()
+        elif self.recipes.get("layer_wise_quant"):
+            level = "ENABLE_BASIC"
+            logger.info("Force set graph optimization level to 'ENABLE_BASIC' for layer-wise quantization")
         elif options.onnxrt.graph_optimization.level is not None:
             level = options.onnxrt.graph_optimization.level
         elif self.recipes.get("graph_optimization_level", None) is not None:
@@ -778,10 +913,23 @@ def _pre_optimize(self, model, level=1):
            )
         sess_options.graph_optimization_level = optimization_levels[level]
         sess_options.optimized_model_filepath = os.path.join(self.work_space, "Optimized_model.onnx")
+        if model.is_large_model and self.recipes.get("layer_wise_quant", False):
+            # save the model and external data for layer-wise quantization
+            external_data_filename = os.path.basename(sess_options.optimized_model_filepath) + "_data"
+            external_data_file_threshold = 1024
+            sess_options.add_session_config_entry(
+                "session.optimized_model_external_initializers_file_name", external_data_filename
+            )
+            sess_options.add_session_config_entry(
+                "session.optimized_model_external_initializers_min_size_in_bytes", str(external_data_file_threshold)
+            )
+            logger.info("Saving optimized model for layer-wise quantization. This may take a while...")
+
         if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"):  # pragma: no cover
             from onnxruntime_extensions import get_library_path
 
             sess_options.register_custom_ops_library(get_library_path())
+
         if not model.is_large_model:
             sess = ort.InferenceSession(
                 model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
@@ -792,13 +940,14 @@ def _pre_optimize(self, model, level=1):
         else:  # pragma: no cover
             logger.warning("Please use model path instead of onnx model object to quantize")
         del sess
-
         tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False)
 
-        if model.is_large_model:  # pragma: no cover
+        # load external data if model is large and not layer wise quantization
+        if model.is_large_model and not self.recipes.get("layer_wise_quant", False):  # pragma: no cover
             from onnx.external_data_helper import load_external_data_for_model
 
             load_external_data_for_model(tmp_model, os.path.split(model.model_path)[0])
+
         model.model_path = sess_options.optimized_model_filepath
         model.model = (
             self._replace_gemm_with_matmul(tmp_model).model
@@ -903,8 +1052,6 @@ def _replace_gemm_with_matmul(model):
         new_nodes = []
         from onnx import numpy_helper
 
-        from neural_compressor.model.onnx_model import ONNXModel
-
         if not isinstance(model, ONNXModel):
             model = ONNXModel(model)
 
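To summarize the control flow the diff above adds to quantize(): the model is split at each split node in turn; the leading part is calibrated with data produced by the previous part, quantized, and merged into the running result, and its intermediate outputs become the calibration data for the next part; the final split quantizes both halves. A framework-free sketch of that loop with placeholder callables standing in for the real ONNXModel/Quantizer machinery (illustrative only, not the adaptor's API):

from typing import Any, Callable, Dict, List, Tuple


def layer_wise_quantize(
    model: Any,
    split_points: List[str],
    calib_data: Any,
    split_fn: Callable[[Any, str], Tuple[Any, Any]],
    calibrate_fn: Callable[[Any, Any], Tuple[Dict, Any]],
    quantize_fn: Callable[[Any, Dict], Any],
    merge_fn: Callable[[Any, Any], Any],
) -> Any:
    """Schematic split -> calibrate -> quantize -> merge loop mirroring the diff."""
    pending = [model]              # models still waiting to be split
    dataloaders = [calib_data]     # per-split calibration data, seeded by the previous split's outputs
    merged = None                  # running merged quantized model
    while pending:
        current = pending.pop(0)
        node = split_points.pop(0)
        last_split = len(split_points) == 0
        part_1, part_2 = split_fn(current, node)
        if not last_split:
            pending.append(part_2)                 # part_2 gets split again on the next iteration

        params, next_data = calibrate_fn(part_1, dataloaders.pop(0))
        dataloaders.append(next_data)              # calibration outputs feed the next part
        merged = merge_fn(merged, quantize_fn(part_1, params))

        if last_split:                             # quantize the trailing half as well
            params, _ = calibrate_fn(part_2, dataloaders.pop(0))
            merged = merge_fn(merged, quantize_fn(part_2, params))
    return merged


# Toy demo: "models" are strings, splitting cuts at the first separator, quantizing upper-cases.
result = layer_wise_quantize(
    model="a|b|c",
    split_points=["|", "|"],
    calib_data=None,
    split_fn=lambda m, sep: tuple(m.split(sep, 1)),
    calibrate_fn=lambda part, data: ({}, data),
    quantize_fn=lambda part, params: part.upper(),
    merge_fn=lambda merged, part: part if merged is None else merged + part,
)
print(result)  # -> ABC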
neural_compressor/adaptor/ox_utils/calibration.py

Lines changed: 36 additions & 5 deletions
@@ -63,6 +63,7 @@ def __init__(
         iterations=[],
         backend="CPUExecutionProvider",
         reduce_range=False,
+        **kwargs,
     ):
         """Initialization.
 
@@ -94,6 +95,16 @@ def __init__(
         self.ort_version = Version(onnxruntime.__version__)
         self.reduce_range = reduce_range
 
+        self.layer_wise = True if len(kwargs.get("split_model_input_names", [])) != 0 else False
+        if self.layer_wise:
+            self.split_model_input_names = kwargs.get("split_model_input_names", [])
+        self._dataloder_for_next_split_model = None
+
+    @property
+    def dataloder_for_next_split_model(self):
+        """Return dataloader for next split model for layer-wise quantization."""
+        return self._dataloder_for_next_split_model
+
     def augment_graph(self, activation_only=False, weight_only=False):
         """Augment_graph.
 
@@ -245,12 +256,13 @@ def get_intermediate_outputs(self, q_config=None):
 
         len_inputs = len(session.get_inputs())
         inputs_names = [session.get_inputs()[i].name for i in range(len_inputs)]
+        len_outputs = len(session.get_outputs())
+        outputs_names = [session.get_outputs()[i].name for i in range(len_outputs)]
 
         node_output_names = [
             output.name if output.name not in self.dequantized_output else self.dequantized_output[output.name]
             for output in session.get_outputs()
         ]
-
         augment_model_wrapper = (
             ONNXModel(self.augmented_model)
             if not self.model_wrapper.is_large_model
@@ -271,6 +283,7 @@ def get_intermediate_outputs(self, q_config=None):
         output_dicts = {}
         intermediate_tensor = {}
         name_to_calibrator = {}
+        ort_inputs_for_next_split_model = []
         for idx, (inputs, labels) in enumerate(self.dataloader):
             ort_inputs = {}
 
@@ -281,15 +294,25 @@
                 else:
                     ort_inputs.update({inputs_names[0]: to_numpy(inputs)})
             else:
-                assert len_inputs == len(inputs), "number of input tensors must align with graph inputs"
+                if not self.layer_wise:
+                    # for layer-wise calibration
+                    assert len_inputs == len(inputs), "number of input tensors must align with graph inputs"
 
                 if isinstance(inputs, dict):
                     for name, input in inputs.items():
                         ort_inputs.update({name: to_numpy(input)})
                 else:
                     ort_inputs = dict(zip(inputs_names, [to_numpy(i) for i in inputs]))
 
-            def _collect_data():
+            def _collect_data(ort_inputs):
+                if self.layer_wise:
+                    # for layer-wise calibration
+                    ort_inputs = {
+                        input_name: input_tensor
+                        for input_name, input_tensor in ort_inputs.items()
+                        if input_name in self.split_model_input_names
+                    }
+
                 for output_idx, output in enumerate(session.run(None, ort_inputs)):
                     if q_config is not None and output.size != 0:
                         node_name = name_to_node[node_output_names[output_idx]]
@@ -321,13 +344,18 @@ def _collect_data():
                     elif q_config is None:
                         output_dicts.setdefault(node_output_names[output_idx], []).append(output)
 
+                    if self.layer_wise:
+                        # for layer-wise calibration
+                        ort_inputs.update({outputs_names[output_idx]: output})
+                        ort_inputs_for_next_split_model.append((ort_inputs, labels))
+
             if self.iterations != []:
                 if idx > max(self.iterations):
                     break
                 if idx in self.iterations:
-                    _collect_data()
+                    _collect_data(ort_inputs)
             else:
-                _collect_data()
+                _collect_data(ort_inputs)
 
         # for kl and percentile method, collect calibration range after all tensors are collected.
         merged_dict = intermediate_tensor
@@ -344,6 +372,9 @@ def _collect_data():
             output_dicts.setdefault(output_name, []).append(list(calibrator.calib_range))
             calibrator.clear()
             del calibrator
+
+        self._dataloder_for_next_split_model = ort_inputs_for_next_split_model
+
         return list(output_dicts.keys()), output_dicts
 
     def _dequantize(self, tensor, scale_tensor, zo_tensor):
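
The calibration change above is what chains the splits together: when layer_wise is enabled, each batch is first filtered down to the inputs the current split actually consumes, the split is run, and its intermediate outputs are appended to the batch so the resulting (inputs, labels) pairs can act as the dataloader for the next split. A standalone sketch of that filter-run-forward step; run_session is a stand-in for session.run(None, ort_inputs), not an onnxruntime call:

import numpy as np


def collect_for_next_split(batches, split_input_names, run_session):
    """Filter each batch to the current split's inputs, run the split, and keep
    (inputs + produced outputs, labels) as calibration data for the next split.

    run_session stands in for session.run(None, ort_inputs) in the diff and must
    return a dict {output_name: tensor}; it is a placeholder, not an onnxruntime API.
    """
    next_split_batches = []
    for ort_inputs, labels in batches:
        # keep only the tensors this split actually consumes (layer-wise path)
        ort_inputs = {name: tensor for name, tensor in ort_inputs.items() if name in split_input_names}
        outputs = run_session(ort_inputs)
        ort_inputs.update(outputs)  # intermediate outputs become the next split's inputs
        next_split_batches.append((ort_inputs, labels))
    return next_split_batches


# Toy usage: a fake "session" that doubles its single input tensor.
batches = [({"x": np.ones(2), "unused": np.zeros(2)}, 0)]
print(collect_for_next_split(batches, {"x"}, lambda feed: {"y": feed["x"] * 2}))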
