diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 7244b6f7a..34d466be4 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -61,8 +61,7 @@ env: PYTORCH_CUDA_ALLOC_CONF: 'expandable_segments:True' MAX_JOBS: 8 RUNNER: 10.0.13.31 - TRANSFORMERS_DIFF_TESTS: "models/test_internlm.py,models/test_internlm2_5.py,models/test_xverse.py" - TORCH_2_5_TESTS: "test_evalplus.py,test_perplexity.py,test_q4_ipex.py,test_ipex_xpu.py,test_save_loaded_quantized_model.py,test_quant_formats.py,models/test_hymba.py" + LEGACY_TESTS: "models/test_internlm.py,models/test_internlm2_5.py,models/test_xverse.py" IGNORED_TEST_FILES: "test_tgi.py,test_gptneox.py,models/test_mixtral.py,models/test_phi_3_moe.py" GPTQMODEL_FORCE_BUILD: 1 repo: ${{ github.event.inputs.repo || github.repository }} @@ -139,7 +138,7 @@ jobs: import os import re - TRANSFORMERS_DIFF_TESTS = '${TRANSFORMERS_DIFF_TESTS}' + LEGACY_TESTS = '${LEGACY_TESTS}' IGNORED_TEST_FILES = '${IGNORED_TEST_FILES}' TEST_NAMES='${{ github.event.inputs.test_names }}' @@ -147,7 +146,7 @@ jobs: input_test_files_list = [f.strip().removesuffix('.py') for f in TEST_NAMES.split(',') if f.strip()] - transformers_test_files = [f.strip().removesuffix('.py') for f in f'{TRANSFORMERS_DIFF_TESTS}'.split(',') if f.strip()] + transformers_test_files = [f.strip().removesuffix('.py') for f in f'{LEGACY_TESTS}'.split(',') if f.strip()] transformers_test_files = [f for f in transformers_test_files if not input_test_files_list or f in input_test_files_list] all_tests = [f.removesuffix('.py') for f in os.listdir('tests/') if f.startswith('test_') and f.endswith('.py') and f.strip().removesuffix('py') not in f'{IGNORED_TEST_FILES}'] @@ -190,8 +189,8 @@ jobs: echo "Conditions:" echo "will build run: ${{ github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.torch-files != '[]' && needs.list-test-files.outputs.transformers-files != '[]' && !(needs.list-test-files.outputs.m4-files == '[]' && needs.list-test-files.outputs.m4-files == '[]') }}" - echo "will transformers_diff run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.transformers-files != '[]' }}" - echo "will torch2_5 run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.torch-files != '[]' }}" + echo "will legacy run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.transformers-files != '[]' }}" + echo "will torch run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.torch-files != '[]' }}" echo "will m4 run: ${{ (github.event.inputs.test_names == '' || contains(github.event.inputs.test_names, 'apple') || contains(github.event.inputs.test_names, 'mlx') ) && (needs.list-test-files.outputs.m4-files != '' || needs.list-test-files.outputs.m4-files != '[]') }}" build: @@ -201,7 +200,13 @@ jobs: - list-test-files if: github.event.inputs.m4-only != 'true' && (needs.list-test-files.outputs.torch-files != '[]' || needs.list-test-files.outputs.transformers-files != '[]') container: - image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v5 + image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v7 + 
options: --device /dev/dri --ipc=host --runtime=nvidia --gpus all + volumes: + - /dev/dri/by-path:/dev/dri/by-path + - /home/ci/models:/monster/data/model + - /home/ci/models/huggingface:/github/home/.cache/huggingface + steps: - name: Checkout Codes uses: actions/checkout@v4 @@ -286,7 +291,7 @@ jobs: if: always() run: pip cache purge && uv cache clean && rm -rf ./* ./.* - transformers_diff: + legacy: needs: - build - list-test-files @@ -294,7 +299,7 @@ jobs: runs-on: [ self-hosted, xeon5 ] if: always() && !cancelled() && (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.transformers-files != '[]' container: - image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v5 + image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v7 volumes: - /home/ci/models:/monster/data/model - /home/ci/models/huggingface:/github/home/.cache/huggingface @@ -383,7 +388,7 @@ jobs: - name: Install wheel run: | - uv pip install git+https://github.com/ModelCloud/Tokenicer -U + uv pip install colorlog git+https://github.com/ModelCloud/Tokenicer -U echo "===== install optimum bitblas parameterized uvicorn =====" uv pip install optimum bitblas==0.0.1.dev13 parameterized uvicorn -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple echo "===== install dist/whl =====" @@ -407,10 +412,10 @@ jobs: gpu_id=-1 while [ "$gpu_id" -lt 0 ]; do - gpu_id=$(curl -s "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}") + gpu_id=$(curl -s "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}&exclusive=${{ github.event.inputs.exclusive-gpu }}") if [ "$gpu_id" -lt 0 ]; then - echo "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME} returned $gpu_id" + echo "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}&exclusive=${{ github.event.inputs.exclusive-gpu }} returned $gpu_id" echo "No available GPU, waiting 5 seconds..." 
sleep 5 else @@ -441,7 +446,7 @@ jobs: if: always() run: pip cache purge && uv cache clean && rm -rf ./* ./.* - torch2_5: + torch: needs: - build - list-test-files @@ -449,7 +454,7 @@ jobs: runs-on: [ self-hosted, xeon5 ] if: always() && !cancelled() && (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.torch-files != '[]' container: - image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v5 + image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v7 options: --device /dev/dri --ipc=host --runtime=nvidia --gpus all volumes: - /dev/dri/by-path:/dev/dri/by-path @@ -541,35 +546,58 @@ jobs: - name: Install wheel run: | - if [ "${{ matrix.test_script }}" == "test_quant_formats" ] || [ "${{ matrix.test_script }}" == "test_perplexity" ]; then - echo "===== install auto_round =====" - uv pip install auto_round -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple + uv pip install -U transformers colorlog + if [ "${{ matrix.test_script }}" == "test_quant_formats" ] || [ "${{ matrix.test_script }}" == "test_perplexity" ] || [ "${{ matrix.test_script }}" == "test_q4_bitblas" ]; then + echo "===== install auto_round bitblas==0.0.1.dev13 =====" + uv pip install auto_round bitblas==0.0.1.dev13 -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi + if [ "${{ matrix.test_script }}" == "models/test_cohere2" ] || [ "${{ matrix.test_script }}" == "models/test_gemma" ]; then echo "===== install transformers from git =====" - uv pip install -U git+https://github.com/huggingface/transformers.git -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple + uv pip install -U transformers -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi + if [[ "${{ matrix.test_script }}" == *xpu* ]]; then + echo "===== switching to xpu env =====" source /etc/profile.d/pyenv.sh && pyenv activate xpu + uv pip install colorlog + fi + + if [[ "${{ matrix.test_script }}" == "test_sglang.py" ]]; then + uv pip install transformers==4.48.3 + fi + + if [[ "${{ matrix.test_script }}" == *ipex* ]] && [[ "${{ matrix.test_script }}" != *xpu* ]]; then + uv pip uninstall torchvision torch flash_attn # fix ipex can't be used with torch+cu126 + uv pip install torchvision torch + uv pip install -U intel_extension_for_pytorch -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi if [[ "${{ matrix.test_script }}" == *"mlx"* ]]; then uv pip install mlx_lm --no-build-isolation -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi + if [[ "${{ matrix.test_script }}" == "test_modelscope" ]]; then + echo "===== installing modelscope =====" uv pip install modelscope --no-build-isolation -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi - echo "===== install dist/whl =====" uv pip install git+https://github.com/ModelCloud/Tokenicer -U - uv pip install dist/*.whl -i http://${{ 
needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple + + # ipex doesn't need to compile kernels. xpu can't install cuda package + if [[ "${{ matrix.test_script }}" != *ipex* && "${{ matrix.test_script }}" != *xpu* ]]; then + echo "===== install dist/whl =====" + uv pip install dist/*.whl -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple + else + echo "===== install with local files for xpu env =====" + export CUDA_VISIBLE_DEVICES="" + unset TORCH_CUDA_ARCH_LIST + uv pip install . --no-build-isolation + fi if [ "${{ matrix.test_script }}" == "test_transformers" ]; then echo "===== install optimum from git =====" uv pip install -U git+https://github.com/huggingface/optimum.git -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} - echo "===== install transformers from git =====" - uv pip install -U git+https://github.com/huggingface/transformers.git -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} - uv pip install torch==2.5.1 # fix optimum will install torch 2.6.0 fi if [[ "${{ matrix.test_script }}" == "test_sglang" ]]; then @@ -577,16 +605,16 @@ jobs: fi - name: Find suitable GPU - if: ${{ !contains(matrix.test_script, 'ipex') && !cancelled() }} + if: ${{ !contains(matrix.test_script, 'ipex') && !contains(matrix.test_script, 'xpu') && !cancelled() }} run: | timestamp=$(date +%s%3N) gpu_id=-1 while [ "$gpu_id" -lt 0 ]; do - gpu_id=$(curl -s "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}") + gpu_id=$(curl -s "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}&exclusive=${{ github.event.inputs.exclusive-gpu }}") if [ "$gpu_id" -lt 0 ]; then - echo "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME} returned $gpu_id" + echo "http://${{ needs.check-vm.outputs.ip }}/gpu/get?id=${{ github.run_id }}×tamp=$timestamp&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}&exclusive=${{ github.event.inputs.exclusive-gpu }} returned $gpu_id" echo "No available GPU, waiting 5 seconds..." 
sleep 5 else @@ -617,12 +645,14 @@ jobs: curl "http://${{ needs.check-vm.outputs.ip }}/gpu/log_test_vram?id=${{ github.run_id }}&gpu=${{ env.CUDA_VISIBLE_DEVICES }}&range=$execution_time&unit=second&test=${{ matrix.test_script }}" - name: Release GPU - if: always() && !contains(matrix.test_script, 'ipex') + if: always() && !contains(matrix.test_script, 'ipex') && !contains(matrix.test_script, 'xpu') run: curl -X GET "http://${{ needs.check-vm.outputs.ip }}/gpu/release?id=${{ github.run_id }}&gpu=${{ env.CUDA_VISIBLE_DEVICES }}×tamp=${{ env.STEP_TIMESTAMP }}&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}" - + - name: Clean cache if: always() - run: pip cache purge && uv cache clean && rm -rf ./* ./.* + run: | + rm ~/.cache/evalplus/*pkl || true + pip cache purge && uv cache clean && rm -rf ./* ./.* show-statistics: runs-on: [ self-hosted, xeon5 ] @@ -630,8 +660,8 @@ jobs: container: image: modelcloud/gptqmodel:alpine-ci-v1 needs: - - transformers_diff - - torch2_5 + - legacy + - torch steps: - name: Print statistics run: curl "http://10.0.14.248/gpu/get_vram_logs?id=${{ github.run_id }}" diff --git a/README.md b/README.md index 6288dc4c1..8cb350678 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ ## News * 02/12/2025 [1.9.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.9.0): ⚡ Offload `tokenizer` fixes to [Toke(n)icer](https://github.com/modelcloud/tokenicer) pkg. Optimized `lm_head` quant time and vram usage. - Optimized `DeekSeek v3/R1` model quant vram usage. Fixed `Optimum` compat regresion in `v1.8.1`. 3x speed-up for `Torch` kernel when using Pytorch >= 2.5.0 with `model.compile()`. New `calibration_dataset_concat_size` option to enable calibration data `concat` mode to mimic original GPTQ data packing strategy which may improve quant speed and accuracy for datasets like `wikitext2`. + Optimized `DeekSeek v3/R1` model quant vram usage. Fixed `Optimum` compat regresion in `v1.8.1`. 3x speed-up for `Torch` kernel when using Pytorch >= 2.5.0 with `model.optimize()`. New `calibration_dataset_concat_size` option to enable calibration data `concat` mode to mimic original GPTQ data packing strategy which may improve quant speed and accuracy for datasets like `wikitext2`. * 02/08/2025 [1.8.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.8.1): ⚡ `DeekSeek v3/R1` model support. New flexible weight `packing`: allow quantized weights to be packed to `[int32, int16, int8]` dtypes. `Triton` and `Torch` kernels supports full range of new `QuantizeConfig.pack_dtype`. New `auto_gc: bool` control in `quantize()` which can reduce quantization time for small model with no chance of oom. 
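The 1.9.0 note above introduces the `calibration_dataset_concat_size` option. Below is a minimal usage sketch, assuming the option is passed as a keyword to `quantize()` as the release note suggests; the model id, dataset slice, and output path are illustrative placeholders, not taken from this patch.

```python
# Hypothetical sketch of concat-mode calibration (names marked as placeholders are assumptions).
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig

# small wikitext2 slice; dataset choice and sample count are placeholders
calibration = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"][:1024]

qcfg = QuantizeConfig(bits=4, group_size=128)
model = GPTQModel.load("facebook/opt-125m", qcfg)  # placeholder model id

# concat mode packs calibration rows into fixed-size blocks, mimicking the original GPTQ packing strategy
model.quantize(calibration, calibration_dataset_concat_size=2048, batch_size=1)
model.save("opt-125m-gptq-4bit")
```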
diff --git a/examples/benchmark/generation_speed.py b/examples/benchmark/generation_speed.py index add850be4..ad7eaea4c 100644 --- a/examples/benchmark/generation_speed.py +++ b/examples/benchmark/generation_speed.py @@ -195,8 +195,8 @@ def load_model_tokenizer( def benchmark_generation_speed(model, tokenizer, examples, generation_config): generation_time_list = [] num_generated_tokens_list = [] - progress_bar = ProgressBar(examples) - for example in progress_bar: + pb = ProgressBar(examples) + for example in pb: input_ids = example["input_ids"].to(model.device) start = time.time() @@ -217,7 +217,7 @@ def benchmark_generation_speed(model, tokenizer, examples, generation_config): ) num_generated_tokens_list.append(num_generated_tokens) - progress_bar.set_postfix( + pb.set_postfix( num_tokens=num_generated_tokens_list[-1], time=generation_time_list[-1], speed=f"{num_generated_tokens_list[-1] / generation_time_list[-1]:.3f} tokens/s", diff --git a/format/format.sh b/format/format.sh index 516900e78..a0d7769bc 100755 --- a/format/format.sh +++ b/format/format.sh @@ -3,7 +3,7 @@ cd "$(dirname "$0")" || exit # force ruff/isort to be same version as setup.py -pip install -U ruff==0.9.5 isort==6.0.0 +pip install -U gptqmodel["quality"] ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes ruff_status=$? diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index f015202a9..4a13698b4 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + from .models import GPTQModel, get_best_device from .quantization import BaseQuantizeConfig, QuantizeConfig from .utils import BACKEND from .utils.exllama import exllama_set_max_input_length from .version import __version__ -import os if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: try: from modelscope.utils.hf_util.patcher import patch_hub diff --git a/gptqmodel/adapter/__init__.py b/gptqmodel/adapter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/adapter/adapter.py b/gptqmodel/adapter/adapter.py new file mode 100644 index 000000000..5791c6948 --- /dev/null +++ b/gptqmodel/adapter/adapter.py @@ -0,0 +1,207 @@ +import os +from dataclasses import dataclass, field +from typing import Dict, List, Union +from urllib.parse import urlparse + +import safetensors +import torch +from gptqmodel.utils.logger import setup_logger +from gptqmodel.utils.torch import torch_compile + +logger = setup_logger() +LORA_MERGED_WEIGHT_PATHS = [None, ""] + +# TODO FIX ME: cache of adapter tensors loaded from disk +adapter_load_cache = None + +class Adapter(): + def __init__(self, rank: int, path: str = None): + self.rank = rank + self.path = path.lower().strip() if isinstance(path, str) else path + + def validate_path(self, local_only=False): + if not self.path or not isinstance(self.path, str): + raise ValueError("Adapter: `path` str is required.") + + if local_only: + if self.path.startswith("http"): + raise ValueError(f"Adapter: `path` str in this context must be a local os path: actual = `{self.path}`.") + + # override me + def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + pass + + # override me + def post_init(self, weight_key: str, device: torch.device, **kwargs): + pass + + # override me + def optimize(self): + pass + + # override 
me + @classmethod + def name(cls) -> List[str]: + pass + + # override me + @classmethod + def parameter_keys(cls) -> [str]: # name of tensors/parameters in attribute key name + pass + + +@dataclass +class Lora(Adapter): + def __init__(self, rank: int, path: str = None, lora_A: torch.Tensor = None, lora_B: torch.Tensor = None): + super().__init__(rank, path) + + self.lora_A = lora_A + self.lora_B = lora_B + + @classmethod + def name(cls) -> str: + return "lora" + + @classmethod + def parameter_keys(cls) -> List[str]: + return ["lora_A", "lora_B"] + + def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): + pass + #logger.info("Adapter: optimize (compile)") + #self.apply = torch_compile(self.apply, backend=backend, mode=mode, fullgraph=fullgraph) + + def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + # original code + # out = out + ((x @ self.lora_A) @ self.lora_B) + + # fix batch for lora + # Some kernels do not reshape x, such as marlin / exllama / exllamav2. + # out.dim() > x.dim() is used to exclude these kernels without additional processing + if out.dim() > x.dim() and out.shape[0] > 1: + out_orgi_shape = out.shape + out = out.view(-1, out.shape[-1]) + out.add_((x @ self.lora_A) @ self.lora_B) + out = out.view(out_orgi_shape) + return out + else: + return out.add_((x @ self.lora_A) @ self.lora_B) + + def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=None, lora_B: torch.Tensor=None): + # self.register_buffer("lora_A", lora_A) + # self.register_buffer("lora_B", lora_B) + + # we need since lora A/B weights may be merged into model tensors and not separate + if lora_A is not None and lora_B is not None: + # print(f"Adapter has preloaded lora_A and lora_B") + self.lora_A, self.lora_B = lora_A, lora_B + return + + global adapter_load_cache + if adapter_load_cache is None: + if os.path.isfile(self.path): + lora_path = self.path + logger.info(f"Adapter: Loading `{self.path}` tensors from disk") # {adapter_load_cache} + elif self.path.startswith("http"): + from huggingface_hub import hf_hub_download + result = self.parse_url(self.path) + if len(result) == 3: + logger.info(f"Adapter: Downloading adapter weights from hf repo: `{result[0]}` revision: `{result[1]}` file: `{result[2]}`") + lora_path = hf_hub_download(repo_id=result[0], revision=result[1], filename=result[2]) + elif len(result) == 1: + logger.info(f"Adapter: Downloading adapter weights from uri = `{self.path}`") + import requests + response = requests.get(self.path, stream=True) + lora_path = "lora.safetensors" + with open(lora_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + else: + raise Exception(f"Adapter: Lora path is invalid: `{self.path}`") + else: + from huggingface_hub import HfApi, hf_hub_download + files = [f for f in HfApi().list_repo_files(self.path) if f in ["lora.safetensors", "eora_test.safetensors"]] + + if files: + lora_path = hf_hub_download(repo_id=self.path, filename=files[0]) + # print(f"Adapter tensors loaded from `{self.path}`") + else: + raise Exception(f"Adapter: There's no lora.safetensors or eora_test.safetensors on repo `{self.path}`") + + adapter_load_cache = safetensors.torch.load_file(lora_path) + + weight_key = weight_key.lower() + + # hack for HF Auto compat + if not f"{weight_key}.lora_A.weight" in adapter_load_cache: + weight_key = weight_key.removeprefix("model.") + + #print(f"loaded lora weight keys: {adapter_load_cache.keys()}") + lora_A = 
adapter_load_cache.pop(f"{weight_key}.lora_A.weight").T
+        lora_B = adapter_load_cache.pop(f"{weight_key}.lora_B.weight").T
+
+        # since the loader cache is a singleton, we need to reset it to None so ci loop tests can pass
+        if len(adapter_load_cache) == 0:
+            adapter_load_cache = None
+
+        # print(f"Adapter: {self.name()}, loaded lora_A shape: {lora_A.shape}")
+        # print(f"Adapter: {self.name()}, loaded lora_B shape: {lora_B.shape}")
+        if lora_A.dtype != torch.float16 or lora_B.dtype != torch.float16:
+            logger.warn(f"Adapter: `lora_A` and `lora_B` tensors should be of dtype = `torch.float16`: actual = `[{lora_A.dtype}, {lora_B.dtype}]`.")
+
+        self.lora_A = lora_A.to(device=device, dtype=torch.float16)
+        self.lora_B = lora_B.to(device=device, dtype=torch.float16)
+
+        #print(f"Adapter: lora_A {lora_A.shape}: `{lora_B}`")
+        #print(f"Adapter: lora_B {lora_B.shape}: `{lora_B}`")
+
+    def parse_url(self, url: str):
+        parsed_url = urlparse(url)
+
+        if parsed_url.netloc.endswith("huggingface.co") or parsed_url.netloc.endswith("hf.co"):
+            parts = parsed_url.path.strip("/").split("/")
+
+            if "blob" in parts:
+                idx = parts.index("blob")
+                repo_id = "/".join(parts[:idx])
+                rev = parts[idx + 1]
+                filename = parts[idx + 2].split("?")[0]  # remove ?download=true
+                return [repo_id, rev, filename]
+            else:
+                return [url]
+        return []
+
+    def to_dict(self):
+        return {
+            "name": self.name(),
+            "path": self.path,
+            "rank": self.rank
+        }
+
+ADAPTER_MAPPING = {Lora.name(): Lora}
+
+# accept both Adapter cls instance or Dict()
+def normalize_adapter(adapter: Union[Dict, Adapter]):
+    if adapter is None:
+        return None
+
+    if isinstance(adapter, Adapter):
+        return adapter
+
+    if not isinstance(adapter, Dict):
+        raise ValueError("Adapter: Invalid adapter config: `adapter`.")
+
+    adapter_type = adapter.pop("name", None)
+    if adapter_type is None:
+        raise ValueError(f"Adapter: Invalid adapter class `{adapter_type}`: expected = `{ADAPTER_MAPPING}`.")
+
+    adapterCls = ADAPTER_MAPPING.get(adapter_type)
+    if adapterCls is None:
+        raise ValueError(f"Adapter: Compatible adapters include `{ADAPTER_MAPPING.keys()}`: actual `{(adapter_type)}`.")
+
+    try:
+        adapterInstance = adapterCls(**adapter)
+    except Exception:
+        raise ValueError(f"Adapter: Invalid adapter config: `{adapter}`.")
+
+    return adapterInstance
diff --git a/gptqmodel/eora/__init__.py b/gptqmodel/eora/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/gptqmodel/eora/eora.py b/gptqmodel/eora/eora.py
new file mode 100644
index 000000000..3fc6d385b
--- /dev/null
+++ b/gptqmodel/eora/eora.py
@@ -0,0 +1,91 @@
+# Copyright 2024-2025 NVIDIA CORPORATION
+# EoRA arXiv: https://arxiv.org/abs/2410.21271
+# EoRA Official Repo: https://github.com/NVlabs/EoRA
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Tuple
+
+import torch
+from gptqmodel.looper.named_module import NamedModule
+from gptqmodel.utils.logger import setup_logger
+from torch import Tensor
+
+logger = setup_logger()
+
+def eora_process_input(input: Tensor, name: str, eigen_scaling_diag_matrix: Dict[str, torch.Tensor], sample_size: int):
+    inp = input[0].to(dtype=torch.float32)
+    if inp.dim() == 2:
+        inp = inp.unsqueeze(0)
+
+    tmp = inp.shape[0]
+    adds = torch.matmul(inp.transpose(1, 2), inp)
+    adds_sum = torch.sum(adds, dim=0)
+
+    ## Adding tmp to the denominator is only for mathematical stability
+    eigen_scaling_diag_matrix[name] *= sample_size / (sample_size + tmp)
+    eigen_scaling_diag_matrix[name] += adds_sum / sample_size
+
+    del inp, tmp, adds, adds_sum
+
+def eora_compute_lora(
+        device: torch.device,
+        w_wq_delta: Tensor,  # need the w (original weight) and wq (quantized qweight) delta in float32
+        module: NamedModule,
+        eigen_scaling_diag_matrix: torch.Tensor,
+        rank: int) -> Tuple[Tensor, Tensor]:
+
+    assert w_wq_delta.dtype == torch.float32
+
+    # save this later for SVD
+    raw_scaling_diag_matrix = eigen_scaling_diag_matrix.to(dtype=torch.float64, device=device)
+
+    L, Q = torch.linalg.eigh(raw_scaling_diag_matrix)
+    if (L < 0).any():
+        ## When expanding the calibration data size for EoRA, I suggest maintaining the balance by allocating 50% to general input (C4) and the remaining 50% to downstream task data.
+        logger.warn(f"Found negative eigenvalues in `{module.name}`. Please increase your calibration data set for EoRA.")
+        minimum = torch.min(L[L > 0])
+        L[L < 0] = minimum
+
+    sqrtEigenvalues = torch.sqrt(L)
+    scaling_diag_matrix = Q @ torch.diag(sqrtEigenvalues)
+
+    try:
+        scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix)
+    except Exception:
+        logger.warn("`scaling_diag_matrix` is not full rank!")  # TODO: assert?
+        scaling_diag_matrix += 1e-6 * torch.eye(scaling_diag_matrix.shape[0]).to(device)
+        scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix)
+
+    scaling_diag_matrix = scaling_diag_matrix.to(dtype=torch.float32)
+    scaling_matrix_inv = scaling_matrix_inv.to(dtype=torch.float32)
+
+    delta_scale = torch.matmul(w_wq_delta, scaling_diag_matrix)
+
+    U, S, V = torch.linalg.svd(delta_scale, full_matrices=False)
+    lowrank_r = rank
+    truc_s = S[:lowrank_r]
+    truc_u = U[:, :lowrank_r]
+    truc_v = torch.matmul(V[:lowrank_r, :], scaling_matrix_inv)
+    truc_sigma = torch.diag(truc_s)
+
+    sqrtS = torch.sqrt(truc_sigma)
+    B = torch.matmul(truc_u, sqrtS).to(dtype=torch.float16)
+    A = torch.matmul(sqrtS, truc_v).to(dtype=torch.float16)
+
+
+    del L, Q, U, S, V
+    del w_wq_delta, raw_scaling_diag_matrix, sqrtEigenvalues, scaling_diag_matrix, scaling_matrix_inv, delta_scale
+    del truc_s, truc_u, truc_v, truc_sigma, sqrtS
+
+    return A, B
diff --git a/gptqmodel/looper/__init__.py b/gptqmodel/looper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/gptqmodel/looper/dequantize_processor.py b/gptqmodel/looper/dequantize_processor.py
new file mode 100644
index 000000000..9540627b5
--- /dev/null
+++ b/gptqmodel/looper/dequantize_processor.py
@@ -0,0 +1,63 @@
+# Copyright 2024-2025 ModelCloud.ai
+# Copyright 2024-2025 qubitium@modelcloud.ai
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from gptqmodel.looper.loop_processor import LoopProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear +from gptqmodel.utils.logger import setup_logger +from gptqmodel.utils.torch import torch_compile + +logger = setup_logger() + +class DequantizeProcessor(LoopProcessor): + def __init__(self, quantized_modules: Dict[str, TorchQuantLinear]): + super().__init__(tokenizer=None, qcfg=None, calibration_dataset=None, calibration_dataset_concat_size=None, + prepare_dataset_func=None, batch_size=1, + logger_board="", require_fwd=True) + + self.quantized_modules = quantized_modules + + def set_calibration_dataset(self, calibration_dataset): + self.calibration_dataset = None + self.num_batches = 0 + + # de-quantize weights + def process(self, module: NamedModule): + device = module.weight.device + w = module.weight.data + + # TODO fix num_itr param..need to calculate this before dequant + m = self.quantized_modules.pop(module.full_name) + m.dequantize_weight = torch_compile(m.dequantize_weight) + wq = m.dequantize_weight().T.to(device=device) + + module.state.update({ + "w": w, + "wq": wq, + }) + + def submodule_finalize(self, module: NamedModule): + module.state.pop("w", None) # no need for these weights now + module.state.pop("wq", None) # no need for these weights now + + def verify_calibration_dataset(self, processor_index: int) -> bool: + return False + + @classmethod + def name(cls) -> str: + return "de-quantize" diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py new file mode 100644 index 000000000..337a4adec --- /dev/null +++ b/gptqmodel/looper/eora_processor.py @@ -0,0 +1,229 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import time +from typing import Callable, Dict, Optional, Tuple + +import torch +from gptqmodel.adapter.adapter import Lora +from gptqmodel.eora.eora import eora_compute_lora, eora_process_input +from gptqmodel.looper.loop_processor import LoopProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.models import BaseGPTQModel +from gptqmodel.models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, + PROCESS_LOG_MODULE, PROCESS_LOG_NAME, PROCESS_LOG_TIME) +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.quantization.gptq import CPU +from gptqmodel.utils.logger import setup_logger +from gptqmodel.utils.model import move_to +from gptqmodel.utils.torch import torch_compile, torch_sync +from torch.nn import Module + +logger = setup_logger() + + +class EoraProcessor(LoopProcessor): + def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, prepare_dataset_func, + calibration_dataset_concat_size: Optional[int], batch_size: int, + logger_board: str = "", require_fwd: bool = True, + ): + super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, + logger_board=logger_board, require_fwd=require_fwd) + + # dict: key is module name, value is the accumulated eigen_scaling_diag_matrix + self.eigen_scaling_diag_matrix: Dict[str, torch.float32] = {} + + + # Increase the dynamo cache size limit, default of 8 is too low + if torch._dynamo.config.cache_size_limit < 64: + torch._dynamo.config.cache_size_limit = 64 + + # needed by eora + # torch._dynamo.config.capture_scalar_outputs = True + + self.eora_compute_lora = torch_compile(eora_compute_lora) + self.eora_process_input = torch_compile(eora_process_input) + + # self.eora_compute_lora = eora_compute_lora + # self.eora_process_input = eora_process_input + + def log_plotly(self): + task = self.logger_task + if task is not None: + from gptqmodel.utils.plotly import create_plotly + x = list(range(self.layer_count)) + gpu_fig = create_plotly(x=x, y=self.gpu_memorys, xaxis_title="layer", yaxis_title="GPU usage (GB)") + cpu_fig = create_plotly(x=x, y=self.cpu_memorys, xaxis_title="layer", yaxis_title="CPU usage (GB)") + time_fig = create_plotly(x=self.module_names, y=self.durations, xaxis_title="layer", yaxis_title="time") + task.get_logger().report_plotly('GPU Memory', 'GPU Memory', gpu_fig) + task.get_logger().report_plotly('CPU Memory', 'CPU Memory', cpu_fig) + task.get_logger().report_plotly('quant_time', 'quant_time', time_fig) + + def set_calibration_dataset(self, calibration_dataset): + self.calibration_dataset = calibration_dataset + self.num_batches = len(calibration_dataset) + + def preprocess(self, module: NamedModule, **kwargs): + # entire module is skipped + if self.qcfg.dynamic_get(layer_name=module.full_name) == False: + module.adapter_cfg = None # hack + return + + adapter_cfg = copy.deepcopy(self.qcfg.adapter) + + # dynamic overrides + if self.qcfg.dynamic is not None: + adapter_cfg.adapter = self.qcfg.dynamic_get(module.full_name, "adapter", adapter_cfg) + + # hack store property inside module + module.adapter_cfg = adapter_cfg + + self.eigen_scaling_diag_matrix[module.name] = 0 # torch.tensor(0.0, dtype=torch.float32) + + return + + def is_skipped(self, module: NamedModule) -> bool: + # dynamic override removed eora processing for this module + return module.adapter_cfg in [None, {}] + + def 
preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + def tmp(_, input: Tuple[torch.Tensor, ...], output: torch.Tensor): + self.eora_process_input( + input=input, + name=name, + eigen_scaling_diag_matrix=self.eigen_scaling_diag_matrix, + sample_size=self.num_batches + ) + return tmp + + def process(self, module: NamedModule): + assert isinstance(module.adapter_cfg, Lora) + + self.pb.info(f"EoRA gen: {module.name} in layer {module.layer_index} of {self.layer_count - 1}") + + start = time.time() + + eigen_scaling_diag_matrix = self.eigen_scaling_diag_matrix[module.name] + + w: torch.Tensor = module.state.pop("w") + w_device = w.device # TODO clear up device situation between w and wq + wq: torch.Tensor = module.state["wq"] + + # print(f"types: w = `{w.dtype}`, device = `{w.device}`, wq = `{wq.dtype}`, device = `{wq.device}`") + if w.dtype != torch.float16: + w_wq_delta = w.to(dtype=torch.float32) - wq # wq is float16 + else: + w_wq_delta = w - wq + + assert w_wq_delta.dtype == torch.float32 + + # print(f"types: w_q_delta = `{w_wq_delta.dtype}`, device = `{w_wq_delta.device}`") + del w + + A, B = self.eora_compute_lora( + device=w_device, + w_wq_delta=w_wq_delta, + module=module, + eigen_scaling_diag_matrix=eigen_scaling_diag_matrix, + rank=module.adapter_cfg.rank + ) + + # wq with A/B applied + computed_wq = wq + (B @ A) + + module.state.update({ + "wq": move_to(wq, device=CPU, stream=self.stream), + }) + + # override module weight with computed weight with B@A delta + module.weight.data = computed_wq.to(dtype=module.weight.data.dtype) + + # for assert weight + # module.state.update({ + # "wq_ab": move_to(computed_wq.to(dtype=module.weight.data.dtype), device=CPU, stream=self.stream), + # }) + + # lowrank_dict[f'{layer_name}.lora_A.weight'] = A.cpu().to(dtype=torch.float16) + # lowrank_dict[f'{layer_name}.lora_B.weight'] = B.cpu().to(dtype=torch.float16) + + duration = time.time() - start + self.durations.append(duration) + self.module_names.append(f"layer-{module.layer_index}-{module.name}") + + stat = { + PROCESS_LOG_NAME: self.name(), + PROCESS_LOG_LAYER: module.layer_index, + PROCESS_LOG_MODULE: module.name, + PROCESS_LOG_TIME: f"{duration:.3f}", + PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}" + } + + if self.qcfg.dynamic is not None: + stat["dynamic"] = self.qcfg.dynamic_get(layer_name=module.full_name) + + self.log.append(stat) + logger.info(stat) + + # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}") + self.result_save(module.full_name, { + "lora_A.weight": move_to(A.to(dtype=torch.float16), device=CPU, stream=self.stream), + "lora_B.weight": move_to(B.to(dtype=torch.float16), device=CPU, stream=self.stream), + }) + + # eora = Lora(rank=module.adapter_cfg.rank, lora_A=A, lora_B=B) + # + # module.state.update({ + # "adapter": eora, + # }) + + def submodule_finalize(self, module: NamedModule): + pass + # adapter: Lora = module.state.pop("adapter") + # + # # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}") + # self.result_save(module.full_name, { + # "lora_A.weight": move_to(adapter.lora_A.to(dtype=torch.float16), device=CPU, stream=self.stream), + # # A.to(dtype=torch.float16, device=CPU), + # "lora_B.weight": move_to(adapter.lora_B.to(dtype=torch.float16), device=CPU, stream=self.stream), + # # B.to(dtype=torch.float16, device=CPU), + # }) + + def finalize(self, model: BaseGPTQModel, **kwargs): + # block for streams + if self.stream: + torch_sync() + + del self.eigen_scaling_diag_matrix + + # 
hack: store loras into model until `save()` is called + model.lora_results = self.results() + + super().finalize(model=model, **kwargs) + + def verify_calibration_dataset(self, processor_index: int) -> bool: + if self.calibration_dataset is None: + if processor_index == 0: + raise ValueError("EoraProcessor's calibration_dataset must be provided.") + else: + return False + return True + + @classmethod + def name(cls) -> str: + return "eora" diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py new file mode 100644 index 000000000..dc5bca773 --- /dev/null +++ b/gptqmodel/looper/gptq_processor.py @@ -0,0 +1,226 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Callable, Optional, Tuple + +import torch +from gptqmodel import QuantizeConfig +from gptqmodel.looper.loop_processor import LoopProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.models import BaseGPTQModel +from gptqmodel.models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, + PROCESS_LOG_NAME, PROCESS_LOG_TIME, QUANT_LOG_DAMP, QUANT_LOG_LOSS) +from gptqmodel.quantization import GPTQ +from gptqmodel.quantization.gptq import CPU +from gptqmodel.utils.logger import setup_logger +from gptqmodel.utils.model import move_to, pack_model +from gptqmodel.utils.torch import torch_sync +from torch.nn import Module + +logger = setup_logger() + +class GPTQProcessor(LoopProcessor): + def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, prepare_dataset_func, + calibration_dataset_concat_size: Optional[int], batch_size: int, + logger_board: str = "", require_fwd: bool = True, retain_w: bool = False): + + super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, + logger_board=logger_board, require_fwd=require_fwd) + + self.retain_w = retain_w + self.avg_losses = [] + + def log_plotly(self): + task = self.logger_task + if task is not None: + from gptqmodel.utils.plotly import create_plotly + x = list(range(self.layer_count)) + gpu_fig = create_plotly(x=x, y=self.gpu_memorys, xaxis_title="layer", yaxis_title="GPU usage (GB)") + cpu_fig = create_plotly(x=x, y=self.cpu_memorys, xaxis_title="layer", yaxis_title="CPU usage (GB)") + loss_fig = create_plotly(x=self.module_names, y=self.avg_losses, xaxis_title="layer", yaxis_title="loss") + time_fig = create_plotly(x=self.module_names, y=self.durations, xaxis_title="layer", yaxis_title="time") + task.get_logger().report_plotly('GPU Memory', 'GPU Memory', gpu_fig) + task.get_logger().report_plotly('CPU Memory', 'CPU Memory', cpu_fig) + task.get_logger().report_plotly('avg_loss', 'avg_loss', loss_fig) + task.get_logger().report_plotly('quant_time', 'quant_time', time_fig) + + 
def set_calibration_dataset(self, calibration_dataset):
+        raise NotImplementedError("GPTQProcessor's calibration_dataset cannot be modified")
+
+    def preprocess(self, module: NamedModule, buffered_fwd: bool):
+        # entire module is skipped
+        if self.qcfg.dynamic_get(layer_name=module.full_name) == False:
+            return
+
+        qcfg_clone = copy.deepcopy(self.qcfg)
+
+        # dynamic overrides
+        if self.qcfg.dynamic is not None:
+            qcfg_clone.bits = self.qcfg.dynamic_get(module.full_name, "bits", qcfg_clone.bits)
+            qcfg_clone.sym = self.qcfg.dynamic_get(module.full_name, "sym", qcfg_clone.sym)
+            qcfg_clone.mse = self.qcfg.dynamic_get(module.full_name, "mse", qcfg_clone.mse)
+
+            qcfg_clone.group_size = self.qcfg.dynamic_get(module.full_name, "group_size", qcfg_clone.group_size)
+            qcfg_clone.desc_act = self.qcfg.dynamic_get(module.full_name, "desc_act", qcfg_clone.desc_act)
+            qcfg_clone.damp_percent = self.qcfg.dynamic_get(module.full_name, "damp_percent", qcfg_clone.damp_percent)
+            qcfg_clone.static_groups = self.qcfg.dynamic_get(module.full_name, "static_groups", qcfg_clone.static_groups)
+
+        tmp = GPTQ(module=module, qcfg=qcfg_clone)
+
+        # models like DeepSeek v3/r1 have > 256 sub-modules per layer
+        # use buffered mode so vram doesn't explode: gptq needs to store fwd inputs per each layer fwd
+        # all sub-modules within a single layer need to store all the inputs.
+        # deepseek has a massive # of sub-modules per layer, causing vram pressure
+        # buffered mode is slower due to gpu<->cpu movement
+        if buffered_fwd:  # TODO tweak this number for massive MoE
+            logger.info(f"Experimental: enabling fwd buffered mode for: `{module.name}`")
+            tmp.fwd_inputs_buffered = True
+
+        tmp.quantizer.configure(
+            perchannel=True,
+        )
+        self.tasks[module.name] = tmp
+
+    def is_skipped(self, module: NamedModule) -> bool:
+        # gptq has no dynamic method of full override (removal)
+        t = self.tasks.get(module.name, False)
+        if t == False:
+            return True
+        else:
+            return False
+
+    def preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
+        def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
+            # gptq is mutable.
+            g = self.tasks[name]  # noqa: F821
+            g.add_batch(inp[0].data, out.data)  # noqa: F821
+        return tmp
+
+    def process(self, module: NamedModule):
+        self.pb.info(f"Quantizing {module.name} in layer {module.layer_index} of {self.layer_count - 1}")
+        gptq = self.tasks
+
+        # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}")
+        ## Need to return the quantized_weight for offloading
+        g = gptq[module.name]
+        # TODO FIX ME, quantize does NOT need to pass any args! Check HF compat!
+ wq, scale, zero, g_idx, duration, avg_loss, damp_percent = g.quantize() + ## Assign the quantized weight to the weight + #gptq[name].layer.weight.data = q_full_weight.to(device=gptq[name].device) + + ## Offload the quantized weight to CPU for EoRA + #quantized_weights['model.layers.%d.%s' % (module_index, name)] = q_full_weights.cpu() + + # if task is not None: + # task.get_logger().report_scalar( + # title='Quantization Loss', + # series=f'layer_{module_index}_loss', + # value=avg_loss, + # iteration=name_index, + # ) + # + # task.get_logger().report_scalar( + # title='Quantization Time', + # series=f'layer_{module_index}_time', + # value=duration, + # iteration=name_index, + # ) + self.durations.append(duration) + self.avg_losses.append(avg_loss) + self.module_names.append(f"layer-{module.layer_index}-{module.name}") + + stat = { + PROCESS_LOG_NAME: self.name(), + PROCESS_LOG_LAYER: module.layer_index, + PROCESS_LOG_MODULE: module.name, + QUANT_LOG_LOSS: f"{avg_loss:.5f}", + QUANT_LOG_DAMP: f"{damp_percent:.5f}", + PROCESS_LOG_TIME: f"{duration:.3f}", + PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}", + } + + if self.qcfg.dynamic is not None: + stat["dynamic"] = self.qcfg.dynamic_get(layer_name=module.full_name) + + self.log.append(stat) + logger.info(stat) + + self.result_save(module.full_name, { + "scale": move_to(scale, device=CPU, stream=self.stream), + "zero": move_to(zero, device=CPU, stream=self.stream), + "g_idx": move_to(g_idx, device=CPU, stream=self.stream), + }) + + if self.retain_w: + # original weights + w = module.weight.data + module.state.update({ + "w": w, # bf16/fp16, non-quantized native weight + }) + + gptq[module.name].free() + + # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}") + module.state.update({ + "wq": wq, # fp16, quantized weight but not int4 (packed qweight) + }) + + # prepare for module.forward post generate + module.weight.data = wq + + # submodule_finalized is called in reverse after all next sequential processes are called + def submodule_finalize(self, module: NamedModule): + # generate complete, safe to move to cpu + module.weight.data = move_to(module.state.pop("wq"), device=CPU, stream=self.stream) # large weights is slow to init on cpu + module.state.pop("w", None) # no need for original weights now + + def finalize(self, model: BaseGPTQModel, **kwargs): + # block for streams + if self.stream: + torch_sync() + + backend = kwargs.pop("backend") + model.qlinear_kernel = pack_model( + model=model.model, + quant_result=self.results(), + bits=self.qcfg.bits, + group_size=self.qcfg.group_size, + backend=backend, + desc_act=self.qcfg.desc_act, + format=self.qcfg.format, + lm_head_name=model.lm_head, + dynamic=self.qcfg.dynamic, + parallel_packing=self.qcfg.parallel_packing, + pack_dtype=self.qcfg.pack_dtype, + ) + + # set quantized state + model.quantized = True + + super().finalize(model=model, **kwargs) + + def verify_calibration_dataset(self, processor_index: int) -> bool: + if self.calibration_dataset is None: + raise ValueError("GPTQProcessor's calibration_dataset must be provided.") + else: + return True + + @classmethod + def name(cls) -> str: + return "gptq" diff --git a/gptqmodel/looper/input_cache.py b/gptqmodel/looper/input_cache.py new file mode 100644 index 000000000..444e3e0c3 --- /dev/null +++ b/gptqmodel/looper/input_cache.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Dict, List + +import torch + + +@dataclass +class InputCache: + layer_inputs: List[List[torch.Tensor]] + layer_input_kwargs: 
List[Dict[str, torch.Tensor]] + position_ids: List[torch.Tensor] + attention_masks: List[torch.Tensor] diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py new file mode 100644 index 000000000..fc4a0e860 --- /dev/null +++ b/gptqmodel/looper/loop_processor.py @@ -0,0 +1,218 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from gptqmodel.looper.input_cache import InputCache +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.models import BaseGPTQModel +from gptqmodel.models._const import CALIBRATION_DATASET_CONCAT_CHAR +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.utils.data import collate_data +from gptqmodel.utils.device import get_cpu_usage_memory, get_gpu_usage_memory +from gptqmodel.utils.logger import setup_logger +from torch import Tensor +from torch.nn import Module + +logger = setup_logger() + + +# LoopProcessor is a singleton(), not per module instance +class LoopProcessor: + def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, prepare_dataset_func, + calibration_dataset_concat_size: Optional[int], batch_size: int, + logger_board: str = "", require_fwd: bool = True): + + # result is total collection of all module results mapped by module.full_name + self._results: Dict[str, Any] = {} + + # toggle to enable stream from gpu to cpu + self.stream = False + + self.tokenizer = tokenizer + self.qcfg = qcfg + + # if processor require fwd generate and hooks, set this to true + # looper should bypass generate + hooks if this is false + self.require_fwd = require_fwd + + self.inputs_cache: InputCache = InputCache(None, None, None, None) + self.tasks = {} + + self.pb = None + self.logger_task = None + self.fwd_time = None + self.layer_count = None + + # logging + self.log = [] + self.logger_board = logger_board + self.gpu_memorys = [] + self.cpu_memorys = [] + self.durations = [] + self.module_names = [] + + if self.logger_board == "clearml": + try: + from clearml import Task + from random_word import RandomWords + + from ..utils.plotly import create_plotly + except ImportError as _: + raise ImportError( + "The logger_board is set to 'clearml', but required dependencies are missing. 
" + "Please install them by running: pip install gptqmodel[logger]" + ) + self.logger_task = Task.init(project_name='GPTQModel', + task_name=f'{self.__class__.__name__}-{RandomWords().get_random_word()}', + task_type=Task.TaskTypes.optimizer) + else: + self.logger_task = None + + + # prepare dataset + if calibration_dataset is not None: + if len(calibration_dataset) == 0: + raise ValueError("Calibration dataset must not be empty.") + + min_calibration_dataset_size = 256 + min_calibration_dataset_input_ids_avg_length = 256 + if len(calibration_dataset) < min_calibration_dataset_size: + logger.warning(f"Calibration dataset size should be more than {min_calibration_dataset_size}. " + f"Current: {len(calibration_dataset)}.") + + calibration_dataset = prepare_dataset_func(calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + batch_size=batch_size) + + # Calculate the average length of the average input_ids + total_input_ids_length = 0 + max_input_id_length = 0 + for row in calibration_dataset: + input_ids = row["input_ids"] + if isinstance(input_ids, torch.Tensor): + if input_ids.dim() <= 2: + input_ids_length = input_ids.shape[-1] + else: + raise ValueError( + "Expected a 1-dimensional tensor or 2-dimensional tensor for 'input_ids', but got a tensor with {0} dimensions.".format( + input_ids.dim())) + else: + input_ids_length = len(input_ids) + + if input_ids_length > max_input_id_length: + max_input_id_length = input_ids_length + total_input_ids_length += input_ids_length + avg = total_input_ids_length / len(calibration_dataset) + + if avg < min_calibration_dataset_input_ids_avg_length: + logger.warning(f"The average length of input_ids of calibration_dataset should be greater than " + f"{min_calibration_dataset_input_ids_avg_length}: actual avg: {avg}.") + + self.num_batches = len(calibration_dataset) + + self.calibration_dataset = calibration_dataset + + def result_save(self, key: str, value: Any): + assert self.result_get(key) is None, f"key: {key} already exists in `self.result`" + self._results[key] = value + + def result_get(self, key: str, default: Any = None) -> Any: + return self._results.get(key, default) + + def results(self): + return self._results + + def collect_memory_info(self, layer_index: int): + if self.logger_task is not None: + gpu_memory = get_gpu_usage_memory() + cpu_memory = get_cpu_usage_memory() + self.logger_task.get_logger().report_scalar( + title='GPU Memory', + series='GPU Memory', + value=gpu_memory, + iteration=layer_index, + ) + + self.logger_task.get_logger().report_scalar( + title='CPU Memory', + series='CPU Memory', + value=cpu_memory, + iteration=layer_index, + ) + self.gpu_memorys.append(gpu_memory) + self.cpu_memorys.append(cpu_memory) + + def log_plotly(self): + pass + + def set_calibration_dataset(self, calibration_dataset): + pass + + def set_fwd_time(self, fwd_time: float): + self.fwd_time = fwd_time + + # called first + def preprocess(self, module: NamedModule, **kwargs): + pass + + # after preproces, this process may be skipped due to dynamic override (lora adapter = None) + def is_skipped(self, module: NamedModule) -> bool: + pass + + def receive_input_cache(self, input_cache: InputCache): + self.inputs_cache = input_cache + + # called after every module generate + # may be called multiple times due to batch + def receive_layer_inputs(self, layer_inputs: List[List[Tensor]]): + self.inputs_cache.layer_inputs = layer_inputs + + def clear_cache_data(self): + self.tasks = {} + 
self.inputs_cache.layer_inputs = [] + + def preprocess_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + pass + + # do work and return processor.self state which will updated/merged + def process(self, module: NamedModule): + pass + + # last step, after all loop processor is called + # submodule_finalize is called in reverse after all next sequential processes are called + def submodule_finalize(self, module: NamedModule): + pass + + # last step, after all loop processor is called + # finalize is called in reverse after all next sequential processes are called + def finalize(self, model: BaseGPTQModel, **kwargs): + del self.inputs_cache + del self._results + + def release_calibration_dataset(self): + del self.calibration_dataset + + def number_batches(self) -> int: + return self.num_batches + + def verify_calibration_dataset(self, processor_index: int) -> bool: + pass + + @classmethod + def name(cls) -> str: + pass diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py new file mode 100644 index 000000000..123e88ffc --- /dev/null +++ b/gptqmodel/looper/module_looper.py @@ -0,0 +1,441 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import time +from typing import List + +import torch +from gptqmodel.looper.dequantize_processor import DequantizeProcessor +from gptqmodel.looper.eora_processor import EoraProcessor +from gptqmodel.looper.gptq_processor import GPTQProcessor +from gptqmodel.looper.input_cache import InputCache +from gptqmodel.looper.loop_processor import LoopProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.models import BaseGPTQModel +from gptqmodel.models._const import SUPPORTS_MODULE_TYPES +from gptqmodel.nn_modules.hooked_linear import replace_linear_with_hooked_linear +from gptqmodel.quantization.gptq import CPU +from gptqmodel.utils.logger import setup_logger +from gptqmodel.utils.model import (find_modules, get_device, get_module, get_module_by_name_prefix, + get_moe_layer_modules, move_to, nested_move_to) +from gptqmodel.utils.progress import ProgressBar +from gptqmodel.utils.torch import torch_empty_cache + +logger = setup_logger() + + +class ModuleLooper(): + def __init__(self, model: BaseGPTQModel, processors: List[LoopProcessor]): + self.processors = processors + self.gptq_model = model + + def cache_inputs(self, layers, auto_gc, calibration_data, calibration_enable_gpu_cache): + layer_inputs = [] + attention_masks = [] + position_ids = [] + layer_input_kwargs = [] + + cur_layer_device = get_device(layers[0]) + data_device = cur_layer_device if calibration_enable_gpu_cache else CPU + + # TODO HookLinear add register_forward_pre_hook() + def store_input_hook(_, args, kwargs): + # Positional arguments. 
+ layer_input = [] + for inp in args: + layer_input.append(move_to(inp, device=data_device)) + if len(layer_input) == 0: + # Some models put hidden_states in kwargs instead of args. + # For example, gptj ... + if kwargs.get("hidden_states") is not None: + layer_input.append(move_to(kwargs["hidden_states"], device=data_device)) + + layer_inputs.append(layer_input) + + # Keyword arguments. + if kwargs.get("attention_mask") is not None: + attention_masks.append(kwargs["attention_mask"].to(device=data_device)) + else: + attention_masks.append(None) + + pos_ids = kwargs.get("position_ids", None) + if pos_ids is not None: + position_ids.append(move_to(pos_ids, device=data_device)) + one_kwargs = {} + for (k, v) in kwargs.items(): # make sure other arguments also be captured + if k not in ["hidden_states", "attention_mask", "position_ids"]: + one_kwargs[k] = nested_move_to(v, device=data_device) + layer_input_kwargs.append(one_kwargs) + + raise ValueError + + # move layer to target device + layers[0] = layers[0].to(self.gptq_model.quantize_config.device) + ori_outside_layer_module_devices = {} + for module_name in self.gptq_model.base_modules: + module = get_module_by_name_prefix(self.gptq_model.model, module_name) + + if module is None: + continue + + ori_outside_layer_module_devices[module_name] = get_device(module) + if module is not None: + move_to(module, cur_layer_device) + # TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py + handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True) + is_ovis = self.gptq_model.__class__.__name__ == "OvisGPTQ" + self.gptq_model.pre_quantize_generate_hook_start() + for example in calibration_data: + for k, v in example.items(): + data_device = self.gptq_model.quantize_config.device if k == "pixel_values" else cur_layer_device + if isinstance(v, list): + for index in range(len(v)): + if len(v[index].shape) == 1: + v[index] = v[index].unsqueeze(0) + v[index] = move_to(v[index].to(self.gptq_model.model.visual_tokenizer.dtype) if is_ovis else v[index], + device=data_device) + else: + if len(v.shape) == 1: + v = v.unsqueeze(0) + example[k] = move_to(v, device=data_device) + try: + if is_ovis: + self.gptq_model.generate(inputs=example.pop("input_ids"), max_new_tokens=1024, **example) + else: + self.gptq_model.model(**example) + except ValueError: + pass + self.gptq_model.pre_quantize_generate_hook_end() + handle.remove() + move_to(layers[0], device=CPU) + for module_name in self.gptq_model.base_modules: + module = get_module_by_name_prefix(self.gptq_model.model, module_name) + if module is not None: + move_to(module, device=ori_outside_layer_module_devices[module_name]) + if auto_gc: + torch_empty_cache() + return InputCache(layer_inputs=layer_inputs, layer_input_kwargs=layer_input_kwargs, position_ids=position_ids, + attention_masks=attention_masks) + + @torch.no_grad() + def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=False, **kwargs): + if self.gptq_model.quantize_config.lm_head: + if self.gptq_model.model.config.tie_word_embeddings and hasattr(self.gptq_model.model.model, "_tied_weights_keys"): + tied_keys = self.gptq_model.model._tied_weights_keys + for item in tied_keys: + if self.gptq_model.lm_head in item: + raise NotImplementedError("quantization of `lm_head` layer with `tied_weights=True` model state is not supported. 
Please check model has `tied_weights=False`.") + + lm_head_module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head) + if get_module(self.gptq_model.model, key=self.gptq_model.lm_head) is None: + raise ValueError(f"could not find layer {self.gptq_model.lm_head} in the model, exit...") + + if not isinstance(lm_head_module, tuple(SUPPORTS_MODULE_TYPES)): + raise NotImplementedError(f"This type({type(lm_head_module)}) of lm_head quantization is currently not " + f"supported. SUPPORTS_MODULE_TYPES is {SUPPORTS_MODULE_TYPES}") + + lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": True, "desc_act": False, "mse": 2.4} + if self.gptq_model.quantize_config.dynamic is None: + self.gptq_model.quantize_config.dynamic = {self.gptq_model.lm_head: lm_head_quant_config} + elif self.gptq_model.quantize_config.dynamic_get(self.gptq_model.lm_head, default_value=None) is None: + self.gptq_model.quantize_config.dynamic[self.gptq_model.lm_head] = lm_head_quant_config + + forward_pass_use_cache = self.gptq_model.model.config.use_cache if hasattr(self.gptq_model.model.config, "use_cache") else False + self.gptq_model.model.config.use_cache = False + + layers = get_module_by_name_prefix(self.gptq_model.model, self.gptq_model.layers_node) + + for p_index, processor in enumerate(self.processors): + if not processor.verify_calibration_dataset(p_index): + if isinstance(processor, EoraProcessor): + prev_processor = self.processors[p_index - 1] + processor.set_calibration_dataset(prev_processor.calibration_dataset) + # If calibration_dataset is None or Empty, the input_cache of the previous processor is used. + processor.receive_input_cache(copy.copy(prev_processor.inputs_cache)) + elif isinstance(processor, DequantizeProcessor): + # DequantizeProcessor does not perform any operations on dataset. 
+ processor.set_calibration_dataset([]) + processor.receive_input_cache(InputCache([], [], [], [])) + + continue + + input_cache = self.cache_inputs(layers=layers, auto_gc=auto_gc, + calibration_data=processor.calibration_dataset, + calibration_enable_gpu_cache=calibration_enable_gpu_cache) + processor.receive_input_cache(input_cache) + + # release calibration_dataset + for processor in self.processors: + processor.release_calibration_dataset() + + layer_modules = self.gptq_model.layer_modules + + if not self.gptq_model.quantize_config.true_sequential: + layer_modules = [sum(layer_modules, [])] + + # dynamic expert layer index for model defs + if self.gptq_model.dynamic_expert_index is not None: + num_experts = getattr(self.gptq_model.model.config, self.gptq_model.dynamic_expert_index) + layer_modules = get_moe_layer_modules(layer_modules=self.gptq_model.layer_modules, + num_experts=num_experts) + + layer_count = len(layers) + quant_modules_pb = ProgressBar(range(layer_count + 1 if self.gptq_model.quantize_config.lm_head else layer_count)) + + for processor in self.processors: + processor.layer_count = layer_count + processor.pb = quant_modules_pb + + shared_kv_cache_dict = {} + + # replace linear with hooked linear + replace_linear_with_hooked_linear(self.gptq_model.model) + + for layer_index in quant_modules_pb: + is_lm_head_module = layer_index >= layer_count + + if is_lm_head_module: + quant_modules_pb.info("Quantizing lm_head") + module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head) + layer_inputs = self.gptq_model.lm_head_pre_quantize_generate_hook(layer_inputs) + else: + quant_modules_pb.info(f"Quantizing layer {layer_index} of {layer_count - 1}") + module = layers[layer_index] + + if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): + # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values) + continue + + self.gptq_model.pre_quantize(module) + + cur_layer_device = get_device(module) + full = find_modules(module, name=self.gptq_model.lm_head if is_lm_head_module else "") + modules = [[self.gptq_model.lm_head]] if is_lm_head_module else layer_modules + + for p_index, processor in enumerate(self.processors): + processor.collect_memory_info(layer_index) + + layer_inputs = processor.inputs_cache.layer_inputs + layer_input_kwargs = processor.inputs_cache.layer_input_kwargs + position_ids = processor.inputs_cache.position_ids + attention_masks = processor.inputs_cache.attention_masks + + processed_subset = {} + for index, names in enumerate(modules): + subset = {} + for n in names: + if n in full: + subset[n] = full[n] + # some modules have layer_modules that are dynamic based on config + # ref: deepseek v2/v3/r1 + elif self.gptq_model.layer_modules_strict: + raise ValueError(f"layer module item `{n}` not found in model, please check your model config.") + + + skipped_modules = [] + + for name in subset: + layer_name = self.gptq_model.lm_head if is_lm_head_module else f"{self.gptq_model.layers_node}.{layer_index}.{name}" + + # gptq task is created and stored inside processor + if not isinstance(subset[name], NamedModule): + named_module = NamedModule(subset[name], name=name, full_name=layer_name, + layer_index=layer_index) + if isinstance(processor, EoraProcessor): + named_module.state.update({ + "wq": processor.quantized_weights[layer_name], + }) + # TODO processor.release_quantized_weights() + + subset[name] = named_module + full[name] = named_module + + processor.preprocess(subset[name], 
buffered_fwd=buffered_fwd) + # some modules are skipped + if processor.is_skipped(subset[name]): + skipped_modules.append(name) + + for name in skipped_modules: + subset.pop(name) + + if len(subset) == 0: + continue + + handle = [] + for name in subset: + if hasattr(subset[name], 'forward_hook'): + subset[name].forward_hook = processor.preprocess_fwd_hook(name) + else: + # TODO FIXME: do we even need to hook into modules that are not quantizable? + assert (f"forward_hook missing for module name: `{name}`, layer name: {layer_name}") + handle.append(subset[name].register_forward_hook(processor.preprocess_fwd_hook(name))) + + # logger.info(f"layer-{i}: Begin Forward() Pass") + fwd_start = time.time() + for j in range(processor.num_batches): + layer_input = [] + for k, layer_inp in enumerate(layer_inputs[j]): + layer_input.append(move_to(layer_inp, device=cur_layer_device)) + + mask = attention_masks[j] + layer_attention_mask = mask if mask is None else move_to(mask, device=cur_layer_device) + + additional_layer_inputs = {"attention_mask": layer_attention_mask} + layer_position_ids = ( + None if not position_ids else move_to(position_ids[j], device=cur_layer_device) + ) + if layer_position_ids is not None: + additional_layer_inputs["position_ids"] = layer_position_ids + for k, v in layer_input_kwargs[j].items(): + additional_layer_inputs[k] = nested_move_to(v, device=cur_layer_device) + + # reuse_kv is a flag to reuse the kv cache, only for the hamba model + if hasattr(module, "reuse_kv"): + if module.reuse_kv: + additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get( + layer_index - 1) + + layer_output = module(*layer_input) if is_lm_head_module else module(*layer_input, + **additional_layer_inputs) + if shared_kv_cache_dict.get(layer_index) is None: + shared_kv_cache_dict[layer_index] = layer_output[-1] + else: + module(*layer_input) if is_lm_head_module else module(*layer_input, + **additional_layer_inputs) + + del layer_input + del additional_layer_inputs + + fwd_end = time.time() + fwd_time = fwd_end - fwd_start + + processor.set_fwd_time(fwd_time) + + for h in handle: + h.remove() + + for name in subset: + if hasattr(subset[name], 'forward_hook'): + subset[name].forward_hook = None + + for name_index, name in enumerate(subset): + m = subset[name] + processor.process(module=m) + processed_subset[name] = m + + if index == len(layer_modules) - 1: + if auto_gc: + torch_empty_cache() + + is_last_module = layer_index == len(quant_modules_pb) - 1 + layer_outputs = [] + if not is_last_module: + for j in range(processor.num_batches): + # assert weight + # if isinstance(processor, EoraProcessor): + # for names in modules: + # if n in names: + # assert torch.equal(full[n].weight.data.cpu(), processed_subset[n].state["wq_ab"]) + # assert not torch.equal(full[n].weight.data.cpu(), processed_subset[n].state["wq"]) + # assert not torch.equal(processed_subset[n].state["wq_ab"], processed_subset[n].state["wq"]) + # full[n].weight.data.cuda() + + layer_input = [] + for k, layer_inp in enumerate(layer_inputs[j]): + layer_input.append(move_to(layer_inp, device=cur_layer_device)) + + mask = attention_masks[j] + layer_attention_mask = mask if mask is None else move_to(mask, device=cur_layer_device) + + additional_layer_inputs = {"attention_mask": layer_attention_mask} + layer_position_ids = None if not position_ids else move_to(position_ids[j], device=cur_layer_device) + if layer_position_ids is not None: + additional_layer_inputs["position_ids"] = layer_position_ids + for k, v in 
layer_input_kwargs[j].items(): + additional_layer_inputs[k] = nested_move_to(v, device=cur_layer_device) + + if hasattr(module, "reuse_kv"): + if module.reuse_kv: + additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(layer_index - 1) + + layer_output = move_to( + module(*layer_input)[0] if is_lm_head_module else + module(*layer_input, **additional_layer_inputs)[0], + device=cur_layer_device if calibration_enable_gpu_cache else CPU, + ) + layer_outputs.append([layer_output]) + + del layer_input + del additional_layer_inputs + if processor.num_batches > 1 and j == processor.num_batches - 1: + if auto_gc: + torch_empty_cache() + + # TODO move to processor? + if p_index == len(self.processors) - 1: + if not is_lm_head_module: + layers[layer_index] = self.gptq_model.post_quantize(module) + else: + self.gptq_model.post_quantize(module) + + processor.clear_cache_data() + + processor.receive_layer_inputs(layer_outputs) + + # if last processor, we need to call finalize in reverse + if p_index == len(self.processors) - 1: + for reverse_p in reversed(self.processors): + for name in processed_subset: + reverse_p.submodule_finalize(processed_subset[name]) + del module + + if auto_gc: + torch_empty_cache() + + total_log = {} + + for reverse_p in reversed(self.processors): + if isinstance(reverse_p, GPTQProcessor): + pass + #logger.info(f"Quantization summary:\n{reverse_p.log}") + elif isinstance(reverse_p, EoraProcessor): + pass + #logger.info(f"Eora summary:\n{reverse_p.log}") + elif isinstance(reverse_p, DequantizeProcessor): + # ignore log + pass + else: + logger.info(f"{reverse_p.name()} summary:\n{reverse_p.log}") + + processor_name = reverse_p.name() + total_log[processor_name] = reverse_p.log + if processor_name == "gptq": + self.gptq_model.quant_log = reverse_p.log + + for module_log in reverse_p.log: + logger.info(module_log) + reverse_p.log_plotly() + + reverse_p.finalize(model=self.gptq_model, **kwargs) + + + self.gptq_model.model.config.use_cache = forward_pass_use_cache + + + if auto_gc: + torch_empty_cache() + + return total_log diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py new file mode 100644 index 000000000..bc49d525f --- /dev/null +++ b/gptqmodel/looper/named_module.py @@ -0,0 +1,76 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
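# --- Editor's note: illustrative sketch, not part of the diff. ---
# The ModuleLooper introduced above is driven by a list of LoopProcessors; the call shape
# below mirrors what BaseGPTQModel.quantize() does later in this diff. Model id, calibration
# text and quantization settings are placeholders, and the top-level re-exports of
# GPTQModel and QuantizeConfig are assumed.
from gptqmodel import GPTQModel, QuantizeConfig      # assumed top-level re-exports
from gptqmodel.looper.gptq_processor import GPTQProcessor
from gptqmodel.looper.module_looper import ModuleLooper

calibration_dataset = ["example calibration text ..."]   # List[str] is accepted by prepare_dataset()
model = GPTQModel.load("org/native-model",               # placeholder model id
                       quantize_config=QuantizeConfig(bits=4, group_size=128))

# A single GPTQProcessor is the default pipeline; quantize() appends an EoraProcessor when a
# Lora adapter is requested (retain_w=True then keeps the original weights for EoRA).
processors = [
    GPTQProcessor(
        tokenizer=model.tokenizer,
        qcfg=model.quantize_config,
        calibration_dataset=calibration_dataset,
        prepare_dataset_func=model.prepare_dataset,
        calibration_dataset_concat_size=None,
        batch_size=1,
        logger_board=None,
        retain_w=False,
    ),
]
looper = ModuleLooper(model, processors=processors)
total_log = looper.loop(calibration_enable_gpu_cache=True, auto_gc=True)
# --- End editor's note. ---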
+ +from typing import Any, Dict + +import torch +import transformers +from torch import nn + + +class NamedModule(torch.nn.Module): + def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_index: int) -> None: + super().__init__() + + self.module = module # wrapped module + self.name = name # module name + self.full_name = full_name # module full name (path) within model + self.layer_index = layer_index # layer id in a repeating layer; if in an outside layer, this info may be fake + + # persistent work state for LoopProcessors + # store all `processed()` work state/data/result here + self.state = {} + + # print(f"NamedModule init: name: `{name}, full-name: `{full_name}`") + + # store original in/out features since weight.data will be changed later on + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + elif isinstance(module, nn.Conv2d): + in_features = module.in_channels + out_features = module.out_channels + elif isinstance(module, transformers.pytorch_utils.Conv1D): + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + else: + raise NotImplementedError(f"Unsupported module.module type: `{type(module)}`") + + self.state.update({ + "in_features": in_features, + "out_features": out_features, + }) + + # return stats for mo + # def stats(self) -> Dict[str, float]: + # # -1 means no stats have yet been gathered for the stat property + # return { + # STAT_GPTQ_DURATION: self.state.get(STAT_GPTQ_DURATION, -1), + # STAT_GPTQ_AVG_LOSS: self.state.get(STAT_GPTQ_AVG_LOSS, -1), + # STAT_GPTQ_DAMP_PERCENT: self.state.get(STAT_GPTQ_DAMP_PERCENT, -1), + # STAT_GPTQ_FWD_TIME: self.state.get(STAT_GPTQ_FWD_TIME, -1), + # } + + # getattr is only called if python cannot find attr for `self` + def __getattr__(self, name: str): + return getattr(self.module, name) + + # setattr is always called by python even if attr exists in `self` + def __setattr__(self, name: str, value: Any) -> None: + if name in ["module", "name", "full_name", "layer_index", "state"]: + self.__dict__[name] = value + else: + self.module.__dict__[name] = value diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 9a21df5d7..d40b831b2 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -18,16 +18,13 @@ import os -from lm_eval.utils import make_table -from tokenicer import Tokenicer - if not os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None): os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'expandable_segments:True' print("ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.") if not os.environ.get("CUDA_DEVICE_ORDER", None): os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' - print("ENV: Auto setting CUDA_DEVICE_ORDER=PCI_BUS_ID for compatibililty.") + print("ENV: Auto setting CUDA_DEVICE_ORDER=PCI_BUS_ID for correctness.") import sys # noqa: E402 @@ -38,18 +35,24 @@ import os.path # noqa: E402 import random # noqa: E402 from os.path import isdir, join # noqa: E402 -from typing import Dict, List, Optional, Union # noqa: E402 +from typing import Any, Dict, List, Optional, Type, Union # noqa: E402 import numpy # noqa: E402 import torch # noqa: E402 +from gptqmodel.adapter.adapter import Adapter, Lora, normalize_adapter # noqa: E402 from huggingface_hub import list_repo_files # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 +from tokenicer import Tokenicer # noqa: E402 from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402 +from
..nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 from ..quantization import QUANT_CONFIG_FILENAME # noqa: E402 +from ..quantization.gptq import CPU # noqa: E402 from ..utils import BACKEND # noqa: E402 from ..utils.eval import EVAL # noqa: E402 from ..utils.logger import setup_logger # noqa: E402 -from ..utils.model import check_and_get_model_type # noqa: E402 +from ..utils.model import check_and_get_model_type, find_modules # noqa: E402 +from ..utils.torch import torch_empty_cache # noqa: E402 from .base import BaseGPTQModel, QuantizeConfig # noqa: E402 from .definitions.baichuan import BaiChuanGPTQ # noqa: E402 from .definitions.bloom import BloomGPTQ # noqa: E402 @@ -165,6 +168,7 @@ } + class GPTQModel: def __init__(self): raise EnvironmentError( @@ -185,15 +189,20 @@ def load( verify_hash: Optional[Union[str, List[str]]] = None, **kwargs, ): + # normalize config to cfg instance + if isinstance(quantize_config, Dict): + quantize_config = QuantizeConfig(**quantize_config) + if isinstance(backend, str): backend = BACKEND(backend) - if backend == BACKEND.VLLM: - from ..integration.integration_vllm import patch_vllm - patch_vllm() + # if backend == BACKEND.VLLM: + # from ..integration.integration_vllm import patch_vllm + # patch_vllm() is_quantized = False - if hasattr(AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code), "quantization_config"): + if hasattr(AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code), + "quantization_config"): is_quantized = True else: for name in [QUANT_CONFIG_FILENAME, "quant_config.json"]: @@ -237,14 +246,16 @@ def from_pretrained( trust_remote_code: bool = False, **model_init_kwargs, ) -> BaseGPTQModel: - if hasattr(AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code), "quantization_config"): + if hasattr(AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code), + "quantization_config"): logger.warning("Model is already quantized, will use `from_quantized` to load quantized model.\n" "If you want to quantize the model, please pass un_quantized model path or id, and use " "`from_pretrained` with `quantize_config`.") return cls.from_quantized(model_id_or_path, trust_remote_code=trust_remote_code) if quantize_config and quantize_config.dynamic: - logger.warning("GPTQModel's per-module `dynamic` quantization feature is currently not upstreamed to hf/vllm/sglang. If you're using vllm, you need to install this PR: https://github.com/vllm-project/vllm/pull/7086") + logger.warning( + "GPTQModel's per-module `dynamic` quantization feature is currently not upstreamed to hf/vllm/sglang. 
If you're using vllm, you need to install this PR: https://github.com/vllm-project/vllm/pull/7086") model_type = check_and_get_model_type(model_id_or_path, trust_remote_code) return MODEL_MAP[model_type].from_pretrained( @@ -261,6 +272,7 @@ def from_quantized( device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None, device: Optional[Union[str, int]] = None, backend: Union[str, BACKEND] = BACKEND.AUTO, + adapter: Optional[Adapter | Dict] = None, trust_remote_code: bool = False, # verify weight files matches predefined hash during loading # usage: hash_format:hash_value, example: md5:ugkdh232 @@ -268,6 +280,10 @@ def from_quantized( verify_hash: Optional[Union[str, List[str]]] = None, **kwargs, ) -> BaseGPTQModel: + # normalize adapter to instance + adapter = normalize_adapter(adapter) + + print(f"from_quantized: adapter: {adapter}") model_type = check_and_get_model_type(model_id_or_path, trust_remote_code) if isinstance(backend, str): @@ -280,6 +296,7 @@ def from_quantized( backend=backend, trust_remote_code=trust_remote_code, verify_hash=verify_hash, + adapter=adapter, **kwargs, ) @@ -287,15 +304,16 @@ def from_quantized( def eval( cls, model_or_id_or_path: str=None, - tokenizer: PreTrainedTokenizerBase=None, - tasks: Union[List[EVAL.LM_EVAL], List[EVAL.EVALPLUS]] = None, # set to None to fix mutable warning - framework: EVAL = EVAL.LM_EVAL, - batch_size: int = 1, + tokenizer: Union[PreTrainedTokenizerBase, Tokenicer]=None, + tasks: Union[EVAL.LM_EVAL, EVAL.EVALPLUS, List[EVAL.LM_EVAL], List[EVAL.EVALPLUS]] = None, # set to None to fix mutable warning + framework: Union[Type[EVAL.LM_EVAL],Type[EVAL.EVALPLUS]] = EVAL.LM_EVAL, + batch_size: Union[int, str] = 1, trust_remote_code: bool = False, output_path: Optional[str] = None, - backend: str = 'gptqmodel', + llm_backend: str = 'gptqmodel', + backend: BACKEND = BACKEND.AUTO, # gptqmodel arg only random_seed: int = 1234, # only for framework=EVAL.LM_EVAL backend=vllm - model_args: Dict = None, # only for framework=EVAL.LM_EVAL backend=vllm + model_args: Dict[str, Any] = None, # only for framework=EVAL.LM_EVAL backend=vllm **args ): if model_args is None: @@ -306,17 +324,20 @@ def eval( else: tasks = [EVAL.EVALPLUS.HUMAN] + elif not isinstance(tasks, List): + tasks = [tasks] + if framework is None: - raise ValueError("eval parameter: `framework` cannot be set to None") + raise ValueError("Eval parameter: `framework` cannot be set to None") if not isinstance(tasks, list): - raise ValueError("eval parameter: `tasks` must be of List type") + raise ValueError("Eval parameter: `tasks` must be of List type") - if backend not in ['gptqmodel', 'vllm']: - raise ValueError('Eval framework support backend: [gptqmodel, vllm]') + if llm_backend not in ['gptqmodel', 'vllm']: + raise ValueError('Eval framework support llm_backend: [gptqmodel, vllm]') if isinstance(model_or_id_or_path, str): - model = None + model = GPTQModel.load(model_id_or_path=model_or_id_or_path, backend=backend) model_id_or_path = model_or_id_or_path elif isinstance(model_or_id_or_path, BaseGPTQModel) or isinstance(model_or_id_or_path, PreTrainedModel): model = model_or_id_or_path @@ -328,34 +349,36 @@ def eval( if isinstance(model, BaseGPTQModel): tokenizer = model.tokenizer elif isinstance(model, PreTrainedModel) or model_id_or_path.strip(): - tokenizer = Tokenicer.load(model_id_or_path).tokenizer # lm-eval checks if tokenizer's type is PretrainedTokenizer + tokenizer = Tokenicer.load(model_id_or_path) if tokenizer is None: raise ValueError("Tokenizer: Auto-loading of 
tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.") + if backend=="gptqmodel": # vllm loads tokenizer - model_args["tokenizer"] = tokenizer + if isinstance(tokenizer, Tokenicer): + model_args["tokenizer"] = tokenizer.tokenizer # lm-eval checks if tokenizer's type is PretrainedTokenizer + else: + model_args["tokenizer"] = tokenizer if framework == EVAL.LM_EVAL: for task in tasks: if task not in EVAL.get_task_enums(): - raise ValueError(f"lm_eval support tasks: {EVAL.get_all_tasks_string()}") + raise ValueError(f"Eval.lm_eval supported `tasks`: `{EVAL.get_all_tasks_string()}`, actual = `{task}`") - model_name = "hf" if backend == "gptqmodel" else backend + model_name = "hf" if llm_backend == "gptqmodel" else llm_backend - if backend == "gptqmodel": + if llm_backend == "gptqmodel": model_args["gptqmodel"] = True model_args["pretrained"] = model_id_or_path try: from lm_eval import simple_evaluate - from lm_eval.loggers import EvaluationTracker, WandbLogger from lm_eval.models.huggingface import HFLM - from lm_eval.utils import handle_non_serializable except BaseException: raise ValueError("lm_eval is not installed. Please install via `pip install gptqmodel[eval]`.") - if backend == "gptqmodel" and model is not None: + if llm_backend == "gptqmodel" and model is not None: model_name = HFLM( pretrained=model, batch_size=batch_size, @@ -399,9 +422,10 @@ def eval( batch=batch_size, trust_remote_code=trust_remote_code, output_file=output_path, - backend=backend + backend=llm_backend ) - results[task.value] = {"base tests": base_formatted, "base + extra tests": plus_formatted, "results_path": result_path} + results[task.value] = {"base tests": base_formatted, "base + extra tests": plus_formatted, + "results_path": result_path} print('--------evalplus Eval Result---------') evalplus_make_table(results) print('--------evalplus Result End---------') @@ -428,9 +452,11 @@ def export(model_id_or_path: str, target_path: str, format: str, trust_remote_co from ..utils.mlx import convert_gptq_to_mlx_weights except ImportError: - raise ValueError("MLX not installed. Please install via `pip install gptqmodel[mlx] --no-build-isolation`.") + raise ValueError( + "MLX not installed. 
Please install via `pip install gptqmodel[mlx] --no-build-isolation`.") - mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, gptq_model, gptq_config) + mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, gptq_model, gptq_config, + gptq_model.lm_head) save_weights(target_path, mlx_weights, donate_weights=True) @@ -474,4 +500,59 @@ def push_to_hub(repo_id: str, folder_path=quantized_path, repo_id=repo_id, repo_type=repo_type, - ) \ No newline at end of file + ) + + class adapter: + @classmethod + def generate( + cls, + # eora adapter generation needs config Lora(rank=1, path='lora.safetensors') + adapter: Adapter, + model_id_or_path: str, # native model + quantized_model_id_or_path: str, # gptqmodel quantized model + calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]], + calibration_dataset_concat_size: Optional[int] = None, + batch_size: Optional[int] = 1, + calibration_enable_gpu_cache: Optional[bool] = True, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + logger_board: Optional[str] = None, + # Experimental: enables the buffering of fwd inputs to cpu, slower than non-buffered, may reduce vram usage + buffered_fwd: bool = False, + # torch/cuda GC is auto enabled to reduce vram usage: disable to for small models or you know there is no possibility of oom due to vram to accelerate quantization + auto_gc: bool = True, + ): + if not adapter or not isinstance(adapter, Lora): + raise ValueError(f"Adapter: expected `adapter` type to be `Lora`: actual = `{adapter}`.") + + adapter.validate_path(local_only=True) + + quantized_model = GPTQModel.load( + model_id_or_path=quantized_model_id_or_path, + backend=BACKEND.TORCH, + device=CPU, + ) + + qcfg = quantized_model.quantize_config + qModules: Dict[str, TorchQuantLinear] = find_modules(module=quantized_model.model, layers=[TorchQuantLinear]) + # for name, module in qModules.items(): + # quantized_weights[name] = module.dequantize_weight() + del quantized_model + torch_empty_cache() + + model = GPTQModel.load( + model_id_or_path=model_id_or_path, + quantize_config=qcfg, + backend=BACKEND.TORCH) + + model._eora_generate( + adapter=adapter, + quantized_modules=qModules, + calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + batch_size=batch_size, + calibration_enable_gpu_cache=calibration_enable_gpu_cache, + tokenizer=tokenizer, + logger_board=logger_board, + buffered_fwd=buffered_fwd, + auto_gc=auto_gc) + return diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 0ad73b08a..1e44a7381 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -16,10 +16,11 @@ from __future__ import annotations +import copy import json import os import time -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch._dynamo @@ -27,29 +28,33 @@ from packaging import version from packaging.version import Version from tokenicer import Tokenicer -from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils, ProcessorMixin, \ - AutoProcessor +from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel, + PreTrainedTokenizerBase, ProcessorMixin, modeling_utils) +from ..adapter.adapter import Adapter from ..nn_modules.hooked_linear import replace_linear_with_hooked_linear from ..nn_modules.qlinear import BaseQuantLinear +from 
..nn_modules.qlinear.torch import TorchQuantLinear from ..quantization import GPTQ, QuantizeConfig from ..quantization.config import FORMAT, QUANTIZE_BLACK_LIST, AutoRoundQuantizeConfig from ..utils.backend import BACKEND from ..utils.data import collate_data from ..utils.device import get_cpu_usage_memory, get_gpu_usage_memory +from ..utils.hf import autofix_hf_model_config from ..utils.importer import select_quant_linear from ..utils.logger import setup_logger from ..utils.model import (MODALITY, check_to_quantized, find_modules, get_device, get_module, get_module_by_name_prefix, get_moe_layer_modules, move_to, nested_move_to, pack_model) from ..utils.progress import ProgressBar -from ..utils.torch import torch_empty_cache +from ..utils.torch import torch_compile, torch_empty_cache from ._const import CALIBRATION_DATASET_CONCAT_CHAR, CPU, DEFAULT_MAX_SHARD_SIZE, DEVICE, SUPPORTS_MODULE_TYPES from .loader import ModelLoader -from .writer import (QUANT_LOG_DAMP, QUANT_LOG_FWD_TIME, QUANT_LOG_LAYER, - QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter) +from .writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, + PROCESS_LOG_TIME, QUANT_LOG_DAMP, QUANT_LOG_LOSS, ModelWriter) # pytorch 2.6.0 fixes many compilation errors -PYTORCH_MIN_VERFSION_WITH_COMPILE = Version("2.6.0") +TORCH_MIN_VERSION_STR = '2.6.0' +PYTORCH_MIN_VERSION_WITH_COMPILE = Version(TORCH_MIN_VERSION_STR) def check_support_param_buffer_assignment(*args, **kwargs): return False @@ -77,6 +82,10 @@ class BaseGPTQModel(nn.Module): # for each repeating layer there are multiple modules within each layer layer_modules: List[List[str]] = None + # Strict=True -> all layer_modules must exists in model + # Some models (deepseek2-lite) dynamically create lora modules based on config.rank + layer_modules_strict = True + pre_lm_head_norm_module: str = None # some models require trust_remove_code = True (dbrx_converted) @@ -128,6 +137,7 @@ def __init__( super().__init__() self.model = model + self.compiled = False # set to True while compile() is triggered successfully self.quantized = quantized self.load_quantized_model = load_quantized_model @@ -139,10 +149,14 @@ def __init__( f"Unsupported `tokenizer` type: Expected `PreTrainedTokenizerBase`, actual = `{type(tokenizer)}`.") self.model.tokenizer = self.tokenizer.tokenizer # helpful for CI tests else: - self.tokenizer = tokenizer - self.model.tokenizer = tokenizer # helpful for CI tests + self.tokenizer = tokenizer # TODO none? + self.model.tokenizer = tokenizer # helpful for CI tests # TODO none? 
+ + # auto-fix model config errors + if isinstance(self.model, PreTrainedModel): + autofix_hf_model_config(self.model, path=model_local_path) + + self.quantize_config = quantize_config - self.config = self.model.config if hasattr(self.model, "config") else None # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion self.qlinear_kernel = qlinear_kernel @@ -159,6 +173,22 @@ def __init__( if self.require_monkeypatch: self.monkey_patch() + # hack: circular import + from ..adapter.adapter import Lora + + # check adapter load and print info so users know lora(s) are applied + if isinstance(self.quantize_config.adapter, Lora): + loaded_loras = 0 + qmodules = find_modules(self.model, layers=[BaseQuantLinear]) + for name, m in qmodules.items(): + if all(hasattr(m.adapter, name) for name in Lora.parameter_keys()): + loaded_loras += 1 + + logger.info(f"Adapter: `{loaded_loras}` EoRA/Lora adapters loaded for `{len(qmodules)}` modules.") + + # print kernel info: + logger.info(f"Kernel: loaded -> `[{', '.join(cls.__name__ for cls in self.kernels())}]`") + def prepare_dataset( self, calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[List[int]]], @@ -284,7 +314,6 @@ def _convert_tensor_to_list(tensor): return new_calibration_dataset_batched - @torch.no_grad() def quantize( self, calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]], @@ -299,7 +328,185 @@ def quantize( buffered_fwd: bool = False, # torch/cuda GC is auto enabled to reduce vram usage: disable to for small models or you know there is no possibility of oom due to vram to accelerate quantization auto_gc: bool = True, - ) -> List[Dict[str, str]]: + # eora adapter generation needs config Lora(rank=1, path='lora.safetensors') + adapter: Adapter = None, + adapter_calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]] = None, + ) -> Dict[str, List[Dict[str, str]]]: + if self.quantized: + raise EnvironmentError("quantize() was called on a model that is already quantized") + + if self.quantize_config.quant_method in QUANTIZE_BLACK_LIST: + raise ValueError( + f"Unsupported quantization operation for quant method: {self.quantize_config.quant_method}" + ) + + if backend == BACKEND.IPEX: + self.quantize_config.format = FORMAT.IPEX + + if self.quantize_config.format == FORMAT.MARLIN: + raise ValueError( + "FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQModel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
+ ) + + # Validate quant linear before quantization starts + _ = select_quant_linear( + bits=self.quantize_config.bits, + dynamic=self.quantize_config.dynamic, + group_size=self.quantize_config.group_size, + desc_act=self.quantize_config.desc_act, + sym=self.quantize_config.sym, + backend=backend, + device=DEVICE(self.quantize_config.device), + pack=True, + format=self.quantize_config.format, + pack_dtype=self.quantize_config.pack_dtype, + ) + + # Use the provided tokenizer if one is passed to quantize() + if tokenizer is not None: + if isinstance(tokenizer, PreTrainedTokenizerBase): + # TODO FIX ME...this is a bug + self.tokenizer = Tokenicer.load(tokenizer, trust_remote_code=self.trust_remote_code) + else: + raise ValueError( + f"Unsupported `tokenizer` type: Expected `PreTrainedTokenizerBase`, actual = `{type(tokenizer)}`.") + + if self.quantize_config.format == FORMAT.BITBLAS: + from ..nn_modules.qlinear.bitblas import BITBLAS_AVAILABLE, BITBLAS_INSTALL_HINT + if BITBLAS_AVAILABLE is False: + raise ValueError(BITBLAS_INSTALL_HINT) + + # overwrite quantize_config.adapter + if adapter is not None: + self.quantize_config.adapter = adapter + + from gptqmodel.adapter.adapter import Lora + from gptqmodel.looper.eora_processor import EoraProcessor + from gptqmodel.looper.gptq_processor import GPTQProcessor + from gptqmodel.looper.module_looper import ModuleLooper + + # has lora process + needs_lora = isinstance(self.quantize_config.adapter, Lora) + + # init processor with default GPTQ processor + processors = [ + GPTQProcessor( + tokenizer=self.tokenizer, + qcfg=self.quantize_config, + calibration_dataset=calibration_dataset, + prepare_dataset_func=self.prepare_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + batch_size=batch_size, + logger_board=logger_board, + retain_w=needs_lora, # eora needs original w + ) + ] + + # Append EoRA processor for lora adapter + if needs_lora: + processors.append( + EoraProcessor( + tokenizer=self.tokenizer, + qcfg=self.quantize_config, + calibration_dataset=adapter_calibration_dataset, + prepare_dataset_func=self.prepare_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + batch_size=batch_size, + logger_board=logger_board, + ) + ) + + # prepare processor worker (looper) + module_looper = ModuleLooper(self, processors=processors) + + return module_looper.loop( + calibration_enable_gpu_cache=calibration_enable_gpu_cache, + buffered_fwd=buffered_fwd, + auto_gc=auto_gc, + backend=backend, + ) + + def _eora_generate( + self, + # eora adapter generation needs config Lora(rank=1, path='lora.safetensors') + adapter: Adapter, + quantized_modules: Dict[str, TorchQuantLinear], + calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]], + calibration_dataset_concat_size: Optional[int] = None, + batch_size: int = 1, + calibration_enable_gpu_cache: bool = True, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + logger_board: Optional[str] = None, + # Experimental: enables the buffering of fwd inputs to cpu, slower than non-buffered, may reduce vram usage + buffered_fwd: bool = False, + # torch/cuda GC is auto enabled to reduce vram usage: disable to for small models or you know there is no possibility of oom due to vram to accelerate quantization + auto_gc: bool = True, + ): + if self.quantized: + raise EnvironmentError("eora_generate() is called a model that is already quantized") + + # Use the provided tokenizer if one is passed to quantize() + if tokenizer is not 
None: + if isinstance(tokenizer, PreTrainedTokenizerBase): + # TODO FIX ME...this is a bug + self.tokenizer = Tokenicer.load(tokenizer, trust_remote_code=self.trust_remote_code) + else: + raise ValueError( + f"Unsupported `tokenizer` type: Expected `PreTrainedTokenizerBase`, actual = `{type(tokenizer)}`.") + + from gptqmodel.adapter.adapter import Lora + from gptqmodel.looper.dequantize_processor import DequantizeProcessor + from gptqmodel.looper.eora_processor import EoraProcessor + from gptqmodel.looper.module_looper import ModuleLooper + + self.quantize_config.adapter = adapter + + assert isinstance(self.quantize_config.adapter, Lora) + + # init processor with EoRA processor + processors = [ + DequantizeProcessor( + quantized_modules=quantized_modules, + ), + EoraProcessor( + tokenizer=self.tokenizer, + qcfg=self.quantize_config, + calibration_dataset=calibration_dataset, + prepare_dataset_func=self.prepare_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + batch_size=batch_size, + logger_board=logger_board, + ), + ] + + # prepare processor worker (looper) + module_looper = ModuleLooper(model=self, processors=processors) + + module_looper.loop( + calibration_enable_gpu_cache=calibration_enable_gpu_cache, + buffered_fwd=buffered_fwd, + auto_gc=auto_gc, + ) + + self.eora_save(eora_path=adapter.path) + return + + @torch.no_grad() + def quantize_old( + self, + calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]], + # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model. + calibration_dataset_concat_size: Optional[int] = None, + batch_size: int = 1, + calibration_enable_gpu_cache: bool = True, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + logger_board: Optional[str] = None, + backend: Optional[BACKEND] = BACKEND.AUTO, + # Experimental: enables the buffering of fwd inputs to cpu, slower than non-buffered, may reduce vram usage + buffered_fwd: bool = False, + # torch/cuda GC is auto enabled to reduce vram usage: disable to for small models or you know there is no possibility of oom due to vram to accelerate quantization + auto_gc: bool = True, + ) -> Tuple[List[Dict[str, str]], Dict[str, torch.Tensor]]: if self.quantized: raise EnvironmentError("quantize() is called a model that is already quantized") @@ -467,7 +674,7 @@ def collate_batch(batch): self.qlinear_kernel = pack_model( model=self.model, - quantizers=quantizers, + quant_result=quantizers, bits=self.quantize_config.bits, dynamic=self.quantize_config.dynamic, group_size=self.quantize_config.group_size, @@ -483,7 +690,7 @@ def collate_batch(batch): return if self.quantize_config.lm_head: - if self.model.config.tie_word_embeddings and hasattr(self.model.model, "_tied_weights_keys"): + if self.model.config.tie_word_embeddings and hasattr(self.model, "_tied_weights_keys"): tied_keys = self.model._tied_weights_keys for item in tied_keys: if self.lm_head in item: @@ -524,34 +731,34 @@ def store_input_hook(_, args, kwargs): # Positional arguments. layer_input = [] for inp in args: - layer_input.append(move_to(inp, data_device)) + layer_input.append(move_to(inp, device=data_device)) if len(layer_input) == 0: # Some models put hidden_states in kwargs instead of args. # For example, gptj ... 
if kwargs.get("hidden_states") is not None: - layer_input.append(move_to(kwargs["hidden_states"], data_device)) + layer_input.append(move_to(kwargs["hidden_states"], device=data_device)) layer_inputs.append(layer_input) # Keyword arguments. if kwargs.get("attention_mask") is not None: - attention_masks.append(kwargs["attention_mask"].to(data_device)) + attention_masks.append(kwargs["attention_mask"].to(device=data_device)) else: attention_masks.append(None) pos_ids = kwargs.get("position_ids", None) if pos_ids is not None: - position_ids.append(move_to(pos_ids, data_device)) + position_ids.append(move_to(pos_ids, device=data_device)) one_kwargs = {} for (k, v) in kwargs.items(): # make sure other arguments also be captured if k not in ["hidden_states", "attention_mask", "position_ids"]: - one_kwargs[k] = nested_move_to(v, data_device) + one_kwargs[k] = nested_move_to(v, device=data_device) layer_input_kwargs.append(one_kwargs) raise ValueError # move layer to target device - layers[0] = layers[0].to(self.quantize_config.device) + layers[0] = layers[0].to(device=self.quantize_config.device) ori_outside_layer_module_devices = {} for module_name in self.base_modules: @@ -575,7 +782,7 @@ def store_input_hook(_, args, kwargs): for module_index in range(len(v)): if len(v[module_index].shape) == 1: v[module_index] = v[module_index].unsqueeze(0) - v[module_index] = move_to(v[module_index].to(torch.bfloat16) if is_ovis else v[module_index], data_device) + v[module_index] = move_to(v[module_index].to(self.model.visual_tokenizer.dtype) if is_ovis else v[module_index], data_device) else: if len(v.shape) == 1: v = v.unsqueeze(0) @@ -625,14 +832,15 @@ def store_input_hook(_, args, kwargs): # replace linear with hooked linear replace_linear_with_hooked_linear(self.model) + quantized_weights = {} for module_index in quant_modules_pb: is_lm_head_module = module_index >= layer_count if is_lm_head_module: - quant_modules_pb.set_description("Quantizing lm_head") + quant_modules_pb.info("Quantizing lm_head") module = get_module(self.model, key=self.lm_head) layer_inputs = self.lm_head_pre_quantize_generate_hook(layer_inputs) else: - quant_modules_pb.set_description(f"Quantizing layer {module_index} of {layer_count - 1}") + quant_modules_pb.info(f"Quantizing layer {module_index} of {layer_count - 1}") module = layers[module_index] if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): @@ -667,9 +875,7 @@ def store_input_hook(_, args, kwargs): skipped_modules = [] gptq = {} for name in subset: - bits = self.quantize_config.bits - sym = self.quantize_config.sym - mse = self.quantize_config.mse + qcfg_clone = copy.deepcopy(self.quantize_config) # dynamic overrides if self.quantize_config.dynamic is not None: @@ -681,11 +887,15 @@ def store_input_hook(_, args, kwargs): skipped_modules.append(name) continue - bits = self.quantize_config.dynamic_get(layer_name, "bits", bits) - sym = self.quantize_config.dynamic_get(layer_name, "sym", sym) - mse = self.quantize_config.dynamic_get(layer_name, "mse", mse) + qcfg_clone.bits = self.quantize_config.dynamic_get(layer_name, "bits", qcfg_clone.bits) + qcfg_clone.sym = self.quantize_config.dynamic_get(layer_name, "sym", qcfg_clone.sym) + qcfg_clone.mse = self.quantize_config.dynamic_get(layer_name, "mse", qcfg_clone.mse) + qcfg_clone.group_size = self.quantize_config.dynamic_get(layer_name, "group_size", qcfg_clone.group_size) + qcfg_clone.desc_act = self.quantize_config.dynamic_get(layer_name, "desc_act", qcfg_clone.desc_act) + 
qcfg_clone.damp_percent = self.quantize_config.dynamic_get(layer_name, "damp_percent", qcfg_clone.damp_percent) + qcfg_clone.static_groups = self.quantize_config.dynamic_get(layer_name, "static_groups", qcfg_clone.static_groups) - tmp = GPTQ(subset[name]) + tmp = GPTQ(module=subset[name], qcfg=qcfg_clone) gptq[name] = tmp # models like DeepSeek v3/r1 has > 256 $ of sub-modules per layer @@ -698,10 +908,7 @@ def store_input_hook(_, args, kwargs): tmp.fwd_inputs_buffered = True tmp.quantizer.configure( - bits, perchannel=True, - sym=sym, - mse=mse, ) for name in skipped_modules: @@ -774,28 +981,18 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): for name_index, name in enumerate(subset): layer_name = self.lm_head if is_lm_head_module else f"{self.layers_node}.{module_index}.{name}" - quant_modules_pb.set_description(f"Quantizing {name} in layer {module_index} of {layer_count - 1}") + quant_modules_pb.info(f"Quantizing {name} in layer {module_index} of {layer_count - 1}") - group_size = self.quantize_config.group_size - desc_act = self.quantize_config.desc_act - damp_percent = self.quantize_config.damp_percent - static_groups = self.quantize_config.static_groups + # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") + ## Need to return the quantized_weight for offloading + quantized_weight, scale, zero, g_idx, duration, avg_loss, damp_percent = gptq[name].quantize() - # dynamic overrides - if self.quantize_config.dynamic is not None: - group_size = self.quantize_config.dynamic_get(layer_name, "group_size", group_size) - desc_act = self.quantize_config.dynamic_get(layer_name, "desc_act", desc_act) - damp_percent = self.quantize_config.dynamic_get(layer_name, "damp_percent", damp_percent) - static_groups = self.quantize_config.dynamic_get(layer_name, "static_groups", static_groups) + ## Assign the quantized weight to the weight + gptq[name].module.weight.data = quantized_weight.to(device=gptq[name].device) + ## Offload the quantized weight to CPU for EoRA + quantized_weights['model.layers.%d.%s' % (module_index, name)] = quantized_weight.cpu() - # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") - scale, zero, g_idx, duration, avg_loss, damp_percent = gptq[name].quantize( - percdamp=damp_percent, - group_size=group_size, - actorder=desc_act, - static_groups=static_groups, - ) if task is not None: task.get_logger().report_scalar( title='Quantization Loss', @@ -814,8 +1011,8 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): avg_losses.append(avg_loss) module_names.append(f"layer-{module_index}-{name}") - stat = {QUANT_LOG_LAYER: module_index, QUANT_LOG_MODULE: name, QUANT_LOG_LOSS: f"{avg_loss:.5f}", - QUANT_LOG_DAMP: f"{damp_percent:.5f}", QUANT_LOG_TIME: f"{duration:.3f}", QUANT_LOG_FWD_TIME: f"{fwd_time:.3f}"} + stat = {PROCESS_LOG_LAYER: module_index, PROCESS_LOG_MODULE: name, QUANT_LOG_LOSS: f"{avg_loss:.5f}", + QUANT_LOG_DAMP: f"{damp_percent:.5f}", PROCESS_LOG_TIME: f"{duration:.3f}", PROCESS_LOG_FWD_TIME: f"{fwd_time:.3f}"} if self.quantize_config.dynamic is not None: stat["dynamic"] = self.quantize_config.dynamic_get(layer_name=layer_name) @@ -899,7 +1096,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): self.qlinear_kernel = pack_model( model=self.model, - quantizers=quantizers, + quant_result=quantizers, bits=self.quantize_config.bits, group_size=self.quantize_config.group_size, backend=backend, @@ -917,7 +1114,8 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): if auto_gc: 
torch_empty_cache() - return self.quant_log + ## need to return quantized_weight for EoRA + return self.quant_log, quantized_weights def to(self, device: Union[str, torch.device]): if hasattr(self.model, "to"): @@ -931,8 +1129,15 @@ def forward(self, *args, **kwargs): def generate(self, inputs=None, **kwargs): with torch.inference_mode(): + # fix hf generate not applying correct pad token + pad_token_id = kwargs.get("pad_token_id", None) + if pad_token_id is None and self.tokenizer: + kwargs["pad_token_id"] = self.tokenizer.pad_token_id + if isinstance(inputs, str) or (isinstance(inputs, list) and all(isinstance(x, str) for x in inputs)): - inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.model.device) + if self.tokenizer is None: + raise ValueError("You passed in an `input` to `generate()` of type `str` but model is missing `model.tokenizer`. Please set `model.tokenizer = my_tokenizer`.") + inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, padding_side="left").to(self.model.device) return self.model.generate(**inputs, **kwargs) return self.model.generate(inputs=inputs, **kwargs) @@ -957,13 +1162,19 @@ def save( safetensors_metadata: Optional[Dict[str, str]] = None, max_shard_size: Optional[Union[int, str]] = DEFAULT_MAX_SHARD_SIZE, meta_quantizer: Optional[str] = None, + eora_path: Optional[str] = None, **kwargs, ): if self.quantized: # Safetensors is unable to save tied weights, so we untie them here. Reference: https://github.com/huggingface/safetensors/issues/202 #untie_weights(self.model) - self.save_quantized(save_dir, safetensors_metadata, max_shard_size, meta_quantizer) + self.save_quantized( + save_dir=save_dir, + safetensors_metadata=safetensors_metadata, + max_shard_size=max_shard_size, + meta_quantizer=meta_quantizer, + eora_path=eora_path) # overwrite quant_override_files for name, value in self.quant_override_files.items(): @@ -974,39 +1185,59 @@ def save( else: f.write(json.dumps(value)) else: - self.save_pretrained(save_dir, **kwargs) + self.save_pretrained(save_dir=save_dir, **kwargs) + + + # returns all the loaded qlinear types, returns empty [] if non-found + def kernels(self) -> List[Type[BaseQuantLinear]]: + if not isinstance(self.model, nn.Module): + return [] + loaded_kernels = set() + modules = find_modules(self.model, layers=[BaseQuantLinear]) + for k, v in modules.items(): + loaded_kernels.add(v.__class__) + + return list(loaded_kernels) - def compile(self, backend="inductor", mode="max-autotune"): + def compile(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): + logger.warn("Deprecation: `model.compile()` is deprecated. 
Please use `model.optimize()` instead.") + return self.optimize(backend=backend, mode=mode, fullgraph=fullgraph) + + def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): if not self.quantized: logger.warning("model is not quantized, skip compiling...") return self - if Version(torch.__version__) < PYTORCH_MIN_VERFSION_WITH_COMPILE: + if Version(torch.__version__) < PYTORCH_MIN_VERSION_WITH_COMPILE: self.compiled = False - logger.warning("To use compile(), you need to have torch version >= 2.5.1, please upgrade it by `pip install torch -U`") + logger.warning(f"To use compile(), you need to have torch version >= {TORCH_MIN_VERSION_STR}, please " + f"upgrade it by `pip install torch -U`") return self + # needed by eora + # torch._dynamo.config.capture_scalar_outputs = True + + logger.info(f"Compiling qlinear modules with backend: `{backend}`, mode: `{mode}`") + modules = find_modules(self.model, layers=[BaseQuantLinear]) + for name in modules.keys(): + modules[name].optimize(fullgraph=False, backend=backend, mode=mode) + # supress errors until PyTorch fixed: https://github.com/pytorch/pytorch/issues/132635 - #torch._dynamo.config.suppress_errors = True + # torch._dynamo.config.suppress_errors = True logger.info(f"Compiling model with backend: `{backend}`, mode: `{mode}`") - try: - self.model = torch.compile(self.model, fullgraph=True, backend=backend, mode=mode) - self.compiled = True - except Exception as e: - logger.info(f"Compiling model again with `fullgraph=False`; `full-graph=True` compile failed: {e}") - try: - self.model = torch.compile(self.model, fullgraph=False, backend=backend, mode=mode) - self.compiled = True - except Exception as e: - self.compiled = False - logger.info(f"Compiling model failed: running model in non-compiled mode. 
{e}") - - # trigger kernel compilation hooks - if self.compiled: - modules = find_modules(self.model, layers=[BaseQuantLinear]) - for name in modules.keys(): - modules[name].compile() + self.model = torch_compile(self.model, fullgraph=fullgraph, backend=backend, mode=mode) + + #trigger kernel compilation hooks + # if self.compiled: + # modules = find_modules(self.model, layers=[BaseQuantLinear]) + # for name in modules.keys(): + # modules[name].optimize(fullgraph=False, backend=backend, mode=mode) + + # logger.info(f"Compiling qlinear modules with backend: `{backend}`, mode: `{mode}`") + # modules = find_modules(self.model, layers=[BaseQuantLinear]) + # for name in modules.keys(): + # modules[name].optimize(fullgraph=False, backend=backend, mode=mode) return self @@ -1046,11 +1277,11 @@ def lm_head_pre_quantize_generate_hook(self, inputs: List[List[torch.tensor]]) - def pre_quantize(self, module: nn.Module) -> nn.Module: if get_device(module) == CPU and self.quantize_config.device != CPU: - return move_to(module, self.quantize_config.device) + return move_to(module, device=self.quantize_config.device) return module def post_quantize(self, module: nn.Module) -> nn.Module: - return move_to(module, CPU) + return move_to(module, device=CPU) def __getattr__(self, item): try: diff --git a/gptqmodel/models/definitions/deepseek_v2.py b/gptqmodel/models/definitions/deepseek_v2.py index 1a48503b7..4c10ed4e1 100644 --- a/gptqmodel/models/definitions/deepseek_v2.py +++ b/gptqmodel/models/definitions/deepseek_v2.py @@ -33,15 +33,22 @@ class DeepSeekV2GPTQ(BaseGPTQModel): layers_node = "model.layers" layer_type = "DeepseekV2DecoderLayer" + # DeepSeek V2-Lite uses dynamic modules based on lora(rank): + # https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L712 + layer_modules_strict = False + # DeepSeek-V2 uses 160 experts, v2-lite is auto-switched during __init__ layer_modules = [ # DeepSeek-V2 and DeepSeek-V2-Lite use same model_type, but different self_attn # so we provide different layer_modules usage. 
# DeepSeek-V2-Lite usage - ["self_attn.q_proj", "self_attn.kv_a_proj_with_mqa", "self_attn.kv_b_proj"], + #["self_attn.q_proj", "self_attn.kv_a_proj_with_mqa", "self_attn.kv_b_proj"], # DeepSeek-V2 usage, included in layer 0-59 - ["self_attn.q_a_proj", "self_attn.q_b_proj", "self_attn.kv_a_proj_with_mqa", "self_attn.kv_b_proj"], + #["self_attn.q_a_proj", "self_attn.q_b_proj", "self_attn.kv_a_proj_with_mqa", "self_attn.kv_b_proj"], + + # merged v2-lite and v2 + ["self_attn.q_a_proj", "self_attn.q_b_proj", "self_attn.q_proj", "self_attn.kv_a_proj_with_mqa", "self_attn.kv_b_proj"], ["self_attn.o_proj"], diff --git a/gptqmodel/models/definitions/deepseek_v3.py b/gptqmodel/models/definitions/deepseek_v3.py index 768505391..0d32227e7 100644 --- a/gptqmodel/models/definitions/deepseek_v3.py +++ b/gptqmodel/models/definitions/deepseek_v3.py @@ -34,6 +34,9 @@ class DeepSeekV3GPTQ(BaseGPTQModel): layers_node = "model.layers" layer_type = "DeepseekV3DecoderLayer" + # DeepSeek V3 uses dynamic modules based on lora(rank): + layer_modules_strict = False + layer_modules = [ ["self_attn.q_a_proj", "self_attn.q_b_proj", "self_attn.kv_a_proj_with_mqa", "self_attn.kv_b_proj"], diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py index b99cb4aa7..9d2a5f1e9 100644 --- a/gptqmodel/models/definitions/ovis.py +++ b/gptqmodel/models/definitions/ovis.py @@ -40,17 +40,27 @@ class OvisGPTQ(BaseGPTQModel): ["mlp.down_proj"], ] + require_monkeypatch = True + modality = [MODALITY.IMAGE_TO_TEXT] IGNORE_ID = -100 + def monkey_patch(self): + # From config.json, we know that visual_tokenizer.dtype is float32 and text model.config.dtype is bfloat16. + # But in transformers<4.49.0, the dtype returned by AutoModel.from_config(config.visual_tokenizer_config) + # is bfloat16. This appears to be a bug, but OVIS generate() unexpectedly works properly. + # This bug was fixed in transformers 4.49.0.
So visual_tokenizer needs to be converted to model.config.dtype + self.model.visual_tokenizer = self.model.visual_tokenizer.to(dtype=self.model.llm.dtype) + self.model.vte = self.model.vte.to(dtype=self.model.llm.dtype) + def pre_quantize_generate_hook_start(self): - self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, self.quantize_config.device) - self.model.vte = move_to(self.model.vte, self.quantize_config.device) + self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, device=self.quantize_config.device) + self.model.vte = move_to(self.model.vte, device=self.quantize_config.device) def pre_quantize_generate_hook_end(self): - self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, CPU) - self.model.vte = move_to(self.model.vte, CPU) + self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, device=CPU) + self.model.vte = move_to(self.model.vte, device=CPU) def preprocess_dataset(self, sample: Dict) -> Dict: text_max_length = 832 diff --git a/gptqmodel/models/definitions/qwen2_vl.py b/gptqmodel/models/definitions/qwen2_vl.py index ef32abd0b..14c58dc18 100644 --- a/gptqmodel/models/definitions/qwen2_vl.py +++ b/gptqmodel/models/definitions/qwen2_vl.py @@ -81,10 +81,10 @@ class Qwen2VLGPTQ(BaseGPTQModel): } def pre_quantize_generate_hook_start(self): - self.model.visual = move_to(self.model.visual, self.quantize_config.device) + self.model.visual = move_to(self.model.visual, device=self.quantize_config.device) def pre_quantize_generate_hook_end(self): - self.model.visual = move_to(self.model.visual, CPU) + self.model.visual = move_to(self.model.visual, device=CPU) @staticmethod def process_vision_info( diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index a9ad0398e..de39ed66e 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -23,6 +23,7 @@ import torch import transformers + if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: try: from modelscope import snapshot_download @@ -30,6 +31,9 @@ raise ModuleNotFoundError("env `GPTQMODEL_USE_MODELSCOPE` used but modelscope pkg is not found: please install with `pip install modelscope`.") else: from huggingface_hub import snapshot_download + +from gptqmodel.adapter.adapter import Adapter +from huggingface_hub import snapshot_download from packaging.version import InvalidVersion, Version from transformers import AutoConfig, AutoTokenizer, PretrainedConfig from transformers.modeling_utils import no_init_weights @@ -189,7 +193,7 @@ def skip(*args, **kwargs): model.seqlen = model_config[key] break else: - logger.warning("can't get model's sequence length from model config, will set to 4096.") + logger.warning("Model: can't get model's sequence length from model config, will set to 4096.") model.seqlen = 4096 model.eval() @@ -213,6 +217,7 @@ def from_quantized( device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, device: Optional[Union[str, int]] = None, backend: Union[str, BACKEND] = BACKEND.AUTO, + adapter: Optional[Adapter] = None, torch_dtype: [str | torch.dtype] = "auto", trust_remote_code: bool = False, verify_hash: Optional[Union[str, List[str]]] = None, @@ -291,6 +296,10 @@ def from_quantized( qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs) + # inject adapter into qcfg + if adapter is not None: + qcfg.adapter = adapter + qcfg.calculate_bits_per_weight() if backend == BACKEND.VLLM or backend == BACKEND.SGLANG: @@ -403,8 +412,17 @@ def skip(*args, **kwargs): init_contexts = 
[no_init_weights()] with ContextManagers(init_contexts): + if config.architectures: + model_class = getattr(transformers, config.architectures[0], None) + if model_class is not None and hasattr(model_class, "_supports_flash_attn_2"): + supports_flash_attn = model_class._supports_flash_attn_2 + else: + supports_flash_attn = None + else: + supports_flash_attn = None + args = {} - if device in [DEVICE.CUDA, DEVICE.ROCM]: + if supports_flash_attn and device in [DEVICE.CUDA, DEVICE.ROCM]: if ATTN_IMPLEMENTATION in kwargs: args[ATTN_IMPLEMENTATION] = kwargs.pop(ATTN_IMPLEMENTATION, None) if USE_FLASH_ATTENTION_2 in kwargs: @@ -439,25 +457,20 @@ def skip(*args, **kwargs): if any(name.startswith(ignore_module) for ignore_module in ignore_modules) or all( not name.endswith(ignore_module) for sublist in cls.layer_modules for ignore_module in sublist ): - # log non-lm-head quantizerd modules only + # log non-lm-head quantized modules only if name is not cls.lm_head: logger.info(f"The layer {name} is not quantized.") del modules[name] preload_qlinear_kernel = make_quant( model, - modules, - qcfg.bits, - qcfg.group_size, + quant_result=modules, + qcfg=qcfg, backend=backend, - format=qcfg.format, lm_head_name=cls.lm_head, - desc_act=qcfg.desc_act, - sym=qcfg.sym, - dynamic=qcfg.dynamic, device=device, - pack_dtype=qcfg.pack_dtype, ) + if preload_qlinear_kernel == IPEXQuantLinear: qcfg.runtime_format = FORMAT.IPEX @@ -476,17 +489,17 @@ def skip(*args, **kwargs): # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase if not qcfg.sym and not qcfg.is_quantized_by_v2(): raise ValueError( - f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}" + f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}" ) t = time.time() - logger.info(f"Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") + logger.info(f"Format: Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") model = convert_gptq_v1_to_v2_format( model, cfg=qcfg, qlinear_kernel=preload_qlinear_kernel, ) - logger.info(f"Conversion complete: {time.time() - t}s") + logger.info(f"Format: Conversion complete: {time.time() - t}s") load_checkpoint_in_model = False qcfg.runtime_format = FORMAT.GPTQ_V2 @@ -495,11 +508,11 @@ def skip(*args, **kwargs): preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN): if is_sharded: raise ValueError( - "The loading of sharded checkpoints with Marlin is currently not supported." + "Format: The loading of sharded checkpoints with Marlin is currently not supported." ) if not _validate_marlin_device_support(): raise ValueError( - f'Marlin kernel does not support this gpu with compute capability of `{torch.cuda.get_device_capability()}`. Please do not use `back=BACKEND.MARLIN`.' + f'Kernel: Marlin kernel does not support this gpu with compute capability of `{torch.cuda.get_device_capability()}`. Please do not use `back=BACKEND.MARLIN`.' ) # Validate the model can run in Marlin. 
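
The loader hunk above now gates flash-attention on the resolved model class's `_supports_flash_attn_2` flag instead of requesting it unconditionally on CUDA/ROCm. A minimal standalone sketch of the same check, assuming a recent `transformers` release where architecture classes expose that private attribute (the attribute name is taken from this patch; treat its availability as an assumption):

```python
# Sketch: only request flash-attention 2 when the resolved model class advertises support.
import torch
import transformers
from transformers import AutoConfig

def resolve_attn_impl(model_id: str) -> dict:
    config = AutoConfig.from_pretrained(model_id)
    model_class = getattr(transformers, config.architectures[0], None) if config.architectures else None
    supports_fa2 = getattr(model_class, "_supports_flash_attn_2", False) if model_class else False

    args = {}
    if supports_fa2 and torch.cuda.is_available():
        try:
            import flash_attn  # noqa: F401  # fall back silently if flash-attn is not installed
            args["attn_implementation"] = "flash_attention_2"
        except ImportError:
            pass
    return args
```

The returned dict can then be splatted into `AutoModelForCausalLM.from_pretrained(model_id, **resolve_attn_impl(model_id))`.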
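
The `adapter` argument added to `from_quantized()` earlier in this hunk is simply copied onto `qcfg.adapter` before kernel selection, so a quantized checkpoint can be loaded together with EoRA/LoRA weights. A hedged usage sketch, assuming the top-level `GPTQModel.load()` forwards the keyword and that `Lora` accepts `path` and `rank` (both attribute names appear elsewhere in this patch; the exact constructor signature is an assumption):

```python
from gptqmodel import GPTQModel
from gptqmodel.adapter.adapter import Lora

# Hypothetical constructor arguments: `path` points at the EoRA safetensors file,
# `rank` must match the rank the lora_A/lora_B tensors were generated with.
eora = Lora(path="/path/to/eora.safetensors", rank=128)

model = GPTQModel.load(
    "/path/to/quantized-model",  # any GPTQ-quantized checkpoint
    adapter=eora,                # forwarded to from_quantized() and injected into qcfg.adapter
)
```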
@@ -600,7 +613,7 @@ def skip(*args, **kwargs): ) with tempfile.TemporaryDirectory() as temp_dir: - mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, model, qcfg.to_dict()) + mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, model, qcfg.to_dict(), cls.lm_head) save_weights(temp_dir, mlx_weights, donate_weights=True) save_config(mlx_config, config_path=temp_dir + "/config.json") diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index d350de3e3..ee2e88d7d 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -28,6 +28,7 @@ import transformers from huggingface_hub import split_torch_state_dict_into_shards from huggingface_hub.constants import SAFETENSORS_WEIGHTS_FILE_PATTERN +from safetensors.torch import save_file from safetensors.torch import save_file as safe_save from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.modeling_utils import no_init_weights @@ -48,15 +49,17 @@ logger = setup_logger() -QUANT_LOG_LAYER = "layer" -QUANT_LOG_MODULE = "module" +PROCESS_LOG_NAME = "process" +PROCESS_LOG_LAYER = "layer" +PROCESS_LOG_MODULE = "module" QUANT_LOG_LOSS = "loss" QUANT_LOG_DAMP = "damp" -QUANT_LOG_TIME = "time" -QUANT_LOG_FWD_TIME = "fwd_time" +PROCESS_LOG_TIME = "time" +PROCESS_LOG_FWD_TIME = "fwd_time" -def ModelWriter(cls): +EORA_DEFAULT_FILE = "eora.safetensors" +def ModelWriter(cls): def save_pretrained( self, save_dir: str, @@ -67,12 +70,46 @@ def save_pretrained( cls.save_pretrained = save_pretrained + def eora_save(self, eora_path: str): + # save lora tensors + if hasattr(self, 'lora_results'): # hack: TODO + weights = {} + + # convert the dict into safetensors compatible dict + for key, d in self.lora_results.items(): + # must normalize key since HF can load weights as `model.` or not based on what AutoModel is used + key = key.lower().removeprefix("model.") + for lora_key, lora_weight in d.items(): + if isinstance(lora_weight, torch.Tensor): + weights[f"{key}.{lora_key}"] = lora_weight + logger.info(f"lora weight: `{key}.{lora_key}`") + + # then lora_path from `save()` then lora.path + eora_path = eora_path if eora_path else self.quantize_config.adapter.path + + if not eora_path: + raise ValueError(f"Invalid EoRA lora path: actual = `{eora_path}`") + + is_file = eora_path.endswith(".safetensors") + + if not is_file: + eora_path = f"{eora_path}/eora.safetensors" + + logger.info(f"Found EoRA lora weights: saving to {eora_path}") + + os.makedirs(os.path.dirname(eora_path), exist_ok=True) + + save_file(tensors=weights, filename=eora_path, metadata={"format": "pt"}) + + cls.eora_save = eora_save + def save_quantized( self, save_dir: str, safetensors_metadata: Optional[Dict[str, str]] = None, max_shard_size: Optional[Union[int, str]] = DEFAULT_MAX_SHARD_SIZE, meta_quantizer: Optional[str] = None, + eora_path: Optional[str] = None, ): """save quantized model and configs to local disk""" os.makedirs(save_dir, exist_ok=True) @@ -80,9 +117,9 @@ def save_quantized( if self.quant_log: with open(os.path.join(save_dir, "quant_log.csv"), mode='w', newline='') as file: w = csv.writer(file) - w.writerow([QUANT_LOG_LAYER, QUANT_LOG_MODULE, QUANT_LOG_LOSS, QUANT_LOG_DAMP, QUANT_LOG_TIME]) - w.writerows([[entry.get(QUANT_LOG_LAYER), entry.get(QUANT_LOG_MODULE), entry.get(QUANT_LOG_LOSS), - entry.get(QUANT_LOG_DAMP), entry.get(QUANT_LOG_TIME)] for entry in self.quant_log]) + w.writerow([PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, QUANT_LOG_LOSS, QUANT_LOG_DAMP, PROCESS_LOG_TIME]) + 
w.writerows([[entry.get(PROCESS_LOG_LAYER), entry.get(PROCESS_LOG_MODULE), entry.get(QUANT_LOG_LOSS), + entry.get(QUANT_LOG_DAMP), entry.get(PROCESS_LOG_TIME)] for entry in self.quant_log]) pre_quantized_size_mb = get_model_files_size(self.model_local_path) pre_quantized_size_gb = pre_quantized_size_mb / 1024 @@ -130,7 +167,6 @@ def save_quantized( value=self.quantize_config.mse ) - # The config, quantize_config and model may be edited in place in save_quantized. config = copy.deepcopy(self.model.config) quantize_config = copy.deepcopy(self.quantize_config) @@ -180,10 +216,32 @@ def save_quantized( self.model.config = config # Save model config, including generation_config - # use empty state_dict hack to bypass saving weights - self.model.save_pretrained(save_dir, state_dict={}) + # Use empty state_dict hack to bypass saving weights + self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True) + + # Save `quantize_config.json` quantize_config.save_pretrained(save_dir) + def debug_saved_config(path): + # List all files in the directory + files = os.listdir(path) + print("Files in directory:") + for file in files: + print(file) + + config_file_paths = ["generation_config.json", "config.json"] + for file_name in config_file_paths: + full_path = os.path.join(path, file_name) + if os.path.isfile(full_path): + print(f"Content of saved `{file_name}`:") + with open(full_path, 'r') as config_file: + config_data = json.load(config_file) + print(json.dumps(config_data, indent=4)) + else: + print(f"`{file_name}` does not exist in the directory.") + + debug_saved_config(save_dir) + # Save processor related config files. For example: preprocessor_config.json, chat_template.json if hasattr(self,"processor") and isinstance(self.processor, ProcessorMixin): self.processor.save_pretrained(save_dir) @@ -309,6 +367,9 @@ def save_quantized( content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) + # save lora + eora_save(self, eora_path=eora_path) + # If the saved model is a loaded quantized model, do not calculate the size diff. if not self.load_quantized_model: total_size_gb = total_size_mb / 1024 @@ -319,7 +380,6 @@ def save_quantized( logger.info(f"Quantized model size: {total_size_mb:.2f}MB, {total_size_gb:.2f}GB") logger.info(f"Size difference: {size_diff_mb:.2f}MB, {size_diff_gb:.2f}GB - {percent_diff:.2f}%") - # need to copy .py files for model/tokenizers not yet merged to HF transformers if self.trust_remote_code: copy_py_files(save_dir, model_id_or_path=self.model_local_path) @@ -384,15 +444,11 @@ def skip(*args, **kwargs): make_quant( model, - modules, - qcfg.bits, - qcfg.group_size, + quant_result=modules, + qcfg=qcfg, backend=BACKEND.AUTO, - format=qcfg.format, lm_head_name=cls.lm_head, - desc_act=qcfg.desc_act, pack=True, - pack_dtype=qcfg.pack_dtype, ) load_checkpoint_in_model_then_tie_weights( diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 3cffb7a0a..9f94f9488 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
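
The `eora_save` hook added to writer.py above flattens `lora_results` into a single safetensors file, normalizing keys by lower-casing and stripping a leading `model.` prefix so HF-style loaders resolve them either way. A self-contained sketch of that flow, assuming `lora_results` maps full module names to dicts holding `lora_A`/`lora_B` tensors:

```python
import os
import torch
from safetensors.torch import save_file

def save_lora_tensors(lora_results: dict, eora_path: str) -> None:
    weights = {}
    for module_name, tensors in lora_results.items():
        # HF may load weights with or without the leading `model.` prefix, so normalize keys
        key = module_name.lower().removeprefix("model.")
        for lora_key, lora_weight in tensors.items():
            if isinstance(lora_weight, torch.Tensor):
                weights[f"{key}.{lora_key}"] = lora_weight

    # accept either a direct .safetensors path or a directory
    if not eora_path.endswith(".safetensors"):
        eora_path = os.path.join(eora_path, "eora.safetensors")

    os.makedirs(os.path.dirname(eora_path) or ".", exist_ok=True)
    save_file(tensors=weights, filename=eora_path, metadata={"format": "pt"})
```

Round-tripping with `safetensors.torch.load_file(eora_path)` returns the same flat dict keyed `"{module}.{lora_key}"`, which is the layout the per-module adapters expect to find.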
+import copy import math import sys from typing import List, Optional, Tuple @@ -21,6 +22,7 @@ import torch as t # conflict with torch.py import torch.nn as nn import transformers +from gptqmodel.adapter.adapter import LORA_MERGED_WEIGHT_PATHS, Adapter from ...models._const import DEVICE, PLATFORM @@ -37,6 +39,7 @@ class BaseQuantLinear(nn.Module): SUPPORTS_OUT_FEATURES_DIVISIBLE_BY: List[int] = None SUPPORTS_PACK_DTYPES: List[t.dtype] = None + SUPPORTS_ADAPTERS: List[Adapter] = None SUPPORTS_DEVICES: List[DEVICE] = None SUPPORTS_PLATFORM: List[PLATFORM] = None @@ -49,12 +52,16 @@ def __init__(self, out_features: int, bias: bool, pack_dtype: t.dtype, + adapter: Adapter, + name: str = None, register_buffers: bool = False, register_buffers_in_features: int = None, register_buffers_out_features: int = None, **kwargs): super().__init__() - + if name is None: + name = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + self.name = name # full path module name in model weights self.in_features = in_features self.out_features = out_features self.group_size = group_size if group_size != -1 else in_features @@ -63,7 +70,11 @@ def __init__(self, self.pack_dtype = pack_dtype self.maxq = 2 ** self.bits - 1 self.pack_dtype = pack_dtype + # we need to clone the adapter since passed in adapter may be shared + # adapter tensors are lodaed inside adapter so they must be unique per module + self.adapter = copy.deepcopy(adapter) + self.optimized = False if self.pack_dtype == t.int8: self.pack_dtype_bits = 8 @@ -126,6 +137,47 @@ def __init__(self, else: self.bias = None + # load adapter if any + if adapter is not None: + if adapter.path in LORA_MERGED_WEIGHT_PATHS: + print(f"Adapter (merged weights) lazy init: {self.adapter.name()}: {self.adapter}, module: {self.name}") + + # pre allocate buffers so accelerate can auto-bind merged weights in same tensor file as model + self.register_buffer( + "lora_A", + t.zeros((in_features, adapter.rank), dtype=t.float16), + ) + + self.register_buffer( + "lora_B", + t.zeros((adapter.rank, out_features), dtype=t.float16), + ) + else: + pass + # print(f"Adapter lazy init: {self.adapter.name()}: {self.adapter}, module: {self.name}") + + # TDOO: allow merged lora weights exist in gptq model safetensor file for direct loading + # EoRA need to preallocate buffers for Lora_A and B weights so HF can load + # self.register_buffer( + # "lora_A", + # torch.zeros((in_features, 128), dtype=torch.float16), # <-- EoRA lora_A shape needs to be calculated using pass in_features/out_features or other eora_test math + # ) + # + # # EoRA need to preallocate buffers for Lora_A and B weights so HF can load + # self.register_buffer( + # "lora_B", + # torch.zeros((128, out_features), dtype=torch.float16), # <-- EoRA lora_A shape needs to be calculated using pass in_features/out_features or other eora_test math + # ) + + # override me, to perform post-weight load to device init + def post_init(self): + if self.adapter is not None: + self.adapter.post_init( + weight_key=self.name, + device=self.qweight.device, + lora_A=getattr(self, "lora_A", None), + lora_B=getattr(self, "lora_B", None)) + @classmethod # custom quant linear class can override this and add custom checks def validate( @@ -139,11 +191,13 @@ def validate( pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, - trainable:Optional[bool]=None) -> Tuple[ + trainable:Optional[bool]=None, + adapter:Optional[Adapter]=None, + ) -> Tuple[ bool, Optional[Exception]]: return cls._validate(bits=bits, 
group_size=group_size, desc_act=desc_act, sym=sym, - in_features=in_features, out_features=out_features, pack_dtype=pack_dtype, - dynamic=dynamic, device=device, trainable=trainable) + in_features=in_features, out_features=out_features, pack_dtype=pack_dtype, + dynamic=dynamic, device=device, trainable=trainable, adapter=adapter) @classmethod # internal method and should not be overriden @@ -175,14 +229,21 @@ def verify_supports_params(cls): for name, value in child_supports_variables: if not name.startswith("SUPPORTS") or callable(value): continue - if value is None or (isinstance(value, list) and not value): - raise ValueError(f"{cls.__name__}.{name} cannot be None or an empty list.") + if value is None: + raise ValueError(f"{cls.__name__}.{name} cannot be None.") + + # if isinstance(value, list) and not value: + # raise ValueError(f"{cls.__name__}.{name} cannot be an empty list.") @classmethod def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, in_features:int=None, - out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]: + out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, adapter:Optional[Adapter]=None) -> Tuple[bool, Optional[Exception]]: cls.verify_supports_params() + if adapter is not None and adapter.__class__ not in cls.SUPPORTS_ADAPTERS: + err = f"{cls} does not support adapter: {adapter}" + return False, NotImplementedError(err) + if pack_dtype not in cls.SUPPORTS_PACK_DTYPES: err = f"{cls} does not support `pack_dtype`: {pack_dtype}" return False, NotImplementedError(err) @@ -276,61 +337,134 @@ def validate_device(cls, device: DEVICE): if device not in cls.SUPPORTS_DEVICES: raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTS_DEVICES}`: actual device = `{device}`") - # override me, to perform post-weight load to device init - def post_init(self): - pass - + # use optimize so we don't override native module.compile() # override me, to perform any torch.compile logic on the kernel pre forward - def compile(self): + def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): + self.optimized = True pass class PackableQuantLinear(BaseQuantLinear): - def pack(self, linear, scales, zeros, g_idx=None): + def post_init(self, **kwargs): + super().post_init(**kwargs) + + if self.bits in [2, 4, 8]: + wf = t.tensor(list(range(0, self.pack_dtype_bits, self.bits)), dtype=t.int32).unsqueeze(0).to( + device=self.g_idx.device) + elif self.bits == 3: + wf = t.tensor( + [ + [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], + ], + dtype=t.int32, + ).reshape(1, 3, 12).to(device=self.g_idx.device) + + # self.register_buffer("wf_unsqueeze_zero", wf.unsqueeze(0).to(device=self.g_idx.device)) + # self.register_buffer("wf_unsqueeze_neg_one", wf.unsqueeze(-1).to(device=self.g_idx.device)) + # + self.wf_unsqueeze_zero = wf.unsqueeze(0).to(device=self.g_idx.device) + self.wf_unsqueeze_neg_one = wf.unsqueeze(-1).to(device=self.g_idx.device) + + def dequantize_weight(self, num_itr: int = 1): + if self.bits in [2, 4, 8]: + zeros = t.bitwise_right_shift( + t.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor), + self.wf_unsqueeze_zero # self.wf.unsqueeze(0), + ).to(self.dequant_dtype) + zeros = t.bitwise_and(zeros, self.maxq).reshape(self.scales.shape) + + weight = t.bitwise_and( + 
t.bitwise_right_shift( + t.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), + self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1) + ).to(self.dequant_dtype), + self.maxq + ) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand( + -1, -1, -1, 12 + ) + zeros = zeros >> self.wf_unsqueeze_zero # self.wf.unsqueeze(0) + zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4) + zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6) + zeros = zeros & 0x7 + zeros = t.cat( + [zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], + dim=2, + ).reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand( + -1, -1, 12, -1 + ) + weight = (weight >> self.wf_unsqueeze_neg_one) & 0x7 # self.wf.unsqueeze(-1) + weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4) + weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6) + weight = weight & 0x7 + weight = t.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + if num_itr == 1: + weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]) + else: + num_dim = self.g_idx.shape[0] // num_itr + weights = [] + for i in range(num_itr): + scale_i = self.scales[:, i * num_dim: (i + 1) * num_dim] + weight_i = weight[:, i * num_dim: (i + 1) * num_dim] + zeros_i = zeros[:, i * num_dim: (i + 1) * num_dim] + g_idx_i = self.g_idx[i * num_dim: (i + 1) * num_dim].long() + weights.append(scale_i[g_idx_i] * (weight_i - zeros_i[g_idx_i])) + weights = t.cat(weights, dim=1) + + return weights + + def pack(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_idx: t.Tensor=None): W = linear.weight.data.clone() if isinstance(linear, nn.Conv2d): W = W.flatten(1) if isinstance(linear, transformers.pytorch_utils.Conv1D): - W = W.t() + W = W.T self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() + scales = scales.T.contiguous() + zeros = zeros.T.contiguous() scale_zeros = zeros * scales self.scales = scales.clone().to(dtype=t.float16) if linear.bias is not None: self.bias = linear.bias.clone().to(dtype=t.float16) - intweight = t.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(t.int32) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(self.pack_np_math_dtype) + int_weight = t.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(t.int32) + int_weight = int_weight.T.contiguous() + int_weight = int_weight.numpy().astype(self.pack_np_math_dtype) - qweight = np.zeros((intweight.shape[0] // self.pack_dtype_bits * self.bits, intweight.shape[1]), + qweight = np.zeros((int_weight.shape[0] // self.pack_dtype_bits * self.bits, int_weight.shape[1]), dtype=self.pack_np_math_dtype) if self.bits in [2, 4, 8]: for row in range(qweight.shape[0]): for j in range(self.pack_factor): - qweight[row] |= intweight[row * self.pack_factor + j] << (self.bits * j) + qweight[row] |= int_weight[row * self.pack_factor + j] << (self.bits * j) elif self.bits == 3: i = 0 row = 0 while row < qweight.shape[0]: for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i)) + qweight[row] |= int_weight[j] << (3 * (j - i)) i += 10 - qweight[row] |= intweight[i] << 30 + qweight[row] |= int_weight[i] << 30 row += 
1 - qweight[row] |= (intweight[i] >> 2) & 1 + qweight[row] |= (int_weight[i] >> 2) & 1 i += 1 for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 1) + qweight[row] |= int_weight[j] << (3 * (j - i) + 1) i += 10 - qweight[row] |= intweight[i] << 31 + qweight[row] |= int_weight[i] << 31 row += 1 - qweight[row] |= (intweight[i] >> 1) & 0x3 + qweight[row] |= (int_weight[i] >> 1) & 0x3 i += 1 for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 2) + qweight[row] |= int_weight[j] << (3 * (j - i) + 2) i += 10 row += 1 @@ -367,4 +501,14 @@ def pack(self, linear, scales, zeros, g_idx=None): self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype)) + # assert + # assert isinstance(self, TorchQuantLinear), f"type: {self.__class_}" + # wq = linear.weight.data + # wq_dequantized = self.dequantize_weight().T + # print(f"------ WQ -----") + # print(wq) + # print(f"------ WQ Dequantized -----") + # print(wq_dequantized) + # assert t.equal(wq, wq_dequantized) + # print("self qw", self.qweight, self.scales, self.qzeros) diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index 6cd701581..cffce514f 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -23,6 +23,7 @@ import numpy as np import torch import torch.nn as nn +from gptqmodel.adapter.adapter import Adapter, Lora from gptqmodel.nn_modules.qlinear import PackableQuantLinear from ...models._const import DEVICE, PLATFORM @@ -96,6 +97,7 @@ class BitBLASQuantLinear(PackableQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA] SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512] zeros_mode = "quantized" # "original" or "rescale" or "quantized" @@ -121,6 +123,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, enable_tuning: bool = True, fast_decoding: bool = True, propagate_b: bool = BITBLAS_PROPAGATE_WEIGHTS, @@ -137,6 +140,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, register_buffers=False, **kwargs) @@ -287,7 +291,7 @@ def pack(self, linear, scales, zeros, g_idx=None): zeros = zeros.t().contiguous() scale_zeros = zeros * scales self.scales = scales.clone().half() - if linear.bias is not None: + if linear.bias: self.bias = linear.bias.clone().half() intweight = torch.round((W + scale_zeros[g_idx].T) / scales[g_idx].T).to(torch.int) @@ -395,6 +399,10 @@ def forward(self, A): self.bitblas_matmul.call_lib( ctypes.c_void_p(A.data_ptr()) , *self.q_params, ctypes.c_void_p(C.data_ptr()), m ) + + if self.adapter: + C = self.adapter.apply(x=A, out=C) + return C diff --git a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py index b4acf9977..25fd81ff7 100644 --- a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py +++ b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py @@ -17,6 +17,7 @@ from typing import Optional, Tuple import torch +from gptqmodel.adapter.adapter import Adapter, Lora from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear from gptqmodel.utils.logger import setup_logger @@ -47,6 +48,7 @@ class DynamicCudaQuantLinear(TorchQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat 
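
The refactored `pack()` / `dequantize_weight()` pair above implements the usual per-group affine GPTQ scheme: quantize with `q = round(W / s + z)` and reconstruct with `W ≈ s * (q - z)`, indexed through `g_idx`. A tiny toy round trip for 4-bit that ignores the int32 bit-packing itself (group/scale choices here are illustrative only):

```python
import torch

# Toy per-group 4-bit quantize/dequantize round trip (no bit packing), matching the
# scales[g_idx] * (q - zeros[g_idx]) reconstruction used by dequantize_weight().
in_features, out_features, group_size, bits = 8, 4, 4, 4
maxq = 2 ** bits - 1

W = torch.randn(out_features, in_features)
g_idx = torch.arange(in_features) // group_size          # group id per input column

# naive per-group scales and a mid-point zero, just for the demo
scales = torch.stack([W[:, g_idx == g].abs().max(dim=1).values / maxq
                      for g in range(in_features // group_size)]).clamp(min=1e-8)
zeros = torch.full_like(scales, (maxq + 1) / 2)

q = torch.clamp(torch.round(W / scales[g_idx].T + zeros[g_idx].T), 0, maxq)
W_hat = scales[g_idx].T * (q - zeros[g_idx].T)

print((W - W_hat).abs().max())   # small reconstruction error, bounded by scale/2 per group
```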
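
Each kernel's forward in this patch now ends with `out = self.adapter.apply(x=x, out=out)`. The adapter implementation itself is not part of this diff, but based on the unfused EoRA path shown later in `exllama_eora.py` (`.add_((x @ lora_A) @ lora_B)`), a minimal sketch of what such an `apply` amounts to:

```python
import torch

def lora_apply(x: torch.Tensor, out: torch.Tensor,
               lora_A: torch.Tensor, lora_B: torch.Tensor) -> torch.Tensor:
    # x: [tokens, in_features], lora_A: [in_features, rank], lora_B: [rank, out_features]
    # The EoRA/LoRA correction is a low-rank residual added on top of the quantized matmul output.
    return out + (x.to(lora_A.dtype) @ lora_A) @ lora_B
```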
QUANT_TYPE = "cuda" @@ -61,6 +63,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, kernel_switch_threshold=128, **kwargs, ): @@ -77,6 +80,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, **kwargs) # assert in_features % 64 == 0 and out_features % 64 == 0 @@ -116,7 +120,7 @@ def forward(self, x: torch.Tensor): if x.shape[0] >= self.kernel_switch_threshold: # logger.warning_once( # f"Input shape `{x.shape[0]}` >= `{self.kernel_switch_threshold}` is not optimized for cuda kernel: dynamic switching to torch kernel.") - return self._forward(x, x.dtype).reshape(out_shape) + return self._forward(x, x.dtype, out_shape) out = torch.zeros((x.shape[0], self.out_features), device=x.device, dtype=torch.float32) self.qmatmul( @@ -128,10 +132,15 @@ def forward(self, x: torch.Tensor): self.g_idx, ) - out = out.to(x.dtype).reshape(out_shape) + out = out.reshape(out_shape) + + if self.adapter: + out = self.adapter.apply(x=x, out=out) + if self.bias is not None: out.add_(self.bias) - return out + + return out.to(x.dtype) __all__ = ["DynamicCudaQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py index 4dcdd5762..29b6f5670 100644 --- a/gptqmodel/nn_modules/qlinear/exllama.py +++ b/gptqmodel/nn_modules/qlinear/exllama.py @@ -16,13 +16,12 @@ # Adapted from turboderp exllama: https://github.com/turboderp/exllama -import math from logging import getLogger from typing import Optional, Tuple import torch -import torch.nn.functional as F -from gptqmodel.nn_modules.qlinear import PackableQuantLinear +from gptqmodel.adapter.adapter import Adapter, Lora +from gptqmodel.nn_modules.qlinear import BaseQuantLinear from ...models._const import DEVICE, PLATFORM @@ -55,20 +54,21 @@ def ext_q4_matmul(x, q4, q4_width): return output.view(outshape) -class ExllamaQuantLinear(PackableQuantLinear): +class ExllamaQuantLinear(BaseQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] SUPPORTS_DESC_ACT = [True, False] SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True SUPPORTS_TRAINING = False - SUPPORTS_AUTO_PADDING = True + SUPPORTS_AUTO_PADDING = False SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32] SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32] SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "exllama" @@ -85,6 +85,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, **kwargs, ): if exllama_import_exception is not None: @@ -93,15 +94,15 @@ def __init__( ) # backup original values - self.original_out_features = out_features - self.original_in_features = in_features - - # auto pad - group_size = group_size if group_size != -1 else in_features - out_features = out_features + (-out_features % 32) - in_features = in_features + (-in_features % group_size) - self.in_features_padding_size = in_features - self.original_in_features - self.in_features_padding_shape = (0, self.in_features_padding_size) + # self.original_out_features = out_features + # self.original_in_features = in_features + # + # # auto pad + # group_size = group_size if group_size != -1 else in_features + # out_features = out_features + (-out_features % 32) + # in_features = in_features + (-in_features % group_size) + # self.in_features_padding_size = 
in_features - self.original_in_features + # self.in_features_padding_shape = (0, self.in_features_padding_size) super().__init__( bits=bits, @@ -111,9 +112,10 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, register_buffers=True, - register_buffers_in_features=self.original_in_features, - register_buffers_out_feature=self.original_out_features, + register_buffers_in_features=in_features, + register_buffers_out_feature=out_features, **kwargs) @classmethod @@ -124,16 +126,16 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: def post_init(self): # resize due to padding after model weights have been loaded - if self.out_features != self.original_out_features or self.in_features != self.original_in_features: - self.qweight.resize_(self.in_features // self.pack_dtype_bits * self.bits, self.out_features) - self.qzeros.resize_( - math.ceil(self.in_features / self.group_size), - self.out_features // self.pack_dtype_bits * self.bits - ) - self.scales.resize_((math.ceil(self.in_features / self.group_size), self.out_features), ) - self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32, device=self.g_idx.device) - if self.bias is not None: - self.bias.resize_(self.out_features) + # if self.out_features != self.original_out_features or self.in_features != self.original_in_features: + # self.qweight.resize_(self.in_features // self.pack_dtype_bits * self.bits, self.out_features) + # self.qzeros.resize_( + # math.ceil(self.in_features / self.group_size), + # self.out_features // self.pack_dtype_bits * self.bits + # ) + # self.scales.resize_((math.ceil(self.in_features / self.group_size), self.out_features), ) + # self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32, device=self.g_idx.device) + # if self.bias is not None: + # self.bias.resize_(self.out_features) self.width = self.qweight.shape[1] @@ -147,9 +149,12 @@ def post_init(self): self.qweight.device.index, ) + super().post_init() + def forward(self, x): - if x.dtype != torch.float16: + x_dtype = x.dtype + if x_dtype != torch.float16: logger.warning_once( f"Exllama kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model." ) @@ -158,12 +163,15 @@ def forward(self, x): # TODO: need to run checks to make sure there is no performance regression padding with F.pad # if in_features is padded, we need to pad the input as well - if x.size(-1) != self.in_features: - x = F.pad(x, self.in_features_padding_shape) + # if x.size(-1) != self.in_features: + # x = F.pad(x, self.in_features_padding_shape) out = ext_q4_matmul(x, self.q4, self.width) if self.bias is not None: out.add_(self.bias) - return out + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out.to(x_dtype) diff --git a/gptqmodel/nn_modules/qlinear/exllama_eora.py b/gptqmodel/nn_modules/qlinear/exllama_eora.py new file mode 100644 index 000000000..6adce0c25 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/exllama_eora.py @@ -0,0 +1,192 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 + +from typing import Optional, Tuple + +import torch +from gptqmodel.adapter.adapter import Adapter, Lora +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from torch.nn import Parameter + +from ...models._const import DEVICE, PLATFORM +from ...utils.logger import setup_logger + +exllama_v2v_import_exception = None + +try: + import gptqmodel_exllama_eora +except ImportError as e: + exllama_v2v_import_exception = e + +logger = setup_logger() + + + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +NONE_TENSOR = torch.empty((1, 1), device="meta") + + +# TODO remove this? +def _torch_device(idx): + if idx == -1: + return "cpu" + return f"cuda:{idx}" + +def gptq_gemm(x, qweight, qzeros, scales, g_idx, bit): + return gptqmodel_exllama_eora.gptq_gemm(x, qweight, qzeros, scales, g_idx, True, bit) + + +def gptq_gemm_lora(x, qweight, qzeros, scales, g_idx, bit, A, B): + return gptqmodel_exllama_eora.gptq_gemm_lora(x, qweight, qzeros, scales, g_idx, True, bit, A, B) + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + gptqmodel_exllama_eora.gptq_shuffle(q_weight, q_perm, bit) + + +class ExllamaEoraQuantLinear(BaseQuantLinear): + SUPPORTS_BITS = [4, 8] # TODO: validate 2/3 + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True] # TODO: validate False + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False # TODO: validate True + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32] + + SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] + SUPPORTS_PLATFORM = [PLATFORM.LINUX] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + # for transformers/optimum tests compat + QUANT_TYPE = "exllama_v2v" + + """Linear layer implementation with per-group 4-bit quantization of the weights""" + + def __init__(self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + pack_dtype: torch.dtype, + adapter: Adapter, + bias: bool, **kwargs, + ): + if exllama_v2v_import_exception is not None: + raise ValueError( + f"Trying to use the exllama v2 backend, but could not import the C++/CUDA dependencies with the following error: {exllama_v2v_import_exception}" + ) + + # # backup original values + # self.original_out_features = out_features + # self.original_in_features = in_features + # + # # auto pad + # group_size = group_size if group_size != -1 else in_features + # out_features = out_features + (-out_features % 32) + # in_features = in_features + (-in_features % group_size) + # self.in_features_padding_size = in_features - self.original_in_features + # self.in_features_padding_shape = (0, self.in_features_padding_size) + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + register_buffers=True, + 
register_buffers_in_features=in_features, # self.original_in_features + register_buffers_out_feature=out_features, # self.original_out_features + **kwargs) + + + @classmethod + def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: + if exllama_v2v_import_exception is not None: + return False, exllama_v2v_import_exception + return cls._validate(**args) + + def post_init(self): + # resize due to padding after model weights have been loaded + # if self.out_features != self.original_out_features or self.in_features != self.original_in_features: + # self.qweight.resize_(self.in_features // self.pack_dtype_bits * self.bits, self.out_features) + # self.qzeros.resize_( + # math.ceil(self.in_features / self.group_size), + # self.out_features // self.pack_dtype_bits * self.bits + # ) + # self.scales.resize_(math.ceil(self.in_features / self.group_size), self.out_features) + # self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32, device=self.g_idx.device) + # if self.bias is not None: + # self.bias.resize_(self.out_features) + + super().post_init() + + self.qzeros = Parameter(self.qzeros.data, requires_grad=False) + self.qweight = Parameter(self.qweight.data, requires_grad=False) + self.g_idx = Parameter(self.g_idx.data, requires_grad=False) + self.scales = Parameter(self.scales.data, requires_grad=False) + + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if self.desc_act: + self.g_idx.data = torch.argsort(self.g_idx).to(torch.int32) + else: + self.g_idx.data = torch.empty((0,), + dtype=torch.int32, + device=self.g_idx.device) + + gptq_shuffle(self.qweight, self.g_idx, self.bits) + + def forward(self, x): + x_dtype = x.dtype + if x_dtype != torch.float16: + logger.warning_once( + f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model." 
+ ) + + x = x.to(dtype=torch.float16) + + # sync with vllm + out_shape = x.shape[:-1] + (self.qweight.shape[-1],) + reshaped_x = x.reshape(-1, x.shape[-1]) + + # TODO: need to run checks to make sure there is no performance regression padding with F.pad + # if in_features is padded, we need to pad the input as well + # if x.size(-1) != self.in_features: + # x = F.pad(x, self.in_features_padding_shape) + + if self.adapter: + # output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused + output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B) # normal + else: + output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits) + + + if self.bias is not None: + output.add_(self.bias) + + # sync with vllm + output = output.reshape(out_shape) + + return output.to(dtype=x_dtype) diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py index b4429d419..efd573edd 100644 --- a/gptqmodel/nn_modules/qlinear/exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/exllamav2.py @@ -16,11 +16,10 @@ # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 -import math from typing import Optional, Tuple import torch -import torch.nn.functional as F +from gptqmodel.adapter.adapter import Adapter, Lora from gptqmodel.nn_modules.qlinear import BaseQuantLinear from ...models._const import DEVICE, PLATFORM @@ -126,14 +125,14 @@ class ExllamaV2QuantLinear(BaseQuantLinear): SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True SUPPORTS_TRAINING = False - SUPPORTS_AUTO_PADDING = True + SUPPORTS_AUTO_PADDING = False SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32] SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32] SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "exllamav2" @@ -149,6 +148,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, **kwargs, ): if exllama_v2_import_exception is not None: @@ -157,15 +157,15 @@ def __init__( ) # backup original values - self.original_out_features = out_features - self.original_in_features = in_features - - # auto pad - group_size = group_size if group_size != -1 else in_features - out_features = out_features + (-out_features % 32) - in_features = in_features + (-in_features % group_size) - self.in_features_padding_size = in_features - self.original_in_features - self.in_features_padding_shape = (0, self.in_features_padding_size) + # self.original_out_features = out_features + # self.original_in_features = in_features + # + # # auto pad + # group_size = group_size if group_size != -1 else in_features + # out_features = out_features + (-out_features % 32) + # in_features = in_features + (-in_features % group_size) + # self.in_features_padding_size = in_features - self.original_in_features + # self.in_features_padding_shape = (0, self.in_features_padding_size) super().__init__( bits=bits, @@ -176,9 +176,10 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, register_buffers=True, - register_buffers_in_features=self.original_in_features, - register_buffers_out_feature=self.original_out_features, + register_buffers_in_features=in_features, + 
register_buffers_out_feature=out_features, **kwargs) self.q_handle = None @@ -192,16 +193,16 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: def post_init(self, temp_dq): # resize due to padding after model weights have been loaded - if self.out_features != self.original_out_features or self.in_features != self.original_in_features: - self.qweight.resize_(self.in_features // self.pack_dtype_bits * self.bits, self.out_features) - self.qzeros.resize_( - math.ceil(self.in_features / self.group_size), - self.out_features // self.pack_dtype_bits * self.bits - ) - self.scales.resize_(math.ceil(self.in_features / self.group_size), self.out_features) - self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32, device=self.g_idx.device) - if self.bias is not None: - self.bias.resize_(self.out_features) + # if self.out_features != self.original_out_features or self.in_features != self.original_in_features: + # self.qweight.resize_(self.in_features // self.pack_dtype_bits * self.bits, self.out_features) + # self.qzeros.resize_( + # math.ceil(self.in_features / self.group_size), + # self.out_features // self.pack_dtype_bits * self.bits + # ) + # self.scales.resize_(math.ceil(self.in_features / self.group_size), self.out_features) + # self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32, device=self.g_idx.device) + # if self.bias is not None: + # self.bias.resize_(self.out_features) self.q_tensors = { "qweight": self.qweight, @@ -212,8 +213,11 @@ def post_init(self, temp_dq): temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + super().post_init() + def forward(self, x, force_cuda=False): - if x.dtype != torch.float16: + x_dtype = x.dtype + if x_dtype != torch.float16: logger.warning_once( f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model." 
) @@ -222,15 +226,19 @@ def forward(self, x, force_cuda=False): # TODO: need to run checks to make sure there is no performance regression padding with F.pad # if in_features is padded, we need to pad the input as well - if x.size(-1) != self.in_features: - x = F.pad(x, self.in_features_padding_shape) + # if x.size(-1) != self.in_features: + # x = F.pad(x, self.in_features_padding_shape) + - output = ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda) + out = ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda) if self.bias is not None: - output.add_(self.bias) + out.add_(self.bias) + + if self.adapter: + out = self.adapter.apply(x=x, out=out) - return output + return out.to(dtype=x_dtype) def temp_dq_size(self): return self.in_features * self.out_features * 2 + 128 diff --git a/gptqmodel/nn_modules/qlinear/ipex.py b/gptqmodel/nn_modules/qlinear/ipex.py index b4c058d25..0769f7fdc 100644 --- a/gptqmodel/nn_modules/qlinear/ipex.py +++ b/gptqmodel/nn_modules/qlinear/ipex.py @@ -16,15 +16,13 @@ from typing import Optional, Tuple -import numpy as np import torch -import torch.nn as nn -import transformers +from gptqmodel.adapter.adapter import Adapter, Lora from gptqmodel.models._const import DEVICE, PLATFORM -from gptqmodel.nn_modules.qlinear import PackableQuantLinear from ...utils.logger import setup_logger -from ...utils.torch import HAS_XPU +from ...utils.torch import torch_compile +from . import PackableQuantLinear logger = setup_logger() @@ -47,7 +45,7 @@ def ipex_dtype() -> torch.dtype: raise ImportError("intel_extension_for_pytorch not installed. " "Please install via `pip install intel_extension_for_pytorch`") - return torch.float16 if HAS_XPU else torch.bfloat16 + return torch.float16 def convert_dtype_torch2str(dtype): @@ -93,7 +91,7 @@ class IPEXQuantLinear(PackableQuantLinear): SUPPORTS_DESC_ACT = [True, False] SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True - SUPPORTS_TRAINING = True + SUPPORTS_TRAINING = False SUPPORTS_AUTO_PADDING = False SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] @@ -101,7 +99,7 @@ class IPEXQuantLinear(PackableQuantLinear): SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "ipex" @@ -115,8 +113,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, - kernel_switch_threshold=128, - training=False, + adapter: Adapter = None, **kwargs, ): super().__init__( @@ -128,192 +125,40 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, register_buffers=True, **kwargs) - # FIX ME IPEX CPU has no float16 support - self.weight_dtype = torch.float16 if HAS_XPU else torch.bfloat16 - self.init_ipex = False - - self.kernel_switch_threshold = kernel_switch_threshold - - self.training = training - - # for training forward - self.wf = torch.tensor(list(range(0, self.pack_dtype_bits, self.bits)), dtype=torch.int32).unsqueeze(0) + self.weight_dtype = torch.float16 @classmethod - def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: + def validate(cls, bias: bool = False, adapter: Optional[Adapter] = None, **args) -> Tuple[bool, Optional[Exception]]: if not HAS_IPEX: return False, IPEX_ERROR_LOG return cls._validate(**args) def post_init(self): - pass - - def init_ipex_linear(self, x: torch.Tensor): - if not self.training and HAS_IPEX and not x.requires_grad: - 
self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, - self.in_features, self.out_features, None, self.bias, - self.group_size, self.g_idx, quant_method=QuantMethod.GPTQ_GEMM, dtype=QuantDtype.INT4) - - def pack(self, linear, scales, zeros, g_idx=None): - W = linear.weight.data.clone() - if isinstance(linear, nn.Conv2d): - W = W.flatten(1) - if isinstance(linear, transformers.pytorch_utils.Conv1D): - W = W.t() - - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().to(dtype=linear.weight.dtype) - if linear.bias is not None: - self.bias = linear.bias.clone().to(dtype=linear.weight.dtype) - - intweight = torch.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(torch.int) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - - qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) - for row in range(qweight.shape[0]): - i = row * (32 // self.bits) - for j in range(32 // self.bits): - qweight[row] |= intweight[i + j] << (self.bits * j) - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) - for col in range(qzeros.shape[1]): - i = col * (32 // self.bits) - for j in range(32 // self.bits): - qzeros[:, col] |= zeros[:, i + j] << (self.bits * j) - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - + self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight( + self.qweight, + self.scales, + self.qzeros, + self.in_features, + self.out_features, + None, + self.bias, + self.group_size, + self.g_idx, + quant_method=QuantMethod.GPTQ_GEMM, + dtype=QuantDtype.INT4) + + @torch.no_grad() def forward(self, x: torch.Tensor): - if not self.init_ipex: - self.init_ipex_linear(x) - self.init_ipex = True - - if hasattr(self, "ipex_linear"): - with torch.no_grad(): - outputs = self.ipex_linear(x) - return outputs - - if self.wf.device != x.device: - self.wf = self.wf.to(x.device) - out_shape = x.shape[:-1] + (self.out_features,) - x = x.reshape(-1, x.shape[-1]) - x_dtype = x.dtype - zeros = torch.bitwise_right_shift( - torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), - self.wf.unsqueeze(0), - ).to(torch.int16) - zeros = torch.bitwise_and(zeros, (2**self.bits) - 1) - - zeros = zeros + 1 - zeros = zeros.reshape(self.scales.shape) - - weight = torch.bitwise_right_shift( - torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), - self.wf.unsqueeze(-1), - ).to(torch.int16) - weight = torch.bitwise_and(weight, (2**self.bits) - 1) - - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) - num_itr = self.g_idx.shape[0] // x.shape[-1] - if num_itr == 1: - weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]) + if self.adapter: + return self.adapter(x=x, out=self.ipex_linear(x)) else: - num_dim = self.g_idx.shape[0] // num_itr - weights = [] - for i in range(num_itr): - scale_i = self.scales[:, i * num_dim : (i + 1) * num_dim] - weight_i = weight[:, i * num_dim : (i + 1) * num_dim] - zeros_i = zeros[:, i * num_dim : (i + 1) * num_dim] - g_idx_i = self.g_idx[i * num_dim : (i + 1) * num_dim] - weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) - 
weights = torch.cat(weights, dim=1) - out = torch.matmul(x, weights.to(x.dtype)) - out = out.to(x_dtype) - out = out.reshape(out_shape) - if self.bias is not None: - out.add_(self.bias) - - return out - - -@torch.no_grad() -def unpack_to_8bit_signed(qweight, qzeros, bits, g_idx=None): - wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0) - zeros = None - if not torch.all(torch.eq(qzeros, 2004318071 if bits == 4 else 0b01111111011111110111111101111111)): - zp_shape = list(qzeros.shape) - zp_shape[1] = zp_shape[1] * (32 // bits) - - zeros = torch.bitwise_right_shift( - torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) - ).to(torch.int16 if bits == 8 else torch.int8) - torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) - if bits == 8: - zeros = zeros.to(torch.uint8) - zeros = zeros + 1 - try: - zeros = zeros.reshape(zp_shape) - except Exception: - # zeros and scales have different iteam numbers. - # remove 1 (due to 0 + 1 in line 252) - zeros = zeros[zeros != 1] - zeros = zeros.reshape(zp_shape) - - try: - r = torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1) - except BaseException as e: - print(e) - weight = torch.bitwise_right_shift( - r, wf.unsqueeze(-1) - ).to(torch.int16 if bits == 8 else torch.int8) - weight.bitwise_and_((2**bits) - 1) - weight = weight.view(-1, weight.shape[-1]) - - if g_idx is not None: - group_size = weight.shape[0] // qzeros.shape[0] - weight2 = weight.clone() - group_dict = {} - for i in range(len(g_idx)): - group_idx = g_idx[i].item() - if group_idx not in group_dict: - target_idx = group_idx * group_size - group_dict[group_idx] = 0 - else: - group_dict[group_idx] = group_dict[group_idx] + 1 - target_idx = group_idx * group_size + group_dict[group_idx] - weight2[target_idx] = weight[i] - weight = weight2 - - return weight, zeros - - -# Copied from marlin.py -@torch.no_grad() -def dequantize_weight(qweight, qzeros, scales, bits): - unpacked_qweight, unpacked_qzeros = unpack_to_8bit_signed(qweight, qzeros, bits) - group_size = unpacked_qweight.shape[0] // scales.shape[0] - scales = scales.repeat_interleave(group_size, dim=0) - if unpacked_qzeros is not None: - unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) - else: - unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32) - unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales - - return unpacked_qweight, unpacked_qzeros + return self.ipex_linear(x) + def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): + self.forward = torch_compile(self.forward, backend=backend, mode=mode, fullgraph=fullgraph) -__all__ = ["IPEXQuantLinear", "dequantize_weight"] +__all__ = ["IPEXQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py index 6c82c5315..8bde9c56a 100644 --- a/gptqmodel/nn_modules/qlinear/marlin.py +++ b/gptqmodel/nn_modules/qlinear/marlin.py @@ -21,6 +21,7 @@ import numpy as np import torch +from gptqmodel.adapter.adapter import Adapter, Lora from gptqmodel.nn_modules.qlinear import BaseQuantLinear from torch.nn.parameter import Parameter @@ -170,7 +171,7 @@ class MarlinQuantLinear(BaseQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "marlin" @@ -183,6 +184,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + adapter: 
Adapter = None, **kwargs): if marlin_import_exception is not None: raise ValueError( @@ -206,6 +208,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, register_buffers=False, **kwargs) @@ -368,11 +371,13 @@ def post_init(self): group_size=self.group_size) replace_tensor(self, "scales", marlin_scales) + super().post_init() + def forward(self, A: torch.Tensor): if A.dtype != torch.float16: A = A.to(torch.float16) - return apply_gptq_marlin_linear( + out = apply_gptq_marlin_linear( input=A.contiguous() if self.is_lm_head else A, weight=self.qweight, weight_scale=self.scales, @@ -384,7 +389,13 @@ def forward(self, A: torch.Tensor): output_size_per_partition=self.out_features, input_size_per_partition=self.in_features, is_k_full=self.is_k_full, - bias=self.bias) + bias=self.bias, + ) + + if self.adapter: + out = self.adapter.apply(x=A, out=out) + + return out # Precompute permutations for Marlin weight and scale shuffling def _get_perms(): diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index b592de7d2..434d3e019 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -14,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import torch import torch.nn as nn -import torch.nn.functional as F +from gptqmodel.adapter.adapter import Adapter, Lora from gptqmodel.nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear from gptqmodel.utils.logger import setup_logger +from transformers import PreTrainedModel from ...models._const import DEVICE, PLATFORM +from ...utils.torch import torch_compile logger = setup_logger() @@ -40,7 +41,7 @@ class TorchQuantLinear(PackableQuantLinear): SUPPORTS_DEVICES = [DEVICE.ALL] SUPPORTS_PLATFORM = [PLATFORM.ALL] SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32] - + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "torch" @@ -54,6 +55,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, **kwargs, ): super().__init__( @@ -65,67 +67,67 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, register_buffers=True, **kwargs) self.dequant_dtype = torch.int16 if self.bits == 8 else torch.int8 - if self.group_size != self.in_features: - self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) - else: - self.padded_infeatures = self.in_features + # if self.group_size != self.in_features: + # self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) + # else: + # self.padded_infeatures = self.in_features def post_init(self): - if self.padded_infeatures != self.in_features: - self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) - self.qzeros.resize_( - math.ceil(self.padded_infeatures / self.group_size), - self.out_features // self.pack_dtype_bits * self.bits - ) - self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) - self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, - device=self.g_idx.device) - - if self.bits in [2, 4, 8]: - self.register_buffer( - "wf", - torch.tensor(list(range(0, self.pack_dtype_bits, self.bits)), dtype=torch.int32).unsqueeze(0).to(device=self.g_idx.device), - ) - elif self.bits == 3: - 
self.register_buffer( - "wf", - torch.tensor( - [ - [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], - [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], - [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], - ], - dtype=torch.int32, - ).reshape(1, 3, 12).to(device=self.g_idx.device) - ) + # if self.padded_infeatures != self.in_features: + # self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) + # self.qzeros.resize_( + # math.ceil(self.padded_infeatures / self.group_size), + # self.out_features // self.pack_dtype_bits * self.bits + # ) + # self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) + # self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, + # device=self.g_idx.device) + + super().post_init() + + # torch benefits the most from torch.compile, enable it by default + self.optimize() + + def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): + if self.optimized: + return - def compile(self): # compile dequantize - self.dequantize = torch.compile(self.dequantize) + self.dequantize_weight = torch_compile(self.dequantize_weight, backend=backend, mode=mode, fullgraph=fullgraph) + + if self.adapter: + self.adapter.optimize(backend=backend, mode=mode, fullgraph=fullgraph) + + super().optimize() def forward(self, x: torch.Tensor): - if x.size(-1) != self.padded_infeatures: - x = F.pad(x, (0, self.padded_infeatures - self.in_features)) + # if x.size(-1) != self.padded_infeatures: + # x = F.pad(x, (0, self.padded_infeatures - self.in_features)) out_shape = x.shape[:-1] + (self.out_features,) x = x.reshape(-1, x.shape[-1]) - out = self._forward(x, x.dtype) - out = out.reshape(out_shape) + out = self._forward(x, x.dtype, out_shape) return out - def _forward(self, x, x_dtype): + def _forward(self, x, x_dtype, out_shape): num_itr = self.g_idx.shape[0] // x.shape[-1] - weights = self.dequantize(num_itr=num_itr) + weights = self.dequantize_weight(num_itr=num_itr) + + out = torch.matmul(x, weights).reshape(out_shape) - out = torch.matmul(x, weights).to(x_dtype) if self.bias is not None: out.add_(self.bias) - return out + + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out.to(x_dtype) # clear gptq only weights: useful in de-quantization def _empty_gptq_only_weights(self): @@ -134,61 +136,8 @@ def _empty_gptq_only_weights(self): self.g_idx = None self.scales = None - def dequantize(self, num_itr=1): - if self.bits in [4, 8, 2]: - zeros = torch.bitwise_right_shift( - torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor), - self.wf.unsqueeze(0), - ).to(self.dequant_dtype) - zeros = torch.bitwise_and(zeros, self.maxq).reshape(self.scales.shape) - - weight = torch.bitwise_and( - torch.bitwise_right_shift( - torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), - self.wf.unsqueeze(-1), - ).to(self.dequant_dtype), - self.maxq - ) - elif self.bits == 3: - zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand( - -1, -1, -1, 12 - ) - zeros = zeros >> self.wf.unsqueeze(0) - zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4) - zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6) - zeros = zeros & 0x7 - zeros = torch.cat( - [zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], - dim=2, - ).reshape(self.scales.shape) - - weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, 
self.qweight.shape[1]).expand( - -1, -1, 12, -1 - ) - weight = (weight >> self.wf.unsqueeze(-1)) & 0x7 - weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4) - weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6) - weight = weight & 0x7 - weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) - - if num_itr == 1: - weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]) - else: - num_dim = self.g_idx.shape[0] // num_itr - weights = [] - for i in range(num_itr): - scale_i = self.scales[:, i * num_dim: (i + 1) * num_dim] - weight_i = weight[:, i * num_dim: (i + 1) * num_dim] - zeros_i = zeros[:, i * num_dim: (i + 1) * num_dim] - g_idx_i = self.g_idx[i * num_dim: (i + 1) * num_dim].long() - weights.append(scale_i[g_idx_i] * (weight_i - zeros_i[g_idx_i])) - weights = torch.cat(weights, dim=1) - - return weights - -def dequantize_model(model: nn.Module): - for name, module in model.model.named_modules(): +def dequantize_model(model: PreTrainedModel): + for name, module in model.named_modules(): if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchQuantLinear): raise ValueError( "Only models loaded using TorchQuantLinear are supported for dequantization. " @@ -198,14 +147,14 @@ def dequantize_model(model: nn.Module): if isinstance(module, TorchQuantLinear): # Create a new Linear layer with dequantized weights new_module = nn.Linear(module.in_features, module.out_features) - new_module.weight = nn.Parameter(module.dequantize().T.detach().to("cpu", torch.float16)) - new_module.bias = module.bias + new_module.weight = nn.Parameter(module.dequantize_weight().T.detach().to("cpu", torch.float16)) + new_module.bias = torch.nn.Parameter(module.bias) # Replace the module in the model - parent = model.model + parent = model if '.' in name: parent_name, module_name = name.rsplit('.', 1) - parent = dict(model.model.named_modules())[parent_name] + parent = dict(model.named_modules())[parent_name] else: module_name = name diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index 5d1c39ea4..c48c43002 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -14,11 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
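The `dequantize_model` rewrite above resolves each module's parent from its dotted name (`name.rsplit('.', 1)`) and swaps in a plain `nn.Linear` carrying the dequantized weights. Below is a minimal, self-contained sketch of that replace-by-dotted-name pattern in plain PyTorch; the `Block` toy model and `replace_module` helper are illustrative stand-ins, not part of the repo.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

def replace_module(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # resolve the parent of a dotted module path, then swap the child attribute,
    # mirroring the rsplit('.', 1) lookup in dequantize_model
    parent, child_name = model, name
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = dict(model.named_modules())[parent_name]
    setattr(parent, child_name, new_module)

model = nn.Sequential(Block(), Block())
for name, module in list(model.named_modules()):
    if isinstance(module, nn.Linear):
        # stand-in for "dequantize": copy weights out as detached fp16 on cpu
        new_linear = nn.Linear(module.in_features, module.out_features)
        new_linear.weight = nn.Parameter(module.weight.detach().to("cpu", torch.float16))
        new_linear.bias = nn.Parameter(module.bias.detach().to("cpu", torch.float16))
        replace_module(model, name, new_linear)

print(model)
```

Assigning through `setattr` on the resolved parent keeps the module registry consistent, which is why the patch replaces via the parent module rather than mutating the dict returned by `named_modules()`.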
-import math from typing import Optional, Tuple import torch -import torch.nn.functional as F +from gptqmodel.adapter.adapter import Adapter, Lora from packaging import version from ...models._const import DEVICE, PLATFORM @@ -60,7 +59,7 @@ class TritonV2QuantLinear(PackableQuantLinear, TritonModuleMixin): SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.XPU] SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] SUPPORTS_PACK_DTYPES = [torch.int32, torch.int16, torch.int8] - + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "tritonv2" @@ -82,6 +81,7 @@ def __init__( out_features, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, **kwargs, ): if not TRITON_AVAILABLE: @@ -95,13 +95,14 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, + adapter=adapter, register_buffers=True, **kwargs) - if self.group_size != self.in_features: - self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) - else: - self.padded_infeatures = self.in_features + # if self.group_size != self.in_features: + # self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) + # else: + # self.padded_infeatures = self.in_features @classmethod def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: @@ -116,20 +117,21 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: return cls._validate(**args) def post_init(self): - if self.padded_infeatures != self.in_features: - self.qweight.resize_(self.padded_infeatures // self.pack_factor, self.out_features) - self.qzeros.resize_( - math.ceil(self.padded_infeatures / self.group_size), - self.out_features // self.pack_factor - ) - self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) - self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, - device=self.g_idx.device) + # if self.padded_infeatures != self.in_features: + # self.qweight.resize_(self.padded_infeatures // self.pack_factor, self.out_features) + # self.qzeros.resize_( + # math.ceil(self.padded_infeatures / self.group_size), + # self.out_features // self.pack_factor + # ) + # self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) + # self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, + # device=self.g_idx.device) + super().post_init() def forward(self, x): # if in_features is padded, we need to pad the input as well - if x.size(-1) != self.padded_infeatures: - x = F.pad(x, (0, self.padded_infeatures - self.in_features)) + # if x.size(-1) != self.padded_infeatures: + # x = F.pad(x, (0, self.padded_infeatures - self.in_features)) out_shape = x.shape[:-1] + (self.out_features,) @@ -142,11 +144,15 @@ def forward(self, x): self.bits, self.pack_dtype_bits, self.maxq, - ) - out = out.to(dtype=x.dtype).reshape(out_shape) + ).reshape(out_shape) + if self.bias is not None: out.add_(self.bias) - return out + + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out.to(dtype=x.dtype) __all__ = ["TritonV2QuantLinear"] diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index 906263653..8299863d8 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from gptqmodel.adapter.adapter import Lora, normalize_adapter from packaging import 
version from ..utils.logger import setup_logger @@ -57,6 +58,7 @@ META_FIELD_MSE = "mse" +ADAPTER_FIELD = "adapter" # pkg names PKG_AUTO_ROUND = "auto-round" @@ -104,7 +106,6 @@ class QUANT_METHOD: FORMAT_FIELD_JSON: FORMAT_FIELD_CODE, } - def dict_scale_dtype_to_str(d: Dict[str, Any]) -> None: """ Checks whether the passed dictionary and its nested dicts have a *scale_dtype* key and if it's not None, @@ -119,6 +120,10 @@ def dict_scale_dtype_to_str(d: Dict[str, Any]) -> None: def dynamic_get(dynamic: Dict[str, Dict[str, Union[int, bool]]], module_name: str, key: str = None, default_value: Union[int, bool] = None) -> Union[Dict, int, bool]: + + if dynamic is None: + return default_value + for pattern, overrides in dynamic.items(): if pattern.startswith("-:"): if re.match(pattern.removeprefix("-:"), module_name): @@ -177,6 +182,9 @@ class QuantizeConfig(): # affects [`qweights`, `qzeros`] pack_dtype: Optional[Union[str, torch.dtype]] = field(default=torch.int32) + # pending used field + adapter: Optional[Union[Dict[str, Any], Lora]] = field(default=None) + def __post_init__(self): fields_info = fields(self) @@ -187,26 +195,26 @@ def __post_init__(self): if isinstance(self.pack_dtype, str): self.pack_dtype = self.pack_dtype.lower() if self.pack_dtype not in ["int64", "int32", "int16", "int8"]: - raise ValueError(f"Unsupported pack_dtype: {self.pack_dtype}") + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") self.pack_dtype = getattr(torch, self.pack_dtype) elif isinstance(self.pack_dtype, torch.dtype): if self.pack_dtype not in [torch.int64, torch.int32, torch.int16, torch.int8]: - raise ValueError(f"Unsupported pack_dtype: {self.pack_dtype}") + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") else: - raise ValueError(f"Unsupported pack_dtype: {self.pack_dtype}") + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") # validate quant method and format is matched valid_formats = QUANT_METHOD_FORMAT_MAPPING.get(self.quant_method, None) if valid_formats is None: - raise ValueError(f"Unsupported quantization method: {self.quant_method}") + raise ValueError(f"QuantizeConfig: Unsupported `quant_method`: {self.quant_method}") if self.format not in valid_formats: raise ValueError( - f"The checkpoint format used is {self.format}, and the quantization method is {self.quant_method}. " + f"QuantizeConfig: checkpoint `format` used is {self.format}, and the quantization method is {self.quant_method}. 
" ) if self.bits not in fields_info[0].metadata["choices"]: - raise ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.") + raise ValueError(f"QuantizeConfig: `bits` must be in the set of `{fields_info[0].metadata['choices']}`.") if self.dynamic is not None: self.dynamic = { @@ -217,29 +225,43 @@ def __post_init__(self): for layer, layer_dict in self.dynamic.items(): for key, value in layer_dict.items(): if key == "bits" and value not in fields_info[0].metadata["choices"]: - raise ValueError(f"Layer {layer}: only support quantize to {fields_info[0].metadata['choices']} bits.") + raise ValueError(f"QuantizeConfig: Layer `{layer}` only support quantization of `{fields_info[0].metadata['choices']}` bits.") elif key == "group_size" and value != -1 and value <= 0: - raise ValueError("unless equal to -1, group_size must greater then 0.") + raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.") if self.group_size != -1 and self.group_size <= 0: - raise ValueError("unless equal to -1, group_size must greater than 0.") + raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.") if not (0 < self.damp_percent < 1): - raise ValueError("damp_percent must between 0 and 1.") + raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.") if self.damp_auto_increment < 0: - raise ValueError("damp_auto_increment must greater than 0.") + raise ValueError("QuantizeConfig:: `damp_auto_increment` must greater than 0.") # validate meta if self.meta is not None: if not isinstance(self.meta, dict): - raise ValueError("meta must be a dictionary") + raise ValueError("QuantizeConfig: `meta` must be a dictionary") for key, value in self.meta.items(): if not isinstance(key, str): - raise ValueError("Keys in the meta dictionary must be strings") + raise ValueError("QuantizeConfig: `meta` keys must be strings") else: self.meta = {} + # adapter normalize + self.adapter = normalize_adapter(self.adapter) + + #print(f"adapter: {self.adapter}") + + def extension_set(self, key: str, value: Any): + if self.adapter is None: + self.adapter = {} + + self.adapter[key.lower()] = value + + def extension_get(self, key: str) -> Any: + return self.adapter.get(key.lower()) if self.adapter else None + def meta_set(self, key: str, value: Any): self.meta[key] = value @@ -291,9 +313,9 @@ def from_quant_config(cls, quantize_cfg, format: str = None): # compat: format can be passed in via from_quantized() if field missing from json if format: if format not in valid_formats: - raise ValueError(f"Unknown quantization checkpoint format: {format}.") + raise ValueError(f"QuantizeConfig: Unknown quantization checkpoint format: {format}.") if quantize_cfg.get(FORMAT_FIELD_JSON): - raise ValueError("Conflict: quantization format is passed in and also exists in model config.") + raise ValueError("QuantizeConfig: Conflicting quantization format passed in manually and also exists in model config.") # compat: warn if checkpoint_format is missing elif quantize_cfg.get(FORMAT_FIELD_JSON) is None: format_auto_inferred = True @@ -318,7 +340,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None): if val in {FORMAT.GPTQ, FORMAT.GPTQ_V2, FORMAT.MARLIN, FORMAT.BITBLAS}: normalized[key] = val else: - raise ValueError(f"Unknown quantization format: {val}.") + raise ValueError(f"QuantizeConfig: Unknown quantization format: `{val}`.") elif key == QUANT_METHOD_FIELD: val = val.lower() # compat: some hf models use quant_method=marlin or 
bitblas @@ -327,7 +349,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None): elif val == FORMAT.BITBLAS: normalized[FORMAT_FIELD_CODE] = FORMAT.BITBLAS elif val not in {QUANT_METHOD.GPTQ, QUANT_METHOD.AUTO_ROUND}: - raise ValueError(f"Unknown quantization method: {val}.") + raise ValueError(f"QuantizeConfig: Unknown quantization method: `{val}`.") else: normalized[QUANT_METHOD_FIELD] = val elif key == FORMAT_FIELD_COMPAT_MARLIN and val: @@ -335,10 +357,10 @@ def from_quant_config(cls, quantize_cfg, format: str = None): elif key in field_names: normalized[key] = val else: - logger.info(f"Ignoring unknown parameter in the quantization configuration: {key}.") + logger.info(f"QuantizeConfig: Ignoring unknown parameter in the quantization configuration: {key}.") if format_auto_inferred: - logger.info(f"`{FORMAT_FIELD_JSON}` is missing from the quantization configuration and is automatically inferred to {normalized[FORMAT_FIELD_CODE]}") + logger.info(f"QuantizeConfig: `{FORMAT_FIELD_JSON}` is missing from the quantization configuration and is automatically inferred to {normalized[FORMAT_FIELD_CODE]}") if normalized[FORMAT_FIELD_CODE] in {FORMAT.BITBLAS}: # AWQ and Marlin do not reorder the rows. @@ -346,8 +368,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None): if "sym" not in normalized: logger.warning( - "The quantization configuration does not contain an entry `sym` (symmetric quantization). " - "This may result in silent errors. Defaulting to `sym=True`." + "QuantizeConfig: config does not contain `sym` (symmetric quantization). This may result in silent errors. Defaulting to `sym=True`." ) return cls(**normalized) @@ -367,7 +388,7 @@ def from_pretrained(cls, save_dir: str, **kwargs): if resolved_config_file is None: raise ValueError( - "No quantize_config.json, quant_config.json or config.json file was found in the model repository." + "QuantizeConfig: No quantize_config.json, quant_config.json or config.json file was found in the model repository." 
) with open(resolved_config_file, "r", encoding="utf-8") as f: @@ -388,12 +409,15 @@ def to_dict(self): "lm_head": self.lm_head, QUANT_METHOD_FIELD:self.quant_method, FORMAT_FIELD_JSON: self.format, + # convert torch.dtype to string PACK_DTYPE_FIELD: str(self.pack_dtype).split(".")[-1], META_FIELD: self.meta, + # DO NOT EXPORT Adapter to config/json since adapter can be swapped out/in + # ADAPTER_FIELD: self.adapter.to_dict() if self.adapter else None, } - # simplify: clean keys where the value is None - out = {k: v for k, v in out.items() if v is not None} + # simplify: clean keys where the value is None or empty [list, dict] + out = {k: v for k, v in out.items() if v is not None and (v not in [None, {}])} dict_scale_dtype_to_str(out) return out @@ -412,7 +436,12 @@ def calculate_bits_per_weight(self): # FIX ME: g_idx is I32, one per infeature per_group_bits += 4 # ESTIMATE for g_idx int32: one per features/group_size item bpw = per_group_bits / self.group_size + + # normally g_idx (one int32 per in_feature) is allocated in device memory + # but each module may have different in_features and we don't have enough ctx here, so use an estimated `0.1` for now + bpw += 0.1 else: + # there is only one scale int32 + one qzero int32 per entire module, so overall it contributes close to 0 bpw bpw = self.bits logger.info(f"Estimated Quantization BPW (bits per weight): {bpw} bpw, based on [bits: {self.bits}, group_size: {self.group_size}]") @@ -480,4 +509,4 @@ def to_dict(self): class BaseQuantizeConfig(QuantizeConfig): def __init__(self, **kwargs): super().__init__(**kwargs) - logger.warning("BaseQuantizeConfig is re-named and pending deprecation. Please use `QuantizeConfig` instead.") + logger.warning("QuantizeConfig: BaseQuantizeConfig is renamed and pending deprecation. 
Please use `QuantizeConfig` instead.") diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index faa830354..c829805a7 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -20,14 +20,17 @@ import os import sys import time +from typing import Optional import torch import torch.nn as nn import transformers +from gptqmodel.quantization import QuantizeConfig +from ..looper.named_module import NamedModule from ..utils.logger import setup_logger from ..utils.torch import torch_sync -from .quantizer import Quantizer +from .quantizer import HF_OPTIMUM, Quantizer logger = setup_logger() @@ -37,15 +40,23 @@ CPU = torch.device("cpu") class GPTQ: - def __init__(self, module: torch.nn.Module): - self.module = module + def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig]=None): + if isinstance(module, NamedModule): + self.module = module.module + name = module.name + else: + name = HF_OPTIMUM + self.module = module + + self.qcfg = qcfg if qcfg else QuantizeConfig() # HF compat will not pass qcfg self.device = self.module.weight.device self.module_copy = self._clone_module() self.rows, self.columns = self.module_copy.shape[0], self.module_copy.shape[1] # self.H = torch.zeros((self.columns, self.columns), device=self.device) self.nsamples = 0 - self.quantizer = Quantizer() + + self.quantizer = Quantizer(qcfg=self.qcfg, name=name) # fwd input buffer self.fwd_inputs_buffered = False @@ -112,8 +123,7 @@ def process_batch(self, inp): # self.H += 2 / self.nsamples * inp.matmul(inp.t()) self.H += inp.matmul(inp.t()) - # wrapper for backward compat with optimum - # TODO: mark for deprecation + # FIXME, optimum needs fasterquant, we need to remove it def fasterquant( self, blocksize=128, @@ -135,17 +145,19 @@ def hf_quantize( actorder=False, static_groups=False, ): - return self.quantize(blocksize, percdamp, damp_auto_increment, group_size, actorder, static_groups) + self.qcfg.group_size = group_size + self.qcfg.damp_percent = percdamp + self.qcfg.damp_auto_increment = damp_auto_increment + self.qcfg.desc_act = actorder + self.qcfg.static_groups = static_groups + (Q, scale, zero, g_idx, duration, avg_loss, damp_percent) = self.quantize(blocksize=blocksize) + self.module.weight.data = Q + return scale, zero, g_idx, duration, avg_loss, damp_percent @torch.inference_mode() def quantize( self, blocksize=128, - percdamp=0.01, - damp_auto_increment=0.0015, - group_size=-1, - actorder=False, - static_groups=False, ): start = time.time() @@ -156,8 +168,8 @@ def quantize( # release buffer del self.fwd_inputs_buffered_data - if self.device.type not in ["mps", "cpu"]: - self.module.weight.data = self.module.weight.data.cpu() + # if self.device.type not in ["mps", "cpu"]: + # self.module.weight.data = self.module.weight.data.cpu() # TODO: waiting for pytorch implementation of ops for MPS if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1": @@ -169,8 +181,7 @@ def quantize( W = self.module_copy self.module_copy = None - if not self.quantizer.ready(): - self.quantizer.find_params(W, weight=True) + self.quantizer.find_params(W, weight=True) H = self.H del self.H @@ -183,19 +194,19 @@ def quantize( zero = [] now_idx = 1 - if static_groups: + if self.qcfg.static_groups: import copy groups = [] - for i in range(0, self.columns, group_size): + for i in range(0, self.columns, self.qcfg.group_size): quantizer = copy.deepcopy(self.quantizer) - quantizer.find_params(W[:, i : (i + group_size)], weight=True) + quantizer.find_params(W[:, i : 
(i + self.qcfg.group_size)], weight=True) scale.append(quantizer.scale) zero.append(quantizer.zero) groups.append(quantizer) - if actorder: + if self.qcfg.desc_act: perm = torch.argsort(torch.diag(H), descending=True) W = W[:, perm] H = H[perm][:, perm] @@ -204,9 +215,10 @@ def quantize( Losses = torch.zeros_like(W) Q = torch.zeros_like(W) - while 1 > percdamp > 0: + damp_percent = self.qcfg.damp_percent + while 1 > damp_percent > 0: try: - damp = percdamp * torch.mean(torch.diag(H)) + damp = damp_percent * torch.mean(torch.diag(H)) diag = torch.arange(self.columns, device=self.device) H[diag, diag] += damp @@ -216,15 +228,15 @@ def quantize( Hinv = H break except torch._C._LinAlgError as e: - if damp_auto_increment != 0: - logger.warning(f"Current damp={percdamp:.5f} is too low, increased by {damp_auto_increment:.5f}") - percdamp += damp_auto_increment + if self.qcfg.damp_auto_increment != 0: + logger.warning(f"Quantization: Current `damp_percent = {damp_percent:.5f}` is too low, auto-incrementing by `{ self.qcfg.damp_auto_increment:.5f}`") + damp_percent += self.qcfg.damp_auto_increment else: - logger.warning("Please increase damp or nsamples for calibration data to avoid the following quant error. ") + logger.warning("Quantization: Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp_percent:.5f}`") raise e - if not (0 < percdamp < 1): - raise ValueError(f"damp_percent must between 0 and 1. current is {percdamp}") + if not (0 < damp_percent < 1): + raise ValueError(f"Quantization: `damp_percent` must between 0 and 1. current is {damp_percent}") for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) @@ -240,21 +252,21 @@ def quantize( w = W1[:, i] d = Hinv1[i, i] - if group_size != -1: - if not static_groups: - if (i1 + i) % group_size == 0: - self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True) + if self.qcfg.group_size != -1: + if not self.qcfg.static_groups: + if (i1 + i) % self.qcfg.group_size == 0: + self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + self.qcfg.group_size)], weight=True) - if ((i1 + i) // group_size) - now_idx == -1: + if ((i1 + i) // self.qcfg.group_size) - now_idx == -1: scale.append(self.quantizer.scale) zero.append(self.quantizer.zero) now_idx += 1 else: idx = i1 + i - if actorder: + if self.qcfg.desc_act: idx = perm[idx] - self.quantizer = groups[idx // group_size] + self.quantizer = groups[idx // self.qcfg.group_size] q = self.quantizer.quantize(w.unsqueeze(1)).flatten() Q1[:, i] = q @@ -282,31 +294,38 @@ def quantize( if math.isnan(avg_loss): print("Losses sum item:", torch.sum(Losses).item()) - raise ValueError("Quantization failed due to NaN loss") + raise ValueError("Quantization: Failed due to `NaN` loss") - group_size = group_size if group_size != -1 else self.columns + group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns - if static_groups and actorder: + if self.qcfg.static_groups and self.qcfg.desc_act: g_idx = [perm[i] // group_size for i in range(self.columns)] else: g_idx = [i // group_size for i in range(self.columns)] g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) - if actorder: + if self.qcfg.desc_act: Q = Q[:, invperm] g_idx = g_idx[invperm] if isinstance(self.module, transformers.Conv1D): Q = Q.t() + # if Q.shape != self.module.weight.shape: + # self.module.weight.data = Q.reshape(self.module.weight.shape).type_as(self.module.weight.data) + # else: + # self.module.weight.data = 
Q.type_as(self.module.weight.data) + # + # # move back to self.dev + # self.module.weight.data = self.module.weight.data.to(device=self.device) + if Q.shape != self.module.weight.shape: - self.module.weight.data = Q.reshape(self.module.weight.shape).type_as(self.module.weight.data) + Q = Q.reshape(self.module.weight.shape).type_as(self.module.weight.data) else: - self.module.weight.data = Q.type_as(self.module.weight.data) + Q = Q.type_as(self.module.weight.data) - # move back to self.dev - self.module.weight.data = self.module.weight.data.to(device=self.device) + Q = Q.to(device=self.device) # if os.environ.get("DEBUG"): # logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) @@ -319,7 +338,8 @@ def quantize( zero = torch.cat(zero, dim=1) duration = time.time() - start - return scale, zero, g_idx, duration, avg_loss, percdamp + + return Q, scale, zero, g_idx, duration, avg_loss, damp_percent def free(self): # if os.environ.get("DEBUG"): diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py index eec510be1..df7738b5f 100644 --- a/gptqmodel/quantization/quantizer.py +++ b/gptqmodel/quantization/quantizer.py @@ -18,11 +18,13 @@ import torch import torch.nn as nn +from gptqmodel.quantization import QuantizeConfig from ..utils.logger import setup_logger logger = setup_logger() +HF_OPTIMUM = "hf_optimum" def quantize(x, scale, zero, maxq): if maxq < 0: @@ -32,26 +34,32 @@ def quantize(x, scale, zero, maxq): class Quantizer(nn.Module): - def __init__(self, shape=1): + def __init__(self, qcfg: QuantizeConfig, shape=1, name: str=None): super(Quantizer, self).__init__() + + self.qcfg = qcfg self.register_buffer("maxq", torch.tensor(0)) self.register_buffer("scale", torch.zeros(shape)) self.register_buffer("zero", torch.zeros(shape)) + self.name=name + + # FIXME, optimum shouldn't call this directly, it should call hf_configure def configure( self, - bits, perchannel=False, - sym=True, - mse=0.0, # 2.4 grid=100, maxshrink=0.8, trits=False, + bits:int=4, # for hf compat + sym:bool=False, # for hf compat ): - self.maxq = torch.tensor(2**bits - 1) + if self.name == HF_OPTIMUM: + self.qcfg.bits = bits + self.qcfg.sym = sym + + self.maxq = torch.tensor(2**self.qcfg.bits - 1) self.perchannel = perchannel - self.sym = sym - self.mse = mse self.grid = grid self.maxshrink = maxshrink if trits: @@ -80,7 +88,7 @@ def find_params(self, x, weight=False): xmin = torch.minimum(x.min(1)[0], tmp) xmax = torch.maximum(x.max(1)[0], tmp) - if self.sym: + if self.qcfg.sym: xmax = torch.maximum(torch.abs(xmin), xmax) tmp = xmin < 0 if torch.any(tmp): @@ -94,23 +102,23 @@ def find_params(self, x, weight=False): self.zero = xmin else: self.scale = (xmax - xmin) / self.maxq - if self.sym: + if self.qcfg.sym: self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) else: self.zero = torch.round(-xmin / self.scale) - if self.mse > 0.0: + if self.qcfg.mse > 0.0: best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid xmin1 = p * xmin xmax1 = p * xmax scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + zero1 = torch.round(-xmin1 / scale1) if not self.qcfg.sym else self.zero q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) q -= x q.abs_() - q.pow_(self.mse) + q.pow_(self.qcfg.mse) err = torch.sum(q, 1) tmp = err < best if torch.any(tmp): @@ -141,15 +149,13 @@ def find_params(self, x, weight=False): self.zero = self.zero.unsqueeze(0) 
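With the Quantizer now reading `bits`/`sym`/`mse` from the attached `QuantizeConfig`, the scale/zero math it produces is still plain min-max affine quantization. Here is a small numeric sketch of the asymmetric, per-row case in standalone PyTorch; it is illustrative only and mirrors the shape of the `quantize(x, scale, zero, maxq)` helper rather than the full `find_params` grid/MSE search.

```python
import torch

bits = 4
maxq = (1 << bits) - 1  # 15 == largest 4-bit code (codes 0..15)

# toy weight matrix: one row per output channel
w = torch.tensor([[-1.2, 0.3, 2.1, 0.8],
                  [ 0.5, 0.9, 1.4, 2.2]])

# per-row asymmetric range; clamp so 0.0 is always representable
xmin = torch.clamp(w.min(dim=1).values, max=0.0)
xmax = torch.clamp(w.max(dim=1).values, min=0.0)

scale = (xmax - xmin) / maxq
zero = torch.round(-xmin / scale)

# quantize then dequantize
q = torch.clamp(torch.round(w / scale.unsqueeze(1)) + zero.unsqueeze(1), 0, maxq)
w_hat = scale.unsqueeze(1) * (q - zero.unsqueeze(1))

print("scale:", scale)
print("zero:", zero)
print("max abs reconstruction error:", (w - w_hat).abs().max().item())
```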
def quantize(self, x): - if self.ready(): - return quantize(x, self.scale, self.zero, self.maxq) - return x + return quantize(x, self.scale, self.zero, self.maxq) - def enabled(self): - return self.maxq > 0 + # def enabled(self): + # return self.maxq > 0 - def ready(self): - return torch.all(self.scale != 0) + # def ready(self): + # return torch.all(self.scale != 0) __all__ = ["Quantizer"] diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index 69f7e3162..aa0b6f400 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -22,9 +22,11 @@ class BACKEND(str, Enum): AUTO_TRAINABLE = "auto_trainable" # choose the optimal trainable local kernel for post-quant training CUDA = "cuda" TORCH = "torch" + EORA_TORCH = "eora_torch" TRITON = "triton" EXLLAMA_V1 = "exllama_v1" EXLLAMA_V2 = "exllama_v2" + # EXLLAMA_EORA = "exllama_eora" MARLIN = "marlin" BITBLAS = "bitblas" IPEX = "ipex" diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index 2d90f5968..5acf5f7e3 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -92,7 +92,7 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool # Note that due to tvm compilation of per layer modules shapes, the first layer loop is # relatively much slower if caching is not available. estimate time remaining is highly inaccurate - for name, module in ProgressBar(model.named_modules(), desc=message, total=len(list(model.named_modules()))): + for name, module in ProgressBar(model.named_modules(), info=message, total=len(list(model.named_modules()))): if not isinstance(module, model_quantlinear): continue @@ -111,7 +111,8 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool out_features=module.out_features, pack_dtype=qcfg.pack_dtype, bias=module.bias is not None, - enable_tuning=True + enable_tuning=True, + adapter=qcfg.adapter, ) # convert to bitblas format diff --git a/gptqmodel/utils/eval.py b/gptqmodel/utils/eval.py index 2aa080359..60c0eadad 100644 --- a/gptqmodel/utils/eval.py +++ b/gptqmodel/utils/eval.py @@ -21,15 +21,16 @@ from .evalplus import patch_evalplus + class EVAL: - class LM_EVAL(Enum): + class LM_EVAL(str, Enum): ARC_CHALLENGE = "arc_challenge" MMLU = "mmlu" HELLASWAG = "hellaswag" GSM8K_COT = "gsm8k_cot" GPQA = "gpqa" - class EVALPLUS(Enum): + class EVALPLUS(str, Enum): HUMAN = "humaneval" MBPP = "mbpp" @@ -55,7 +56,6 @@ def get_all_tasks_string(cls): full_names.extend(cls.get_full_name(member) for member in attr) return ', '.join(full_names) - def evalplus( model, dataset: str, diff --git a/gptqmodel/utils/evalplus.py b/gptqmodel/utils/evalplus.py index 368c91fa0..b632ee9a2 100644 --- a/gptqmodel/utils/evalplus.py +++ b/gptqmodel/utils/evalplus.py @@ -15,6 +15,7 @@ def patch_evalplus(model): if isinstance(model, BaseGPTQModel) or isinstance(model, PreTrainedModel): model.strip = types.MethodType(patch_strip, model) model.__str__ = types.MethodType(patch_tostring, model) + model.__repr__ = types.MethodType(patch_tostring, model) import torch from evalplus.provider.base import DecoderBase @@ -76,4 +77,16 @@ def __init__( else: # with chat template self.eos += ["\n```\n"] + def __str__(self): + if isinstance(self.model, str): + return self.model + elif isinstance(self.model, PreTrainedModel): + return self.model.config.name_or_path + elif isinstance(self.model, BaseGPTQModel): + return self.model.model_local_path + else: + return self.model.__class__.__name__ + + GPTQModelDecoder.__init__ = 
PatchedGPTQModelDecoder.__init__ + GPTQModelDecoder.__str__ = PatchedGPTQModelDecoder.__str__ diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py new file mode 100644 index 000000000..d4dd5d34f --- /dev/null +++ b/gptqmodel/utils/hf.py @@ -0,0 +1,54 @@ +from gptqmodel.utils.logger import setup_logger +from transformers import GenerationConfig, PreTrainedModel + +logger = setup_logger() + +# TODO FIXME! Pre-quantized use AutoModelForCausalLM.from_pretrained() but post-quantized use AutoModelForCausalLM.from_config() +def autofix_hf_model_config(model: PreTrainedModel, path: str = None): + if model.can_generate(): + # sync config first + if path: + logger.info(f"Model: Loaded `generation_config`: {model.generation_config}") + try: + cfg = GenerationConfig.from_pretrained(pretrained_model_name=path) + if cfg != model.generation_config: + model.generation_config = cfg + logger.info( + "Model: Auto-fixed `generation_config` mismatch between model and `generation_config.json`.") + logger.info(f"Model: Updated `generation_config`: {model.generation_config}") + else: + pass + # logger.info(f"Model: loaded `generation_config` matching `generation_config.json`.") + except Exception: + logger.info("Model: `generation_config.json` not found. Skipped checking.") + + # print(f"Before autofix_hf_model_config: {model.generation_config}") + autofix_hf_generation_config(model.generation_config) + # print(f"After autofix_hf_model_config: {model.generation_config}") + +def autofix_hf_generation_config(cfg: GenerationConfig): + # HF has recently started to perform very strict validation model save which results in warnings on load() + # to become exceptions on save(). + if cfg.do_sample is False: + errors = 0 + if hasattr(cfg, "temperature") and cfg.temperature is not None and cfg.temperature != 1.0: + errors += 1 + if hasattr(cfg, "top_p") and cfg.top_p is not None and cfg.top_p != 1.0: + errors += 1 + if hasattr(cfg, "min_p") and cfg.min_p is not None: + errors += 1 + if hasattr(cfg, "typical_p") and cfg.typical_p is not None and cfg.typical_p != 1.0: + errors += 1 + # contrastive search uses top_k + if (hasattr(cfg, "top_k") and cfg.top_k is not None and cfg.top_k != 50) and (hasattr(cfg, "penalty_alpha") and cfg.penalty_alpha is None): + errors += 1 + if hasattr(cfg, "epsilon_cutoff") and cfg.epsilon_cutoff is not None and cfg.epsilon_cutoff != 0.0: + errors += 1 + if hasattr(cfg, "eta_cutoff") and cfg.eta_cutoff is not None and cfg.eta_cutoff != 0.0: + errors += 1 + + # fix wrong do_sample + if errors > 0: + cfg.do_sample = True + logger.info("Model: Auto-Fixed `generation_config` by setting `do_sample=True`.") + diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index f926816b7..27798549f 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -19,6 +19,7 @@ from typing import Dict, List, Optional, Type, Union import torch +from gptqmodel.adapter.adapter import Adapter from ..models._const import DEVICE, normalize_device from ..nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear @@ -39,20 +40,20 @@ message_logged = False logger = setup_logger() -BACKEND_DICT = OrderedDict({ +AUTO_SELECT_BACKEND_ORDER = OrderedDict({ BACKEND.MARLIN: MarlinQuantLinear, # optimized for bs > 1 BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, # optimized for bs > 1 BACKEND.EXLLAMA_V1: ExllamaQuantLinear, # optimized for bs == 1 - BACKEND.TRITON: TritonV2QuantLinear, - BACKEND.CUDA: DynamicCudaQuantLinear, - BACKEND.BITBLAS: BitBLASQuantLinear, # super slow JIT compile but 
fastest for bs=1 - BACKEND.IPEX: IPEXQuantLinear, - BACKEND.TORCH: TorchQuantLinear, + BACKEND.TRITON: TritonV2QuantLinear, # good all around kernel that JIT compiles + # BACKEND.CUDA: DynamicCudaQuantLinear, + BACKEND.BITBLAS: BitBLASQuantLinear, # super slow AOT pre-compiler but fastest for bs=1 + BACKEND.IPEX: IPEXQuantLinear, # best kernel Intel XPU and CPU with amx/avx512/xmx + BACKEND.TORCH: TorchQuantLinear, # slightly slower than Triton but getting close in Torch 2.6.0+ }) FORMAT_DICT = { - FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH], - FORMAT.GPTQ_V2: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH], + FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH], # BACKEND.EXLLAMA_EORA + FORMAT.GPTQ_V2: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH], # , BACKEND.EXLLAMA_EORA FORMAT.MARLIN: [BACKEND.MARLIN], FORMAT.BITBLAS: [BACKEND.BITBLAS], FORMAT.IPEX: [BACKEND.IPEX], @@ -140,6 +141,7 @@ def hf_select_quant_linear( allow_marlin=True, # TODO: remove this after marlin padding is fixed dynamic=None, pack_dtype=torch.int32, + adapter=None, ) @@ -157,6 +159,7 @@ def select_quant_linear( dynamic=None, pack_dtype: torch.dtype = None, multi_select: bool = False, # return all valid kernels + adapter: Optional[Adapter] = None, ) -> Union[Type[BaseQuantLinear], List[Type[BaseQuantLinear]]]: if device is None: device = DEVICE.XPU if backend == BACKEND.IPEX else DEVICE.CUDA @@ -175,29 +178,40 @@ def select_quant_linear( validated_qlinears = [] # Handle the case where backend is AUTO. if backend in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]: - allow_quant_linears = [(k, v) for k,v in BACKEND_DICT.items() if k in FORMAT_DICT[format]] + allow_quant_linears = [(k, v) for k,v in AUTO_SELECT_BACKEND_ORDER.items() if k in FORMAT_DICT[format]] err = None global message_logged # Suppose all quant linears in the model should have the same backend. 
for k, cls in allow_quant_linears: - validate, err = cls.validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, pack_dtype=pack_dtype, dynamic=dynamic, device=device, trainable=trainable) + validate, err = cls.validate( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + pack_dtype=pack_dtype, + dynamic=dynamic, + device=device, + trainable=trainable, + adapter=adapter, + ) if os.environ.get("DEBUG") and not validate: logger.info(f"skip {k} for {str(err)}") if validate: if pack: check_pack_func = issubclass(cls, PackableQuantLinear) if check_pack_func: - if not message_logged: - logger.info(f"Auto pick kernel based on compatibility: {cls}") - message_logged = True + #if not message_logged: + # logger.info(f"Auto pick kernel based on compatibility: {cls}") + # message_logged = True + logger.info(f"Kernel: Auto-selection: adding candidate `{cls.__name__}`") validated_qlinears.append(cls) if not multi_select: return cls else: - if not message_logged: - logger.info(f"Auto pick kernel based on compatibility: {cls}") - message_logged = True - + #if not message_logged: + # logger.info(f"Auto pick kernel based on compatibility: {cls}") + # message_logged = True + logger.info(f"Kernel: Auto-selection: adding candidate `{cls.__name__}`") validated_qlinears.append(cls) if not multi_select: return cls @@ -216,6 +230,8 @@ def select_quant_linear( qlinear = BitBLASQuantLinear elif backend == BACKEND.MARLIN: qlinear = MarlinQuantLinear + # elif backend == BACKEND.EXLLAMA_EORA: + # qlinear = ExllamaEoraQuantLinear elif backend == BACKEND.EXLLAMA_V2: qlinear = ExllamaV2QuantLinear elif backend == BACKEND.EXLLAMA_V1: @@ -225,13 +241,13 @@ def select_quant_linear( elif backend == BACKEND.IPEX: from ..nn_modules.qlinear.ipex import HAS_IPEX if not HAS_IPEX: - raise ValueError("IPEX is not available.") + raise ValueError("Kernel: IPEX is not installed. Please install it via `pip install gptqmodel['ipex']`") from device_smi import Device cpu_vendor = Device("cpu").vendor if cpu_vendor != "intel": - logger.warning(f"Intel/IPEX cpu kernel is only validated and optimized for Intel cpu. Current cpu vendor: `{cpu_vendor}`.") + logger.warning(f"Kernel: IPEX on cpu is only validated and optimized for Intel cpu with AVX512, AMX, or XMX. Current cpu vendor: `{cpu_vendor}`.") qlinear = IPEXQuantLinear elif backend == BACKEND.TORCH: diff --git a/gptqmodel/utils/logger.py b/gptqmodel/utils/logger.py index 0b3f8e92b..bfde3a9bb 100644 --- a/gptqmodel/utils/logger.py +++ b/gptqmodel/utils/logger.py @@ -15,21 +15,75 @@ # limitations under the License. 
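The `select_quant_linear` changes above keep the same auto-selection shape: walk an ordered candidate map, call each kernel's `validate()`, log skips, and return the first match (or collect all matches when `multi_select=True`). A reduced sketch of that pattern with stand-in kernel classes follows; the class names and the `validate()` signature here are illustrative, not the repo's real kernels.

```python
from collections import OrderedDict
from typing import List, Optional, Tuple, Type

class BaseKernel:
    @classmethod
    def validate(cls, bits: int, group_size: int) -> Tuple[bool, Optional[Exception]]:
        return True, None

class MarlinLike(BaseKernel):
    @classmethod
    def validate(cls, bits: int, group_size: int):
        if bits != 4:
            return False, NotImplementedError("marlin-like kernel supports 4-bit only")
        return True, None

class TorchFallback(BaseKernel):
    pass  # accepts any config

# ordered by preference, like AUTO_SELECT_BACKEND_ORDER
CANDIDATES = OrderedDict([("marlin", MarlinLike), ("torch", TorchFallback)])

def select(bits: int, group_size: int, multi_select: bool = False):
    picked: List[Type[BaseKernel]] = []
    for name, cls in CANDIDATES.items():
        ok, err = cls.validate(bits=bits, group_size=group_size)
        if not ok:
            print(f"skip {name}: {err}")
            continue
        if not multi_select:
            return cls
        picked.append(cls)
    if not picked:
        raise ValueError("no kernel can satisfy the requested quant config")
    return picked

print(select(bits=8, group_size=128))                     # falls through to TorchFallback
print(select(bits=4, group_size=128, multi_select=True))  # both candidates validate
```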
import logging +import sys +from typing import Callable + +from colorlog import ColoredFormatter # global static/shared logger instance logger = None +last_logging_src = 1 # one for logger, 2 for progressbar + +def update_logging_src(src: int): + global last_logging_src + last_logging_src = src def setup_logger(): global logger if logger is not None: return logger + class CustomLogger(logging.Logger): + def critical(self, msg, *args, **kwargs): + op = super().critical + self._process(op, msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + op = super().warning + self._process(op, msg, *args, **kwargs) + + def debug(self, msg, *args, **kwargs): + op = super().debug + self._process(op, msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + op = super().info + self._process(op, msg, *args, **kwargs) + + def _process(self, op: Callable, msg, *args, **kwargs): + global last_logging_src + if last_logging_src == 2: + print(" ", flush=True) + last_logging_src = 1 + op(msg, *args, **kwargs) + + logging.setLoggerClass(CustomLogger) + logger = logging.getLogger(__name__) - handler = logging.StreamHandler() - formatter = logging.Formatter("%(levelname)s - %(message)s") - handler.setFormatter(formatter) logger.propagate = False - logger.addHandler(handler) logger.setLevel(logging.DEBUG) + # Create a colored formatter + formatter = ColoredFormatter( + "%(log_color)s%(levelname)-8s%(reset)s %(message)s", + datefmt=None, + reset=True, + log_colors={ + 'DEBUG': 'cyan', + 'INFO': 'green', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'red,bg_white', + }, + secondary_log_colors={}, + style='%' + ) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + handler.flush = sys.stdout.flush + logger.addHandler(handler) + return logger + + diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py index 41a902629..42b1edb71 100644 --- a/gptqmodel/utils/marlin.py +++ b/gptqmodel/utils/marlin.py @@ -110,7 +110,7 @@ def convert_to_marlin( # TODO: load directly Marlin QuantLinear. message = "Overriding QuantLinear layers to use Marlin's QuantLinear" - for name, module in ProgressBar(model.named_modules(), desc=message, total=len(list(model.named_modules()))): + for name, module in ProgressBar(model.named_modules(), info=message, total=len(list(model.named_modules()))): if not isinstance(module, model_quantlinear): continue diff --git a/gptqmodel/utils/mlx.py b/gptqmodel/utils/mlx.py index 9fa642917..7f02eee60 100644 --- a/gptqmodel/utils/mlx.py +++ b/gptqmodel/utils/mlx.py @@ -20,7 +20,7 @@ logger = setup_logger() -def convert_gptq_to_mlx_weights(model_id_or_path: str, model: Union[PreTrainedModel, BaseGPTQModel], gptq_config: QuantizeConfig): +def convert_gptq_to_mlx_weights(model_id_or_path: str, model: Union[PreTrainedModel, BaseGPTQModel], gptq_config: QuantizeConfig, lm_head_name: str): if not MLX_AVAILABLE: raise ValueError("MLX not installed. 
Please install via `pip install gptqmodel[mlx] --no-build-isolation`.") @@ -49,9 +49,9 @@ def convert_gptq_to_mlx_weights(model_id_or_path: str, model: Union[PreTrainedMo # Convert weights weights = {} n = 1 - pb = ProgressBar(model.named_modules(), prefix="Converting to mlx:", total=len(list(model.named_modules()))) + pb = ProgressBar(model.named_modules(), prefix="Format: Converting to mlx ->", total=len(list(model.named_modules()))) for name, module in pb: - pb.set_description(f"{name}") + pb.info(f"{name}") if isinstance(module, TorchQuantLinear): weights[f"{name}.weight"] = mx.array( module.dequantize_weight().T.detach().to("cpu", torch.float16).numpy() @@ -65,8 +65,7 @@ def convert_gptq_to_mlx_weights(model_id_or_path: str, model: Union[PreTrainedMo n += 1 - elif hasattr(module, "weight") and ( - name != "lm_head" if config.get("tie_word_embeddings", False) else True): + elif hasattr(module, "weight") and (config.tie_word_embeddings or name != lm_head_name): weights[f"{name}.weight"] = mx.array( module.weight.detach().to("cpu", torch.float16).numpy() ) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 3d5cc9a24..ef1ad2607 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -26,19 +26,21 @@ import shutil from concurrent.futures import ThreadPoolExecutor from enum import Enum -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import accelerate import threadpoolctl as tctl import torch import torch.nn as nn import transformers +from gptqmodel.adapter.adapter import Adapter from huggingface_hub import HfApi, hf_hub_download from packaging import version from transformers import AutoConfig, PretrainedConfig from transformers.pytorch_utils import id_tensor_storage from transformers.utils.hub import cached_file +from ..looper.named_module import NamedModule from ..models._const import (CPU, DEVICE, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS, SUPPORTS_MODULE_TYPES) from ..nn_modules.qlinear import BaseQuantLinear @@ -51,7 +53,7 @@ from .importer import select_quant_linear from .logger import setup_logger from .progress import ProgressBar -from .torch import torch_empty_cache +from .torch import torch_empty_cache, torch_new_stream_ctx logger = setup_logger() @@ -88,22 +90,46 @@ def get_device(obj: torch.Tensor | nn.Module): return next(obj.parameters()).device -def move_to(obj: torch.Tensor | nn.Module, device: torch.device): +def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dtype = None, stream: bool = False): if get_device(obj) != device: - obj = obj.to(device) + if stream: + # we cannot support changing dtype and stream at the same time + assert dtype is None, f"streaming does not support changing dtype: actual = `{dtype}" + if not isinstance(obj, torch.Tensor): + raise NotImplementedError( + f"Streaming `move_to` is not supported for non-Tensors: actual = `{obj.__class__.__name__}`") + + if device == CPU: + # print(f" streaming from non-CPU to CPU...nonblocking") + obj_copy = torch.zeros_like(obj, device=CPU, pin_memory=True) + streamCtx = torch_new_stream_ctx() + if streamCtx: + # use streaming context with pinned cpu memory + with streamCtx: + obj_copy.copy_(obj, non_blocking=True) + return obj_copy + else: + # does not support streaming context + obj = obj.to(device=device, non_blocking=True) + else: + # cpu to non-cpu or non-cpu to non-cpu uses normal .to() api + obj = obj.to(device=device, non_blocking=True) + else: + 
obj = obj.to(device=device, dtype=dtype, non_blocking=False) + return obj -def nested_move_to(v, device): +def nested_move_to(v, device, dtype: torch.dtype = None, stream: bool = False): if isinstance(v, torch.Tensor): - return move_to(v, device) + return move_to(v, device=device, dtype=dtype, stream=stream) elif isinstance(v, (list, tuple)): - return type(v)([nested_move_to(e, device) for e in v]) + return type(v)([nested_move_to(e, device=device, dtype=dtype, stream=stream) for e in v]) else: return v -def find_modules(module, layers=None, name=""): +def find_modules(module: nn.Module, layers=None, name: str="") -> Dict[str, nn.Module]: if not layers: layers = SUPPORTS_MODULE_TYPES @@ -140,23 +166,26 @@ def get_module(module, key): module = getattr(module, name, None) return module - def make_quant( module, - names, - bits: int, - group_size: int, + quant_result: Dict[str, Dict[str, Any]], + qcfg: QuantizeConfig, backend: BACKEND, - format: str | FORMAT, lm_head_name: str, - desc_act: bool = False, - sym: bool = True, pack: bool = False, - dynamic=None, device: DEVICE = None, from_quantized: bool = False, - pack_dtype: torch.dtype = None, ) -> Type[BaseQuantLinear]: + + bits = qcfg.bits + group_size =qcfg.group_size + extension = qcfg.adapter + format = qcfg.format + desc_act = qcfg.desc_act + sym = qcfg.sym + dynamic = qcfg.dynamic + pack_dtype = qcfg.pack_dtype + # returns multiple validated kernels quant_linear_candidates = select_quant_linear( bits=bits, @@ -170,9 +199,10 @@ def make_quant( device=device, pack_dtype=pack_dtype, multi_select=True, + adapter=extension, ) - logger.info(f"make_quant: Linear candidates: {quant_linear_candidates}") + logger.info(f"Kernel: candidates -> `[{', '.join(cls.__name__ for cls in quant_linear_candidates)}]`") # loop over actual QLinear init, catch errors and use fallbacks if applicable for cls in quant_linear_candidates: @@ -189,15 +219,18 @@ def make_quant( dynamic=dynamic, group_size=group_size, module=module, - names=names, + quant_result=quant_result, sym=sym, device=device, lm_head_name=lm_head_name, - pack_dtype=pack_dtype) - logger.info(f"make_quant: Selected linear: `{cls}`.") + pack_dtype=pack_dtype, + adapter=qcfg.adapter, + ) + logger.info(f"Kernel: selected -> `{linear_cls.__name__}`.") return linear_cls except NotImplementedError as e: - logger.info(f"make_quant: Skipped linear: `{cls}`. ") + logger.info(f"Kernel: skipped -> `{cls}`.") + # only fallback to other quant linears when backend is auto. 
if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]: raise e @@ -212,86 +245,97 @@ def create_quant_layer( dynamic, group_size: int, module, - names, + quant_result: Dict[str, Dict[str, Any]], sym: bool, device: DEVICE, lm_head_name: str, - pack_dtype: torch.dtype) -> Type[BaseQuantLinear]: + pack_dtype: torch.dtype, + adapter: Optional[Adapter] = None, +) -> Type[BaseQuantLinear]: if isinstance(module, linear_cls): return linear_cls for name, submodule in module.named_modules(): - if name in names: - ori_layer_device = next(submodule.parameters()).device - if isinstance(submodule, nn.Linear): - in_features = submodule.in_features - out_features = submodule.out_features - elif isinstance(submodule, nn.Conv2d): - in_features = submodule.in_channels - out_features = submodule.out_channels - elif isinstance(submodule, transformers.pytorch_utils.Conv1D): - in_features = submodule.weight.shape[0] - out_features = submodule.weight.shape[1] - elif isinstance(submodule, BaseQuantLinear): - # if submodule is already a quant layer, we need to get in_features and out_features from the submodule - in_features = submodule.in_features - out_features = submodule.out_features - else: - raise NotImplementedError(f"Unsupported module {submodule}") - - bias = submodule.bias is not None - - # need copies as dynamic config may override these in for loop - tmp_bits = bits - tmp_group_size = group_size - tmp_desc_act = desc_act - tmp_sym = sym - tmp_pack_dtype = pack_dtype - - # dynamic bits, group_size, sym, pack_dtype for each layer/module - if dynamic is not None: - overrides = dynamic_get(dynamic=dynamic, module_name=name) - # negative module match, skip this module - if overrides == False: # noqa: E712 - continue - - # positive module match - if overrides: - # override base QuantizeConfig for every quant config key/value - tmp_bits = overrides.get("bits", bits) - tmp_group_size = overrides.get("group_size", group_size) - tmp_desc_act = overrides.get("desc_act", desc_act) - tmp_sym = overrides.get("sym", sym) - tmp_pack_dtype = overrides.get("pack_dtype", pack_dtype) - - # when loading a quantized model, device is target device passed in GPTQModel.load() - # check in_features and out_features validate - _, err = linear_cls.validate( - bits=tmp_bits, - group_size=tmp_group_size, - desc_act=tmp_desc_act, - sym=tmp_sym, - pack_dtype=tmp_pack_dtype, - in_features=in_features, - out_features=out_features, - device=device) - if err is not None: - raise err - - new_layer = linear_cls( - bits=tmp_bits, - group_size=tmp_group_size, - desc_act=tmp_desc_act, - sym=tmp_sym, - in_features=in_features, - out_features=out_features, - pack_dtype=tmp_pack_dtype, - bias=bias, - #weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype, - name=name, - lm_head_name=lm_head_name, - ) - new_layer.device = ori_layer_device - recurse_setattr(module, name, new_layer.to(ori_layer_device)) + # skip non-quantized modules + if name not in quant_result: + continue + + ori_layer_device = next(submodule.parameters()).device + if isinstance(submodule, NamedModule): + in_features = submodule.state.get("in_features") + out_features = submodule.state.get("out_features") + elif isinstance(submodule, nn.Linear): + in_features = submodule.in_features + out_features = submodule.out_features + elif isinstance(submodule, nn.Conv2d): + in_features = submodule.in_channels + out_features = submodule.out_channels + elif isinstance(submodule, transformers.pytorch_utils.Conv1D): + in_features = 
submodule.weight.shape[0] + out_features = submodule.weight.shape[1] + elif isinstance(submodule, BaseQuantLinear): + # if submodule is already a quant layer, we need to get in_features and out_features from the submodule + in_features = submodule.in_features + out_features = submodule.out_features + else: + raise NotImplementedError(f"Unsupported module {submodule}") + + bias = submodule.bias is not None + + # need copies as dynamic config may override these in for loop + tmp_bits = bits + tmp_group_size = group_size + tmp_desc_act = desc_act + tmp_sym = sym + tmp_pack_dtype = pack_dtype + + # dynamic bits, group_size, sym, pack_dtype for each layer/module + if dynamic is not None: + overrides = dynamic_get(dynamic=dynamic, module_name=name) + # negative module match, skip this module + if overrides == False: # noqa: E712 + continue + + # positive module match + if overrides: + # override base QuantizeConfig for every quant config key/value + tmp_bits = overrides.get("bits", bits) + tmp_group_size = overrides.get("group_size", group_size) + tmp_desc_act = overrides.get("desc_act", desc_act) + tmp_sym = overrides.get("sym", sym) + tmp_pack_dtype = overrides.get("pack_dtype", pack_dtype) + + # when loading a quantized model, device is target device passed in GPTQModel.load() + # check in_features and out_features validate + _, err = linear_cls.validate( + bits=tmp_bits, + group_size=tmp_group_size, + desc_act=tmp_desc_act, + sym=tmp_sym, + pack_dtype=tmp_pack_dtype, + in_features=in_features, + out_features=out_features, + device=device, + adapter=adapter, # TODO FIX ME..need to pass Eora if loaded + ) + if err is not None: + raise err + + new_layer = linear_cls( + bits=tmp_bits, + group_size=tmp_group_size, + desc_act=tmp_desc_act, + sym=tmp_sym, + in_features=in_features, + out_features=out_features, + pack_dtype=tmp_pack_dtype, + bias=bias, + #weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype, + name=name, + lm_head_name=lm_head_name, + adapter=adapter, + ) + new_layer.device = ori_layer_device + recurse_setattr(module, name, new_layer.to(ori_layer_device)) return linear_cls # public/stable api exposed to transformer/optimum @@ -454,12 +498,13 @@ def convert_gptq_v2_to_v1_format( return model -def pack_module(name, qModules, quantizers, layers, pbar=None): +def pack_module(name, qModules, quant_result, layers, pbar=None): # Limit pack() thread usage to avoid auto-parallizataion regression with tctl.threadpool_limits(limits=1): if pbar: - pbar.set_description(f"Packing {name}") - quantizers[name], scale, zero, g_idx = quantizers[name] + pbar.info(f"Packing {name}") + r = quant_result[name] + scale, zero, g_idx = r.get("scale"), r.get("zero"), r.get("g_idx") # TODO FIX ME: use const, not string for field names layer_device = qModules[name].device qModules[name].to(CPU) layers[name], scale, zero, g_idx = ( @@ -468,7 +513,7 @@ def pack_module(name, qModules, quantizers, layers, pbar=None): zero.to(CPU), g_idx.to(CPU) if g_idx is not None else None, ) - qModules[name].pack(layers[name], scale, zero, g_idx) + qModules[name].pack(linear=layers[name], scales=scale, zeros=zero, g_idx=g_idx) qModules[name].to(layer_device) if pbar: pbar.progress() @@ -476,7 +521,7 @@ def pack_module(name, qModules, quantizers, layers, pbar=None): def pack_model( model, - quantizers, + quant_result: Dict[str, Dict[str, Any]], bits, group_size, backend: BACKEND, @@ -488,25 +533,32 @@ def pack_model( parallel_packing: bool = True, pack_dtype: torch.dtype = 
None, ): + qcfg = QuantizeConfig( + bits=bits, + group_size=group_size, + format=format, + desc_act=desc_act, + sym=sym, + dynamic=dynamic, + pack_dtype=pack_dtype, + ) + model.to(CPU) logger.info("Packing model...") modules = find_modules(model) - modules = {n: modules[n] for n in quantizers} + + modules = {n: modules[n] for n in quant_result} quant_linear_cls = make_quant( model, - quantizers, - bits, - group_size, + quant_result=quant_result, + qcfg=qcfg, backend=backend, - format=format, lm_head_name=lm_head_name, - desc_act=desc_act, pack=True, - dynamic=dynamic, - pack_dtype=pack_dtype, ) + qModules = find_modules(model, [quant_linear_cls]) assert len(qModules) > 0, f"No quantizeed modules[{quant_linear_cls}] found in the model." @@ -521,7 +573,7 @@ def pack_model( with ThreadPoolExecutor(max_workers=max_workers) as executor: with ProgressBar(total=len(names)) as pbar: def wrapper(name): - pack_module(name, qModules, quantizers, modules, pbar) + pack_module(name, qModules, quant_result, modules, pbar) for _ in executor.map(wrapper, names): pass diff --git a/gptqmodel/utils/perplexity.py b/gptqmodel/utils/perplexity.py index f5073aee3..653adb776 100644 --- a/gptqmodel/utils/perplexity.py +++ b/gptqmodel/utils/perplexity.py @@ -149,7 +149,7 @@ def calculate(self, n_ctx=512, n_batch=512): curr_ppl = 0 all_perplexity = [] - with ProgressBar(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress: + with ProgressBar(range(len(tokens[0]) // n_ctx), info="Perplexity: - ") as progress: for i in progress: # Process each batch of tokens nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count) @@ -157,7 +157,7 @@ def calculate(self, n_ctx=512, n_batch=512): # Calculate and display the current perplexity curr_ppl = np.exp(nll / count) all_perplexity.append(curr_ppl) - progress.set_description(f"Perplexity: {curr_ppl:.4f}") + progress.info(f"Perplexity: {curr_ppl:.4f}") return all_perplexity diff --git a/gptqmodel/utils/progress.py b/gptqmodel/utils/progress.py index 6bd63d6ca..19efeb9fc 100644 --- a/gptqmodel/utils/progress.py +++ b/gptqmodel/utils/progress.py @@ -15,9 +15,15 @@ # limitations under the License. 
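# --- hedged usage sketch (editor addition): how call sites migrate to the renamed
# ProgressBar API in this progress.py hunk -- `info=` replaces the old `desc=`
# constructor argument and `pb.info(...)` replaces `pb.set_description(...)`.
# Assumes gptqmodel is importable; the loop body is illustrative only.
from time import sleep

from gptqmodel.utils.progress import ProgressBar

pb = ProgressBar(range(1, 11), info="warming up")
for i in pb:
    pb.info(f"step {i} of 10")  # was: pb.set_description(...)
    sleep(0.01)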
import datetime +import os +import sys import time +from typing import Iterable from warnings import warn +from gptqmodel.utils.logger import setup_logger, update_logging_src + +logger = setup_logger() class ProgressBarWarning(Warning): def __init__(self, msg, fp_write=None, *a, **k): @@ -27,7 +33,17 @@ def __init__(self, msg, fp_write=None, *a, **k): super().__init__(msg, *a, **k) class ProgressBar: - def __init__(self, iterable=None, total=None, prefix='', bar_length=40, fill='█', desc=""): + def __init__(self, + iterable: Iterable=None, + total=None, + prefix:str = '', + bar_length:int =60, + fill:str = '█', + info:str = ""): + + # max info length over the life ot the pb + self.max_info_length = len(info) + if total is None and iterable is not None: try: total = len(iterable) @@ -45,20 +61,43 @@ def __init__(self, iterable=None, total=None, prefix='', bar_length=40, fill=' self.prefix = prefix self.bar_length = bar_length self.fill = fill - self.description = desc - self.current = 0 + self.info_text = info + self.current_iteration = 0 self.time = time.time() - def set_description(self, description): - self.description = description + def info(self, info:str): + if len(info) > self.max_info_length: + self.max_info_length = len(info) + + self.info_text = info - def progress(self, iteration = None): + def progress(self, iteration:int = None): if not iteration: - iteration = self.current - percent = ("{0:.1f}").format(100 * (iteration / float(len(self)))) - filled_length = int(self.bar_length * iteration // len(self)) - bar = self.fill * filled_length + '-' * (self.bar_length - filled_length) - self.log(bar, f"{self.calc_time(iteration)} [{iteration}/{len(self)}] {percent}%") + iteration = self.current_iteration + + columns, _ = terminal_size() + bar_length = columns + bar_length -= len(self.prefix) # +1 for space + bar_length -= len(self.info_text) + + percent_num = iteration / float(len(self)) + percent = ("{0:.1f}").format(100 * (percent_num)) + log = f"{self.calc_time(iteration)} [{iteration}/{len(self)}] {percent}%" + + bar_length -= len(log) + bar_length -= 5 # space + | chars + + # calculate padding + if len(self.info_text) < self.max_info_length: + padding = " " * (self.max_info_length - len(self.info_text)) + else: + padding = "" + + bar_length -= len(padding) + + filled_length = int(bar_length * iteration // len(self)) + bar = self.fill * filled_length + '-' * (bar_length - filled_length) + self.log(bar=bar, log=log, padding=padding, end='\n' if percent_num >= 1.0 else '') def calc_time(self, iteration): used_time = int(time.time() - self.time) @@ -66,8 +105,14 @@ def calc_time(self, iteration): remaining = str(datetime.timedelta(seconds=int((used_time / max(iteration, 1)) * len(self)))) return f"{formatted_time} / {remaining}" - def log(self, bar, log): - print(f'\r{self.prefix} {self.description} |{bar}| {log}', end='', flush=True) + def log(self, bar:str, log:str, padding:str = "", end: str = ""): + # print(f'\r{self.prefix} {self.info_text} |{bar}| {log}', end='', flush=True) + if self.prefix: + print(f'\r{self.prefix} {self.info_text}{padding} |{bar}| {log}', end=end, flush=True) + else: + print(f'\r{self.info_text}{padding} |{bar}| {log}', end=end, flush=True) + + update_logging_src(src=2) # let logger now we logged def __bool__(self): if self.total is not None: @@ -84,6 +129,7 @@ def __len__(self): else self.iterable.__length_hint__() if hasattr(self.iterable, "__length_hint__") else getattr(self, "total", None)) + # TODO FIXME: I have no cluse why the try/catch is catching 
nothing here def __reversed__(self): try: orig = self.iterable @@ -102,6 +148,7 @@ def __contains__(self, item): def __enter__(self): return self + # TODO FIXME: I don't understand the exception here. What are we catching? yield error? def __exit__(self, exc_type, exc_value, traceback): try: self.close() @@ -125,12 +172,60 @@ def __iter__(self): iterable = self.iterable for obj in iterable: - self.current+=1 + self.current_iteration+=1 self.progress() yield obj + + self.progress() return def close(self): - self.log(f"{'-' * self.bar_length}", "100.0%") - + pass + #self.log(f"{self.fill * self.bar_length}", "100.0%", end="\n") + +# copied from github.com/onsim/shutils +def terminal_size(fallback=(80, 24)): + """Get the size of the terminal window. + + For each of the two dimensions, the environment variable, COLUMNS + and LINES respectively, is checked. If the variable is defined and + the value is a positive integer, it is used. + + When COLUMNS or LINES is not defined, which is the common case, + the terminal connected to sys.__stdout__ is queried + by invoking os.get_terminal_size. + + If the terminal size cannot be successfully queried, either because + the system doesn't support querying, or because we are not + connected to a terminal, the value given in fallback parameter + is used. Fallback defaults to (80, 24) which is the default + size used by many terminal emulators. + + The value returned is a named tuple of type os.terminal_size. + """ + # columns, lines are the working values + try: + columns = int(os.environ['COLUMNS']) + except (KeyError, ValueError): + columns = 0 + + try: + lines = int(os.environ['LINES']) + except (KeyError, ValueError): + lines = 0 + + # only query if necessary + if columns <= 0 or lines <= 0: + try: + size = os.get_terminal_size(sys.__stdout__.fileno()) + except (AttributeError, ValueError, OSError): + # stdout is None, closed, detached, or not a terminal, or + # os.get_terminal_size() is unsupported + size = os.terminal_size(fallback) + if columns <= 0: + columns = size.columns or fallback[0] + if lines <= 0: + lines = size.lines or fallback[1] + + return (columns, lines) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index db5dbba51..e83cfdb05 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -15,14 +15,28 @@ # limitations under the License. 
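# --- hedged sketch (editor addition): the width budget ProgressBar.progress() above
# derives from terminal_size() -- the bar gets whatever is left of the terminal after
# the prefix, the (padded) info text, and the time/percent log are subtracted, plus a
# 5-character allowance for spaces and the two `|` separators. Sample strings are made up.
def remaining_bar_width(columns: int, prefix: str, info: str, log: str, padding: str = "") -> int:
    width = columns
    width -= len(prefix)
    width -= len(info)
    width -= len(log)
    width -= 5          # spaces plus the `|...|` characters around the bar
    width -= len(padding)
    return max(width, 0)

# e.g. an 80-column terminal with no prefix:
print(remaining_bar_width(80, prefix="", info="Packing model.layers.0", log="0:00:02 / 0:00:20 [1/10] 10.0%"))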
import gc as py_gc +from typing import Callable, Union import torch +from gptqmodel.utils.logger import setup_logger +from packaging.version import Version HAS_CUDA = False HAS_XPU = False HAS_MPS = False HAS_MLX = False +STREAM = None # cache + +logger = setup_logger() + +# reset dynamo cache on each model load since during ci loop model inference may exhuast cache +torch._dynamo.reset() + +# Increase the dynamo cache size limit, default of 8 is too low +if torch._dynamo.config.cache_size_limit < 128: + torch._dynamo.config.cache_size_limit = 128 + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available(): HAS_CUDA = True @@ -39,6 +53,37 @@ except BaseException: pass +def torch_compile(module: Union[torch.nn.Module, Callable], backend:str ="inductor", mode: str = None, fullgraph=False): + from gptqmodel.models.base import PYTORCH_MIN_VERSION_WITH_COMPILE + + if Version(torch.__version__) < PYTORCH_MIN_VERSION_WITH_COMPILE: + return module + try: + return torch.compile(module, backend=backend, mode=mode, fullgraph=fullgraph) + except BaseException: + logger.warning(f"Failed to compile `{module}`") + return module + +def torch_new_stream(): + global STREAM + if STREAM is None: + return STREAM + + if HAS_CUDA: + STREAM = torch.cuda.Stream() + return STREAM + if HAS_XPU: + STREAM = torch.xpu.Stream() + return STREAM + return None + +def torch_new_stream_ctx(): + if HAS_CUDA: + return torch.cuda.stream(torch_new_stream()) + if HAS_XPU: + return torch.xpu.Stream(torch_new_stream()) + return None + def torch_sync(device: torch.device = None): # check all backends if device is None: diff --git a/requirements.txt b/requirements.txt index e50509f69..6fab58144 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,6 @@ protobuf>=5.29.3 pillow>=11.1.0 hf_transfer>=0.1.9 huggingface_hub>=0.28.1 -tokenicer>=0.0.2 +lm-eval==0.4.7 +colorlog>=6.9.0 +tokenicer>=0.0.2 \ No newline at end of file diff --git a/setup.py b/setup.py index d5589b253..fb47913ef 100644 --- a/setup.py +++ b/setup.py @@ -314,12 +314,12 @@ def run(self): install_requires=requirements, extras_require={ "test": ["pytest>=8.2.2", "parameterized"], - "quality": ["ruff==0.4.9", "isort==5.13.2"], - 'vllm': ["vllm>=0.6.4", "flashinfer-python>=0.2.1"], - 'sglang': ["sglang>=0.3.2", "flashinfer-python>=0.2.1"], + "quality": ["ruff==0.9.6", "isort==6.0.0"], + 'vllm': ["vllm>=0.6.4", "flashinfer-python>=0.2.1"], + 'sglang': ["sglang[srt]>=0.3.2", "flashinfer-python>=0.2.1"], 'bitblas': ["bitblas==0.0.1-dev13"], 'hf': ["optimum>=1.21.2"], - 'ipex': ["intel_extension_for_pytorch>=2.5.0"], + 'ipex': ["intel_extension_for_pytorch>=2.6.0"], 'auto_round': ["auto_round>=0.3"], 'logger': ["clearml", "random_word", "plotly"], 'eval': ["lm_eval>=0.4.7", "evalplus>=0.3.1"], diff --git a/test_prepare_dataset.py b/test_prepare_dataset.py new file mode 100644 index 000000000..425431546 --- /dev/null +++ b/test_prepare_dataset.py @@ -0,0 +1,66 @@ + +from datasets import load_dataset +from gptqmodel import GPTQModel, QuantizeConfig + + +def question_answering_format(question, answer): + + return f"Question: {question}\nAnswer: {answer}" + +## An example of using ARC for construting the EoRA calibration set + +def construct_c4(nsamples): + calibration_dataset = load_dataset( + "allenai/c4", + data_files="en/c4-train.00001-of-01024.json.gz", + split="train" + ).select(range(1024))["text"] + return calibration_dataset + +def construct_ARC(nsamples): + arc_easy_calibration_dataset = load_dataset('ai2_arc', 
'ARC-Easy', split='train').select(range(nsamples)) + arc_challenge_calibration_dataset = load_dataset('ai2_arc', 'ARC-Challenge', split='train').select(range(nsamples)) + dataset = [] + + for example in arc_easy_calibration_dataset: + answer = example['choices']['text'][example['choices']['label'].index(example['answerKey'])] + question = example['question'] + dataset.append(question_answering_format(question=question,answer=answer)) + + for example in arc_challenge_calibration_dataset: + answer = example['choices']['text'][example['choices']['label'].index(example['answerKey'])] + question = example['question'] + dataset.append(question_answering_format(question=question,answer=answer)) + + ## we recommend also include some examples from C4 to avoid overfitting to the downstream data + c4_dataset = load_dataset( + "allenai/c4", + data_files="en/c4-train.00001-of-01024.json.gz", + split="train" + ).select(range(nsamples))["text"] + + return dataset + c4_dataset + + +# arc_calibration_dataset = construct_ARC(1024) +# print(len(arc_calibration_dataset)) +# print(arc_calibration_dataset[-1]) + +# c4_calibrarion_dataset = construct_c4(1024) + +# model_id = "meta-llama/Llama-3.2-1B" +# quant_config = QuantizeConfig(bits=4, group_size=128) +# model = GPTQModel.load(model_id, quant_config) + +# ## tokenizer for testing +# from transformers import AutoTokenizer + +# tokenizer = AutoTokenizer.from_pretrained(model_id) + +# prepare_dataset = model.prepare_dataset(c4_calibrarion_dataset) + + +# inputs = tokenizer(c4_calibrarion_dataset[0], return_tensors="pt") +# print(inputs['input_ids'].shape) + +# print(prepare_dataset[0]['input_ids'].shape) \ No newline at end of file diff --git a/tests/benchmark/benchmark.py b/tests/benchmark/benchmark.py index b23b5ca17..5aeb3f276 100644 --- a/tests/benchmark/benchmark.py +++ b/tests/benchmark/benchmark.py @@ -22,10 +22,10 @@ class TestInference(BenchmarkTest): @parameterized.expand( [ - (BACKEND.TORCH, 'cuda', 292.50), - (BACKEND.TORCH, 'cpu', 5.50), - (BACKEND.TORCH, 'xpu', 58.20), - (BACKEND.TORCH, 'mps', 3.40), + (BACKEND.TORCH, 'cuda', 210), + # (BACKEND.TORCH, 'cpu', 5.50), + # (BACKEND.TORCH, 'xpu', 58.20), + # (BACKEND.TORCH, 'mps', 3.40), ] ) def test_inference(self, backend, device, tokens_per_second): diff --git a/tests/benchmark/benchmark_test.py b/tests/benchmark/benchmark_test.py index 8ce94bada..ff84a693f 100644 --- a/tests/benchmark/benchmark_test.py +++ b/tests/benchmark/benchmark_test.py @@ -23,13 +23,13 @@ from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.utils.progress import ProgressBar # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class BenchmarkTest(unittest.TestCase): MODEL_id = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" - MIN_NEW_TOEKNS = 10 - NUM_RUNS = 10 + MIN_NEW_TOKENS = 10 + MAX_NEW_TOKENS = 20 + NUM_RUNS = 50 PROMPTS = [ "I am in Paris and I", "The capital of the United Kingdom is", @@ -45,31 +45,38 @@ class BenchmarkTest(unittest.TestCase): MAX_DELTA_FLOOR_PERCENT = 0.25 MAX_POSITIVE_DELTA_CEIL_PERCENT = 1.0 - def benchmark(self, backend, device, tokens_per_second): - model = GPTQModel.from_quantized( + def benchmark(self, backend, device, tokens_per_second: int, warmup_iter: int = 1): + model = GPTQModel.load( self.MODEL_id, device=device, backend=backend, + use_cache=False, ) - tokenizer = AutoTokenizer.from_pretrained(self.MODEL_id) - tokenizer.pad_token = tokenizer.eos_token - inp = tokenizer(self.PROMPTS, padding=True, truncation=True, return_tensors="pt", 
padding_side='left').to(device) + model.optimize() + + tokenizer = model.tokenizer + inp = tokenizer(self.PROMPTS, padding=True, padding_side="left", pad_to_multiple_of=16, truncation=True, return_tensors="pt",).to(device) + + print(f"Warming up: warmup_iter = `{warmup_iter}`") + for i in range(warmup_iter): + _ = model.generate(**inp, min_new_tokens=self.MIN_NEW_TOKENS, + max_new_tokens=self.MAX_NEW_TOKENS) times = [] pb = ProgressBar(range(self.NUM_RUNS)) for i in pb: - pb.set_description(f"run index {i} of {self.NUM_RUNS -1}") + pb.info(f"run index {i} of {self.NUM_RUNS - 1}") start_time = time.time() - _ = model.generate(**inp, num_beams=1, min_new_tokens=self.MIN_NEW_TOEKNS, - max_new_tokens=self.MIN_NEW_TOEKNS) + _ = model.generate(**inp,min_new_tokens=self.MIN_NEW_TOKENS, + max_new_tokens=self.MAX_NEW_TOKENS) end_time = time.time() elapsed_time = end_time - start_time times.append(elapsed_time) sum_time = sum(times) - sum_tokens = len(self.PROMPTS) * self.MIN_NEW_TOEKNS * self.NUM_RUNS + sum_tokens = len(self.PROMPTS) * self.MIN_NEW_TOKENS * self.NUM_RUNS avg_tokens_per_second = sum_tokens / sum_time print("**************** Benchmark Result Info****************") diff --git a/tests/cpu/test_progress_bar.py b/tests/cpu/test_progress_bar.py new file mode 100644 index 000000000..30cd73f88 --- /dev/null +++ b/tests/cpu/test_progress_bar.py @@ -0,0 +1,14 @@ +import unittest +from time import sleep + +from gptqmodel.utils.progress import ProgressBar + + +class TestBits(unittest.TestCase): + def test_progress_bar(self): + pb = ProgressBar(range(1,101)) + for i in pb: + pb.info(f"Test run index {i} of 100") + sleep(0.1) + + diff --git a/tests/inference_speed.py b/tests/inference_speed.py index 9714c51c2..7281aa41f 100644 --- a/tests/inference_speed.py +++ b/tests/inference_speed.py @@ -17,6 +17,8 @@ import os import time +from gptqmodel.utils.torch import torch_empty_cache + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -47,14 +49,14 @@ class InferenceSpeed(unittest.TestCase): MAX_DELTA_FLOOR_PERCENT = 0.25 MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.25 - def inference(self, model_path, backend, tokens_per_second, assert_result=True, compile=False, warmup_runs=0): + def inference(self, model_path, backend, tokens_per_second, assert_result=True, optimize=False, fullgraph=False, warmup_runs=0): model = GPTQModel.from_quantized( model_path, backend=backend, ) - if compile: - model.compile() + if optimize: + model.optimize(fullgraph=fullgraph) tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer.pad_token_id = tokenizer.eos_token_id @@ -68,7 +70,7 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, if warmup_runs > 0: pb = ProgressBar(range(warmup_runs)) for i in pb: - pb.set_description(f"warmup run index {i} of {self.NUM_RUNS - 1}") + pb.info(f"warmup run index {i} of {self.NUM_RUNS - 1}") start_time = time.time() result = model.generate(**inp, max_new_tokens=self.MAX_NEW_TOEKNS, pad_token_id=tokenizer.pad_token_id) end_time = time.time() @@ -87,7 +89,7 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, print(f"\n**************** {backend} Warm-up Result Info****************") print(f"Times: {times}") - print(f"New Tokens: {tokens}") + print(f"New Tokens (Size Per Batch Request): {tokens}") print(f"Sum Times: {sum_time}") print(f"Sum New Tokens: {sum_tokens}") print(f"New Token Per Second: {avg_tokens_per_second} token/s") @@ -95,7 +97,7 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, pb = 
ProgressBar(range(self.NUM_RUNS)) for i in pb: - pb.set_description(f"run index {i} of {self.NUM_RUNS - 1}") + pb.info(f"run index {i} of {self.NUM_RUNS - 1}") start_time = time.time() result = model.generate(**inp, max_new_tokens=self.MAX_NEW_TOEKNS, pad_token_id=tokenizer.pad_token_id) end_time = time.time() @@ -129,3 +131,6 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, self.assertTrue(negative_pct <= diff_pct <= positive_pct, f"Tokens Per Second: {avg_tokens_per_second} diff {diff_pct:.2f}% is out of the expected range [{negative_pct}-{positive_pct}%]") + + del model + torch_empty_cache() diff --git a/tests/models/model_test.py b/tests/models/model_test.py index b01339b7f..111ce21a2 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -19,8 +19,6 @@ import sys from typing import Dict, List -from gptqmodel.utils.eval import EVAL - if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -40,6 +38,7 @@ from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT # noqa: E402 from gptqmodel.quantization.config import QuantizeConfig # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402 @@ -63,6 +62,7 @@ class ModelTest(unittest.TestCase): USE_VLLM = False INPUTS_MAX_LENGTH = 2048 MODEL_MAX_LEN = 4096 + DATASET_SIZE = 256 DELETE_QUANTIZED_MODEL = True KERNEL_QUANT = {} # kernel sets @@ -131,7 +131,7 @@ def load_tokenizer(self, model_id_or_path, trust_remote_code=False): return tokenizer @classmethod - def load_dataset(self, tokenizer, rows: int = 128): + def load_dataset(self, tokenizer, rows: int = DATASET_SIZE): traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train") datas = [] @@ -246,7 +246,7 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa return model, tokenizer - def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, delete_quantized_model=False): + def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, delete_quantized_model=False, extra_args:dict=None): try: with tempfile.TemporaryDirectory() as tmp_dir: if self.USE_VLLM: @@ -260,11 +260,13 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del } else: model_args = {} + if extra_args: + model_args.update(extra_args) from lm_eval.tasks import TaskManager from lm_eval.utils import make_table results = GPTQModel.eval( model_or_id_or_path=model, - backend="vllm" if self.USE_VLLM else "gptqmodel", + llm_backend="vllm" if self.USE_VLLM else "gptqmodel", model_args=model_args, output_path=tmp_dir, framework=EVAL.LM_EVAL, diff --git a/tests/models/test_falcon.py b/tests/models/test_falcon.py index 3387721ff..b58b89392 100644 --- a/tests/models/test_falcon.py +++ b/tests/models/test_falcon.py @@ -23,7 +23,7 @@ class TestFalcon(ModelTest): NATIVE_ARC_CHALLENGE_ACC = 0.3993 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4292 APPLY_CHAT_TEMPLATE = True - TRUST_REMOTE_CODE = True + TRUST_REMOTE_CODE = False TORCH_DTYPE = torch.float16 QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.52 BATCH_SIZE = 6 diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py index cdd3b84cb..92dc21b6a 100644 --- 
a/tests/models/test_opt.py +++ b/tests/models/test_opt.py @@ -15,7 +15,7 @@ # limitations under the License. from gptqmodel import BACKEND -from gptqmodel.utils.importer import BACKEND_DICT +from gptqmodel.utils.importer import AUTO_SELECT_BACKEND_ORDER from model_test import ModelTest @@ -24,8 +24,8 @@ class TestOpt(ModelTest): NATIVE_ARC_CHALLENGE_ACC = 0.1894 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2278 - KERNEL_QUANT = {BACKEND_DICT[BACKEND.EXLLAMA_V1]} - KERNEL_INFERENCE = {BACKEND_DICT[BACKEND.MARLIN]} + KERNEL_QUANT = {AUTO_SELECT_BACKEND_ORDER[BACKEND.TRITON]} + KERNEL_INFERENCE = {AUTO_SELECT_BACKEND_ORDER[BACKEND.MARLIN]} def test_opt(self): self.quant_lm_eval() diff --git a/tests/pytest.ini b/tests/pytest.ini index 603f470f8..6ecfee9ef 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,3 +1,4 @@ [pytest] addopts=-s -v log_cli=true +norecursedirs = tasks evalplus_results \ No newline at end of file diff --git a/tests/tasks/arc/arc_easy.yaml b/tests/tasks/arc/arc_easy.yaml index 5375ca035..1b2e369a4 100644 --- a/tests/tasks/arc/arc_easy.yaml +++ b/tests/tasks/arc/arc_easy.yaml @@ -1,7 +1,7 @@ tag: - ai2_arc task: arc_easy -dataset_path: /monster/data/model/dataset/allenai-ai2_arc +dataset_path: allenai/ai2_arc dataset_name: ARC-Easy output_type: multiple_choice training_split: train diff --git a/tests/test_adapter_config.py b/tests/test_adapter_config.py new file mode 100644 index 000000000..6c09017e4 --- /dev/null +++ b/tests/test_adapter_config.py @@ -0,0 +1,91 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# -- do not touch +import os + +from gptqmodel import QuantizeConfig +from gptqmodel.adapter.adapter import Lora, normalize_adapter + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch + +import unittest # noqa: E402 + +lora = "lora" + +class TestExtensionConfig(unittest.TestCase): + @classmethod + def setUpClass(self): + pass + + def test_extension_parse(self): + ext = normalize_adapter(adapter={"name": lora, "rank": 128}) + + assert isinstance(ext, Lora) + assert ext.rank == 128 + print(f"{ext}") + + ext = normalize_adapter(adapter=Lora(rank=128)) + + assert isinstance(ext, Lora) + assert ext.rank == 128 + print(f"{ext}") + + try: + normalize_adapter(adapter={"name": lora, "rank": 128, "crash": 1}) + raise RuntimeError("Non supported extension.property should crash on decode") + except Exception: + pass + + try: + normalize_adapter(adapter={"CRASH": {"rank": 128}}) + raise RuntimeError("Non supported extension should crash on decode") + except Exception: + pass + + + def test_extension_config(self): + rank_field = "rank" + rank = 2 + lora_config = Lora(rank=rank) + + kv = lora_config.to_dict() + print(f"{lora} config: {kv}") + + assert lora_config.rank == rank + assert len(kv) == 3 + assert rank_field in kv.keys() + assert kv[rank_field] == rank + + def test_extension_embed(self): + bits = 4 + rank = 2 + + eora_config = Lora(rank=rank) + + qconfig = QuantizeConfig( + bits=bits, + adapter=eora_config, + ) + + print(f"qconfig: {qconfig}") + + assert qconfig.bits == bits + assert qconfig.adapter == eora_config + assert qconfig.adapter.rank == rank + + + diff --git a/tests/test_bits.py b/tests/test_bits.py index 2b9ce716c..6f2dc1843 100644 --- a/tests/test_bits.py +++ b/tests/test_bits.py @@ -17,14 +17,12 @@ # -- do not touch import os - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import logging # noqa: E402 import tempfile # noqa: E402 import traceback # noqa: E402 import unittest # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 @@ -37,6 +35,7 @@ from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from lm_eval.utils import make_table # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 logger = logging.getLogger(__name__) @@ -76,7 +75,7 @@ def check_results(self, bits: int, task_results): diff_pct = self.calculatorPer(filter=filter, value=value, base_value=base_value) negative_pct = 100 * (1 - self.QUANT_ARC_MAX_DELTA_FLOOR_PERCENT) positive_pct = 100 * (1 + self.QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT) - self.assertTrue(negative_pct <= diff_pct <= positive_pct, f"{filter}: {value} diff {diff_pct:.2f}% is out of the expected range [{negative_pct}-{positive_pct}%], expected: `{base_value}`") + self.assertTrue(negative_pct <= diff_pct <= positive_pct, f"{filter}: {value} diff {diff_pct:.2f}% is out of the expected range [{negative_pct}-{positive_pct}%], expected: {base_value}") @classmethod def setUpClass(cls): diff --git a/tests/test_bits_new.py b/tests/test_bits_new.py new file mode 100644 index 000000000..125169453 --- /dev/null +++ b/tests/test_bits_new.py @@ -0,0 +1,187 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch + +import tempfile # noqa: E402 +from typing import Optional # noqa: E402 + +from datasets import load_dataset # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel.adapter.adapter import Lora # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from tabulate import tabulate # noqa: E402 + + +def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): + # test post-quant inference + model = GPTQModel.load( + model_id_or_path=path, + backend=backend, + adapter=adapter, + ) + + # torch can benefit from optimization + if backend == BACKEND.TORCH: + model.optimize() + + tokens = model.generate("Capital of France is")[0] + result = model.tokenizer.decode(tokens) + print(f"BACKEND: {backend}, Result: {result}") + # assert "paris" in result.lower(), f"`paris` not found in `{result}`" + + bench_result = GPTQModel.eval( + model_or_id_or_path=model, + framework=EVAL.LM_EVAL, + tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU], + batch_size=16, + ) + + del model + torch_empty_cache() + + return bench_result + +class Test(ModelTest): + # NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/" + #NATIVE_MODEL_ID = "/monster/data/model/tinyllama-15M-stories" + # NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + # NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-3B-Instruct" + + + NATIVE_ARC_CHALLENGE_ACC = 0.3567 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36 + + @classmethod + def setUpClass(cls): + pass +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=2 BITS=2 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-1B-Instruct pytest tests/test_quant_and_eora.py +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 BITS=3 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-1B-Instruct pytest tests/test_quant_and_eora.py +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=2 BITS=4 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-1B-Instruct pytest tests/test_quant_and_eora.py +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=3 BITS=8 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-1B-Instruct pytest tests/test_quant_and_eora.py +# +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=4 BITS=2 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-3B-Instruct pytest tests/test_quant_and_eora.py +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=5 BITS=3 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-3B-Instruct pytest tests/test_quant_and_eora.py +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=7 BITS=4 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-3B-Instruct pytest tests/test_quant_and_eora.py +# clear && CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 BITS=8 NATIVE_MODEL_ID=/monster/data/model/Llama-3.2-3B-Instruct pytest 
tests/test_quant_and_eora.py + + + def test_quant_and_eora(self): + bits = int(os.environ["BITS"]) + self.NATIVE_MODEL_ID = os.environ["NATIVE_MODEL_ID"] + + print(f"eeeeee gpu: testing {bits}: bits, model: {self.NATIVE_MODEL_ID}") + group_size = 128 + desc_act = True + rank = 128 + batch_size = 1 + calibration_dataset_rows = 512 + calibration_dataset_concat_size = 0 # disable + auto_gc = False + adapter_file_name = "eora.safetensors" + dataset_id = "allenai/c4" + dataset_files = "en/c4-train.00001-of-01024.json.gz" + + config_dict = { + "model_id": self.NATIVE_MODEL_ID, + "dataset_id": dataset_id, + "dataset_files": dataset_files, + "bits": bits, + "group_size": group_size, + "desc_act": desc_act, + "rank": rank, + "batch_size": batch_size, + "calibration_dataset_rows": calibration_dataset_rows, + "calibration_dataset_concat_size": calibration_dataset_concat_size, + "auto_gc": auto_gc, + "adapter_file_name": adapter_file_name, + } + + calibration_dataset = load_dataset( + dataset_id, + data_files=dataset_files, + split="train" + ).select(range(calibration_dataset_rows))["text"] + + with tempfile.TemporaryDirectory(): + # eora = Lora( + # # for quant, path is save path. for load, it is loading path + # path=os.path.join(tmpdir, adapter_file_name), + # rank=rank, + # ) + + quant_config = QuantizeConfig( + bits=bits, + group_size=group_size, + desc_act=desc_act, # bitblas only supports DESC_ACT=False + # adapter=eora, + ) + + save_path=os.path.join(f"./{quant_config.bits}", self.NATIVE_MODEL_ID.removeprefix("/monster/data/model/")) + + if os.path.exists(save_path): + self.NATIVE_MODEL_ID=save_path + + model = GPTQModel.load( + model_id_or_path=self.NATIVE_MODEL_ID, + quantize_config=quant_config, + ) + + if not model.quantized: + model.quantize( + calibration_dataset=calibration_dataset, + batch_size=batch_size, + auto_gc=auto_gc, + calibration_dataset_concat_size=calibration_dataset_concat_size, + backend=BACKEND.TORCH, + ) # + + + # EoRA adapter is saved according to Lora.path property + # if Lora.path is not set, we will save the lora as "lora.safetensors" in the same path as quant model + # You can also pass `eora_path` to `model.save()` to override this save path + model.save(save_path) + + del model + torch_empty_cache() + + # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, + for backend in [ BACKEND.TORCH ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN + base_bench = bench(path=save_path, backend=backend, adapter=None) # inference using qweights only + # eora_bench = bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora) + + print('--------GPTQModel + EoRA Config ---------') + + # Convert the dictionary to a list of lists for tabulate + table_data = [[key, value] for key, value in config_dict.items()] + print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid")) + + print('--------Eval GPTQ Result---------') + print(make_table(base_bench)) + if "groups" in base_bench: + print(make_table(base_bench, "groups")) + + # print('--------Eval GPTQ + EoRA Result---------') + # print(make_table(eora_bench)) + # if "groups" in eora_bench: + # print(make_table(eora_bench, "groups")) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index b47ae558a..3e5874507 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -17,7 +17,6 @@ # -- do not touch import os -from gptqmodel.nn_modules.qlinear.dynamic_cuda import DynamicCudaQuantLinear from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear 
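# --- hedged sketch (editor addition): the per-module `dynamic` override map consumed by
# dynamic_get()/create_quant_layer earlier in this patch and exercised by the test_dynamic.py
# hunk here. The override field names come from that code; the key pattern syntax is an
# assumption for illustration only.
dynamic = {
    # positive match: override the base QuantizeConfig fields for matching modules
    r".*\.mlp\..*": {"bits": 8, "group_size": 64, "desc_act": False, "sym": True},
}
# a module that dynamic_get() resolves to False is skipped entirely (left unquantized),
# which is what the `if overrides == False: continue` branch above implements.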
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -111,13 +110,12 @@ def tearDownClass(cls): @parameterized.expand( [ # exllama v1/v2 only supports 4bit so does not support dynamic bits control - (BACKEND.TORCH, TorchQuantLinear, 15.7372), - (BACKEND.CUDA, DynamicCudaQuantLinear, 15.7372), - (BACKEND.TRITON, TritonV2QuantLinear, 15.7372), - (BACKEND.MARLIN, MarlinQuantLinear, 15.8582), # A100: 15.7545 + (BACKEND.TORCH, TorchQuantLinear, 15.793), + (BACKEND.TRITON, TritonV2QuantLinear, 15.793), + (BACKEND.MARLIN, MarlinQuantLinear, 15.829), ] ) - def test_dynamic_bits(self, backend, backendQLinear, ppl): + def test_dynamic_bits(self, backend, backendQLinear, expected_ppl): model = GPTQModel.load( self.tmp_quant_path.name, backend=backend, @@ -133,7 +131,7 @@ def test_dynamic_bits(self, backend, backendQLinear, ppl): del model print(f"Backend: {backend}, PPL: {dynamic_bits_ppl}") - assert dynamic_bits_ppl <= ppl + assert dynamic_bits_ppl <= expected_ppl, f"PPL expected: `{expected_ppl}`, actual = `{dynamic_bits_ppl}`" def test_skip_module(self): dynamic = { diff --git a/tests/test_eval.py b/tests/test_eval.py index 3bae072b8..06f76743c 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -15,10 +15,12 @@ # limitations under the License. import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import tempfile # noqa: E402 import unittest # noqa: E402 +from typing import Type # noqa: E402 from typing import Union # noqa: E402 from gptqmodel import GPTQModel # noqa: E402 @@ -31,17 +33,18 @@ class TestEval(unittest.TestCase): @classmethod def setUpClass(self): self.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" + self.model = GPTQModel.load(self.MODEL_ID) @parameterized.expand( [ (EVAL.LM_EVAL, EVAL.LM_EVAL.ARC_CHALLENGE, 'gptqmodel'), - (EVAL.EVALPLUS, EVAL.EVALPLUS.HUMAN, 'gptqmodel'), (EVAL.LM_EVAL, EVAL.LM_EVAL.ARC_CHALLENGE, 'vllm'), + (EVAL.EVALPLUS, EVAL.EVALPLUS.HUMAN, 'gptqmodel'), (EVAL.EVALPLUS, EVAL.EVALPLUS.HUMAN, 'vllm'), (EVAL.LM_EVAL, EVAL.LM_EVAL.GPQA, 'vllm'), ] ) - def test_eval_gptqmodel(self, eval_backend: EVAL, task: Union[EVAL.LM_EVAL, EVAL.EVALPLUS], backend: str): + def test_eval_gptqmodel(self, framework: Union[Type[EVAL.LM_EVAL],Type[EVAL.EVALPLUS]], task: Union[EVAL.LM_EVAL, EVAL.EVALPLUS], llm_backend: str): with tempfile.TemporaryDirectory() as tmp_dir: output_path = f"{tmp_dir}/result.json" model_args = {} @@ -49,16 +52,16 @@ def test_eval_gptqmodel(self, eval_backend: EVAL, task: Union[EVAL.LM_EVAL, EVAL model_args["gpu_memory_utilization"]=0.7 results = GPTQModel.eval(model_or_id_or_path=self.MODEL_ID, - framework=eval_backend, + framework=framework, tasks=[task], - batch_size=32, + batch_size=16, output_path=output_path, - backend=backend, + llm_backend=llm_backend, model_args=model_args, task_manager=TaskManager(include_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "tasks"), include_defaults=False) ) - if eval_backend == EVAL.LM_EVAL: + if llm_backend == EVAL.LM_EVAL: if task == EVAL.LM_EVAL.GPQA: gpqa_main_n_shot = results['results'].get('gpqa_main_n_shot', {}).get('acc,none') gpqa_main_zeroshot = results['results'].get('gpqa_main_zeroshot', {}).get('acc,none') @@ -71,7 +74,7 @@ def test_eval_gptqmodel(self, eval_backend: EVAL, task: Union[EVAL.LM_EVAL, EVAL self.assertGreaterEqual(acc_score, 0.28, "acc score does not match expected result") self.assertGreaterEqual(acc_norm_score, 0.32, "acc_norm score does not match expected result") - elif eval_backend == EVAL.EVALPLUS: + elif llm_backend == 
EVAL.EVALPLUS: result = results.get(task.value) base_formatted, plus_formatted, _ = float(result.get("base tests")), float( result.get("base + extra tests")), result.get("results_path") diff --git a/tests/test_evalplus.py b/tests/test_evalplus.py index 8fb0fb49e..13d7251b7 100644 --- a/tests/test_evalplus.py +++ b/tests/test_evalplus.py @@ -23,6 +23,7 @@ import tempfile # noqa: E402 import unittest # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.utils.eval import evalplus # noqa: E402 @@ -34,7 +35,10 @@ def setUpClass(self): def test_evalplus(self): with tempfile.TemporaryDirectory() as tmp_dir: output_file = f"{tmp_dir}/result.json" - base_formatted, plus_formatted, _ = evalplus(model=self.MODEL_ID, dataset='humaneval', output_file=output_file) + + model = GPTQModel.load(self.MODEL_ID) + + base_formatted, plus_formatted, _ = evalplus(model=model, dataset='humaneval', output_file=output_file) self.assertGreaterEqual(float(base_formatted), 0.26, "Base score does not match expected result") self.assertGreaterEqual(float(plus_formatted), 0.23, "Plus score does not match expected result") diff --git a/tests/test_group_size.py b/tests/test_group_size.py index ddf0f4326..719866080 100644 --- a/tests/test_group_size.py +++ b/tests/test_group_size.py @@ -33,14 +33,14 @@ from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 -from gptqmodel.utils.eval import lm_eval # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 from lm_eval.utils import make_table # noqa: E402 from transformers import AutoTokenizer # noqa: E402 logger = logging.getLogger(__name__) RAND_SEED = 42 -TASK_NAME = "arc_challenge" +TASK_NAME = EVAL.LM_EVAL.ARC_CHALLENGE class TestGroupSize(unittest.TestCase): QLINEAR_DICT = { @@ -117,15 +117,15 @@ def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir): device_map="auto", backend=inference_backend, ) - results = lm_eval( - model, - model_name="hf", + results = GPTQModel.eval( + model_or_id_or_path=model, output_path=tmp_dir, tasks=TASK_NAME, apply_chat_template=False, trust_remote_code=False, batch_size=32, gen_kwargs="temperature=0.0,top_k=50", + random_seed=RAND_SEED, ) print('--------Eval Result---------') print(make_table(results)) diff --git a/tests/test_inference_speed.py b/tests/test_inference_speed.py index 94460e76b..ed9955b3f 100644 --- a/tests/test_inference_speed.py +++ b/tests/test_inference_speed.py @@ -17,8 +17,6 @@ # -- do not touch import os -import torch - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" from gptqmodel.utils import BACKEND # noqa: E402 # -- end do not touch @@ -44,21 +42,19 @@ class TestInferenceSpeed(InferenceSpeed): @parameterized.expand( [ - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.MARLIN, 286.74), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.CUDA, 161.72), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V1, 282.64), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V2, 290.60), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TRITON, 239.58), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 227.96), - (InferenceSpeed.BITBLAS_NATIVE_MODEL_ID, BACKEND.BITBLAS, 2167.38), # Second time running bitblas, there is cache + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.MARLIN, 286.74, False, False), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.CUDA, 161.72, True, False), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 227.96, 
True, False), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 53, False, False), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V1, 282.64, False, False), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V2, 290.60, False, False), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TRITON, 239.58, False, False), + (InferenceSpeed.BITBLAS_NATIVE_MODEL_ID, BACKEND.BITBLAS, 2167.38, False, False), # Second time running bitblas, there is cache ] ) - def test_inference_speed(self, model_path, backend, tokens_per_second): - # Start a fresh compile for each parameter of the test case - torch._dynamo.reset() - + def test_inference_speed(self, model_path, backend, tokens_per_second, optimize, fullgraph): # There are differences between the results of the first and second runs of bitblas # (there is a cache when running bitblas for the second time), # so only the results of the second run of bitblas are asserted. # The first run of bitblas only prints relevant information - self.inference(model_path=model_path, backend=backend, tokens_per_second=tokens_per_second, compile=True, warmup_runs=1) + self.inference(model_path=model_path, backend=backend, tokens_per_second=tokens_per_second, optimize=optimize, fullgraph=fullgraph, warmup_runs=1) diff --git a/tests/test_lm_eval.py b/tests/test_lm_eval.py index a6f903752..1ceaffaf1 100644 --- a/tests/test_lm_eval.py +++ b/tests/test_lm_eval.py @@ -17,16 +17,14 @@ # -- do not touch import os - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import tempfile # noqa: E402 import unittest # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 - -from gptqmodel import GPTQModel # noqa: E402 +from gptqmodel import BACKEND, GPTQModel from gptqmodel.utils.eval import EVAL # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 class TestLmEval(unittest.TestCase): @@ -37,7 +35,11 @@ def setUpClass(self): self.random_seed = 1234 self.task = EVAL.LM_EVAL.ARC_CHALLENGE - def test_lm_eval(self): + # self.acc_score = 0.3183 + self.acc_norm_score = 0.3515 + + + def test_eval_direct(self): with tempfile.TemporaryDirectory() as tmp_dir: results = GPTQModel.eval( model_or_id_or_path=self.MODEL_ID, @@ -52,9 +54,29 @@ def test_lm_eval(self): print(make_table(results, "groups")) print('--------lm_eval Result End---------') - acc_score = results['results'].get(self.task.value, {}).get('acc,none') + results['results'].get(self.task.value, {}).get('acc,none') acc_norm_score = results['results'].get(self.task.value, {}).get('acc_norm,none') - self.assertGreaterEqual(acc_score, 0.28, "acc score does not match expected result") - self.assertGreaterEqual(acc_norm_score, 0.32, "acc_norm score does not match expected result") + # self.assertGreaterEqual(acc_score, self.acc_score, "acc score does not match expected result") + self.assertGreaterEqual(acc_norm_score, self.acc_norm_score, "acc_norm score does not match expected result") + + def test_eval_path(self): + with tempfile.TemporaryDirectory() as tmp_dir: + results = GPTQModel.eval( + model_or_id_or_path=self.MODEL_ID, + backend = BACKEND.EXLLAMA_V2, # for path loading, can override backend + output_path=tmp_dir, + tasks=[self.task], + ) + + print('--------lm_eval Eval Result---------') + print(make_table(results)) + if "groups" in results: + print(make_table(results, "groups")) + print('--------lm_eval Result End---------') + + # acc_score = results['results'].get(self.task, {}).get('acc,none') + acc_norm_score = results['results'].get(self.task, {}).get('acc_norm,none') + # 
self.assertGreaterEqual(acc_score, self.acc_score, "acc score does not match expected result") + self.assertGreaterEqual(acc_norm_score, self.acc_norm_score, "acc_norm score does not match expected result") diff --git a/tests/test_lm_head.py b/tests/test_lm_head.py index 00b01f048..c5d39bacf 100644 --- a/tests/test_lm_head.py +++ b/tests/test_lm_head.py @@ -46,7 +46,7 @@ def test_eval(self): class TestLmHeadQuant(ModelTest): APPLY_CHAT_TEMPLATE = True - EXPECT_LM_HEAD_LOSS = 31.11202 + EXPECT_LM_HEAD_LOSS = 23.84 sample_length = 1024 samples = 128 diff --git a/tests/test_lora.py b/tests/test_lora.py new file mode 100644 index 000000000..0e50794fb --- /dev/null +++ b/tests/test_lora.py @@ -0,0 +1,95 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch + +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.adapter.adapter import Lora # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from parameterized import parameterized # noqa: E402 + + +class Test(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/sliuau-llama3.2-1b-4bit-group128" + lora_path = "/monster/data/model/sliuau-llama3.2-1b-4bit-group128/llama3.2-1b-4bit-group128-eora-rank128-arc/adapter_model.safetensors" #"sliuau/llama3.2-1b-4bit-group128-eora_test-rank128-arc/blob/main/adapter_model.safetensors" #"sliuau/llama3.2-1b-4bit-group128-eora_test-rank128-arc" + + NATIVE_ARC_CHALLENGE_ACC = 0.3567 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36 + + @classmethod + def setUpClass(cls): + cls.adapter = Lora(path=cls.lora_path, rank=128) + + @parameterized.expand([ + # BACKEND.EXLLAMA_V2V, + #BACKEND.TORCH, + # BACKEND.CUDA, + # BACKEND.TRITON, + # BACKEND.EXLLAMA_V1, + BACKEND.EXLLAMA_V2, + # BACKEND.MARLIN, + # # (BACKEND.IPEX), <-- not tested yet + # # (BACKEND.BITBLAS, <-- not tested yet + ]) + def test_load(self, backend: BACKEND): + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + adapter=self.adapter, + backend=backend, + device_map="auto", + ) + + # print(model) + tokens = model.generate("Capital of France is")[0] + result = model.tokenizer.decode(tokens) + print(f"Result: {result}") + self.assertIn("paris", result.lower()) + + @parameterized.expand([ + BACKEND.EXLLAMA_V2, + ]) + def test_download(self, backend: BACKEND): + adapter = Lora(path="https://huggingface.co/sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc/blob/main/adapter_model.safetensors", rank=128) + + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + adapter=adapter, + backend=backend, + device_map="auto", + ) + + tokens = model.generate("Capital of France is")[0] + result = model.tokenizer.decode(tokens) + print(f"Result: {result}") + self.assertIn("paris", result.lower()) + + def test_lm_eval_from_path(self): + adapter = Lora(path=self.lora_path, rank=128) + task_results = self.lm_eval(self.NATIVE_MODEL_ID, 
extra_args={"adapter": adapter.to_dict()}) # "backend":"exllama_v2", + self.check_results(task_results) + + def test_lm_eval_from_model(self): + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + adapter=self.adapter, + # backend=BACKEND.EXLLAMA_V2V, + ) + task_results = self.lm_eval(model) + self.check_results(task_results) diff --git a/tests/test_modelscope.py b/tests/test_modelscope.py index 95fc43bf9..22fcf2663 100644 --- a/tests/test_modelscope.py +++ b/tests/test_modelscope.py @@ -1,7 +1,8 @@ import os + os.environ["GPTQMODEL_USE_MODELSCOPE"] = "True" -from models.model_test import ModelTest # noqa: E402 from gptqmodel import GPTQModel # noqa: E402 +from models.model_test import ModelTest # noqa: E402 class TestLoadModelscope(ModelTest): @@ -17,4 +18,4 @@ def test_load_modelscope(self): str_output = model.tokenizer.decode(result) assert "beijing" in str_output.lower() or "bei-jing" in str_output.lower() - del model \ No newline at end of file + del model diff --git a/tests/test_packing_speed.py b/tests/test_packing_speed.py index 7b9594403..516c45b8a 100644 --- a/tests/test_packing_speed.py +++ b/tests/test_packing_speed.py @@ -106,34 +106,34 @@ def pack(self, qlinearCls): [ # [ExllamaQuantLinear, 9.63], # A100 Z3: 36.89 # 4090? 26.5349 # [TritonV2QuantLinear, 9.67], # A100 Z3: 35.04 # 4090? 26.5268 - [TorchQuantLinear, 13.819], # A100 Z3 33.56 # 4090? 27.0297 + [TorchQuantLinear, 16.63], # A100 Z3 33.56 # 4090? 27.0297 ] ) def test_pack_speed(self, qlinearCls, expect_time): + start = time.time() with threadpoolctl.threadpool_limits(limits=1): - now = time.time() for i in range(30): self.pack(qlinearCls) - time_usage = time.time() - now + time_usage = time.time() - start speed = self.k * self.k / time_usage print(f"{qlinearCls.__name__}, time={time_usage}, speed={speed:.4f}") - self.assertLess(abs(time_usage - expect_time) / expect_time, 0.025, msg=f"time: {time_usage}") + self.assertLess((time_usage - expect_time) / expect_time, 0.025, msg=f"time: {time_usage}") @parameterized.expand( [ # [ExllamaQuantLinear, 9.63], # A100 Z3: 36.89 # 4090? 26.5349 # [TritonV2QuantLinear, 9.67], # A100 Z3: 35.04 # 4090? 26.5268 - [TorchQuantLinear, 10.674], # A100 Z3 33.56 # 4090? 27.0297 + [TorchQuantLinear, 12.51], # A100 Z3 33.56 # 4090? 
27.0297 ] ) def test_pack_speed_2_threads(self, qlinearCls, expect_time): + start = time.time() with threadpoolctl.threadpool_limits(limits=2): - now = time.time() for i in range(30): self.pack(qlinearCls) - time_usage = time.time() - now + time_usage = time.time() - start speed = self.k * self.k / time_usage print(f"{qlinearCls.__name__}, time={time_usage}, speed={speed:.4f}") - self.assertLess(abs(time_usage - expect_time) / expect_time, 0.025, msg=f"time: {time_usage}") + self.assertLess((time_usage - expect_time) / expect_time, 0.025, msg=f"time: {time_usage}") diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py index 67cd8ce01..5518a3a1a 100644 --- a/tests/test_perplexity.py +++ b/tests/test_perplexity.py @@ -25,7 +25,7 @@ import unittest # noqa: E402 from datasets import load_dataset # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.quantization.config import FORMAT, QUANT_METHOD, AutoRoundQuantizeConfig, QuantizeConfig # noqa: E402 from gptqmodel.utils import Perplexity # noqa: E402 from gptqmodel.utils.rocm import IS_ROCM # noqa: E402 @@ -129,12 +129,12 @@ def calculate_native_ppl(self, format): @parameterized.expand( [ - (QUANT_METHOD.GPTQ, FORMAT.GPTQ, 8, 32, True), # A100, 4889 max ram + # (QUANT_METHOD.GPTQ, FORMAT.GPTQ, 8, 32, True), # A100, 4889 max ram (QUANT_METHOD.GPTQ, FORMAT.GPTQ, 8, 32, False), # A100, 6571 max ram - (QUANT_METHOD.GPTQ, FORMAT.GPTQ_V2, 8, 32, False), - (QUANT_METHOD.GPTQ, FORMAT.GPTQ_V2, 4, 32, False), - (QUANT_METHOD.GPTQ, FORMAT.GPTQ, 4, 32, False), - (QUANT_METHOD.GPTQ, FORMAT.BITBLAS, 4, 32, False), + # (QUANT_METHOD.GPTQ, FORMAT.GPTQ_V2, 8, 32, False), + # (QUANT_METHOD.GPTQ, FORMAT.GPTQ_V2, 4, 32, False), + # (QUANT_METHOD.GPTQ, FORMAT.GPTQ, 4, 32, False), + # (QUANT_METHOD.GPTQ, FORMAT.BITBLAS, 4, 32, False), # (QUANT_METHOD.AUTO_ROUND, FORMAT.GPTQ, 4, 32, False), ] ) @@ -173,7 +173,7 @@ def test_quantized_perplexity(self, method: QUANT_METHOD, format: FORMAT, bits: model.quantize( dataset, batch_size=128 if IS_ROCM else 256, - buffered_fwd=buffered_fwd, + # buffered_fwd=buffered_fwd, TODO FIX ME auto_gc=False, # speed up quant ) quant_time = time.time() - start @@ -204,6 +204,7 @@ def test_quantized_perplexity(self, method: QUANT_METHOD, format: FORMAT, bits: model = GPTQModel.load( tmp_dir, + backend=BACKEND.EORA_TORCH, device_map="auto", ) diff --git a/tests/test_post_quant_eora.py b/tests/test_post_quant_eora.py new file mode 100644 index 000000000..1ded29448 --- /dev/null +++ b/tests/test_post_quant_eora.py @@ -0,0 +1,137 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
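# --- hedged sketch (editor addition): the effect of dropping abs() in the packing-speed
# assertions above -- the check becomes one-sided, so a run that finishes faster than the
# reference can no longer fail; only a slowdown beyond the 2.5% tolerance does.
def within_budget(time_usage: float, expect_time: float, tolerance: float = 0.025) -> bool:
    return (time_usage - expect_time) / expect_time < tolerance

assert within_budget(10.0, 12.51)       # faster than expected: passes
assert not within_budget(13.5, 12.51)   # ~7.9% slower: fails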
+
+# -- do not touch
+import os
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# -- end do not touch
+
+import tempfile  # noqa: E402
+from typing import Optional  # noqa: E402
+
+from datasets import load_dataset  # noqa: E402
+from gptqmodel import BACKEND, GPTQModel  # noqa: E402
+from gptqmodel.adapter.adapter import Lora  # noqa: E402
+from gptqmodel.utils.eval import EVAL  # noqa: E402
+from gptqmodel.utils.torch import torch_empty_cache  # noqa: E402
+from lm_eval.utils import make_table  # noqa: E402
+from models.model_test import ModelTest  # noqa: E402
+from tabulate import tabulate  # noqa: E402
+
+
+def bench(path: str, backend: BACKEND, adapter: Optional[Lora]):
+    # test post-quant inference
+    model = GPTQModel.load(
+        model_id_or_path=path,
+        backend=backend,
+        adapter=adapter,
+    )
+
+    # torch can benefit from optimization
+    if backend == BACKEND.TORCH:
+        model.optimize()
+
+    tokens = model.generate("Capital of France is")[0]
+    result = model.tokenizer.decode(tokens)
+    print(f"BACKEND: {backend}, Result: {result}")
+    if "paris" not in result.lower():
+        raise AssertionError(f"`paris` not found in `{result}`")
+
+    bench_result = GPTQModel.eval(
+        model_or_id_or_path=model,
+        framework=EVAL.LM_EVAL,
+        tasks=[EVAL.LM_EVAL.ARC_CHALLENGE]
+    )
+
+    del model
+    torch_empty_cache()
+
+    return bench_result
+
+
+class TestEoraPostQuant(ModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct"
+    QUANTIZED_MODEL_PATH = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1/"
+
+    @classmethod
+    def setUpClass(cls):
+        pass
+
+    def test_post_quant_eora(self):
+        bits = 4
+        group_size = 128
+        desc_act = True
+        rank = 256
+        batch_size = 1
+        calibration_dataset_rows = 1024
+        calibration_dataset_concat_size = 0  # disable
+        auto_gc = False
+        adapter_file_name = "eora.safetensors"
+
+        config_dict = {
+            "bits": bits,
+            "group_size": group_size,
+            "desc_act": desc_act,
+            "rank": rank,
+            "batch_size": batch_size,
+            "calibration_dataset_rows": calibration_dataset_rows,
+            "calibration_dataset_concat_size": calibration_dataset_concat_size,
+            "auto_gc": auto_gc,
+            "adapter_file_name": adapter_file_name,
+        }
+
+        calibration_dataset = load_dataset(
+            "allenai/c4",
+            data_files="en/c4-train.00001-of-01024.json.gz",
+            split="train"
+        ).select(range(calibration_dataset_rows))["text"]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            eora = Lora(
+                # for eora generation, path is the adapter save path; for load, it is the loading path
+                path=os.path.join(tmpdir, adapter_file_name),
+                rank=rank,
+            )
+
+            # eora generation and save in one step
+            GPTQModel.adapter.generate(
+                adapter=eora,
+                model_id_or_path=self.NATIVE_MODEL_ID,
+                quantized_model_id_or_path=self.QUANTIZED_MODEL_PATH,
+                calibration_dataset=calibration_dataset,
+                calibration_dataset_concat_size=calibration_dataset_concat_size,
+                auto_gc=auto_gc)
+
+            # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA,
+            for backend in [BACKEND.TORCH]:  # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V, BACKEND.MARLIN
+                base_bench = bench(path=self.QUANTIZED_MODEL_PATH, backend=backend, adapter=None)  # inference using qweights only
+                eora_bench = bench(path=self.QUANTIZED_MODEL_PATH, backend=backend, adapter=eora)  # inference using eora (lora)
+
+                print('--------Quant/EoRA Config ---------')
+
+                # Convert the dictionary to a list of lists for tabulate
+                table_data = [[key, value] for key, value in config_dict.items()]
+                print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid"))
+
+                print('--------Eval Base
Result---------') + print(make_table(base_bench)) + if "groups" in base_bench: + print(make_table(base_bench, "groups")) + + print('--------Eval EoRA Result---------') + print(make_table(eora_bench)) + if "groups" in eora_bench: + print(make_table(eora_bench, "groups")) diff --git a/tests/test_q4_cuda.py b/tests/test_q4_cuda.py index e42bc359b..51af7c270 100644 --- a/tests/test_q4_cuda.py +++ b/tests/test_q4_cuda.py @@ -16,16 +16,13 @@ # -- do not touch import os -import tempfile - -from gptqmodel.utils import Perplexity os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import torch # noqa: E402 -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py new file mode 100644 index 000000000..f05220b02 --- /dev/null +++ b/tests/test_quant_and_eora.py @@ -0,0 +1,160 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch + +import tempfile # noqa: E402 +from typing import Optional # noqa: E402 + +from datasets import load_dataset # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel.adapter.adapter import Lora # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from tabulate import tabulate # noqa: E402 + + +def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): + # test post-quant inference + model = GPTQModel.load( + model_id_or_path=path, + backend=backend, + adapter=adapter, + ) + + tokens = model.generate("Capital of France is")[0] + result = model.tokenizer.decode(tokens) + print(f"BACKEND: {backend}, Result: {result}") + assert "paris" in result.lower(), f"`paris` not found in `{result}`" + + bench_result = GPTQModel.eval( + model_or_id_or_path=model, + framework=EVAL.LM_EVAL, + tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU], + batch_size=32, + ) + + del model + torch_empty_cache() + + return bench_result + +class Test(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/" + #NATIVE_MODEL_ID = "/monster/data/model/tinyllama-15M-stories" + #NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B" + + NATIVE_ARC_CHALLENGE_ACC = 0.3567 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36 + + @classmethod + def setUpClass(cls): + pass + + def test_quant_and_eora(self): + bits = 4 + group_size = 128 + desc_act = True + rank = 128 + batch_size = 1 + calibration_dataset_rows = 512 + calibration_dataset_concat_size = 0 
# disable + auto_gc = False + adapter_file_name = "eora.safetensors" + dataset_id = "allenai/c4" + dataset_files = "en/c4-train.00001-of-01024.json.gz" + + config_dict = { + "model_id": self.NATIVE_MODEL_ID, + "dataset_id": dataset_id, + "dataset_files": dataset_files, + "bits": bits, + "group_size": group_size, + "desc_act": desc_act, + "rank": rank, + "batch_size": batch_size, + "calibration_dataset_rows": calibration_dataset_rows, + "calibration_dataset_concat_size": calibration_dataset_concat_size, + "auto_gc": auto_gc, + "adapter_file_name": adapter_file_name, + } + + calibration_dataset = load_dataset( + dataset_id, + data_files=dataset_files, + split="train" + ).select(range(calibration_dataset_rows))["text"] + + with tempfile.TemporaryDirectory() as tmpdir: + eora = Lora( + # for quant, path is save path. for load, it is loading path + path=os.path.join(tmpdir, adapter_file_name), + rank=rank, + ) + + quant_config = QuantizeConfig( + bits=bits, + group_size=group_size, + desc_act=desc_act, # bitblas only supports DESC_ACT=False + adapter=eora, + ) + + model = GPTQModel.load( + model_id_or_path=self.NATIVE_MODEL_ID, + quantize_config=quant_config, + ) + + model.quantize( + calibration_dataset=calibration_dataset, + batch_size=batch_size, + auto_gc=auto_gc, + calibration_dataset_concat_size=calibration_dataset_concat_size, + ) # + + # EoRA adapter is saved according to Lora.path property + # if Lora.path is not set, we will save the lora as "lora.safetensors" in the same path as quant model + # You can also pass `eora_path` to `model.save()` to override this save path + model.save(tmpdir) + + del model + torch_empty_cache() + + # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, + for backend in [ BACKEND.TORCH ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN + base_bench = bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only + eora_bench = bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora) + + print('--------GPTQModel + EoRA Config ---------') + + # Convert the dictionary to a list of lists for tabulate + table_data = [[key, value] for key, value in config_dict.items()] + print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid")) + + print('--------Eval GPTQ Result---------') + print(make_table(base_bench)) + if "groups" in base_bench: + print(make_table(base_bench, "groups")) + + print('--------Eval GPTQ + EoRA Result---------') + print(make_table(eora_bench)) + if "groups" in eora_bench: + print(make_table(eora_bench, "groups")) diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py index 2ce433759..59f23308c 100644 --- a/tests/test_quant_formats.py +++ b/tests/test_quant_formats.py @@ -50,7 +50,6 @@ def setUpClass(self): @parameterized.expand( [ (QUANT_METHOD.GPTQ, BACKEND.AUTO, False, FORMAT.GPTQ, 8), - (QUANT_METHOD.GPTQ, BACKEND.IPEX, False, FORMAT.GPTQ, 4), (QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, True, FORMAT.GPTQ_V2, 4), (QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, False, FORMAT.GPTQ, 4), ] @@ -99,6 +98,8 @@ def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, forma backend=backend, ) + self.assertInference(model) + logging.info(f"Loaded config: {model.quantize_config}") versionable = model.quantize_config.meta_get_versionable(META_FIELD_QUANTIZER) diff --git a/tests/test_quant_formats_ipex.py b/tests/test_quant_formats_ipex.py new file mode 100644 index 000000000..a2774d8ad --- /dev/null +++ b/tests/test_quant_formats_ipex.py @@ 
-0,0 +1,110 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch + +import json # noqa: E402 +import logging # noqa: E402 +import tempfile # noqa: E402 + +from datasets import load_dataset # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, __version__, get_best_device # noqa: E402 +from gptqmodel.quantization import FORMAT, QUANT_CONFIG_FILENAME, QUANT_METHOD # noqa: E402 +from gptqmodel.quantization.config import (META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL, # noqa: E402 + AutoRoundQuantizeConfig, QuantizeConfig) +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from parameterized import parameterized # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + + +class TestQuantization(ModelTest): + + @classmethod + def setUpClass(self): + self.pretrained_model_id = "/monster/data/model/Qwen2.5-0.5B-Instruct/" #"/monster/data/model/TinyLlama-1.1B-intermediate-step-1431k-3T" + + self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_id, use_fast=True) + + traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train") + self.calibration_dataset = [self.tokenizer(example["text"]) for example in traindata.select(range(32))] + + + @parameterized.expand( + [ + (QUANT_METHOD.GPTQ, BACKEND.IPEX, False, FORMAT.GPTQ, 4), + ] + ) + def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, format: FORMAT, bits: int): + if method == QUANT_METHOD.GPTQ: + quantize_config = QuantizeConfig( + bits=bits, + group_size=128, + desc_act=False if format == FORMAT.MARLIN else True, + sym=sym, + format=format, + damp_percent=0.05 + ) + elif method == QUANT_METHOD.AUTO_ROUND: + quantize_config = AutoRoundQuantizeConfig( + bits=bits, + group_size=128, + sym=sym, + format=format, + ) + else: + raise ValueError(f"Invalid quantization method: {method}") + + model = GPTQModel.load( + self.pretrained_model_id, + quantize_config=quantize_config, + ) + model.quantize(self.calibration_dataset, batch_size=32) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save(tmpdirname) + + logging.info(f"Saved config mem: {model.quantize_config}") + + with open(tmpdirname + "/" + QUANT_CONFIG_FILENAME, "r") as f: + file_dict = json.loads(f.read()) + + # make sure the json dict saved to file matches config in memory + assert model.quantize_config.to_dict() == file_dict + logging.info(f"Saved config file: {file_dict}") + + model = GPTQModel.load( + tmpdirname, + device=get_best_device(backend), + backend=backend, + ) + + self.assertInference(model) + + logging.info(f"Loaded config: {model.quantize_config}") + + versionable = model.quantize_config.meta_get_versionable(META_FIELD_QUANTIZER) + assert META_QUANTIZER_GPTQMODEL 
in [v[0] for v in versionable] + for producer, _version in versionable: + if producer == META_QUANTIZER_GPTQMODEL: + assert _version == __version__ + + del model + torch_empty_cache() diff --git a/tests/test_quant_time.py b/tests/test_quant_time.py index acc82674b..b925a9c0b 100644 --- a/tests/test_quant_time.py +++ b/tests/test_quant_time.py @@ -27,15 +27,15 @@ class TestQuantTime(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" - INPUTS_MAX_LENGTH = 2048 DATASETS_MAX_COUNT = 128 - QUANT_TIME = 136 + QUANT_TIME = 116 MAX_DELTA_PERCENT = 5 # % def test_quant_time(self): quantize_config = QuantizeConfig( bits=4, group_size=128, + desc_act=True, ) model = GPTQModel.load( @@ -44,13 +44,18 @@ def test_quant_time(self): ) tokenizer = model.tokenizer - datasets = self.load_dataset(tokenizer) + datasets = self.load_dataset(tokenizer, self.DATASETS_MAX_COUNT) - start_time = time.time() - model.quantize(datasets, batch_size=4) + start = time.time() + model.quantize( + calibration_dataset=datasets, + # calibration_dataset_concat_size=2048, + batch_size=4, + auto_gc=False, + ) end_time = time.time() - quant_time = end_time - start_time + quant_time = end_time - start diff_pct = (quant_time / self.QUANT_TIME) print("**************** Quant Time Result Info****************") diff --git a/tests/test_save_loaded_quantized_model.py b/tests/test_save_loaded_quantized_model.py index cf540b4a5..6f85bd14f 100644 --- a/tests/test_save_loaded_quantized_model.py +++ b/tests/test_save_loaded_quantized_model.py @@ -37,7 +37,6 @@ class TestSave(unittest.TestCase): (BACKEND.TRITON), (BACKEND.BITBLAS), (BACKEND.MARLIN), - (BACKEND.IPEX), ] ) def test_save(self, backend: BACKEND): diff --git a/tests/test_save_loaded_quantized_model_ipex.py b/tests/test_save_loaded_quantized_model_ipex.py new file mode 100644 index 000000000..70a6e526a --- /dev/null +++ b/tests/test_save_loaded_quantized_model_ipex.py @@ -0,0 +1,60 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
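+
+"""Save/reload round trip checked below for the IPEX backend (split out of
+tests/test_save_loaded_quantized_model.py so it runs as its own case).
+A minimal sketch using the same calls as the test; MODEL_ID is the local test
+fixture defined below, not a published checkpoint:
+
+    model = GPTQModel.load(MODEL_ID, backend=BACKEND.IPEX)
+    model.save(tmpdir)                                    # re-save the loaded quantized model
+    reloaded = GPTQModel.load(tmpdir, backend=BACKEND.IPEX)
+    # generations from `model` and `reloaded` should start with the same tokens
+"""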
+
+# -- do not touch
+import os
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# -- end do not touch
+import tempfile  # noqa: E402
+import unittest  # noqa: E402
+
+from gptqmodel import BACKEND, GPTQModel, get_best_device  # noqa: E402
+from parameterized import parameterized  # noqa: E402
+from transformers import AutoTokenizer  # noqa: E402
+
+MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
+
+class TestSave(unittest.TestCase):
+    @parameterized.expand(
+        [
+            (BACKEND.IPEX),
+        ]
+    )
+    def test_save(self, backend: BACKEND):
+        prompt = "I am in Paris and"
+        device = get_best_device(backend)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        inp = tokenizer(prompt, return_tensors="pt").to(device)
+
+        # origin model produces correct output
+        origin_model = GPTQModel.load(MODEL_ID, backend=backend)
+        origin_model_res = origin_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
+        origin_model_predicted_text = tokenizer.decode(origin_model_res[0])
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            origin_model.save(tmpdir)
+
+            # the re-saved model should produce the same output
+            new_model = GPTQModel.load(tmpdir, backend=backend)
+
+            new_model_res = new_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
+            new_model_predicted_text = tokenizer.decode(new_model_res[0])
+
+            print("origin_model_predicted_text", origin_model_predicted_text)
+            print("new_model_predicted_text", new_model_predicted_text)
+
+            self.assertEqual(origin_model_predicted_text[:20], new_model_predicted_text[:20])
diff --git a/tests/test_sglang.py b/tests/test_sglang.py
index 7fc4aa22f..cbc8e6344 100644
--- a/tests/test_sglang.py
+++ b/tests/test_sglang.py
@@ -20,10 +20,7 @@
 # -- end do not touch
 
 import importlib.util  # noqa: E402
-import subprocess  # noqa: E402
-import sys  # noqa: E402
 
-import torch  # noqa: E402
 from gptqmodel import BACKEND, GPTQModel  # noqa: E402
 from models.model_test import ModelTest  # noqa: E402
 
@@ -33,10 +30,8 @@ class TestLoadSglang(ModelTest):
     @classmethod
     def setUpClass(self):
         # sglang set disable_flashinfer=True still import flashinfer
-        if importlib.util.find_spec("flashinfer") is None:
-            subprocess.check_call([sys.executable, "-m", "pip", "install", "flashinfer", "-i", f"https://flashinfer.ai/whl/cu{torch.version.cuda.replace('.', '')}/torch{'.'.join(torch.__version__.split('.')[:2])}"])
-        if importlib.util.find_spec("sglang") is None:
-            subprocess.check_call([sys.executable, "-m", "pip", "install", "sglang[srt]>=0.3.2"])
+        if importlib.util.find_spec("flashinfer") is None or importlib.util.find_spec("sglang") is None:
+            raise RuntimeError("flashinfer and sglang are required by this test. You can install them with `pip install gptqmodel['sglang']`")
 
         self.MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 4e2fad487..65ad31d3e 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -15,12 +15,15 @@
 # limitations under the License.
import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import tempfile # noqa: E402 import unittest # noqa: E402 + +import transformers # noqa: E402 from packaging.version import Version # noqa: E402 from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig # noqa: E402 -import transformers # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 class TestTransformersIntegration(unittest.TestCase): @@ -39,6 +42,9 @@ def _test_load_quantized_model_gptq_v1(self, device_map): self.assertInference(model=model, tokenizer=tokenizer) + del model + torch_empty_cache() + def _test_load_quantized_model_gptq_v2(self, device_map): model_id_or_path = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" model = AutoModelForCausalLM.from_pretrained(model_id_or_path, device_map=device_map) @@ -47,6 +53,9 @@ def _test_load_quantized_model_gptq_v2(self, device_map): self.assertInference(model=model, tokenizer=tokenizer) + del model + torch_empty_cache() + def _test_quantize(self, device_map): model_id = "/monster/data/model/opt-125m" tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -65,6 +74,9 @@ def _test_quantize(self, device_map): self.assertIn("is a good", generate_str.lower()) + del model + torch_empty_cache() + def test_load_quantized_model_gptq_v1_ipex(self): self._test_load_quantized_model_gptq_v1(device_map="cpu") @@ -104,4 +116,4 @@ def generate(self, model, tokenizer, prompt=None): res = model.generate(**inp, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=30) output = tokenizer.decode(res[0]) print(f"Result is: >>\n{output}\n<<") - return output \ No newline at end of file + return output diff --git a/tests/test_vllm.py b/tests/test_vllm.py index d5e9c7cd3..16534b9cb 100644 --- a/tests/test_vllm.py +++ b/tests/test_vllm.py @@ -21,11 +21,8 @@ # -- end do not touch import importlib.util # noqa: E402 -import subprocess # noqa: E402 -import sys # noqa: E402 import tempfile # noqa: E402 -import torch # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402