Commit 1beb435

Support MatMulFpQ4 for onnxruntime 1.16.0 (#1293)
Signed-off-by: Mengni Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f64833d commit 1beb435

2 files changed (+43, -74 lines)

neural_compressor/adaptor/onnxrt.py (1 addition, 1 deletion)

@@ -1690,7 +1690,7 @@ def _dump_model_op_stats(self, model, tune_cfg):
 
         dtype_set = set()
         for node in model.nodes():
-            if node.op_type == "MatMulWithQuantWeight":
+            if node.op_type == "MatMulFpQ4":
                 optype = "MatMul"
             else:
                 optype = node.op_type
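For context, the stats pass simply folds the new contrib op back into the MatMul bucket. A minimal sketch of that bucketing, using a made-up list of op types rather than the adaptor's real graph walk:

from collections import Counter

# Hypothetical op-type list; the adaptor iterates model.nodes() instead.
node_op_types = ["MatMul", "MatMulFpQ4", "Add", "MatMulFpQ4"]
stats = Counter("MatMul" if op == "MatMulFpQ4" else op for op in node_op_types)
print(stats)  # Counter({'MatMul': 3, 'Add': 1})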

neural_compressor/adaptor/ox_utils/weight_only.py (42 additions, 73 deletions)

@@ -20,63 +20,28 @@
 import logging
 import math
 import os
+import struct
 import sys
 
 import numpy as np
 import onnx
 from onnx import helper, numpy_helper
 from onnx import onnx_pb as onnx_proto
+from packaging.version import Version
 
 from neural_compressor.model.model import BaseModel
 from neural_compressor.model.onnx_model import ONNXModel
 from neural_compressor.utils.utility import LazyImport
 
 ort = LazyImport("onnxruntime")
 logger = logging.getLogger("neural_compressor")
-
-
-WEIGHT_ONLY_OP_SUPPORTED = False
-
-
-def check_op_support_status():
-    """Check whether weight-only op is supported."""
-    input_tensor = helper.make_tensor_value_info("input", 1, [1, 32])
-    output_tensor = helper.make_tensor_value_info("output", 1, [1, 64])
-    initializers = []
-    # weight shape (32, 64)
-    packed_weight = np.random.randint(0, high=16, size=(64, 1, 16), dtype="uint8")
-    initializers.append(onnx.helper.make_tensor("weight", 2, packed_weight.shape, packed_weight.flatten().tolist()))
-    scale = np.random.random((64, 1)).astype("float32")
-    initializers.append(onnx.helper.make_tensor("scale", 1, scale.shape, scale.flatten().tolist()))
-
-    kwargs = {}
-    kwargs["K"] = 32
-    kwargs["N"] = 64
-    kwargs["bits"] = 4
-    kwargs["block_size"] = 32
-    node = onnx.helper.make_node(
-        "MatMulWithQuantWeight",
-        inputs=["input", "weight", "scale"],
-        outputs=["output"],
-        name="test",
-        domain="com.microsoft",
-        **kwargs,
-    )
-
-    global WEIGHT_ONLY_OP_SUPPORTED
-    graph = helper.make_graph([node], "test", [input_tensor], [output_tensor], initializer=initializers)
-    model = helper.make_model(graph)
-    try:
-        ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
-        WEIGHT_ONLY_OP_SUPPORTED = True
-    except:
-        WEIGHT_ONLY_OP_SUPPORTED = False
+ONNXRT116_VERSION = Version("1.16.0")
 
 
 def make_matmul_weight_only_node(
     node, weight_shape, num_bits, group_size, k_blocks, q_weight, scale, zero_point
 ):  # pragma: no cover
-    """Build MatMulWithQuantWeight node.
+    """Build MatMulFpQ4 node.
 
     Args:
         node: original matmul node
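The runtime probe (check_op_support_status) is replaced by a plain version comparison: MatMulFpQ4 ships with onnxruntime 1.16.0, so availability can be decided from the installed package version. A minimal sketch of that gate; the helper name matmul_fpq4_supported is mine, not part of the patch:

import onnxruntime as ort
from packaging.version import Version

ONNXRT116_VERSION = Version("1.16.0")

def matmul_fpq4_supported(num_bits: int, group_size: int) -> bool:
    # MatMulFpQ4 currently only covers 4-bit weights with a group size of 32.
    return Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32

print(matmul_fpq4_supported(4, 32))  # True on onnxruntime >= 1.16.0
print(matmul_fpq4_supported(8, 32))  # False: the fake-quantization fallback is used instead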
@@ -89,46 +54,49 @@ def make_matmul_weight_only_node(
         zero_point (array): zero point
 
     Returns:
-        matmul_weight_only_node: MatMulWithQuantWeight node
-        new_inits: initializers of the MatMulWithQuantWeight node
+        matmul_weight_only_node: MatMulFpQ4 node
+        new_inits: initializers of the MatMulFpQ4 node
     """
-    blob_size = group_size // 2
+    if zero_point is not None:
+        blob_size = group_size // 2 + 4 + 1
+        offset = 5
+    else:
+        blob_size = group_size // 2 + 4
+        offset = 4
+
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
     for i in range(q_weight.shape[0]):
-        for k in range(0, group_size, 2):
-            packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4
+        bf = struct.pack("f", scale[i])
+        packed[i][0] = bf[0]
+        packed[i][1] = bf[1]
+        packed[i][2] = bf[2]
+        packed[i][3] = bf[3]
 
-    packed = np.reshape(packed, (-1, k_blocks, blob_size))
-    scale = np.reshape(scale, (-1, k_blocks)).astype("float32")
+        if zero_point is not None:
+            packed[i][4] = zero_point[i]
+
+        packed[i][offset:] = np.bitwise_or(
+            q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits)
+        )
 
+    packed = packed.reshape(-1)
     q_weight_tensor = onnx.helper.make_tensor(
         name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)),
         data_type=2,
         dims=packed.shape,
         vals=packed.tobytes(),
         raw=True,
     )
-    scale_tensor = onnx.helper.make_tensor(
-        name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tobytes(), raw=True
+    shape_tensor = onnx.helper.make_tensor(
+        name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64")
     )
-    input_names = [node.input[0], q_weight_tensor.name, scale_tensor.name]
-    new_inits = [q_weight_tensor, scale_tensor]
-
-    if zero_point is not None:
-        zero_point = np.reshape(zero_point, (-1, k_blocks)).astype("uint8")
-        zp_tensor = onnx.helper.make_tensor(
-            name=node.input[1] + "_zp", data_type=2, dims=zero_point.shape, vals=zero_point.tobytes(), raw=True
-        )
-        input_names.append(zp_tensor.name)
-        new_inits.append(zp_tensor)
+    input_names = [node.input[0], q_weight_tensor.name, shape_tensor.name]
+    new_inits = [q_weight_tensor, shape_tensor]
 
     kwargs = {}
-    kwargs["K"] = weight_shape[0]
-    kwargs["N"] = weight_shape[1]
-    kwargs["bits"] = num_bits
-    kwargs["block_size"] = group_size
+    kwargs["blk_quant_type"] = 1 if zero_point is not None else 0
     matmul_weight_only_node = onnx.helper.make_node(
-        "MatMulWithQuantWeight",
+        "MatMulFpQ4",
         inputs=input_names,
         outputs=node.output,
         name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
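Each quantized block is now serialized into a self-contained blob: a 4-byte float32 scale, an optional zero-point byte, then group_size // 2 bytes holding two 4-bit codes each. A minimal sketch of one such blob for a single 32-element block (illustrative values only; the patch loops over every row of q_weight):

import struct
import numpy as np

group_size, num_bits = 32, 4
q_block = np.random.randint(0, 16, size=group_size, dtype="uint8")   # 4-bit codes
scale, zero_point = np.float32(0.05), np.uint8(8)

blob = np.zeros(group_size // 2 + 4 + 1, dtype="uint8")              # asymmetric layout: 21 bytes
blob[0:4] = np.frombuffer(struct.pack("f", scale), dtype="uint8")    # native-endian fp32 scale
blob[4] = zero_point                                                  # omitted in the symmetric layout
blob[5:] = np.bitwise_or(q_block[: group_size // 2],
                         np.left_shift(q_block[group_size // 2 :], num_bits))
print(blob.shape)  # (21,) == group_size // 2 + 4 + 1

The weight shape travels separately in the new "_shape" initializer (data_type=7, i.e. int64), and the blk_quant_type attribute selects between the symmetric (0) and asymmetric (1) layouts.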
@@ -260,7 +228,6 @@ def rtn_quantize(
     Returns:
         model: fake quantized ONNXModel
     """
-    check_op_support_status()
     model = model if isinstance(model, BaseModel) else ONNXModel(model)
     new_nodes = []
     remove_nodes = []
@@ -290,8 +257,8 @@ def rtn_quantize(
 
             weight = pad_tensor(weight, group_size, k_blocks)
 
-            if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32:  # pragma: no cover
-                # currently MatMulWithQuantWeights only support 4 bits and 32 group_size
+            if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:  # pragma: no cover
+                # currently MatMulFpQ4 only support 4 bits and 32 group_size
                 q_weight, scale, zp = quant_tensor(
                     weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
                 )
@@ -394,7 +361,9 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
             weight = weight.T * scales
             weight = pad_tensor(weight, group_size, (org_w_shape[0] + group_size - 1) // group_size).T
 
-            if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32:  # pragma: no cover
+            if (
+                Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
+            ):  # pragma: no cover
                 q_weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint") / np.expand_dims(
                     scales, axis=-1
                 )
@@ -535,8 +504,10 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g
             for i_s in range(10):
                 ratio = 1 - i_s / 100
                 weight = copy.deepcopy(org_weight)
-                if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32:  # pragma: no cover
-                    # currently MatMulWithQuantWeights only support 4 bits and 32 group_size
+                if (
+                    Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
+                ):  # pragma: no cover
+                    # currently MatMulFpQ4 only support 4 bits and 32 group_size
                     weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1))
                 else:
                     weight = qdq_tensor(weight, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
@@ -644,7 +615,6 @@ def awq_quantize(
     Returns:
         model: fake quantized ONNXModel
     """
-    check_op_support_status()
     model = model if isinstance(model, BaseModel) else ONNXModel(model)
     output_dicts = {}
     full_ratio = {}
@@ -918,7 +888,6 @@ def gptq_quantize(
     Returns:
         model: fake quantized ONNXModel
     """
-    check_op_support_status()
     model = model if isinstance(model, BaseModel) else ONNXModel(model)
     output_dicts = {}
 
@@ -1013,8 +982,8 @@ def gptq_quantize(
 
         weight_tensor = model.get_initializer(node.input[1])
         init_share_num = model.get_initializer_share_num(node.input[1])
-        if WEIGHT_ONLY_OP_SUPPORTED and num_bits == 4 and group_size == 32:  # pragma: no cover
-            # currently MatMulWithQuantWeights only support 4 bits and 32 group_size
+        if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:  # pragma: no cover
+            # currently MatMulFpQ4 only support 4 bits and 32 group_size
             org_shape = weight.shape
             k_blocks = (org_shape[0] + group_size - 1) // group_size
             q_weight = pad_tensor(q_weight, group_size, k_blocks)
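To see how the layout round-trips, here is a companion sketch (my own illustration, not an onnxruntime API) that decodes one asymmetric blob built as in the packing sketch above back into its scale, zero point, and 4-bit codes; per the packing code, the low nibbles hold the first half of the block and the high nibbles the second half:

import struct
import numpy as np

def unpack_block(blob: np.ndarray, group_size: int = 32):
    # Assumes the asymmetric 21-byte layout: fp32 scale, zero-point byte, 16 packed bytes.
    scale = struct.unpack("f", blob[0:4].tobytes())[0]
    zero_point = int(blob[4])
    nibbles = blob[5:]
    codes = np.concatenate([nibbles & 0x0F, nibbles >> 4])  # first half, then second half
    return scale, zero_point, codes

# Typical affine dequantization of each code, for illustration: w ~= scale * (code - zero_point)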
