Support MatMulNBits op for ort 1.17 #1327

Merged
merged 12 commits into from Oct 24, 2023
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/onnxrt.py
@@ -1694,7 +1694,7 @@ def _dump_model_op_stats(self, model, tune_cfg):

dtype_set = set()
for node in model.nodes():
if node.op_type == "MatMulFpQ4":
if node.op_type in ["MatMulFpQ4", "MatMulNBits"]:
optype = "MatMul"
else:
optype = node.op_type
153 changes: 116 additions & 37 deletions neural_compressor/adaptor/ox_utils/weight_only.py
@@ -36,6 +36,23 @@
ort = LazyImport("onnxruntime")
logger = logging.getLogger("neural_compressor")
ONNXRT116_VERSION = Version("1.16.0")
ONNXRT1161_VERSION = Version("1.16.1")


def get_blob_size(group_size, has_zp): # pragma: no cover
"""Get blob_size.

Args:
group_size (int): how many elements share one scale/zp
has_zp (bool): whether zero_point is not None
"""
if Version(ort.__version__) > ONNXRT1161_VERSION:
blob_size = group_size // 2
elif has_zp:
blob_size = group_size // 2 + 4 + 1
else:
blob_size = group_size // 2 + 4
return blob_size
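For context, a minimal sketch of what these sizes amount to (toy numbers, not part of the diff), assuming 4-bit quantization and group_size = 32: the MatMulNBits blob carries only the packed nibbles, while the MatMulFpQ4 blob additionally inlines a 4-byte fp32 scale and, for asymmetric quantization, a 1-byte zero point per group.

```python
# Illustrative only: per-group blob sizes for group_size = 32, 4-bit weights.
GROUP_SIZE = 32

matmulnbits_blob = GROUP_SIZE // 2               # 16 bytes: packed nibbles only (ort > 1.16.1)
matmulfpq4_sym_blob = GROUP_SIZE // 2 + 4        # 20 bytes: + inlined fp32 scale
matmulfpq4_asym_blob = GROUP_SIZE // 2 + 4 + 1   # 21 bytes: + uint8 zero point

print(matmulnbits_blob, matmulfpq4_sym_blob, matmulfpq4_asym_blob)  # 16 20 21
```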


def make_matmul_weight_only_node(
@@ -54,49 +71,102 @@ def make_matmul_weight_only_node(
zero_point (array): zero point

Returns:
matmul_weight_only_node: MatMulFpQ4 node
new_inits: initializers of the MatMulFpQ4 node
matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
new_inits: initializers of the new node
"""
if zero_point is not None:
blob_size = group_size // 2 + 4 + 1
offset = 5
else:
blob_size = group_size // 2 + 4
offset = 4

blob_size = get_blob_size(group_size, zero_point is not None)
packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
for i in range(q_weight.shape[0]):
bf = struct.pack("f", scale[i])
packed[i][0] = bf[0]
packed[i][1] = bf[1]
packed[i][2] = bf[2]
packed[i][3] = bf[3]
q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
input_names = [node.input[0], q_weight_name]
new_inits = []
kwargs = {}

if Version(ort.__version__) > ONNXRT1161_VERSION:
op_type = "MatMulNBits"

# pack quantized weight
for i in range(q_weight.shape[0]):
for k in range(0, group_size, 2):
packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4
packed = np.reshape(packed, (-1, k_blocks, blob_size))

# build scale tensor
scale = np.reshape(scale, (-1, k_blocks)).astype("float32")
scale_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tobytes(), raw=True
)
input_names.append(scale_tensor.name)
new_inits.append(scale_tensor)

# build zero_point tensor
if zero_point is not None:
packed[i][4] = zero_point[i]
if num_bits > 4:
packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
else:
packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
for i in range(zero_point.shape[0] // k_blocks):
for j in range(k_blocks):
idx = i * k_blocks + j
zp = zero_point[idx]
packed_zp[idx // 2] = (
((packed_zp[idx // 2] & 0x0F) | (zp << 4))
if (idx & 1)
else ((packed_zp[idx // 2] & 0xF0) | zp)
)

zp_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
)
input_names.append(zp_tensor.name)
new_inits.append(zp_tensor)

# set kwargs
kwargs["K"] = weight_shape[0]
kwargs["N"] = weight_shape[1]
kwargs["bits"] = num_bits
kwargs["block_size"] = group_size

else:
offset = 5 if zero_point is not None else 4
op_type = "MatMulFpQ4"

# pack quantized weight
for i in range(q_weight.shape[0]):
bf = struct.pack("f", scale[i])
packed[i][0] = bf[0]
packed[i][1] = bf[1]
packed[i][2] = bf[2]
packed[i][3] = bf[3]

if zero_point is not None:
packed[i][4] = zero_point[i]

packed[i][offset:] = np.bitwise_or(
q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits)
)
packed = packed.reshape(-1)

packed[i][offset:] = np.bitwise_or(
q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits)
# build shape tensor
shape_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64")
)
new_inits.append(shape_tensor)
input_names.append(shape_tensor.name)

# set kwargs
kwargs["blk_quant_type"] = 1 if zero_point is not None else 0

packed = packed.reshape(-1)
q_weight_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)),
name=q_weight_name,
data_type=2,
dims=packed.shape,
vals=packed.tobytes(),
raw=True,
)
shape_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64")
)
input_names = [node.input[0], q_weight_tensor.name, shape_tensor.name]
new_inits = [q_weight_tensor, shape_tensor]
new_inits.append(q_weight_tensor)

kwargs = {}
kwargs["blk_quant_type"] = 1 if zero_point is not None else 0
matmul_weight_only_node = onnx.helper.make_node(
"MatMulFpQ4",
op_type,
inputs=input_names,
outputs=node.output,
name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
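As a reference, here is a small standalone sketch (toy values; behaviour inferred from the diff, not library code) of the MatMulNBits packing above: two 4-bit weights per byte with the even index in the low nibble, and zero points packed the same way on top of a 136 (0x88) fill, which encodes the default zero point of 8 in both nibbles of padded blocks.

```python
import numpy as np

# Toy 4-bit weight values for one row (values in 0..15).
q_row = np.array([1, 2, 3, 15, 0, 7, 8, 4], dtype="uint8")

# Two weights per byte: even index -> low nibble, odd index -> high nibble.
packed = np.zeros(len(q_row) // 2, dtype="uint8")
for k in range(0, len(q_row), 2):
    packed[k // 2] = q_row[k] | (q_row[k + 1] << 4)
print([hex(b) for b in packed])  # ['0x21', '0xf3', '0x70', '0x48']

# Zero points use the same nibble layout; 136 == 0x88 pre-fills both nibbles
# with 8 (the uint4 mid-point) so padded/unused blocks get a neutral zero point.
zp = np.array([5, 9, 8], dtype="uint8")
packed_zp = np.full((len(zp) + 1) // 2, 136, dtype="uint8")
for idx, v in enumerate(zp):
    if idx & 1:
        packed_zp[idx // 2] = (packed_zp[idx // 2] & 0x0F) | (v << 4)
    else:
        packed_zp[idx // 2] = (packed_zp[idx // 2] & 0xF0) | v
print([hex(b) for b in packed_zp])  # ['0x95', '0x88']
```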
@@ -257,8 +327,11 @@ def rtn_quantize(

weight = pad_tensor(weight, group_size, k_blocks)

if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32: # pragma: no cover
# currently MatMulFpQ4 only support 4 bits and 32 group_size
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# MatMulFpQ4 supports only 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
q_weight, scale, zp = quant_tensor(
weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
)
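A hypothetical helper (not part of this PR; names are illustrative) that makes the version gate above explicit: ort 1.16.0/1.16.1 only ship MatMulFpQ4 (4 bits, group_size 32), while newer releases expose MatMulNBits, which accepts 4 bits with power-of-two block sizes; anything else stays on the plain QDQ fallback.

```python
from packaging.version import Version
import onnxruntime as ort

ONNXRT116_VERSION = Version("1.16.0")
ONNXRT1161_VERSION = Version("1.16.1")

def woq_matmul_op(num_bits, group_size):
    """Return the weight-only MatMul op usable with the installed ort, or None."""
    ort_version = Version(ort.__version__)
    if ort_version > ONNXRT1161_VERSION and num_bits == 4:
        return "MatMulNBits"   # 4 bits, power-of-two block_size
    if ort_version >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32:
        return "MatMulFpQ4"    # 4 bits, group_size 32 only
    return None                # keep the plain QDQ fallback
```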
@@ -361,9 +434,11 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
weight = weight.T * scales
weight = pad_tensor(weight, group_size, (org_w_shape[0] + group_size - 1) // group_size).T

if (
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# MatMulFpQ4 supports only 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
q_weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint") / np.expand_dims(
scales, axis=-1
)
@@ -504,10 +579,11 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g
for i_s in range(10):
ratio = 1 - i_s / 100
weight = copy.deepcopy(org_weight)
if (
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# currently MatMulFpQ4 only support 4 bits and 32 group_size
# MatMulFpQ4 supports only 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
weight = qdq_tensor(weight, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1))
else:
weight = qdq_tensor(weight, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
@@ -571,9 +647,9 @@ def prepare_inputs(model, n_samples, dataloader):

if isinstance(data[0], dict):
inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()]))
elif isinstance(data[0], np.ndarray):
elif isinstance(data[0], np.ndarray): # pragma: no cover
inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]])]))
else:
else: # pragma: no cover
inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0])]))
return inputs, so
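A small illustrative sketch (hypothetical input names, not library code) of what this normalization produces: whether the dataloader yields a dict, a single ndarray, or a list of arrays, each calibration sample ends up as a {input_name: numpy array} feed dict.

```python
import numpy as np

inputs_names = ["input_ids", "attention_mask"]  # assumed model input names

def normalize(sample):
    # Mirrors the three branches above for one dataloader sample.
    if isinstance(sample, dict):
        return {name: np.asarray(v) for name, v in sample.items()}
    if isinstance(sample, np.ndarray):
        return {inputs_names[0]: sample}
    return {name: np.asarray(v) for name, v in zip(inputs_names, sample)}

print(sorted(normalize({"input_ids": np.ones((1, 8))}).keys()))      # ['input_ids']
print(sorted(normalize(np.ones((1, 8))).keys()))                     # ['input_ids']
print(sorted(normalize([np.ones((1, 8)), np.ones((1, 8))]).keys()))  # ['attention_mask', 'input_ids']
```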

@@ -982,8 +1058,11 @@ def gptq_quantize(

weight_tensor = model.get_initializer(node.input[1])
init_share_num = model.get_initializer_share_num(node.input[1])
if Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32: # pragma: no cover
# currently MatMulFpQ4 only support 4 bits and 32 group_size
if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
): # pragma: no cover
# MatMulFpQ4 supports only 4 bits and group_size 32 with ort 1.16.0 and 1.16.1
# MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1
org_shape = weight.shape
k_blocks = (org_shape[0] + group_size - 1) // group_size
q_weight = pad_tensor(q_weight, group_size, k_blocks)