
Commit 67a31ba

Support MatMulNBit op for ort 1.17 (#1327)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1beb435 commit 67a31ba

2 files changed: 117 additions, 38 deletions


neural_compressor/adaptor/onnxrt.py

Lines changed: 1 addition & 1 deletion
@@ -1690,7 +1690,7 @@ def _dump_model_op_stats(self, model, tune_cfg):
 
         dtype_set = set()
         for node in model.nodes():
-            if node.op_type == "MatMulFpQ4":
+            if node.op_type in ["MatMulFpQ4", "MatMulNBits"]:
                 optype = "MatMul"
             else:
                 optype = node.op_type
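This remapping only affects the op-type statistics: both weight-only kernels are tallied under their logical MatMul op. A minimal, self-contained sketch of that bookkeeping (the node list and dtype labels below are illustrative, not the adaptor's real objects):

from collections import Counter

# Hypothetical (op_type, dtype) pairs collected from a quantized graph.
nodes = [("MatMulNBits", "int4"), ("MatMulFpQ4", "int4"), ("Relu", "fp32")]

stats = Counter()
for op_type, dtype in nodes:
    # Report the weight-only kernels under the logical MatMul op type.
    if op_type in ["MatMulFpQ4", "MatMulNBits"]:
        op_type = "MatMul"
    stats[(op_type, dtype)] += 1

print(stats)  # Counter({('MatMul', 'int4'): 2, ('Relu', 'fp32'): 1})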

neural_compressor/adaptor/ox_utils/weight_only.py

Lines changed: 116 additions & 37 deletions
@@ -36,6 +36,23 @@
 ort = LazyImport("onnxruntime")
 logger = logging.getLogger("neural_compressor")
 ONNXRT116_VERSION = Version("1.16.0")
+ONNXRT1161_VERSION = Version("1.16.1")
+
+
+def get_blob_size(group_size, has_zp):  # pragma: no cover
+    """Get blob_size.
+
+    Args:
+        group_size (int): how many elements share one scale/zp
+        has_zp (bool): whether zero_point is not None
+    """
+    if Version(ort.__version__) > ONNXRT1161_VERSION:
+        blob_size = group_size // 2
+    elif has_zp:
+        blob_size = group_size // 2 + 4 + 1
+    else:
+        blob_size = group_size // 2 + 4
+    return blob_size
 
 
 def make_matmul_weight_only_node(
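The new helper returns the per-block byte count of the packed weight: with ort above 1.16.1 (the MatMulNBits path) a block holds only the packed 4-bit values, while the older MatMulFpQ4 blob also embeds a 4-byte fp32 scale and, when asymmetric, a 1-byte zero point. A minimal sketch of the same arithmetic, with the onnxruntime version passed in explicitly instead of read from the installed package:

from packaging.version import Version

ONNXRT1161_VERSION = Version("1.16.1")

def blob_size(group_size, has_zp, ort_version):
    """Mirror of get_blob_size with the onnxruntime version passed explicitly (sketch)."""
    if Version(ort_version) > ONNXRT1161_VERSION:
        return group_size // 2          # MatMulNBits: packed nibbles only
    if has_zp:
        return group_size // 2 + 4 + 1  # MatMulFpQ4: nibbles + fp32 scale + uint8 zero point
    return group_size // 2 + 4          # MatMulFpQ4: nibbles + fp32 scale

assert blob_size(32, has_zp=True, ort_version="1.17.0") == 16
assert blob_size(32, has_zp=True, ort_version="1.16.1") == 21
assert blob_size(32, has_zp=False, ort_version="1.16.0") == 20

Those sizes explain why the MatMulNBits path below builds separate scale and zero-point initializers instead of embedding them in each blob.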
@@ -54,49 +71,102 @@ def make_matmul_weight_only_node(
         zero_point (array): zero point
 
     Returns:
-        matmul_weight_only_node: MatMulFpQ4 node
-        new_inits: initializers of the MatMulFpQ4 node
+        matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
+        new_inits: initializers of the new node
     """
-    if zero_point is not None:
-        blob_size = group_size // 2 + 4 + 1
-        offset = 5
-    else:
-        blob_size = group_size // 2 + 4
-        offset = 4
-
+    blob_size = get_blob_size(group_size, zero_point is not None)
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
-    for i in range(q_weight.shape[0]):
-        bf = struct.pack("f", scale[i])
-        packed[i][0] = bf[0]
-        packed[i][1] = bf[1]
-        packed[i][2] = bf[2]
-        packed[i][3] = bf[3]
+    q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
+    input_names = [node.input[0], q_weight_name]
+    new_inits = []
+    kwargs = {}
+
+    if Version(ort.__version__) > ONNXRT1161_VERSION:
+        op_type = "MatMulNBits"
+
+        # pack quantized weight
+        for i in range(q_weight.shape[0]):
+            for k in range(0, group_size, 2):
+                packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4
+        packed = np.reshape(packed, (-1, k_blocks, blob_size))
 
+        # build scale tensor
+        scale = np.reshape(scale, (-1, k_blocks)).astype("float32")
+        scale_tensor = onnx.helper.make_tensor(
+            name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tobytes(), raw=True
+        )
+        input_names.append(scale_tensor.name)
+        new_inits.append(scale_tensor)
+
+        # build zero_point tensor
         if zero_point is not None:
-            packed[i][4] = zero_point[i]
+            if num_bits > 4:
+                packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
+            else:
+                packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
+                for i in range(zero_point.shape[0] // k_blocks):
+                    for j in range(k_blocks):
+                        idx = i * k_blocks + j
+                        zp = zero_point[idx]
+                        packed_zp[idx // 2] = (
+                            ((packed_zp[idx // 2] & 0x0F) | (zp << 4))
+                            if (idx & 1)
+                            else ((packed_zp[idx // 2] & 0xF0) | zp)
+                        )
+
+            zp_tensor = onnx.helper.make_tensor(
+                name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
+            )
+            input_names.append(zp_tensor.name)
+            new_inits.append(zp_tensor)
+
+        # set kwargs
+        kwargs["K"] = weight_shape[0]
+        kwargs["N"] = weight_shape[1]
+        kwargs["bits"] = num_bits
+        kwargs["block_size"] = group_size
+
+    else:
+        offset = 5 if zero_point is not None else 4
+        op_type = "MatMulFpQ4"
+
+        # pack quantized weight
+        for i in range(q_weight.shape[0]):
+            bf = struct.pack("f", scale[i])
+            packed[i][0] = bf[0]
+            packed[i][1] = bf[1]
+            packed[i][2] = bf[2]
+            packed[i][3] = bf[3]
+
+            if zero_point is not None:
+                packed[i][4] = zero_point[i]
+
+            packed[i][offset:] = np.bitwise_or(
+                q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits)
+            )
+        packed = packed.reshape(-1)
 
-        packed[i][offset:] = np.bitwise_or(
-            q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits)
+        # build shape tensor
+        shape_tensor = onnx.helper.make_tensor(
+            name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64")
         )
+        new_inits.append(shape_tensor)
+        input_names.append(shape_tensor.name)
+
+        # set kwargs
+        kwargs["blk_quant_type"] = 1 if zero_point is not None else 0
 
-        packed = packed.reshape(-1)
     q_weight_tensor = onnx.helper.make_tensor(
-        name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)),
+        name=q_weight_name,
         data_type=2,
         dims=packed.shape,
         vals=packed.tobytes(),
         raw=True,
     )
-    shape_tensor = onnx.helper.make_tensor(
-        name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64")
-    )
-    input_names = [node.input[0], q_weight_tensor.name, shape_tensor.name]
-    new_inits = [q_weight_tensor, shape_tensor]
+    new_inits.append(q_weight_tensor)
 
-    kwargs = {}
-    kwargs["blk_quant_type"] = 1 if zero_point is not None else 0
     matmul_weight_only_node = onnx.helper.make_node(
-        "MatMulFpQ4",
+        op_type,
         inputs=input_names,
         outputs=node.output,
         name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
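On the MatMulNBits branch above, two 4-bit values are packed per byte (the even index lands in the low nibble) and 4-bit zero points are packed the same way into a buffer pre-filled with 136 (0x88), so unused positions default to the symmetric zero point of 8. A self-contained sketch of just that packing on toy data (shapes and values are illustrative):

import numpy as np

group_size = 8
q_weight = np.array([[7, 2, 15, 0, 8, 8, 3, 9]], dtype="uint8")  # one block of 4-bit values

# Pack two 4-bit weights per byte: even index -> low nibble, odd index -> high nibble.
packed = np.zeros((q_weight.shape[0], group_size // 2), dtype="uint8")
for i in range(q_weight.shape[0]):
    for k in range(0, group_size, 2):
        packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4
print([hex(b) for b in packed[0]])  # ['0x27', '0xf', '0x88', '0x93']

# Zero points are packed the same way; 0x88 pre-fills both nibbles with the
# default zero point of 8 used for symmetric 4-bit quantization.
zero_point = np.array([5, 8, 11], dtype="uint8")
packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")  # 136 == 0x88
for idx, zp in enumerate(zero_point):
    if idx & 1:
        packed_zp[idx // 2] = (packed_zp[idx // 2] & 0x0F) | (zp << 4)
    else:
        packed_zp[idx // 2] = (packed_zp[idx // 2] & 0xF0) | zp
print([hex(b) for b in packed_zp])  # ['0x85', '0x8b']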
@@ -257,8 +327,11 @@ def rtn_quantize(
 
             weight = pad_tensor(weight, group_size, k_blocks)
 
-            if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:  # pragma: no cover
-                # currently MatMulFpQ4 only support 4 bits and 32 group_size
+            if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
+                Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
+            ):  # pragma: no cover
+                # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1
+                # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
                 q_weight, scale, zp = quant_tensor(
                     weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
                 )
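With the widened condition, which kernel a MatMul ends up as now depends on the installed onnxruntime: MatMulNBits for ort above 1.16.1 with 4 bits and a power-of-two group size, MatMulFpQ4 for ort 1.16.0/1.16.1 with 4 bits and group_size 32, and the plain integer-weight fallback otherwise. A rough decision-table sketch of that dispatch (the helper name and the None fallback are illustrative, not part of the library):

from packaging.version import Version

ONNXRT116_VERSION = Version("1.16.0")
ONNXRT1161_VERSION = Version("1.16.1")

def weight_only_op(ort_version, num_bits, group_size):
    """Which MatMul kernel the RTN/AWQ/GPTQ paths would target (sketch)."""
    v = Version(ort_version)
    if v > ONNXRT1161_VERSION and num_bits == 4:
        return "MatMulNBits"
    if v >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:
        return "MatMulFpQ4"
    return None  # fall back to the generic int-weight QDQ path

print(weight_only_op("1.17.0", 4, 128))  # MatMulNBits
print(weight_only_op("1.16.1", 4, 32))   # MatMulFpQ4
print(weight_only_op("1.16.1", 4, 128))  # None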
@@ -361,9 +434,11 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
                 weight = weight.T * scales
                 weight = pad_tensor(weight, group_size, (org_w_shape[0] + group_size - 1) // group_size).T
 
-                if (
+                if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
                     Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
                 ):  # pragma: no cover
+                    # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1
+                    # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
                     q_weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint") / np.expand_dims(
                         scales, axis=-1
                     )
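The AWQ scale search quantizes and immediately dequantizes the scaled weight to measure quantization loss; qdq_tensor is assumed here to perform a per-group asymmetric ("uint") round trip. A minimal numpy sketch of such a round trip, not the library's exact implementation:

import numpy as np

def qdq_uint_per_group(weight, num_bits=4, group_size=32):
    """Quantize-dequantize rows in groups of `group_size` with asymmetric uint levels (sketch)."""
    k, n = weight.shape  # assume k is already padded to a multiple of group_size
    w = weight.reshape(k // group_size, group_size, n)
    maxq = 2**num_bits - 1
    wmin, wmax = w.min(axis=1, keepdims=True), w.max(axis=1, keepdims=True)
    scale = np.maximum((wmax - wmin) / maxq, 1e-9)
    zp = np.round(-wmin / scale)
    q = np.clip(np.round(w / scale) + zp, 0, maxq)
    return ((q - zp) * scale).reshape(k, n)

w = np.random.randn(64, 8).astype("float32")
err = np.mean((w - qdq_uint_per_group(w)) ** 2)
print(f"mean squared quantization error: {err:.6f}")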
@@ -504,10 +579,11 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g
             for i_s in range(10):
                 ratio = 1 - i_s / 100
                 weight = copy.deepcopy(org_weight)
-                if (
+                if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
                     Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
                 ):  # pragma: no cover
-                    # currently MatMulFpQ4 only support 4 bits and 32 group_size
+                    # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1
+                    # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
                     weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1))
                 else:
                     weight = qdq_tensor(weight, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
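apply_awq_clip sweeps clip ratios 1.00, 0.99, ..., 0.91 and keeps the one whose quantize-dequantize output best matches the float output. A self-contained sketch of that search with a simple per-tensor qdq helper and an MSE criterion (the real code applies the ratio per group via qdq_tensor and scores against recorded activations):

import numpy as np

def qdq_uint(w, num_bits=4, ratio=1.0):
    """Per-tensor asymmetric quant-dequant; `ratio` shrinks the range like a clip (sketch)."""
    maxq = 2**num_bits - 1
    wmin, wmax = w.min() * ratio, w.max() * ratio
    scale = max((wmax - wmin) / maxq, 1e-9)
    zp = np.round(-wmin / scale)
    q = np.clip(np.round(w / scale) + zp, 0, maxq)
    return (q - zp) * scale

def search_clip_ratio(weight, inp):
    """Sweep ratios 1.00, 0.99, ..., 0.91 and keep the one with the lowest output MSE."""
    fp_out = inp @ weight
    best_ratio, best_err = 1.0, np.inf
    for i_s in range(10):
        ratio = 1 - i_s / 100
        err = np.mean((fp_out - inp @ qdq_uint(weight, ratio=ratio)) ** 2)
        if err < best_err:
            best_ratio, best_err = ratio, err
    return best_ratio

w = np.random.randn(16, 4).astype("float32")
x = np.random.randn(8, 16).astype("float32")
print(search_clip_ratio(w, x))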
@@ -571,9 +647,9 @@ def prepare_inputs(model, n_samples, dataloader):
 
         if isinstance(data[0], dict):
             inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()]))
-        elif isinstance(data[0], np.ndarray):
+        elif isinstance(data[0], np.ndarray):  # pragma: no cover
             inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]])]))
-        else:
+        else:  # pragma: no cover
             inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0])]))
     return inputs, so
 
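prepare_inputs turns each calibration batch into an onnxruntime feed dict keyed by the graph's input names; the two branches newly excluded from coverage just zip a bare ndarray or a tuple of arrays against those names. A hedged sketch of that mapping with a stand-in to_numpy and illustrative input names:

import numpy as np

def to_numpy(x):
    """Best-effort conversion used here in place of the library's to_numpy helper (assumption)."""
    return x if isinstance(x, np.ndarray) else np.asarray(x)

def build_feed(inputs_names, data):
    """Mirror the three branches in prepare_inputs for one batch."""
    if isinstance(data[0], dict):          # dataloader already yields name -> tensor
        return {name: to_numpy(v) for name, v in data[0].items()}
    elif isinstance(data[0], np.ndarray):  # single-input model, bare ndarray batch
        return {name: inp for name, inp in zip(inputs_names, [data[0]])}
    else:                                  # tuple/list of arrays, one per graph input
        return {name: to_numpy(inp) for name, inp in zip(inputs_names, data[0])}

names = ["input_ids", "attention_mask"]
batch = ([np.zeros((1, 8), dtype="int64"), np.ones((1, 8), dtype="int64")], None)  # (data, label)
print(build_feed(names, batch).keys())  # dict_keys(['input_ids', 'attention_mask'])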
@@ -982,8 +1058,11 @@ def gptq_quantize(
 
             weight_tensor = model.get_initializer(node.input[1])
             init_share_num = model.get_initializer_share_num(node.input[1])
-            if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:  # pragma: no cover
-                # currently MatMulFpQ4 only support 4 bits and 32 group_size
+            if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
+                Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
+            ):  # pragma: no cover
+                # MatMulFpQ4 supports 4 bits and 32 group_size with ort 1.16.0 and 1.16.1
+                # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
                 org_shape = weight.shape
                 k_blocks = (org_shape[0] + group_size - 1) // group_size
                 q_weight = pad_tensor(q_weight, group_size, k_blocks)
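Both the RTN and GPTQ paths first pad the weight along K so it splits evenly into k_blocks groups before packing; pad_tensor is assumed here to zero-pad rows up to k_blocks * group_size. A quick numpy sketch of that arithmetic:

import numpy as np

def pad_rows(weight, group_size, k_blocks):
    """Zero-pad the first axis of `weight` up to k_blocks * group_size rows (assumed behavior)."""
    pad_len = k_blocks * group_size - weight.shape[0]
    if pad_len > 0:
        weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant")
    return weight

org_shape = (100, 64)          # K x N
group_size = 32
k_blocks = (org_shape[0] + group_size - 1) // group_size  # ceil(100 / 32) == 4
padded = pad_rows(np.ones(org_shape, dtype="float32"), group_size, k_blocks)
print(k_blocks, padded.shape)  # 4 (128, 64)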
