diff --git a/.azure-pipelines/scripts/install_nc.sh b/.azure-pipelines/scripts/install_nc.sh index 2e7b3ab7c26..0c0535fb840 100644 --- a/.azure-pipelines/scripts/install_nc.sh +++ b/.azure-pipelines/scripts/install_nc.sh @@ -5,15 +5,19 @@ cd /neural-compressor if [[ $1 = *"3x_pt" ]]; then python -m pip install --no-cache-dir -r requirements_pt.txt python setup.py pt bdist_wheel - pip install dist/neural_compressor*.whl + pip install dist/neural_compressor*.whl --force-reinstall elif [[ $1 = *"3x_tf" ]]; then python -m pip install --no-cache-dir -r requirements_tf.txt python setup.py tf bdist_wheel - pip install dist/neural_compressor*.whl + pip install dist/neural_compressor*.whl --force-reinstall +elif [[ $1 = *"3x_ort" ]]; then + python -m pip install --no-cache-dir -r requirements_ort.txt + python setup.py ort bdist_wheel + pip install dist/neural_compressor*.whl --force-reinstall else python -m pip install --no-cache-dir -r requirements.txt python setup.py 2x bdist_wheel - pip install dist/neural_compressor*.whl + pip install dist/neural_compressor*.whl --force-reinstall fi echo -e "\n pip list after install Neural Compressor ... " diff --git a/.azure-pipelines/scripts/ut/3x/collect_log_3x.sh b/.azure-pipelines/scripts/ut/3x/collect_log_3x.sh index 4b3087b95f4..bdb79e6d7c8 100644 --- a/.azure-pipelines/scripts/ut/3x/collect_log_3x.sh +++ b/.azure-pipelines/scripts/ut/3x/collect_log_3x.sh @@ -1,5 +1,6 @@ source /neural-compressor/.azure-pipelines/scripts/change_color.sh +set -xe pip install coverage export COVERAGE_RCFILE=/neural-compressor/.azure-pipelines/scripts/ut/3x/coverage.${1} coverage_log="/neural-compressor/log_dir/coverage_log" @@ -22,7 +23,7 @@ cd /neural-compressor git config --global --add safe.directory /neural-compressor git fetch git checkout master -echo y | pip uninstall neural-compressor +echo y | pip uninstall neural_compressor_${1} cd /neural-compressor/.azure-pipelines/scripts && bash install_nc.sh ${1} coverage erase diff --git a/.azure-pipelines/scripts/ut/3x/coverage.3x_ort b/.azure-pipelines/scripts/ut/3x/coverage.3x_ort new file mode 100644 index 00000000000..1404dccbaee --- /dev/null +++ b/.azure-pipelines/scripts/ut/3x/coverage.3x_ort @@ -0,0 +1,15 @@ +[run] +branch = True + +[report] +include = + */neural_compressor/common/* + */neural_compressor/onnxrt/* +exclude_lines = + pragma: no cover + raise NotImplementedError + raise TypeError + if self.device == "gpu": + if device == "gpu": + except ImportError: + except Exception as e: \ No newline at end of file diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_ort.sh b/.azure-pipelines/scripts/ut/3x/run_3x_ort.sh new file mode 100644 index 00000000000..21f423a2a4d --- /dev/null +++ b/.azure-pipelines/scripts/ut/3x/run_3x_ort.sh @@ -0,0 +1,34 @@ +#!/bin/bash +python -c "import neural_compressor as nc" +test_case="run 3x ONNXRT" +echo "${test_case}" + +# install requirements +echo "set up UT env..." +pip install -r /neural-compressor/test/3x/onnxrt/requirements.txt +pip install coverage +pip list + +export COVERAGE_RCFILE=/neural-compressor/.azure-pipelines/scripts/ut/3x/coverage.3x_ort +inc_path=$(python -c 'import neural_compressor; print(neural_compressor.__path__[0])') +cd /neural-compressor/test || exit 1 +find ./3x/onnxrt/* -name "test*.py" | sed 's,\.\/,coverage run --source='"${inc_path}"' --append ,g' | sed 's/$/ --verbose/'> run.sh + +LOG_DIR=/neural-compressor/log_dir +mkdir -p ${LOG_DIR} +ut_log_name=${LOG_DIR}/ut_3x_ort.log + +echo "cat run.sh..." 
+sort run.sh -o run.sh +cat run.sh | tee ${ut_log_name} +echo "------UT start-------" +bash -x run.sh 2>&1 | tee -a ${ut_log_name} +cp .coverage ${LOG_DIR}/.coverage + +echo "------UT end -------" + +if [ $(grep -c "FAILED" ${ut_log_name}) != 0 ] || [ $(grep -c "core dumped" ${ut_log_name}) != 0 ] || [ $(grep -c "ModuleNotFoundError:" ${ut_log_name}) != 0 ] || [ $(grep -c "OK" ${ut_log_name}) == 0 ];then + echo "Find errors in UT test, please check the output..." + exit 1 +fi +echo "UT finished successfully! " \ No newline at end of file diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh index 7453adca931..cdc93eeee2b 100644 --- a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh +++ b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh @@ -5,9 +5,8 @@ echo "${test_case}" # install requirements echo "set up UT env..." -pip install transformers +pip install -r /neural-compressor/test/3x/torch/requirements.txt pip install coverage -pip install pytest pip list export COVERAGE_RCFILE=/neural-compressor/.azure-pipelines/scripts/ut/3x/coverage.3x_pt diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh b/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh index 87bfca17a09..ce0cad9bab2 100644 --- a/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh +++ b/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh @@ -5,8 +5,8 @@ echo "${test_case}" # install requirements echo "set up UT env..." +pip install -r /neural-compressor/test/3x/tensorflow/requirements.txt pip install coverage -pip install pytest pip list export COVERAGE_RCFILE=/neural-compressor/.azure-pipelines/scripts/ut/3x/coverage.3x_tf diff --git a/.azure-pipelines/ut-3x-ort.yml b/.azure-pipelines/ut-3x-ort.yml new file mode 100644 index 00000000000..6619883ef6e --- /dev/null +++ b/.azure-pipelines/ut-3x-ort.yml @@ -0,0 +1,106 @@ +trigger: none + +pr: + autoCancel: true + drafts: false + branches: + include: + - master + paths: + include: + - neural_compressor/common + - neural_compressor/onnxrt + - test/3x/onnxrt + - setup.py + - requirements_ort.txt + +pool: ICX-16C + +variables: + IMAGE_NAME: "neural-compressor" + IMAGE_TAG: "py310" + UPLOAD_PATH: $(Build.SourcesDirectory)/log_dir + DOWNLOAD_PATH: $(Build.SourcesDirectory)/log_dir + ARTIFACT_NAME: "UT_coverage_report_3x_ort" + REPO: $(Build.Repository.Uri) + +stages: + - stage: ONNXRT + displayName: Unit Test 3x ONNXRT + dependsOn: [] + jobs: + - job: + displayName: Unit Test 3x ONNXRT + steps: + - template: template/ut-template.yml + parameters: + dockerConfigName: "commonDockerConfig" + utScriptFileName: "3x/run_3x_ort" + uploadPath: $(UPLOAD_PATH) + utArtifact: "ut_coverage_3x" + + + - stage: ONNXRT_baseline + displayName: Unit Test 3x ONNXRT baseline + dependsOn: [] + jobs: + - job: + displayName: Unit Test 3x ONNXRT baseline + steps: + - template: template/ut-template.yml + parameters: + dockerConfigName: "gitCloneDockerConfig" + utScriptFileName: "3x/run_3x_ort" + uploadPath: $(UPLOAD_PATH) + utArtifact: "ut_coverage_3x_baseline" + repo: $(REPO) + + - stage: Coverage + displayName: "Coverage Combine" + pool: + vmImage: "ubuntu-latest" + dependsOn: [ONNXRT, ONNXRT_baseline] + jobs: + - job: CollectDatafiles + steps: + - script: | + if [[ ! $(docker images | grep -i ${IMAGE_NAME}:${IMAGE_TAG}) ]]; then + docker build -f ${BUILD_SOURCESDIRECTORY}/.azure-pipelines/docker/Dockerfile.devel -t ${IMAGE_NAME}:${IMAGE_TAG} . + fi + docker images | grep -i ${IMAGE_NAME} + if [[ $? 
-ne 0 ]]; then + echo "NO Such Repo" + exit 1 + fi + displayName: "Build develop docker image" + + - task: DownloadPipelineArtifact@2 + inputs: + artifact: + path: $(DOWNLOAD_PATH) + + - script: | + echo "--- create container ---" + docker run -d -it --name="collectLogs" -v ${BUILD_SOURCESDIRECTORY}:/neural-compressor ${IMAGE_NAME}:${IMAGE_TAG} /bin/bash + echo "--- docker ps ---" + docker ps + echo "--- collect logs ---" + docker exec collectLogs /bin/bash +x -c "cd /neural-compressor/.azure-pipelines/scripts \ + && bash install_nc.sh 3x_ort \ + && bash ut/3x/collect_log_3x.sh 3x_ort" + displayName: "Collect UT Coverage" + + - task: PublishPipelineArtifact@1 + condition: succeededOrFailed() + inputs: + targetPath: $(UPLOAD_PATH) + artifact: $(ARTIFACT_NAME) + publishLocation: "pipeline" + + - task: Bash@3 + condition: always() + inputs: + targetType: "inline" + script: | + docker exec collectLogs bash -c "rm -fr /neural-compressor/* && rm -fr /neural-compressor/.* || true" + displayName: "Docker clean up" diff --git a/.azure-pipelines/ut-3x-pt.yml b/.azure-pipelines/ut-3x-pt.yml index 3a83c1e7d09..69fc5e718e4 100644 --- a/.azure-pipelines/ut-3x-pt.yml +++ b/.azure-pipelines/ut-3x-pt.yml @@ -88,7 +88,7 @@ stages: docker exec collectLogs /bin/bash +x -c "cd /neural-compressor/.azure-pipelines/scripts \ && bash install_nc.sh 3x_pt \ && bash ut/3x/collect_log_3x.sh 3x_pt" - displayName: "collect logs" + displayName: "Collect UT Coverage" - task: PublishPipelineArtifact@1 condition: succeededOrFailed() diff --git a/.azure-pipelines/ut-3x-tf.yml b/.azure-pipelines/ut-3x-tf.yml index 1824e350786..259135b716c 100644 --- a/.azure-pipelines/ut-3x-tf.yml +++ b/.azure-pipelines/ut-3x-tf.yml @@ -88,7 +88,7 @@ stages: docker exec collectLogs /bin/bash +x -c "cd /neural-compressor/.azure-pipelines/scripts \ && bash install_nc.sh 3x_tf \ && bash ut/3x/collect_log_3x.sh 3x_tf" - displayName: "collect logs" + displayName: "Collect UT Coverage" - task: PublishPipelineArtifact@1 condition: succeededOrFailed() diff --git a/.azure-pipelines/ut-basic-no-cover.yml b/.azure-pipelines/ut-basic-no-cover.yml index a0af4d5a5b3..395c8cc4be4 100644 --- a/.azure-pipelines/ut-basic-no-cover.yml +++ b/.azure-pipelines/ut-basic-no-cover.yml @@ -19,6 +19,7 @@ pr: - neural_compressor/common - neural_compressor/torch - neural_compressor/tensorflow + - neural_compressor/onnxrt pool: ICX-16C diff --git a/.azure-pipelines/ut-basic.yml b/.azure-pipelines/ut-basic.yml index 98aec0732ab..707b073654f 100644 --- a/.azure-pipelines/ut-basic.yml +++ b/.azure-pipelines/ut-basic.yml @@ -19,6 +19,7 @@ pr: - neural_compressor/common - neural_compressor/torch - neural_compressor/tensorflow + - neural_compressor/onnxrt pool: ICX-16C @@ -257,7 +258,7 @@ stages: docker exec collectLogs /bin/bash +x -c "cd /neural-compressor/.azure-pipelines/scripts \ && bash install_nc.sh \ && bash ut/collect_log.sh" - displayName: "collect logs" + displayName: "Collect UT Coverage" - task: PublishPipelineArtifact@1 condition: succeededOrFailed() diff --git a/neural_compressor/onnxrt/__init__.py b/neural_compressor/onnxrt/__init__.py new file mode 100644 index 00000000000..50496342a50 --- /dev/null +++ b/neural_compressor/onnxrt/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neural_compressor.onnxrt.utils.utility import register_algo +from neural_compressor.onnxrt.algorithms import rtn_quantize_entry + +from neural_compressor.onnxrt.quantization import ( + RTNConfig, + get_default_rtn_config, +) diff --git a/neural_compressor/onnxrt/algorithms/__init__.py b/neural_compressor/onnxrt/algorithms/__init__.py new file mode 100644 index 00000000000..168548659a2 --- /dev/null +++ b/neural_compressor/onnxrt/algorithms/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from neural_compressor.onnxrt.algorithms.weight_only.algo_entry import rtn_quantize_entry diff --git a/neural_compressor/onnxrt/algorithms/weight_only/__init__.py b/neural_compressor/onnxrt/algorithms/weight_only/__init__.py new file mode 100644 index 00000000000..28f108cb636 --- /dev/null +++ b/neural_compressor/onnxrt/algorithms/weight_only/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neural_compressor/onnxrt/algorithms/weight_only/algo_entry.py b/neural_compressor/onnxrt/algorithms/weight_only/algo_entry.py new file mode 100644 index 00000000000..00d4aea4cdb --- /dev/null +++ b/neural_compressor/onnxrt/algorithms/weight_only/algo_entry.py @@ -0,0 +1,40 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from pathlib import Path +from typing import Dict, Tuple, Union + +import onnx + +from neural_compressor.common.logger import Logger +from neural_compressor.common.utility import RTN +from neural_compressor.onnxrt.quantization.config import RTNConfig +from neural_compressor.onnxrt.utils.utility import register_algo + +logger = Logger().get_logger() + + +###################### RTN Algo Entry ################################## +@register_algo(name=RTN) +def rtn_quantize_entry(model: Union[Path, str], quant_config: RTNConfig, *args, **kwargs) -> onnx.ModelProto: + """The main entry to apply rtn quantization.""" + from neural_compressor.onnxrt.algorithms.weight_only.rtn import apply_rtn_on_model + + # map config to each op + model_info = quant_config.get_model_info(model=model) + configs_mapping = quant_config.to_config_mapping(model_info=model_info) + logger.debug(configs_mapping) + model = apply_rtn_on_model(model, configs_mapping) + return model diff --git a/neural_compressor/onnxrt/algorithms/weight_only/rtn.py b/neural_compressor/onnxrt/algorithms/weight_only/rtn.py new file mode 100644 index 00000000000..cfbd7931777 --- /dev/null +++ b/neural_compressor/onnxrt/algorithms/weight_only/rtn.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 MIT HAN Lab +# This source code is licensed under the MIT license +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from pathlib import Path +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import onnx +import onnxruntime as ort +from packaging.version import Version + +from neural_compressor.onnxrt.algorithms.weight_only.utility import make_matmul_weight_only_node +from neural_compressor.onnxrt.quantization.config import RTNConfig +from neural_compressor.onnxrt.utils.onnx_model import ONNXModel +from neural_compressor.onnxrt.utils.utility import ( + ONNXRT116_VERSION, + ONNXRT1161_VERSION, + dtype_mapping, + simple_progress_bar, +) + + +def pad_tensor(weight, group_size, k_blocks): + """Pad tensor rows so that they are divisible by group_size. + + Args: + weight (array): weight + group_size (int): how many elements share one scale/zp + k_blocks (int): the number of blocks + + Returns: + weight: padded weight + """ + if group_size == -1: + return weight + + org_w_shape = weight.shape + padded_rows = k_blocks * group_size + pad_len = padded_rows - org_w_shape[0] + + if pad_len > 0: + weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant") + + return weight + + +def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0): + """Quantize tensor per group. + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 32. + scheme (str, optional): quantization scheme. Defaults to "asym". + dtype (str, optional): data type. Defaults to "int". + ratio (float, optional): percentile of clip. Defaults to 1.0.
+ + Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + data = np.reshape(data, (-1, group_size)) + if scheme == "asym" or dtype == "uint": + maxq = 2**num_bits - 1 + minq = 0 + elif scheme == "sym": + maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0 + minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1 + + rmin = np.min(data, axis=1, keepdims=True) * ratio + rmax = np.max(data, axis=1, keepdims=True) * ratio + if scheme == "sym": + max_range = np.maximum(np.abs(rmin), np.abs(rmax)) + scale = np.ones(rmax.shape) + scale[max_range > 0] = np.array( + [float(i) / (maxq - minq) for i in (max_range[max_range > 0] * 2.0).flatten().tolist()] + ) + zero_point = ( + np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1)) + ) + else: + scale = np.ones(rmax.shape) + scale[rmin != rmax] = np.array( + [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()] + ) + zero_point = ( + ((np.zeros(scale.shape) - rmin) / scale).round() + if dtype == "int" + else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8") + ) + return np.clip((data / scale + zero_point).round(), minq, maxq), scale, zero_point + + +def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0): + """Quantize and de-quantize tensor per group. + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 32. + scheme (str, optional): quantization scheme. Defaults to "asym". + dtype (str, optional): data type. Defaults to "int". + ratio (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: quant-dequant weight + """ + org_shape = data.shape + weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio) + return np.reshape(scale * (weight - zp), org_shape) + + +def rtn_quantize( + model: Union[onnx.ModelProto, ONNXModel, Path, str], + weight_config: Optional[Dict[tuple, dict]] = {}, + num_bits: Optional[int] = 4, + group_size: Optional[int] = 32, + scheme: Optional[str] = "asym", + ratios: Optional[dict] = {}, + accuracy_level: Optional[int] = 0, + providers: Optional[list] = ["CPUExecutionProvider"], +) -> onnx.ModelProto: + """Quantize the model with the round-to-nearest method. + + Args: + model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model + weight_config (Optional[Dict[tuple, dict]], optional): quantization config + For example, + weight_config = { + '(fc2, "MatMul")': + { + 'weight_dtype': 'int', + 'weight_bits': 4, + 'weight_group_size': 32, + 'weight_sym': True, + 'accuracy_level': 0 + } + }. Defaults to {}. + num_bits (Optional[int], optional): num_bits. Defaults to 4. + group_size (Optional[int], optional): how many elements share one scale/zp. Defaults to 32. + scheme (Optional[str], optional): sym or asym. Defaults to "asym". + ratios (Optional[dict], optional): percentile of clip. Defaults to {}. + accuracy_level (Optional[int], optional): + accuracy level. Support 0 (unset), 1(fp32 compute type of jblas kernel), + 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), + 4 (int8 compute type of jblas kernel). Defaults to 0. + providers (Optional[list], optional): providers to use. Defaults to ["CPUExecutionProvider"].
+ + Returns: + onnx.ModelProto: quantized ONNXModel + """ + if not isinstance(model, ONNXModel): + model = ONNXModel(model) + base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" + new_nodes = [] + remove_nodes = [] + total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]]) + curr_id = 0 + for node in model.nodes(): + if node.op_type in ["MatMul"]: + curr_id += 1 + simple_progress_bar(total_num, curr_id) + + # check op_type of node is MatMul + # check dim 1 of input is weight tensor + # check weight_type is not "fp32" + if ( + node.op_type in ["MatMul"] # check op_type of node is MatMul + and model.get_initializer(node.input[1]) is not None + and weight_config.get((node.name, node.op_type), {}).get("weight_dtype", "fp32") != "fp32" + ): + weight_tensor = model.get_initializer(node.input[1]) + weight = onnx.numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy() + if len(weight.shape) != 2: + continue + + dtype = weight.dtype + if (node.name, node.op_type) in weight_config: + num_bits = weight_config[(node.name, node.op_type)].get("weight_bits", 4) + group_size = weight_config[(node.name, node.op_type)].get("weight_group_size", 32) + scheme = "sym" if weight_config[(node.name, node.op_type)].get("weight_sym", True) else "asym" + accuracy_level = weight_config[(node.name, node.op_type)].get("accuracy_level", 0) + + org_w_shape = weight.shape # ic, oc + group_size = group_size if group_size != -1 else org_w_shape[0] + + k_blocks = (org_w_shape[0] - 1) // group_size + 1 + init_share_num = model.get_initializer_share_num(node.input[1]) + + weight = pad_tensor(weight, group_size, k_blocks) + + satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4 + satisfy_MatMulFpQ4_condition = ( + Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32 + ) + if ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or ( + "CUDAExecutionProvider" not in providers + and (satisfy_MatMulFpQ4_condition or satisfy_MatMulNBits_condition) + ): # pragma: no cover + # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP + # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP + q_weight, scale, zp = quant_tensor( + weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1) + ) + q_matmul_node, new_inits = make_matmul_weight_only_node( + node=node, + weight_shape=org_w_shape, + num_bits=num_bits, + group_size=group_size, + k_blocks=k_blocks, + q_weight=q_weight.astype("uint8"), + scale=scale.astype(dtype), + zero_point=zp if scheme == "asym" else None, + accuracy_level=accuracy_level, + ) + + model.add_initializers(new_inits) + remove_nodes.append(node) + new_nodes.append(q_matmul_node) + else: + q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1)) + q_weight = np.reshape(q_weight, (org_w_shape[1], -1)) + q_weight = np.transpose(q_weight) + q_weight = q_weight[: org_w_shape[0], :].astype(dtype) + q_weight_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), + data_type=dtype_mapping[str(dtype)], + dims=weight.shape, + vals=q_weight.tobytes(), + raw=True, + ) + model.add_initializer(q_weight_tensor) + node.input[1] = q_weight_tensor.name + if init_share_num == 1: + model.remove_initializer(weight_tensor) + + model.add_nodes(new_nodes) + model.remove_nodes(remove_nodes) + 
model.topological_sort() + return model.model + + +def apply_rtn_on_model(model: onnx.ModelProto, quant_config: Dict[Tuple[str, callable], RTNConfig]) -> onnx.ModelProto: + providers = quant_config.pop("providers", ["CPUExecutionProvider"]) + + # change op config to dict type + for op_name_type, op_config in quant_config.items(): + if isinstance(op_config, RTNConfig): + quant_config[op_name_type] = op_config.to_dict() + + return rtn_quantize(model, weight_config=quant_config, providers=providers) diff --git a/neural_compressor/onnxrt/algorithms/weight_only/utility.py b/neural_compressor/onnxrt/algorithms/weight_only/utility.py new file mode 100644 index 00000000000..281859a8be7 --- /dev/null +++ b/neural_compressor/onnxrt/algorithms/weight_only/utility.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 MIT HAN Lab +# This source code is licensed under the MIT license +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import struct + +import numpy as np +import onnx +import onnxruntime as ort +from packaging.version import Version + +from neural_compressor.onnxrt.utils.utility import ONNXRT1161_VERSION, dtype_mapping + + +def get_blob_size(group_size, has_zp): # pragma: no cover + """Get blob_size. + + Args: + group_size (int): how many elements share one scale/zp + has_zp (bool): whether a zero_point exists + """ + if Version(ort.__version__) > ONNXRT1161_VERSION: + blob_size = group_size // 2 + elif has_zp: + blob_size = group_size // 2 + 4 + 1 + else: + blob_size = group_size // 2 + 4 + return blob_size + + +def make_matmul_weight_only_node( + node, + weight_shape, + num_bits, + group_size, + k_blocks, + q_weight, + scale, + zero_point, + accuracy_level=0, +): # pragma: no cover + """Build MatMulFpQ4 or MatMulNBits node. + + Args: + node: original matmul node + weight_shape: original weight shape + num_bits (int): num_bits + group_size (int): how many elements share one scale/zp + k_blocks (int): number of blocks + q_weight (array): quantized weight + scale (array): scale + zero_point (array): zero point + accuracy_level (int): accuracy level.
Support 0 (unset), 1(fp32 compute type of jblas kernel), + 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), + 4 (int8 compute type of jblas kernel) + + Returns: + matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node + new_inits: initializers of the new node + """ + blob_size = get_blob_size(group_size, zero_point is not None) + packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8") + q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)) + input_names = [node.input[0], q_weight_name] + new_inits = [] + kwargs = {} + + if Version(ort.__version__) > ONNXRT1161_VERSION: + op_type = "MatMulNBits" + + # pack quantized weight + for i in range(q_weight.shape[0]): + for k in range(0, group_size, 2): + packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4 + packed = np.reshape(packed, (-1, k_blocks, blob_size)) + + # build scale tensor + scale = np.reshape(scale, (-1, k_blocks)) + scale_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_scale", + data_type=dtype_mapping[str(scale.dtype)], + dims=scale.shape, + vals=scale.tobytes(), + raw=True, + ) + input_names.append(scale_tensor.name) + new_inits.append(scale_tensor) + + # build zero_point tensor + if zero_point is not None: + if num_bits > 4: + packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8") + else: + packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8") + for i in range(zero_point.shape[0] // k_blocks): + for j in range(k_blocks): + idx = i * k_blocks + j + zp = zero_point[idx] + packed_zp[idx // 2] = ( + ((packed_zp[idx // 2] & 0x0F) | (zp << 4)) + if (idx & 1) + else ((packed_zp[idx // 2] & 0xF0) | zp) + ) + + zp_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True + ) + input_names.append(zp_tensor.name) + new_inits.append(zp_tensor) + + # set kwargs + kwargs["K"] = weight_shape[0] + kwargs["N"] = weight_shape[1] + kwargs["bits"] = num_bits + kwargs["block_size"] = group_size + if accuracy_level > 0: + # require onnxruntime > 1.16.3 + kwargs["accuracy_level"] = accuracy_level + + else: + offset = 5 if zero_point is not None else 4 + op_type = "MatMulFpQ4" + + # pack quantized weight + for i in range(q_weight.shape[0]): + bf = struct.pack("f", scale[i]) + packed[i][0] = bf[0] + packed[i][1] = bf[1] + packed[i][2] = bf[2] + packed[i][3] = bf[3] + + if zero_point is not None: + packed[i][4] = zero_point[i] + + packed[i][offset:] = np.bitwise_or( + q_weight[i][: group_size // 2], np.left_shift(q_weight[i][group_size // 2 :], num_bits) + ) + packed = packed.reshape(-1) + + # build shape tensor + shape_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_shape", data_type=7, dims=(2,), vals=np.array(weight_shape, dtype="int64") + ) + new_inits.append(shape_tensor) + input_names.append(shape_tensor.name) + + # set kwargs + kwargs["blk_quant_type"] = 1 if zero_point is not None else 0 + + q_weight_tensor = onnx.helper.make_tensor( + name=q_weight_name, + data_type=2, + dims=packed.shape, + vals=packed.tobytes(), + raw=True, + ) + new_inits.append(q_weight_tensor) + + matmul_weight_only_node = onnx.helper.make_node( + op_type, + inputs=input_names, + outputs=node.output, + name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits), + domain="com.microsoft", + **kwargs, + ) + return matmul_weight_only_node, new_inits diff --git a/neural_compressor/onnxrt/quantization/__init__.py b/neural_compressor/onnxrt/quantization/__init__.py new 
file mode 100644 index 00000000000..4bd664019ac --- /dev/null +++ b/neural_compressor/onnxrt/quantization/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neural_compressor.onnxrt.quantization.quantize import _quantize +from neural_compressor.onnxrt.quantization.config import ( + RTNConfig, + get_default_rtn_config, +) diff --git a/neural_compressor/onnxrt/quantization/config.py b/neural_compressor/onnxrt/quantization/config.py new file mode 100644 index 00000000000..5ba32b55e38 --- /dev/null +++ b/neural_compressor/onnxrt/quantization/config.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from collections import OrderedDict +from enum import Enum +from pathlib import Path +from typing import Callable, List, NamedTuple, Optional, Tuple, Union + +import onnx + +from neural_compressor.common.base_config import BaseConfig, register_config +from neural_compressor.common.logger import Logger +from neural_compressor.common.utility import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE, RTN + +logger = Logger().get_logger() + +FRAMEWORK_NAME = "onnxrt" + + +class Backend(Enum): + DEFAULT = "onnxrt_cpu" + CUDA = "onnxrt_cuda" + + +class OperatorConfig(NamedTuple): + config: BaseConfig + operators: List[Union[str, Callable]] + backend: List[Backend] + valid_func_list: List[Callable] = [] + + +######################## RNT Config ############################### + + +@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN) +class RTNConfig(BaseConfig): + """Config class for round-to-nearest weight-only quantization.""" + + supported_configs: List[OperatorConfig] = [] + node_params_list = [ + "weight_dtype", + "weight_bits", + "weight_group_size", + "weight_sym", + "act_dtype", + "accuracy_level", + ] + model_params_list = ["providers"] + params_list = node_params_list + model_params_list + name = RTN + + def __init__( + self, + weight_dtype: str = "int", + weight_bits: int = 4, + weight_group_size: int = 32, + weight_sym: bool = True, + act_dtype: str = "fp32", + accuracy_level: int = 0, + providers: list = ["CPUExecutionProvider"], + white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + ): + """Init RTN weight-only quantization config. + + Args: + weight_dtype (str): Data type for weights, default is "int". + weight_bits (int): Number of bits used to represent weights, default is 4. 
+ weight_group_size (int): Size of weight groups, default is 32. + weight_sym (bool): Indicates whether weights are symmetric, default is True. + act_dtype (str): Data type for activations, default is "fp32". + """ + super().__init__(white_list=white_list) + self.weight_bits = weight_bits + self.weight_dtype = weight_dtype + self.weight_group_size = weight_group_size + self.weight_sym = weight_sym + self.act_dtype = act_dtype + self.accuracy_level = accuracy_level + self.providers = providers + self._post_init() + + def get_model_params_dict(self): + result = dict() + for param in self.model_params_list: + result[param] = getattr(self, param) + return result + + def to_dict(self): + return super().to_dict(params_list=self.params_list) + + @classmethod + def from_dict(cls, config_dict): + return super(RTNConfig, cls).from_dict(config_dict=config_dict) + + @classmethod + def register_supported_configs(cls) -> List[OperatorConfig]: + supported_configs = [] + linear_rtn_config = RTNConfig( + weight_dtype=["int"], + weight_bits=[4, 3, 8], + weight_group_size=[32, -1, 1, 16, 64, 128, 256, 512, 1024], + weight_sym=[True, False], + act_dtype=["fp32"], + ) + operators = ["MatMul"] + supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators, backend=Backend.DEFAULT)) + cls.supported_configs = supported_configs + + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]: + config_mapping = OrderedDict() + if config_list is None: + config_list = [self] + for config in config_list: + # update model level setting + config_mapping.update(config.get_model_params_dict()) + + # update node level setting + global_config = config.global_config + op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() + for op_name, op_type in model_info: + if self.global_config is not None: + config_mapping[(op_name, op_type)] = global_config + if op_type in op_type_config_dict: + config_mapping[(op_name, op_type)] = op_name_config_dict[op_type] + for op_name_pattern in op_name_config_dict: + if re.match(op_name_pattern, op_name): + config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern] + return config_mapping + + @staticmethod + def get_model_info(model: Union[onnx.ModelProto, Path, str]) -> List[Tuple[str, Callable]]: + if not isinstance(model, onnx.ModelProto): + model = onnx.load(model) + white_list = ["MatMul"] + filter_result = [] + for node in model.graph.node: + if node.op_type in white_list: + pair = (node.name, node.op_type) + filter_result.append(pair) + logger.debug(f"Get model info: {filter_result}") + return filter_result + + +# TODO(Yi) run `register_supported_configs` for all registered config. +RTNConfig.register_supported_configs() + + +def get_default_rtn_config() -> RTNConfig: + """Generate the default rtn config. + + Returns: + the default rtn config. + """ + return RTNConfig() diff --git a/neural_compressor/onnxrt/quantization/quantize.py b/neural_compressor/onnxrt/quantization/quantize.py new file mode 100644 index 00000000000..976150bff45 --- /dev/null +++ b/neural_compressor/onnxrt/quantization/quantize.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Optional, Tuple + +import onnx +from onnxruntime.quantization import CalibrationDataReader + +from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry +from neural_compressor.common.logger import Logger +from neural_compressor.onnxrt.quantization.config import FRAMEWORK_NAME +from neural_compressor.onnxrt.utils.utility import algos_mapping + +logger = Logger().get_logger() + + +def need_apply(quant_config: BaseConfig, algo_name): + return quant_config.name == algo_name if hasattr(quant_config, "name") else False + + +# only for internal usage now +def _quantize( + model_input: Tuple[Path, str], + quant_config: BaseConfig, + calibration_data_reader: Optional[CalibrationDataReader] = None, +) -> onnx.ModelProto: + """The main entry to quantize a model. + + Args: + model_input (Tuple[Path, str]): Path or str to the model to quantize. + quant_config (BaseConfig): a quantization configuration. + + Returns: + onnx.ModelProto: The quantized model. + """ + registered_configs = config_registry.get_cls_configs() + if isinstance(quant_config, dict): + quant_config = ComposableConfig.from_dict(quant_config, config_registry=registered_configs[FRAMEWORK_NAME]) + logger.info(f"Parsed a config dict to construct the quantization config: {quant_config}.") + else: + assert isinstance( + quant_config, BaseConfig + ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}." + logger.info(f"Quantize model with config: \n {quant_config} \n") + + # select quantization algo according to config + for algo_name, algo_func in algos_mapping.items(): + if need_apply(quant_config, algo_name): + logger.info(f"Start to apply {algo_name} on the model.") + q_model = algo_func(model_input, quant_config, calibration_data_reader=calibration_data_reader) + return q_model diff --git a/neural_compressor/onnxrt/utils/__init__.py b/neural_compressor/onnxrt/utils/__init__.py new file mode 100644 index 00000000000..b9011c35785 --- /dev/null +++ b/neural_compressor/onnxrt/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
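To make the intent of the new 3x ONNX Runtime API easier to review, here is a minimal usage sketch (illustrative only, not part of the patch) that wires together RTNConfig from config.py and the internal _quantize helper from quantize.py; the model file "model.onnx" is a hypothetical placeholder.

# Illustrative sketch only -- "model.onnx" is a hypothetical path, not shipped with this patch.
from neural_compressor.onnxrt import RTNConfig, get_default_rtn_config
from neural_compressor.onnxrt.quantization import _quantize  # internal-only entry for now

# 4-bit weight-only RTN config; get_default_rtn_config() returns the same defaults.
quant_config = RTNConfig(weight_bits=4, weight_group_size=32, weight_sym=True)

# _quantize maps the config onto MatMul nodes and dispatches to the registered
# rtn_quantize_entry, returning an onnx.ModelProto in which eligible MatMul weights
# become MatMulNBits/MatMulFpQ4 nodes or fake-quantized initializers.
q_model = _quantize("model.onnx", quant_config=quant_config)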
diff --git a/neural_compressor/onnxrt/utils/onnx_model.py b/neural_compressor/onnxrt/utils/onnx_model.py new file mode 100644 index 00000000000..57499316622 --- /dev/null +++ b/neural_compressor/onnxrt/utils/onnx_model.py @@ -0,0 +1,1249 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Class for ONNX model.""" + +import os +import sys +from pathlib import Path + +import onnx + +from neural_compressor.common.logger import Logger +from neural_compressor.onnxrt.utils.utility import MAXIMUM_PROTOBUF, find_by_name + +logger = Logger().get_logger() + + +class ONNXModel: + """Build ONNX model.""" + + def __init__(self, model, **kwargs): + """Initialize an ONNX model. + + Args: + model (str or ModelProto): path to onnx model or loaded ModelProto model object. + ignore_warning (bool): ignore large model warning. Default is False. + load_external_data (bool): load external data for large model. Default is True. + """ + self._model = model if not isinstance(model, str) else onnx.load(model, load_external_data=False) + self._model_path = None if not isinstance(model, str) else model + + self.check_is_large_model() + if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False): + logger.warning("Model size > 2GB. 
Please use model path instead of onnx model object to quantize") + + if self._is_large_model and isinstance(model, str) and kwargs.get("load_external_data", True): + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(self._model, os.path.dirname(self._model_path)) + + self._config = None + if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()): + from transformers import PretrainedConfig + + self._config = PretrainedConfig.from_pretrained(Path(model).parent.as_posix()) + + self.node_name_counter = {} + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + self._graph_info = {} + self._get_graph_info() + self._q_config = None + + def check_is_large_model(self): + """Check model > 2GB.""" + init_size = 0 + for init in self._model.graph.initializer: + # if initializer has external data location, return True + if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: + self._is_large_model = True + return + # if raise error of initializer size > 2GB, return True + try: + init_bytes = init.SerializeToString() + init_size += sys.getsizeof(init_bytes) + except Exception as e: + if "exceeds maximum protobuf size of 2GB" in str(e): + self._is_large_model = True + return + else: # pragma: no cover + raise e + if init_size > MAXIMUM_PROTOBUF: + self._is_large_model = True + return + self._is_large_model = False + + @property + def is_large_model(self): + """Check the onnx model is over 2GB.""" + return self._is_large_model + + @property + def model_path(self): + """Return model path.""" + return self._model_path + + @model_path.setter + def model_path(self, path): + """Set model path.""" + self._model_path = path + + def framework(self): + """Return framework.""" + return "onnxruntime" + + @property + def q_config(self): + """Return q_config.""" + return self._q_config + + @q_config.setter + def q_config(self, q_config): + """Set q_config.""" + self._q_config = q_config + + @property + def hf_config(self): + """Return huggingface config if model is Transformer-based.""" + return self._config + + @property + def model(self): + """Return model itself.""" + return self._model + + @model.setter + def model(self, model): + """Set model itself.""" + self._model = model + self._graph_info = {} + self._get_graph_info() + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + + def input(self): + """Return input of model.""" + return [i.name for i in self._model.graph.input] + + def output(self): + """Return output of model.""" + return [i.name for i in self._model.graph.output] + + def update(self): + """Update model info.""" + self._graph_info = {} + self._get_graph_info() + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + + @property + def graph_info(self): + """Return ORT Graph Info object holding information about backend graph.""" + return self._graph_info + + def _get_graph_info(self): + """Update graph info.""" + for node in self._model.graph.node: + self.graph_info.update({node.name: node.op_type}) + + def save(self, root): + """Save ONNX model.""" + if os.path.split(root)[0] != "" and not 
os.path.exists(os.path.split(root)[0]): + raise ValueError('"root" directory does not exists.') + if self.is_large_model: # pragma: no cover + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(self._model, os.path.split(self._model_path)[0]) + onnx.save_model( + self._model, + root, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=root.split("/")[-1] + "_data", + size_threshold=1024, + convert_attribute=False, + ) + else: + onnx.save(self._model, root) + + if self._config is not None: + model_type = "" if not hasattr(self._config, "model_type") else getattr(self._config, "model_type") + setattr(self._config.__class__, "model_type", model_type) + output_config_file = Path(root).parent.joinpath("config.json").as_posix() + self._config.to_json_file(output_config_file, use_diff=False) + + def nodes(self): + """Return model nodes.""" + return self._model.graph.node + + def initializer(self): + """Return model initializer.""" + return self._model.graph.initializer + + def graph(self): + """Return model graph.""" + return self._model.graph + + def ir_version(self): + """Return model ir_version.""" + return self._model.ir_version + + def opset_import(self): + """Return model opset_import.""" + return self._model.opset_import + + def remove_node(self, node): + """Remove a node from model.""" + if node in self._model.graph.node: + self._model.graph.node.remove(node) + + def remove_nodes(self, nodes_to_remove): + """Remove nodes from model.""" + for node in nodes_to_remove: + self.remove_node(node) + + def add_node(self, node): + """Add a node to model.""" + self._model.graph.node.extend([node]) + + def add_nodes(self, nodes_to_add): + """Add nodes to model.""" + self._model.graph.node.extend(nodes_to_add) + + def add_initializer(self, tensor): + """Add a initializer to model.""" + if find_by_name(tensor.name, self._model.graph.initializer) is None: + self._model.graph.initializer.extend([tensor]) + + def add_initializers(self, tensors): + """Add initializers to model.""" + for tensor in tensors: + self.add_initializer(tensor) + + def get_initializer(self, name): + """Get an initializer by name.""" + for tensor in self._model.graph.initializer: + if tensor.name == name: + return tensor + return None + + def get_initializer_share_num(self, name): + """Get the number of shares of initializer.""" + num = 0 + if self.get_initializer(name) is None: + return num + + for node in self.nodes(): + if name in node.input: + num += 1 + return num + + def get_node(self, name): + """Get a node by name.""" + for node in self._model.graph.node: + if node.name == name: + return node + return None + + def remove_initializer(self, tensor): + """Remove an initializer from model.""" + if tensor in self._model.graph.initializer: + self._model.graph.initializer.remove(tensor) + + def remove_initializers(self, init_to_remove): + """Remove initializers from model.""" + for initializer in init_to_remove: + self.remove_initializer(initializer) + + def set_initializer(self, tensor, array, raw=False): + """Update initializer.""" + old_tensor = self.get_initializer(tensor) + self.remove_initializer(old_tensor) + dims = old_tensor.dims + data_type = old_tensor.data_type + new_tensor = ( + onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist()) + if not raw + else onnx.helper.make_tensor(tensor, data_type, dims, array.tostring(), raw=raw) + ) + self.add_initializer(new_tensor) + + @property + def input_name_to_nodes(self): + """Return input 
names of nodes.""" + return self._input_name_to_nodes + + def _get_input_name_to_nodes(self, nodes): + """Get input names of nodes.""" + for node in nodes: + attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(attrs) > 0: + for attr in attrs: + self._get_input_name_to_nodes(attr.g.node) + for input_name in node.input: + if len(input_name.strip()) != 0: + if input_name not in self._input_name_to_nodes: + self._input_name_to_nodes[input_name] = [node] + else: + self._input_name_to_nodes[input_name].append(node) + + @property + def output_name_to_node(self): + """Return output names of nodes.""" + return self._output_name_to_node + + def _get_output_name_to_node(self, nodes): + """Get output names of nodes.""" + for node in nodes: + attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(attrs) > 0: + for attr in attrs: + self._get_output_name_to_node(attr.g.node) + for output_name in node.output: + if len(output_name.strip()) != 0: + self._output_name_to_node[output_name] = node + + def get_siblings(self, node): + """Get siblings nodes.""" + siblings = [] + for parent in self.get_parents(node): + for child in self.get_children(parent): + if child.name != node.name: + siblings.append(child) + return siblings + + def get_children(self, node, input_name_to_nodes=None): + """Get children nodes.""" + if input_name_to_nodes is None: + input_name_to_nodes = self._input_name_to_nodes + + children = [] + for output in node.output: + if output in input_name_to_nodes: + for child in input_name_to_nodes[output]: + children.append(child) + return children + + def get_parents(self, node, output_name_to_node=None): + """Get parents nodes.""" + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + parents = [] + for input in node.input: + if input in output_name_to_node: + parents.append(output_name_to_node[input]) + return parents + + def get_parent(self, node, idx, output_name_to_node=None): + """Get parent node by idx.""" + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + if len(node.input) <= idx: + return None + + input = node.input[idx] + if input not in output_name_to_node: + return None + + return output_name_to_node[input] + + def find_node_by_name(self, node_name, new_nodes_list, graph): + """Find out node by name.""" + graph_nodes_list = list(graph.node) # deep copy + graph_nodes_list.extend(new_nodes_list) + node = find_by_name(node_name, graph_nodes_list) + return node + + def find_nodes_by_initializer(self, graph, initializer): + """Find all nodes with given initializer as an input.""" + nodes = [] + for node in graph.node: + for node_input in node.input: + if node_input == initializer.name: + nodes.append(node) + return nodes + + def get_scale_zero(self, tensor): + """Help function to get scale and zero_point.""" + if not tensor.endswith("_quantized"): + logger.debug("Find {} in the quantized graph is not quantized.".format(tensor)) + return None, None + + def _searcher(tensor_name): + """Search scale and zero point tensor recursively.""" + node = self._input_name_to_nodes[tensor_name][0] + parent = self._output_name_to_node[tensor_name] if tensor_name in self._output_name_to_node else None + direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"] + if parent is not None and parent.op_type in direct_int8: + 
fp32_tensor_name = ( + parent.input[0] + .replace("_quantized", "") + .replace("_QuantizeLinear", "") + .replace("_QuantizeInput", "") + ) + elif node.op_type in ["Gather"]: # pragma: no cover + fp32_tensor_name = ( + node.output[0] + .replace("_quantized", "") + .replace("_QuantizeLinear", "") + .replace("_QuantizeInput", "") + ) + else: + fp32_tensor_name = ( + tensor_name.replace("_quantized", "").replace("_QuantizeLinear", "").replace("_QuantizeInput", "") + ) + scale = fp32_tensor_name + "_scale" + scale_tensor = self.get_initializer(scale) + zo = fp32_tensor_name + "_zero_point" + zo_tensor = self.get_initializer(zo) + + if scale_tensor is None or zo_tensor is None: + if parent is not None: + scale_tensor, zo_tensor = _searcher(parent.input[0]) + return scale_tensor, zo_tensor + + node = self._input_name_to_nodes[tensor][0] + # TODO check if scale_tensor and zero_point is needed + # for bias of qlinearconv, scale and zero_point is not needed + if (node.op_type == "QLinearConv" and tensor == node.input[-1]) or ( + node.op_type == "QGemm" and tensor == node.input[-3] + ): + return None, None + else: + scale_tensor, zo_tensor = _searcher(tensor) + assert scale_tensor, "missing scale for tensor {}".format(tensor) + assert zo_tensor, "missing zero point for tensor {}".format(tensor) + return scale_tensor, zo_tensor + + def save_model_to_file(self, output_path, use_external_data_format=False): + """Save model to external data, which is needed for model size > 2GB.""" + from onnx.external_data_helper import convert_model_to_external_data + + if use_external_data_format: + convert_model_to_external_data( + self._model, all_tensors_to_one_file=True, location=Path(output_path).name + ".data" + ) + onnx.save_model(self._model, output_path) + + @staticmethod + def replace_node_input(node, old_input_name, new_input_name): + """Replace input of a node.""" + assert isinstance(old_input_name, str) and isinstance(new_input_name, str) + for j in range(len(node.input)): + if node.input[j] == old_input_name: + node.input[j] = new_input_name + + def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=[], black_optype=[]): + """Replace inputs of all nodes.""" + if len(white_optype) > 0: + for node in self.model.graph.node: + if node.op_type in white_optype: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + else: + for node in self.model.graph.node: + if node.op_type not in black_optype: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + + @staticmethod + def replace_node_output(node, old_output_name, new_output_name): + """Replace output of a node.""" + assert isinstance(old_output_name, str) and isinstance(new_output_name, str) + for j in range(len(node.output)): + if node.output[j] == old_output_name: + node.output[j] = new_output_name + + def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=[], black_optype=[]): + """Replace outputs of all nodes.""" + if len(white_optype) > 0: + for node in self.model.graph.node: + if node.op_type in white_optype: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + else: + for node in self.model.graph.node: + if node.op_type not in black_optype: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + + def remove_unused_nodes(self): + """Remove unused nodes.""" + unused_nodes = [] + nodes = self.nodes() + for node in nodes: + if ( + node.op_type == "Constant" + and node.output[0] not in self._model.graph.output + and node.output[0] 
not in self._input_name_to_nodes + ): + unused_nodes.append(node) + elif ( + node.op_type == "QuantizeLinear" + and len(self.get_children(node)) == 1 + and self.get_children(node)[0].op_type == "DequantizeLinear" + and node.input[0] not in self._output_name_to_node + and self.get_children(node)[0].output[0] not in self._input_name_to_nodes + ): + unused_nodes.append(node) + unused_nodes.extend(self.get_children(node)) + else: + # remove the node if it does not serve as the input or output of any other nodes + unused = True + for output in node.output: + if output in self._input_name_to_nodes or output in self.output(): + unused = False + break + for input in node.input: + if self.get_initializer(input) is not None: + continue + elif input in self._output_name_to_node or input in self.input(): + unused = False + break + if unused: + unused_nodes.append(node) + self.remove_nodes(unused_nodes) + + ununsed_weights = [] + for w in self._model.graph.initializer: + if w.name not in self._input_name_to_nodes and w.name not in self._model.graph.output: + ununsed_weights.append(w) + # Remove from graph.input + for graph_input in self.graph().input: + if graph_input.name == w.name: + self.graph().input.remove(graph_input) + + self.remove_initializers(ununsed_weights) + self.update() + + def topological_sort(self, enable_subgraph=False): + """Topological sort the model.""" + import copy + from collections import deque + from functools import reduce + + if not enable_subgraph: + input_name_to_nodes = {} + output_name_to_node = {} + for node in self.model.graph.node: + for input_name in node.input: + if len(input_name.strip()) != 0: + if input_name not in input_name_to_nodes: + input_name_to_nodes[input_name] = [node] + else: + input_name_to_nodes[input_name].append(node) + for output_name in node.output: + if len(output_name.strip()) != 0: + output_name_to_node[output_name] = node + else: # pragma: no cover + input_name_to_nodes = self._input_name_to_nodes + output_name_to_node = self._output_name_to_node + + all_nodes = {} + q = deque() + wait = deque() + for inp in self.model.graph.input: + q.extend(input_name_to_nodes[inp.name]) + for n in self.model.graph.node: + if all([i not in output_name_to_node and i not in self.input() for i in n.input]): + q.append(n) + + while q: + n = q.popleft() + if not all([output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node]): + if n not in wait: + wait.append(n) + continue + + all_nodes[n.name] = n + for out in n.output: + if out in input_name_to_nodes: + q.extend([i for i in input_name_to_nodes[out] if i.name not in all_nodes and i not in q]) + if len(q) == 0 and len(wait) != 0: + q = copy.deepcopy(wait) + wait.clear() + nodes = [i[1] for i in all_nodes.items()] + assert len(list(set([n.name for n in nodes]))) == len(list(set([n.name for n in self.model.graph.node]))) + self.model.graph.ClearField("node") + self.model.graph.node.extend(nodes) + + def get_nodes_chain(self, start, stop, result_chain=[]): + """Get nodes chain with given start node and stop node.""" + from collections import deque + + from onnx import NodeProto + + # process start node list + start_node = deque() + for node in start: + if isinstance(node, str): + start_node.append(node) + elif isinstance(node, NodeProto): + start_node.append(node.name) + else: + assert False, "'get_nodes_chain' function only support list[string]" "or list[NodeProto] params" + + # process stop node list + stop_node = [] + for node in stop: + if isinstance(node, str): + 
stop_node.append(node) + elif isinstance(node, NodeProto): + stop_node.append(node.name) + else: + assert False, "'get_nodes_chain' function only support list[string]" "or list[NodeProto] params" + + while start_node: + node_name = start_node.popleft() + if node_name in stop_node: + continue + if node_name not in result_chain: + result_chain.append(node_name) + else: + continue + + node = find_by_name(node_name, list(self.model.graph.node)) + for parent in self.get_parents(node): + start_node.append(parent.name) + + return result_chain + + def find_split_node_for_layer_wise_quantization(self): + """Find split node for layer wise quantization.""" + # find split nodes of decoder blocks + # embed -> decoder.0 -(split_node)-> ... -(split_node)-> decoder.n -(split_node)-> norm -> head + # after split: embed -> decoder.0, + # decoder.1, + # decoder.2, + # ..., + # decoder.n, + # norm -> head + start_nodes = [] + for node in self._model.graph.node: + start_node, qkv_nodes_list = None, None + if node.op_type == "SkipLayerNormalization": + start_node = node + qkv_nodes_list = [ + self.match_parent_path( + start_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ), + ] + if node.op_type == "Add": + start_node = node + qkv_nodes_list = [ + # match base attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, None, 0, 0, 0], + ), + self.match_parent_path( + start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] + ), + # match gpt attention no past structure + self.match_parent_path( + start_node, + ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], + [None, 0, 0, 0, 0, 0], + output_name_to_node=self.output_name_to_node, + return_indice=[], + ), + # match bart attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [0, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"], + [None, 0, None, 0, None, 0], + ), + self.match_parent_path( + start_node, + ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"], + [None, 0, None, 0, 0], + ), + ] + if not start_node: + continue + if not any(qkv_nodes_list): + continue + start_nodes.append(start_node) + return start_nodes + + def find_qkv_in_attention(self, find_all=False): + """Find qkv MatMul in Attention. + + Args: + find_all (bool, optional): find all qkv MatMul. 
Defaults to False + + Returns: + qkv (list): qkv MatMul list + """ + qkv = [] + for node in self._model.graph.node: + if node.op_type == "Attention": + qkv.append([node.name]) + continue + start_node, qkv_nodes_list = None, None + if node.op_type == "SkipLayerNormalization": + start_node = node + qkv_nodes_list = [ + self.match_parent_path( + start_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ), + ] + if node.op_type == "Add": + start_node = node + qkv_nodes_list = [ + # match base attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, None, 0, 0, 0], + ), + self.match_parent_path( + start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] + ), + # match gpt attention no past structure + self.match_parent_path( + start_node, + ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], + [None, 0, 0, 0, 0, 0], + output_name_to_node=self.output_name_to_node, + return_indice=[], + ), + # match bart attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [0, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, None, 0, 0, 0, 0], + ), + ] + if not start_node: + continue + if not any(qkv_nodes_list): + continue + qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1] + other_inputs = [] + for input in start_node.input: + if input not in self.output_name_to_node: + continue + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + continue + root_input = other_inputs[0] + input_name_to_nodes = self.input_name_to_nodes + children = input_name_to_nodes[root_input] + children_types = [child.op_type for child in children] + if children_types.count("MatMul") == 3: + qkv.append([child.name for child in children if child.op_type == "MatMul"]) + if not find_all: + break + return qkv + + def find_ffn_matmul(self, attention_index, attention_matmul_list, block_len): + """Find MatMul in FFN. 
+ + Args: + attention_index (list): index of Attention + attention_matmul_list (list): list of Attention and MatMul nodes + block_len (int): block length + + Returns: + list: list of MatMul in FFN + """ + ffn_matmul = [] + for idx in range(len(attention_index)): + if idx != len(attention_index) - 1: + index = attention_index[idx + 1] + if index - 2 >= 0: + ffn_matmul.append([attention_matmul_list[index - 2], attention_matmul_list[index - 1]]) + else: + index = attention_index[idx] + if index + block_len - 1 < len(attention_matmul_list): + ffn_matmul.append( + [attention_matmul_list[index + block_len - 2], attention_matmul_list[index + block_len - 1]] + ) + return ffn_matmul + + def export(self, save_path, conf): + """Export Qlinear to QDQ model.""" + from neural_compressor.config import ONNXQlinear2QDQConfig + from neural_compressor.experimental.export import onnx_qlinear_to_qdq + + if isinstance(conf, ONNXQlinear2QDQConfig): + add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, self._input_name_to_nodes) + self.add_nodes(add_nodes) + self.remove_nodes(remove_nodes) + self.add_initializers(inits) + self.update() + self.remove_unused_nodes() + self.topological_sort() + self.save(save_path) + else: + logger.warning("Unsupported config for export, " "only ONNXQlinear2QDQConfig is supported!") + exit(0) + + def add_tensors_to_outputs(self, tensor_names): + """Add the tensors to the model outputs to gets their values. + + Args: + tensor_names: The names of tensors to be dumped. + """ + added_outputs = [] + for tensor in tensor_names: + if tensor not in self.output(): + added_tensor = onnx.helper.ValueInfoProto() + added_tensor.name = tensor + added_outputs.append(added_tensor) + self._model.graph.output.extend(added_outputs) # pylint: disable=no-member + + def remove_tensors_from_outputs(self, tensor_names): + """Remove the tensors from the model outputs. + + Args: + tensor_names: The names of tensors to be removed. + """ + removed_outputs = [] + for tensor in tensor_names: + if tensor in self.output(): + removed_outputs.append(self._model.graph.output[self.output().index(tensor)]) + for output in removed_outputs: + self._model.graph.output.remove(output) + + def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=[]): + """Find parent node based on constraints on op_type. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + + Returns: + parent: The matched parent node. None if not found. + index: The input index of matched parent node. None if not found. + """ + for i, input in enumerate(node.input): + if input in output_name_to_node: + parent = output_name_to_node[input] + if parent.op_type == parent_op_type and parent not in exclude: + return parent, i + return None, None + + def match_parent( + self, + node, + parent_op_type, + input_index=None, + output_name_to_node=None, + exclude=[], + return_indice=None, + ): + """Find parent node based on constraints on op_type and index. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + input_index (int or None): only check the parent given input index of current node. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). 
+ return_indice (list): a list to append the input index when input_index is None. + + Returns: + parent: The matched parent node. + """ + assert node is not None + assert input_index is None or input_index >= 0 + + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + if input_index is None: + parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) + if return_indice is not None: + return_indice.append(index) + return parent + + if input_index >= len(node.input): + return None + + parent = self.get_parent(node, input_index, output_name_to_node) + if parent is not None and parent.op_type == parent_op_type and parent not in exclude: + return parent + + return None + + def match_parent_path( + self, + node, + parent_op_types, + parent_input_index, + output_name_to_node=None, + return_indice=None, + ): + """Find a sequence of input edges based on constraints on parent op_type and index. + + Args: + node (str): current node name. + parent_op_types (list): constraint of parent node op_type of each input edge. + parent_input_index (list): constraint of input index of each input edge. + None means no constraint. + output_name_to_node (dict): dictionary with output name as key, and node as value. + return_indice (list): a list to append the input index when there is + no constraint on input index of an edge. + + Returns: + parents: a list of matched parent nodes. + """ + assert len(parent_input_index) == len(parent_op_types) + + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + current_node = node + matched_parents = [] + for i, op_type in enumerate(parent_op_types): + matched_parent = self.match_parent( + current_node, + op_type, + parent_input_index[i], + output_name_to_node, + exclude=[], + return_indice=return_indice, + ) + if matched_parent is None: + return None + + matched_parents.append(matched_parent) + current_node = matched_parent + + return matched_parents + + def is_smoothquant_model(self): + """Check whether the model is smooth quantized. + + Returns: + bool: True if the model is smooth quantized, otherwise False. + """ + for init in self.model.graph.initializer: + if "_smooth_scale" in init.name: + return True + return False + + def find_split_nodes(self): + """Find split nodes for layer-wise quantization.""" + split_nodes = self.find_split_node_for_layer_wise_quantization() + return split_nodes + + def split_model_with_node( + self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True + ): + """Split model into two parts at a given node. + + Args: + split_node_name (str): name of the node at which the model is split. + path_of_model_to_split (str): path of the model to be split. + shape_infer (bool): do shape inference. Default is True. + save_both_split_models (bool): whether to save both split models. + False means only the first split model is saved. + True means both split models are saved. + Default is True. + + Returns: + tuple: the first split model, the second split model + """ + # origin model : ... -> node_1 -> split_node -> node_2 -> ... + # split model 1: ... -> node_1 -> split_node + # split model 2: node_2 -> ...
+ + split_model_part_1 = onnx.ModelProto() + split_model_part_1.CopyFrom(self._model) + split_model_part_1.graph.ClearField("node") + + split_model_part_2 = onnx.ModelProto() + split_model_part_2.CopyFrom(self._model) + split_model_part_2.graph.ClearField("node") + + split_node_output = None + part_idx = 1 + for node in self._model.graph.node: + if part_idx == 1: + split_model_part_1.graph.node.append(node) + elif part_idx == 2: + split_model_part_2.graph.node.append(node) + + if node.name == split_node_name: + split_node_output = node.output + part_idx = 2 + + assert len(split_node_output) == 1, ( + "Only support split at node with 1 output tensor, while " + "current split node {} has {} output tensors".format(split_node_name, len(split_node_output)) + ) + split_tensor_name = split_node_output[0] + + # infer shape of the model to be split + if shape_infer: + try: + from neural_compressor.adaptor.ox_utils.util import infer_shapes + + self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path)) + except Exception as e: # pragma: no cover + logger.error( + "Shape infer fails for layer-wise quantization. " + "We would recommend checking the graph optimization level of your model " + "and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', " + "as this may help avoid this error." + ) + raise e + + split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name) + split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape) + + split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True) + split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True) + + # remove unused input & output + split_model_part_1._remove_unused_input_output() + split_model_part_2._remove_unused_input_output() + + split_model_part_1.model.graph.output.append(split_tensor) + split_model_part_2.model.graph.input.append(split_tensor) + + insert_output_for_model_1 = [] + insert_input_for_model_2 = [] + for output in split_model_part_1.output_name_to_node.keys(): + if output in split_model_part_2.input_name_to_nodes.keys(): + output_type, output_shape = self._get_output_type_shape_by_tensor_name(output) + output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape) + if output_tensor not in split_model_part_1.model.graph.output: + insert_output_for_model_1.append(output_tensor) + if output_tensor not in split_model_part_2.model.graph.input: + insert_input_for_model_2.append(output_tensor) + + # insert model 1 output + for output in insert_output_for_model_1: + split_model_part_1.model.graph.output.append(output) + + # insert model 2 input + for input in insert_input_for_model_2: + split_model_part_2.model.graph.input.append(input) + + # remove unused init + split_model_part_1.remove_unused_init() + split_model_part_2.remove_unused_init() + + split_model_part_1.update() + split_model_part_2.update() + + dir_of_model_to_split = os.path.dirname(path_of_model_to_split) + + split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split) + split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx") + split_model_part_1.model_path = split_model_part_1_path + split_model_part_1._save_split_model(split_model_part_1_path) + split_model_part_1.check_is_large_model() + logger.debug("save split model part 1 to {} for layer wise quantization".format(split_model_part_1_path)) + + if save_both_split_models: + 
split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split) + split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx") + split_model_part_2.model_path = split_model_part_2_path + split_model_part_2._save_split_model(split_model_part_2_path) + split_model_part_2.check_is_large_model() + logger.debug("save split model part 2 to {} for layer wise quantization".format(split_model_part_2_path)) + return split_model_part_1, split_model_part_2 + else: + return split_model_part_1, split_model_part_2 + + def _save_split_model(self, save_path): + """Save split model as external data for layer-wise quantization. + + Args: + save_path (str): the path to save the split model. + """ + if os.path.exists(save_path + "_data"): + os.remove(save_path + "_data") + onnx.save_model( + self._model, + save_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=save_path.split("/")[-1] + "_data", + size_threshold=1024, + convert_attribute=False, + ) + + def _get_output_type_shape_by_tensor_name(self, tensor_name): + """Get output type and shape with a tensor name. + + Args: + tensor_name (str): name of a tensor + + Returns: + tuple: output type and shape + """ + elem_type = onnx.TensorProto.FLOAT + shape = None + for output in self._model.graph.value_info: + if output.name == tensor_name: + elem_type = output.type.tensor_type.elem_type + shape = [ + dim.dim_value if dim.HasField("dim_value") else -1 for dim in output.type.tensor_type.shape.dim + ] + break + return elem_type, shape + + def _remove_unused_input_output(self): + """Remove unused input & output for split model.""" + remove_outputs = [] + remove_inputs = [] + for output in self._model.graph.output: + if output.name not in self.output_name_to_node.keys(): + remove_outputs.append(output) + + for input in self._model.graph.input: + if input.name not in self.input_name_to_nodes.keys(): + remove_inputs.append(input) + + for output in remove_outputs: + self._model.graph.output.remove(output) + for input in remove_inputs: + self._model.graph.input.remove(input) + + def remove_unused_init(self): + """Remove unused initializers.""" + removed_inits = [] + for init in self._model.graph.initializer: + if init.name not in self.input_name_to_nodes.keys(): + removed_inits.append(init) + self.remove_initializers(removed_inits) + + def load_model_initializer_by_tensor(self, data_path=None): + """Load model initializer by tensor. + + Args: + data_path (str, optional): the directory of saved initializers. Defaults to None. + """ + from onnx.external_data_helper import load_external_data_for_tensor + + if data_path is None: + data_path = os.path.dirname(self._model_path) + for init in self._model.graph.initializer: + if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: + load_external_data_for_tensor(init, data_path) + + def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False): + """Write external data of the merged quantized model to a new location to save memory. + + Args: + external_data_location (str, optional): external data location of the merged quantized model. + Defaults to "external.data". + overwrite (bool, optional): if True, remove existing external data. Defaults to False.
+ """ + from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors + + if overwrite and os.path.exists(os.path.join(os.path.dirname(self._model_path), external_data_location)): + os.remove(os.path.join(os.path.dirname(self._model_path), external_data_location)) + self.load_model_initializer_by_tensor() + convert_model_to_external_data(self._model, location=external_data_location) + # TODO : if init is already saved, skip write it + write_external_data_tensors(self._model, filepath=os.path.dirname(self._model_path)) + + def merge_split_models(self, to_merge_model): + """Merge two split model into final model.""" + to_merge_model.write_external_data_to_new_location() + self.add_nodes([node for node in to_merge_model.nodes()]) + self.add_initializers([init for init in to_merge_model.initializer()]) + self.update() + + # add new output + for output in to_merge_model.graph().output: + if output.name not in self.output(): + self._model.graph.output.append(output) + + # remove unused output + remove_output = [] + for output in self._model.graph.output: + if output.name in to_merge_model.input(): + remove_output.append(output) + for output in remove_output: + self._model.graph.output.remove(output) + + # add new input + for input in to_merge_model.graph().input: + if ( + input.name not in self.input() + and input.name not in self.output() + and input.name not in self.output_name_to_node.keys() + ): + self._model.graph.input.append(input) + + def re_org_output(self, origin_output): + """Re-org output of merged model for layer-wise quantization.""" + outputs = {} + tmp_remove = [] + for output in self._model.graph.output: + outputs[output.name] = output + tmp_remove.append(output) + + for output in tmp_remove: + self._model.graph.output.remove(output) + + for out_name in origin_output: + self._model.graph.output.append(outputs[out_name]) diff --git a/neural_compressor/onnxrt/utils/utility.py b/neural_compressor/onnxrt/utils/utility.py new file mode 100644 index 00000000000..6844e28cd0b --- /dev/null +++ b/neural_compressor/onnxrt/utils/utility.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import Path +from typing import Callable, Dict, List, Tuple, Union + +import onnx +from packaging.version import Version + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() + +ONNXRT116_VERSION = Version("1.16.0") +ONNXRT1161_VERSION = Version("1.16.1") + +# Dictionary to store a mapping between algorithm names and corresponding algo implementation(function) +algos_mapping: Dict[str, Callable] = {} + +# All constants for onnxrt +WHITE_MODULE_LIST = ["MatMul", "Conv"] + +MAXIMUM_PROTOBUF = 2147483648 + +dtype_mapping = { + "fp32": 1, + "float32": 1, + "uint8": 2, + "int8": 3, + "uint16": 4, + "int16": 5, + "int32": 6, + "int64": 7, + "string": 8, + "bool": 9, + "fp16": 10, + "float16": 10, + "double": 11, + "uint32": 12, + "uint64": 13, + "complex64": 14, + "complex128": 15, + "bf16": 16, + "bfloat16": 16, +} + + +def find_by_name(name, item_list): + """Helper function to find item by name in a list.""" + items = [] + for item in item_list: + assert hasattr(item, "name"), "{} should have a 'name' attribute defined".format(item) # pragma: no cover + if item.name == name: + items.append(item) + if len(items) > 0: + return items[0] + else: + return None + + +def simple_progress_bar(total, i): + """Progress bar for cases where tqdm can't be used.""" + progress = i / total + bar_length = 20 + bar = "#" * int(bar_length * progress) + spaces = " " * (bar_length - len(bar)) + percentage = progress * 100 + print(f"\rProgress: [{bar}{spaces}] {percentage:.2f}%", end="") + + +def register_algo(name): + """Decorator function to register algorithms in the algos_mapping dictionary. + + Usage example: + @register_algo(name=example_algo) + def example_algo(model: Union[onnx.ModelProto, Path, str], + quant_config: RTNConfig) -> onnx.ModelProto: + ... + + Args: + name (str): The name under which the algorithm function will be registered. + + Returns: + decorator: The decorator function to be used with algorithm functions. 
+ """ + + def decorator(algo_func): + algos_mapping[name] = algo_func + return algo_func + + return decorator + + +def get_model_info( + model: Union[onnx.ModelProto, Path, str], white_op_type_list: List[Callable] +) -> List[Tuple[str, Callable]]: + if not isinstance(model, onnx.ModelProto): + model = onnx.load(model) + filter_result = [] + filter_result_set = set() + for node in model.graph.node: + if node.op_type in white_op_type_list: + pair = (node.name, node.op_type) + if pair not in filter_result_set: + filter_result_set.add(pair) + filter_result.append(pair) + logger.debug(f"Get model info: {filter_result}") + return filter_result diff --git a/requirements_ort.txt b/requirements_ort.txt new file mode 100644 index 00000000000..3a27c292e06 --- /dev/null +++ b/requirements_ort.txt @@ -0,0 +1,4 @@ +numpy +onnx +onnxruntime +onnxruntime-extensions diff --git a/setup.py b/setup.py index c0f5539e948..9d91d8ab8f0 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,8 @@ def fetch_requirements(path): "neural_compressor.torch.*", "neural_compressor.tensorflow", "neural_compressor.tensorflow.*", + "neural_compressor.onnxrt", + "neural_compressor.onnxrt.*", ], ), "package_data": {"": ["*.yaml"]}, @@ -91,6 +93,7 @@ def fetch_requirements(path): "neural_compressor.onnxrt.*", ], ), + "install_requires": fetch_requirements("requirements_ort.txt"), }, "neural_insights": { "project_name": "neural_insights", diff --git a/test/3x/onnxrt/quantization/weight_only/test_rtn.py b/test/3x/onnxrt/quantization/weight_only/test_rtn.py new file mode 100644 index 00000000000..541434796ad --- /dev/null +++ b/test/3x/onnxrt/quantization/weight_only/test_rtn.py @@ -0,0 +1,79 @@ +import os +import shutil +import unittest + +from optimum.exporters.onnx import main_export + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() + + +def find_onnx_file(folder_path): + # return first .onnx file path in folder_path + for root, dirs, files in os.walk(folder_path): + for file in files: + if file.endswith(".onnx"): + return os.path.join(root, file) + return None + + +class TestRTNQuant(unittest.TestCase): + @classmethod + def setUpClass(self): + main_export( + "hf-internal-testing/tiny-random-gptj", + output="gptj", + ) + self.gptj = find_onnx_file("./gptj") + + @classmethod + def tearDownClass(self): + shutil.rmtree("gptj", ignore_errors=True) + + def setUp(self): + # print the test name + logger.info(f"Running ONNXRT TestRTNQuant test: {self.id()}") + + def _count_woq_matmul(self, q_model, bits=4, group_size=32): + op_names = [ + i.name + for i in q_model.graph.node + if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size)) + ] + return len(op_names) + + def _apply_rtn(self, quant_config): + logger.info(f"Test RTN with config {quant_config}") + from neural_compressor.onnxrt.quantization.quantize import _quantize + + fp32_model = self.gptj + qmodel = _quantize(fp32_model, quant_config) + self.assertIsNotNone(qmodel) + return qmodel + + def test_rtn(self): + from neural_compressor.onnxrt import RTNConfig + + # some tests were skipped to accelerate the CI + # TODO: check params combination. + # TODO: Add number check for group_size. 
+ rtn_options = { + "weight_dtype": ["int"], + "weight_bits": [4, 3, 8], + "weight_group_size": [32], + "weight_sym": [True, False], + "act_dtype": ["fp32"], + } + from itertools import product + + keys = RTNConfig.params_list + for value in product(*rtn_options.values()): + d = dict(zip(keys, value)) + quant_config = RTNConfig(**d) + qmodel = self._apply_rtn(quant_config) + self.assertEqual(self._count_woq_matmul(qmodel, bits=value[1], group_size=value[2]), 30) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/3x/onnxrt/requirements.txt b/test/3x/onnxrt/requirements.txt new file mode 100644 index 00000000000..4165ba5e0a6 --- /dev/null +++ b/test/3x/onnxrt/requirements.txt @@ -0,0 +1,2 @@ +optimum +pytest diff --git a/test/3x/onnxrt/test_config.py b/test/3x/onnxrt/test_config.py new file mode 100644 index 00000000000..4b9801f76e2 --- /dev/null +++ b/test/3x/onnxrt/test_config.py @@ -0,0 +1,334 @@ +import copy +import os +import shutil +import unittest + +import numpy as np +import onnx +from optimum.exporters.onnx import main_export + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() + + +def find_onnx_file(folder_path): + # return first .onnx file path in folder_path + for root, dirs, files in os.walk(folder_path): + for file in files: + if file.endswith(".onnx"): + return os.path.join(root, file) + return None + + +def build_simple_onnx_model(): + A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 5, 5]) + C = onnx.helper.make_tensor_value_info("C", onnx.TensorProto.FLOAT, [1, 5, 2]) + D = onnx.helper.make_tensor_value_info("D", onnx.TensorProto.FLOAT, [1, 5, 2]) + H = onnx.helper.make_tensor_value_info("H", onnx.TensorProto.FLOAT, [1, 5, 2]) + + e_value = np.random.randint(2, size=(10)).astype(np.float32) + B_init = onnx.helper.make_tensor("B", onnx.TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist()) + E_init = onnx.helper.make_tensor("E", onnx.TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist()) + + matmul_node = onnx.helper.make_node("MatMul", ["A", "B"], ["C"], name="Matmul") + add = onnx.helper.make_node("Add", ["C", "E"], ["D"], name="add") + + f_value = np.random.randint(2, size=(10)).astype(np.float32) + F_init = onnx.helper.make_tensor("F", onnx.TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist()) + add2 = onnx.helper.make_node("Add", ["D", "F"], ["H"], name="add2") + + graph = onnx.helper.make_graph([matmul_node, add, add2], "test_graph_1", [A], [H], [B_init, E_init, F_init]) + model = onnx.helper.make_model(graph) + model = onnx.helper.make_model(graph, **{"opset_imports": [onnx.helper.make_opsetid("", 13)]}) + return model + + +class TestQuantizationConfig(unittest.TestCase): + @classmethod + def setUpClass(self): + main_export( + "hf-internal-testing/tiny-random-gptj", + output="gptj", + ) + self.gptj = find_onnx_file("./gptj") + + simple_onnx_model = build_simple_onnx_model() + onnx.save(simple_onnx_model, "simple_onnx_model.onnx") + self.simple_onnx_model = "simple_onnx_model.onnx" + + @classmethod + def tearDownClass(self): + shutil.rmtree("gptj", ignore_errors=True) + os.remove("simple_onnx_model.onnx") + + def setUp(self): + # print the test name + logger.info(f"Running TestQuantizationConfig test: {self.id()}") + + def _check_model_is_quantized(self, model): + node_optypes = [node.op_type for node in model.graph.node] + return "MatMulNBits" in node_optypes or "MatMulFpQ4" in node_optypes + + def _check_node_is_quantized(self, model, node_name): + for node in model.graph.node: 
+ if (node.name == node_name or node.name == node_name + "_Q4") and node.op_type in [ + "MatMulNBits", + "MatMulFpQ4", + ]: + return True + return False + + def _count_woq_matmul(self, q_model, bits=4, group_size=32): + op_names = [ + i.name + for i in q_model.graph.node + if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size)) + ] + return len(op_names) + + def test_quantize_rtn_from_dict_default(self): + logger.info("test_quantize_rtn_from_dict_default") + from neural_compressor.onnxrt import get_default_rtn_config + from neural_compressor.onnxrt.quantization.quantize import _quantize + + fp32_model = self.simple_onnx_model + qmodel = _quantize(fp32_model, quant_config=get_default_rtn_config()) + self.assertIsNotNone(qmodel) + self.assertTrue(self._check_model_is_quantized(qmodel)) + + def test_quantize_rtn_from_dict_beginner(self): + from neural_compressor.onnxrt.quantization.quantize import _quantize + + quant_config = { + "rtn": { + "weight_bits": 4, + "weight_group_size": 32, + }, + } + fp32_model = self.simple_onnx_model + qmodel = _quantize(fp32_model, quant_config) + self.assertIsNotNone(qmodel) + self.assertIsNotNone(qmodel) + self.assertTrue(self._check_model_is_quantized(qmodel)) + + def test_quantize_rtn_from_class_beginner(self): + from neural_compressor.onnxrt import RTNConfig + from neural_compressor.onnxrt.quantization.quantize import _quantize + + quant_config = RTNConfig(weight_bits=4, weight_group_size=32) + fp32_model = self.simple_onnx_model + qmodel = _quantize(fp32_model, quant_config) + self.assertIsNotNone(qmodel) + + def test_quantize_rtn_fallback_from_class_beginner(self): + from neural_compressor.onnxrt import RTNConfig + from neural_compressor.onnxrt.quantization.quantize import _quantize + + fp32_config = RTNConfig(weight_dtype="fp32") + fp32_model = self.gptj + quant_config = RTNConfig( + weight_bits=4, + weight_dtype="int", + weight_sym=False, + weight_group_size=32, + ) + quant_config.set_local("/h.4/mlp/fc_out/MatMul", fp32_config) + qmodel = _quantize(fp32_model, quant_config) + self.assertIsNotNone(qmodel) + self.assertEqual(self._count_woq_matmul(qmodel), 29) + self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) + + def test_quantize_rtn_from_dict_advance(self): + from neural_compressor.onnxrt.quantization.quantize import _quantize + + fp32_model = self.gptj + quant_config = { + "rtn": { + "global": { + "weight_bits": 4, + "weight_group_size": 32, + }, + "local": { + "/h.4/mlp/fc_out/MatMul": { + "weight_dtype": "fp32", + } + }, + } + } + qmodel = _quantize(fp32_model, quant_config) + self.assertIsNotNone(qmodel) + self.assertEqual(self._count_woq_matmul(qmodel), 29) + self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) + + fp32_model = self.gptj + quant_config = { + "rtn": { + "global": { + "weight_bits": 4, + "weight_group_size": 32, + }, + "local": { + "/h.4/mlp/fc_out/MatMul": { + "weight_bits": 8, + "weight_group_size": 32, + } + }, + } + } + qmodel = _quantize(fp32_model, quant_config) + self.assertIsNotNone(qmodel) + for node in qmodel.graph.node: + if node.name == "/h.4/mlp/fc_out/MatMul": + self.assertTrue(node.input[1].endswith("Q8G32")) + + def test_config_white_lst(self): + from neural_compressor.onnxrt import RTNConfig + from neural_compressor.onnxrt.quantization.quantize import _quantize + + global_config = RTNConfig(weight_bits=4) + # set operator instance + fc_out_config = RTNConfig(weight_dtype="fp32", 
white_list=["/h.4/mlp/fc_out/MatMul"]) + # get model and quantize + fp32_model = self.gptj + qmodel = _quantize(fp32_model, quant_config=global_config + fc_out_config) + self.assertIsNotNone(qmodel) + self.assertEqual(self._count_woq_matmul(qmodel), 29) + self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) + + def test_config_white_lst2(self): + from neural_compressor.onnxrt import RTNConfig + from neural_compressor.onnxrt.quantization.quantize import _quantize + + global_config = RTNConfig(weight_dtype="fp32") + # set operator instance + fc_out_config = RTNConfig(weight_bits=4, white_list=["/h.4/mlp/fc_out/MatMul"]) + # get model and quantize + fp32_model = self.gptj + qmodel = _quantize(fp32_model, quant_config=global_config + fc_out_config) + self.assertIsNotNone(qmodel) + self.assertEqual(self._count_woq_matmul(qmodel), 1) + onnx.save(qmodel, "qmodel.onnx") + self.assertTrue(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) + + def test_config_white_lst3(self): + from neural_compressor.onnxrt import RTNConfig + from neural_compressor.onnxrt.utils.utility import get_model_info + + global_config = RTNConfig(weight_bits=4) + # set operator instance + fc_out_config = RTNConfig(weight_bits=8, white_list=["/h.4/mlp/fc_out/MatMul"]) + quant_config = global_config + fc_out_config + # get model and quantize + fp32_model = self.gptj + model_info = get_model_info(fp32_model, white_op_type_list=["MatMul"]) + logger.info(quant_config) + configs_mapping = quant_config.to_config_mapping(model_info=model_info) + logger.info(configs_mapping) + self.assertTrue(configs_mapping[("/h.4/mlp/fc_out/MatMul", "MatMul")].weight_bits == 8) + self.assertTrue(configs_mapping[("/h.4/mlp/fc_in/MatMul", "MatMul")].weight_bits == 4) + + def test_config_from_dict(self): + from neural_compressor.onnxrt import RTNConfig + + quant_config = { + "rtn": { + "global": { + "weight_dtype": "int", + "weight_bits": 4, + "weight_group_size": 32, + }, + "local": { + "fc1": { + "weight_dtype": "int", + "weight_bits": 8, + } + }, + } + } + config = RTNConfig.from_dict(quant_config["rtn"]) + self.assertIsNotNone(config.local_config) + + def test_config_to_dict(self): + from neural_compressor.onnxrt import RTNConfig + + quant_config = RTNConfig(weight_bits=4) + fc_out_config = RTNConfig(weight_bits=8) + quant_config.set_local("/h.4/mlp/fc_out/MatMul", fc_out_config) + config_dict = quant_config.to_dict() + self.assertIn("global", config_dict) + self.assertIn("local", config_dict) + + def test_same_type_configs_addition(self): + from neural_compressor.onnxrt import RTNConfig + + quant_config1 = { + "rtn": { + "weight_dtype": "int", + "weight_bits": 4, + "weight_group_size": 32, + }, + } + q_config = RTNConfig.from_dict(quant_config1["rtn"]) + quant_config2 = { + "rtn": { + "global": { + "weight_bits": 8, + "weight_group_size": 32, + }, + "local": { + "/h.4/mlp/fc_out/MatMul": { + "weight_dtype": "int", + "weight_bits": 4, + } + }, + } + } + q_config2 = RTNConfig.from_dict(quant_config2["rtn"]) + q_config3 = q_config + q_config2 + q3_dict = q_config3.to_dict() + for op_name, op_config in quant_config2["rtn"]["local"].items(): + for attr, val in op_config.items(): + self.assertEqual(q3_dict["local"][op_name][attr], val) + self.assertNotEqual(q3_dict["global"]["weight_bits"], quant_config2["rtn"]["global"]["weight_bits"]) + + def test_config_mapping(self): + from neural_compressor.onnxrt import RTNConfig + from neural_compressor.onnxrt.utils.utility import get_model_info + + quant_config = 
RTNConfig(weight_bits=4) + # set operator instance + fc_out_config = RTNConfig(weight_bits=8) + quant_config.set_local("/h.4/mlp/fc_out/MatMul", fc_out_config) + # get model and quantize + fp32_model = self.gptj + model_info = get_model_info(fp32_model, white_op_type_list=["MatMul"]) + logger.info(quant_config) + configs_mapping = quant_config.to_config_mapping(model_info=model_info) + logger.info(configs_mapping) + self.assertTrue(configs_mapping[("/h.4/mlp/fc_out/MatMul", "MatMul")].weight_bits == 8) + self.assertTrue(configs_mapping[("/h.4/mlp/fc_in/MatMul", "MatMul")].weight_bits == 4) + # test regular expression matching + fc_config = RTNConfig(weight_bits=3) + quant_config.set_local("/h.[1-4]/mlp/fc_out/MatMul", fc_config) + configs_mapping = quant_config.to_config_mapping(model_info=model_info) + logger.info(configs_mapping) + self.assertTrue(configs_mapping[("/h.4/mlp/fc_out/MatMul", "MatMul")].weight_bits == 3) + self.assertTrue(configs_mapping[("/h.3/mlp/fc_out/MatMul", "MatMul")].weight_bits == 3) + self.assertTrue(configs_mapping[("/h.2/mlp/fc_out/MatMul", "MatMul")].weight_bits == 3) + self.assertTrue(configs_mapping[("/h.1/mlp/fc_out/MatMul", "MatMul")].weight_bits == 3) + + +class TestQuantConfigForAutotune(unittest.TestCase): + def test_expand_config(self): + # test the expand functionality; the user is not aware of it + from neural_compressor.onnxrt import RTNConfig + + tune_config = RTNConfig(weight_bits=[4, 8]) + expand_config_list = RTNConfig.expand(tune_config) + self.assertEqual(expand_config_list[0].weight_bits, 4) + self.assertEqual(expand_config_list[1].weight_bits, 8) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/3x/tensorflow/requirements.txt b/test/3x/tensorflow/requirements.txt new file mode 100644 index 00000000000..e079f8a6038 --- /dev/null +++ b/test/3x/tensorflow/requirements.txt @@ -0,0 +1 @@ +pytest diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt new file mode 100644 index 00000000000..793bdfc0350 --- /dev/null +++ b/test/3x/torch/requirements.txt @@ -0,0 +1,2 @@ +pytest +transformers
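
For orientation, a minimal usage sketch of the RTN weight-only flow exercised by the tests above. It relies only on APIs that appear in this change (RTNConfig, get_default_rtn_config, and the internal _quantize helper from neural_compressor.onnxrt.quantization.quantize); "model.onnx" and the MatMul node name are placeholders rather than artifacts of this patch.

    import onnx

    from neural_compressor.onnxrt import RTNConfig, get_default_rtn_config
    from neural_compressor.onnxrt.quantization.quantize import _quantize

    # 1) quantize every supported MatMul with the default 4-bit RTN config
    qmodel = _quantize("model.onnx", quant_config=get_default_rtn_config())

    # 2) global 4-bit config plus a per-node override that keeps one MatMul in fp32
    quant_config = RTNConfig(weight_bits=4, weight_group_size=32)
    quant_config.set_local("/h.4/mlp/fc_out/MatMul", RTNConfig(weight_dtype="fp32"))
    qmodel = _quantize("model.onnx", quant_config)

    # the result is an onnx.ModelProto and can be saved as usual
    onnx.save(qmodel, "model_woq.onnx")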
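Likewise, a short sketch of how the register_algo decorator and the algos_mapping table in neural_compressor/onnxrt/utils/utility.py are meant to be used together; the algorithm name and the function body below are hypothetical, chosen only for illustration.

    from neural_compressor.onnxrt.utils.utility import algos_mapping, register_algo

    @register_algo(name="rtn")  # hypothetical registration name
    def rtn_entry(model, quant_config):
        # a real implementation would transform the onnx.ModelProto according to quant_config
        return model

    # the decorator stores the function in the dispatch table and returns it unchanged
    assert algos_mapping["rtn"] is rtn_entry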