Commit c58aeaa

Fix ONNXRT calibration for Dml EP (#1526)
Signed-off-by: yuwenzho <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d22df53 commit c58aeaa
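
In short: when the backend was DmlExecutionProvider, calibration used to be forced onto CPUExecutionProvider (see the onnxrt.py change below). After this commit the calibration session runs on the configured EP itself: graph optimizations are disabled during calibration (presumably so optimizer passes don't fold away the intermediate tensors the augmented graph exposes), and weight ranges are computed directly from initializers instead of through an inference pass. A minimal sketch of the resulting session setup, assuming an already-augmented model; the path and provider list are illustrative, not from the diff:

    import onnxruntime

    so = onnxruntime.SessionOptions()
    # Keep the augmented graph intact: optimizer passes could otherwise fuse
    # or prune the intermediate tensors that calibration needs to observe.
    so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    session = onnxruntime.InferenceSession(
        "augmented_model.onnx",  # hypothetical path to the augmented model
        so,
        providers=["DmlExecutionProvider"],  # no forced CPU fallback anymore
    )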

3 files changed: +145 −47 lines

neural_compressor/adaptor/onnxrt.py (+1 −1)
@@ -765,7 +765,7 @@ def _get_quantize_params(self, model, data_loader, quantize_config, iterations,
             black_nodes=black_nodes,
             white_nodes=white_nodes,
             iterations=list(range(0, iterations)),
-            backend=self.backend if self.backend != "DmlExecutionProvider" else "CPUExecutionProvider",
+            backend=self.backend,
             reduce_range=self.reduce_range,
             **kwargs,
         )

neural_compressor/adaptor/ox_utils/calibration.py (+128 −36)
@@ -105,7 +105,7 @@ def dataloder_for_next_split_model(self):
         """Return dataloader for next split model for layer-wise quantization."""
         return self._dataloder_for_next_split_model

-    def augment_graph(self, activation_only=False, weight_only=False):
+    def augment_graph(self):
         """Augment_graph.

         Adds nodes to all quantization_candidates op type nodes in model and
@@ -118,7 +118,7 @@ def augment_graph(self, activation_only=False, weight_only=False):
         self.dequantized_output.clear()
         onnx_version = Version(onnx.__version__)
         if onnx_version < ONNX18_VERSION:
-            logger.warning("Static quantization for NLP model is supported " "at onnx 1.8.0 and newer.")
+            logger.warning("Static quantization for NLP model is supported at onnx 1.8.0 and newer.")
         if self.already_quantized and any(
             [i.dims in [1, 2] for i in self.model_wrapper.initializer() if i.name.endswith("_scale")]
         ):
@@ -138,53 +138,43 @@ def augment_graph(self, activation_only=False, weight_only=False):
         for augment_node_type in self.augment_nodes:
             if augment_node_type not in ["DequantizeLinear"]:  # pragma: no cover
                 raise ValueError(
-                    "Unexpected augment_node {} only DequantizeLinear is " "supported".format(augment_node_type)
+                    "Unexpected augment_node {} only DequantizeLinear is supported".format(augment_node_type)
                 )

         if self.already_quantized:
             # mapping between fp32 node and int8 node
             new_white_nodes = []
             for white_node in self.white_nodes:
                 new_white_node = white_node + "_quant"
-                assert new_white_node in model_nodes_names, "no quantized {} in the " "graph".format(white_node)
+                assert new_white_node in model_nodes_names, "no quantized {} in the graph".format(white_node)
                 new_white_nodes.append(new_white_node)
             self.white_nodes = new_white_nodes

-        initializers = {i.name: i.data_type for i in model.graph.initializer}
         node_outputs = []
         for node in model.graph.node:  # pylint: disable=no-member
             node_outputs.extend(node.output)
             should_be_dump = ((node.op_type in self.dump_op_types) and (node.name not in self.black_nodes)) or (
                 node.name in self.white_nodes
             )
             if should_be_dump:
-                if not weight_only and not activation_only:
-                    tensors_to_dump.update([input for input in node.input if len(input) != 0])
-                    tensors_to_dump.update([output for output in node.output if len(output) != 0])
-                    tensors_to_dump.update(node.output)
-                elif weight_only:
-                    for input in node.input:
-                        if (
-                            self.already_quantized
-                            and input.replace("_dequantized", "_quantized") in initializers
-                            and len(input) != 0
-                        ):
-                            tensors_to_dump.add(input)
-                        elif not self.already_quantized and input in initializers and len(input) != 0:
+                # add input tensors which should be dump
+                for input in node.input:
+                    if len(input) != 0:  # to prevent input is ""
+                        initializer_tensor = self.model_wrapper.get_initializer(input)
+                        if initializer_tensor is None:
                             tensors_to_dump.add(input)
-                elif activation_only:
-                    if len(node.input[0]) != 0:
-                        tensors_to_dump.update([node.input[0]])
+                # add output tensors which should be dump
+                tensors_to_dump.update([output for output in node.output if len(output) != 0])

         model_inputs = [i.name for i in model.graph.input]
         for tensor in tensors_to_dump:
-            if tensor not in node_outputs and tensor not in initializers and tensor not in model_inputs:
+            if tensor not in node_outputs and tensor not in model_inputs:
                 continue
             if self.augment_nodes:
                 for augment_node_type in self.augment_nodes:
                     if augment_node_type in ["DequantizeLinear"]:
                         # insert DequantizeLinear node as output
-                        if tensor.endswith("_scale") or tensor.endswith("_zero_point"):
+                        if tensor.endswith("_scale") or tensor.endswith("_zero_point"):  # pragma: no cover
                             continue

                         if not self.dynamically_quantized:
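
With the activation_only/weight_only branches removed, augment_graph now applies one uniform rule: for every dump candidate, collect its non-initializer inputs (activations) plus all non-empty outputs. A standalone sketch of that selection rule; the function name and candidate-set argument are illustrative, and the real code consults self.model_wrapper.get_initializer instead of a name set:

    import onnx

    def select_tensors_to_dump(model: onnx.ModelProto, candidate_names: set) -> set:
        """Sketch of the unified rule: for each candidate node, dump its
        non-initializer inputs (activations) and all non-empty outputs."""
        initializer_names = {init.name for init in model.graph.initializer}
        tensors_to_dump = set()
        for node in model.graph.node:
            if node.name not in candidate_names:
                continue
            for inp in node.input:
                # skip optional (empty) inputs and weights stored as initializers
                if inp and inp not in initializer_names:
                    tensors_to_dump.add(inp)
            # every real output is a dump candidate
            tensors_to_dump.update(out for out in node.output if out)
        return tensors_to_dump

Weights excluded here are not lost: they are handled separately by the new get_weight_tensors_calib_range below, without any inference.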
@@ -238,10 +228,18 @@
                 convert_attribute=False,
             )

-    def get_intermediate_outputs(self, q_config=None):
-        """Gather intermediate model outputs after running inference."""
+    def get_activation_tensors_calib_range(self, q_config=None):
+        """Get calib ranges of activation tensors.
+
+        Args:
+            q_config (dict, optional): quantization config. Defaults to None.
+
+        Returns:
+            dict: calib ranges
+        """
         # conduct inference session and get intermediate outputs
         so = onnxruntime.SessionOptions()
+        so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
         if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"):  # pragma: no cover
             from onnxruntime_extensions import get_library_path

@@ -280,7 +278,7 @@ def get_intermediate_outputs(self, q_config=None):
             assert node, "{} is neither an input nor an output of nodes in augmented model.".format(data_name)
             name_to_node[data_name] = node.name

-        output_dicts = {}
+        activation_tensors_calib_range = {}
         intermediate_tensor = {}
         name_to_calibrator = {}
         ort_inputs_for_next_split_model = []
@@ -294,8 +292,8 @@
                 else:
                     ort_inputs.update({inputs_names[0]: to_numpy(inputs)})
             else:
+                # skip check input length for layer-wise calibration
                 if not self.layer_wise:
-                    # for layer-wise calibration
                     assert len_inputs == len(inputs), "number of input tensors must align with graph inputs"

                 if isinstance(inputs, dict):
@@ -335,14 +333,16 @@ def _collect_data(ort_inputs):
                         # per iteration in the future.
                         if calibrator.method_name == "minmax":
                             calibrator.collect(output)
-                            output_dicts[node_output_names[output_idx]] = [list(calibrator.calib_range)]
+                            activation_tensors_calib_range[node_output_names[output_idx]] = [
+                                list(calibrator.calib_range)
+                            ]
                             name_to_calibrator[node_output_names[output_idx]] = calibrator
                         else:
                             intermediate_tensor.setdefault((node_output_names[output_idx], node_name), []).append(
                                 output
                             )
                 elif q_config is None:
-                    output_dicts.setdefault(node_output_names[output_idx], []).append(output)
+                    activation_tensors_calib_range.setdefault(node_output_names[output_idx], []).append(output)

         if self.layer_wise:
             # for layer-wise calibration
@@ -369,12 +369,94 @@ def _collect_data(ort_inputs):
                 )
                 calibrator = CALIBRATOR[calib_method]()
                 calibrator.collect(datas)
-                output_dicts.setdefault(output_name, []).append(list(calibrator.calib_range))
+                activation_tensors_calib_range.setdefault(output_name, []).append(list(calibrator.calib_range))
                 calibrator.clear()
                 del calibrator

+        # set for layer-wise quant
         self._dataloder_for_next_split_model = ort_inputs_for_next_split_model

+        return activation_tensors_calib_range
+
+    def get_weight_tensors_calib_range(self):
+        """Get calib ranges of weight tensors.
+
+        Returns:
+            dict: calib ranges
+        """
+        model_nodes_names = [node.name for node in self.model.graph.node]
+
+        # if augmented_model is not None, it means self.white_nodes is already updated in augment_graph func
+        # then skip update here
+        if self.already_quantized and self.augmented_model is None:
+            # mapping between fp32 node and int8 node
+            new_white_nodes = []
+            for white_node in self.white_nodes:
+                new_white_node = white_node + "_quant"
+                assert new_white_node in model_nodes_names, "no quantized {} in the " "graph".format(white_node)
+                new_white_nodes.append(new_white_node)
+            self.white_nodes = new_white_nodes
+
+        added_outputs = set()
+        initializer_tensors_to_dump = []
+        initializers = [init.name for init in self.model.graph.initializer]
+        for node in self.model.graph.node:  # pylint: disable=no-member
+            should_be_dump = ((node.op_type in self.dump_op_types) and (node.name not in self.black_nodes)) or (
+                node.name in self.white_nodes
+            )
+            if should_be_dump:
+                for input in node.input:
+                    if (
+                        (self.already_quantized and input.replace("_dequantized", "_quantized") in initializers)
+                        or (not self.already_quantized and input in initializers)
+                    ) and len(input) != 0:
+                        added_outputs.add(input)
+
+        for tensor in added_outputs:
+            if tensor not in initializers:
+                continue
+            if self.augment_nodes:
+                for augment_node_type in self.augment_nodes:
+                    if augment_node_type in ["DequantizeLinear"]:
+                        if not (tensor.endswith("_scale") or tensor.endswith("_zero_point")):
+                            initializer_tensors_to_dump.append(tensor)
+            else:
+                initializer_tensors_to_dump.append(tensor)
+
+        weight_tensors_calib_range = {}
+        for initializer_tensor_name in initializer_tensors_to_dump:
+            if self.layer_wise:
+                self.model_wrapper.load_model_initializer_by_tensor()
+            initializer_tensor = self.model_wrapper.get_initializer(initializer_tensor_name)
+
+            # double check initializer tensor is not None
+            if initializer_tensor is None:  # pragma: no cover
+                continue
+
+            initializer_tensor = numpy_helper.to_array(
+                initializer_tensor,
+                base_dir=os.path.dirname(self.model_wrapper.model_path)
+                if self.model_wrapper.model_path is not None
+                else "",
+            )
+            calibrator = CALIBRATOR["minmax"]()  # use minmax method to calibrate initializer tensors
+            calibrator.collect(initializer_tensor)
+            weight_tensors_calib_range[initializer_tensor_name] = [list(calibrator.calib_range)]
+            calibrator.clear()
+            del calibrator
+        return weight_tensors_calib_range
+
+    def get_intermediate_outputs(self, q_config=None, activation_only=False, weight_only=False):
+        """Gather intermediate model outputs after running inference."""
+        output_dicts = {}
+        if not activation_only and not weight_only:
+            output_dicts = self.get_activation_tensors_calib_range(q_config)
+            output_dicts.update(self.get_weight_tensors_calib_range())
+        elif weight_only:
+            output_dicts = self.get_weight_tensors_calib_range()
+        elif activation_only:
+            output_dicts = self.get_activation_tensors_calib_range(q_config)
+
         return list(output_dicts.keys()), output_dicts

     def _dequantize(self, tensor, scale_tensor, zo_tensor):
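
The new get_weight_tensors_calib_range means weight ranges no longer require running the model at all: each dumped initializer is loaded with onnx.numpy_helper.to_array and min/max calibrated. A self-contained sketch of the same idea; the function name is illustrative, and neural-compressor's internal CALIBRATOR registry is replaced by a direct min/max, which is what a minmax calibrator computes:

    import onnx
    from onnx import numpy_helper

    def weight_calib_ranges(model: onnx.ModelProto, names: list) -> dict:
        """Sketch: min/max ranges of weight (initializer) tensors, no inference."""
        inits = {init.name: init for init in model.graph.initializer}
        ranges = {}
        for name in names:
            init = inits.get(name)
            if init is None:
                continue
            # models with external data would also need base_dir=...,
            # as the real code passes os.path.dirname(model_path)
            arr = numpy_helper.to_array(init)
            ranges[name] = [[float(arr.min()), float(arr.max())]]
        return ranges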
@@ -472,7 +554,12 @@ def _map_calibration(self, node_output_names, output_dicts):
         return final_dict

     def dump_minmax(self, q_config):
-        """Get min/max values of tensors."""
+        """Get calib ranges of tensors."""
+        # pipeline of getting calib ranges of tensors during calibration:
+        # 1. augment_graph(): insert activation tensors to model output
+        # 2. get_intermediate_outputs():
+        #   2.1 get_activation_tensors_calib_range(): get calib ranges of activation tensors using the augment graph
+        #   2.2 get_weight_tensors_calib_range(): get calib ranges of weight tensors
         self.augment_graph()
         node_output_names, output_dicts = self.get_intermediate_outputs(q_config)
         return self._map_calibration(node_output_names, output_dicts)
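
Put together, the pipeline in the comments above reads as below; a hedged usage sketch, assuming an already-constructed ONNXRTAugment instance named augment and a q_config dict (note _map_calibration is private, shown here only to trace the flow):

    # 1. expose activation tensors as extra model outputs
    augment.augment_graph()
    # 2. activation ranges via inference on the augmented model,
    #    plus weight ranges straight from initializers
    node_output_names, output_dicts = augment.get_intermediate_outputs(q_config)
    # 3. map raw ranges back to nodes
    final_dict = augment._map_calibration(node_output_names, output_dicts)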
@@ -553,15 +640,20 @@ def dump_tensor(self, activation=True, weight=False, format=None):
             self.already_quantized = True
             self.dynamically_quantized = "DynamicQuantizeLinear" in [node.op_type for node in self.model.graph.node]
         is_qdq = format == "qdq"
-        self.augment_graph(activation_only=not weight, weight_only=not activation)
-        _, output_dicts = self.get_intermediate_outputs()
+        if activation:
+            self.augment_graph()  # add activation tensors to model output
+        _, output_dicts = self.get_intermediate_outputs(activation_only=not weight, weight_only=not activation)
         iters = len(list(output_dicts.values())[-1])
         map_node_activation = [{} for _ in range(iters)]
         map_node_weight = {}
         self.white_nodes = [node.replace("_quant", "") for node in self.white_nodes]
-        augmengted_wrapper = ONNXModel(self.augmented_model)
-        map_output = augmengted_wrapper.output_name_to_node
-        map_input = augmengted_wrapper.input_name_to_nodes
+
+        if activation and self.augmented_model is None:
+            raise ValueError("augmented model should not be None when dump activation tensors.")
+        # if activation tensors are not dumped, then use origin model wrapper
+        model_wrapper = ONNXModel(self.augmented_model) if activation else self.model_wrapper
+        map_output = model_wrapper.output_name_to_node
+        map_input = model_wrapper.input_name_to_nodes
         model_output_names = [t.name for t in self.model.graph.output]
         model_input_names = [t.name for t in self.model.graph.input]
         model_initializer_names = [t.name for t in self.model.graph.initializer]

test/adaptor/onnxrt_adaptor/test_onnxrt_augment.py (+16 −10)
@@ -330,7 +330,7 @@ def test_augment_graph(self):
         attn_output_scale = generate_input_initializer([1], np.float32, "attn_output_scale")
         Q_zo = helper.make_tensor_value_info("attn_output_zero_point", TensorProto.INT8, [1])
         attn_output_zero_point = generate_input_initializer([1], np.int8, "attn_output_zero_point")
-        Output = helper.make_tensor_value_info("output", TensorProto.INT8, [13, 7])
+        Output = helper.make_tensor_value_info("attn_output_quantized", TensorProto.INT8, [13, 7])
         attention_node = onnx.helper.make_node(
             "QAttention",
             [
@@ -386,15 +386,15 @@ def test_augment_graph(self):
         augment.augment_nodes = ["DequantizeLinear"]
         augment.already_quantized = True

-        augment.augment_graph(activation_only=True, weight_only=False)
+        augment.augment_graph()
         augmented_model = augment.augmented_model

         augmented_model_node_names = [node.name for node in augmented_model.graph.node]
         augmented_model_outputs = [output.name for output in augmented_model.graph.output]
-        added_node_names = ["attention_quant", "attn_output_QuantizeLinear"]
-        added_outputs = ["input_quantized_output", "output"]
+        added_node_names = ["attention_quant", "attn_output_QuantizeLinear", "input_quantized_DequantizeLinear"]
+        added_outputs = ["attn_output_quantized", "input_quantized_output", "attn_output"]
         self.assertEqual(len(augmented_model_node_names), 3)
-        self.assertEqual(len(augmented_model_outputs), 2)
+        self.assertEqual(len(augmented_model_outputs), 3)
         for name in added_node_names:
             self.assertTrue(name in augmented_model_node_names)
         for output in added_outputs:
@@ -470,15 +470,21 @@ def test_augment_graph(self):
         augment = ONNXRTAugment(ONNXModel(model), data_reader, [], white_nodes=["conv"])
         augment.augment_nodes = ["DequantizeLinear"]
         augment.already_quantized = True
-        augment.augment_graph(activation_only=True, weight_only=False)
+        augment.augment_graph()
         augmented_model = augment.augmented_model

         augmented_model_node_names = [node.name for node in augmented_model.graph.node]
         augmented_model_outputs = [output.name for output in augmented_model.graph.output]
-        added_node_names = ["A_QuantizeLinear", "conv_quant", "D_DequantizeLinear", "A_quantized_DequantizeLinear"]
-        added_outputs = ["D", "A_quantized_output"]
-        self.assertEqual(len(augmented_model_node_names), 4)
-        self.assertEqual(len(augmented_model_outputs), 2)
+        added_node_names = [
+            "A_QuantizeLinear",
+            "conv_quant",
+            "D_DequantizeLinear",
+            "D_quantized_DequantizeLinear",
+            "A_quantized_DequantizeLinear",
+        ]
+        added_outputs = ["D", "D_quantized_output", "A_quantized_output"]
+        self.assertEqual(len(augmented_model_node_names), 5)
+        self.assertEqual(len(augmented_model_outputs), 3)
         for name in added_node_names:
             self.assertTrue(name in augmented_model_node_names)
         for output in added_outputs:
