diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 7244b6f7a..ea523f6f1 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -61,8 +61,7 @@ env: PYTORCH_CUDA_ALLOC_CONF: 'expandable_segments:True' MAX_JOBS: 8 RUNNER: 10.0.13.31 - TRANSFORMERS_DIFF_TESTS: "models/test_internlm.py,models/test_internlm2_5.py,models/test_xverse.py" - TORCH_2_5_TESTS: "test_evalplus.py,test_perplexity.py,test_q4_ipex.py,test_ipex_xpu.py,test_save_loaded_quantized_model.py,test_quant_formats.py,models/test_hymba.py" + LEGACY_TESTS: "models/test_internlm.py,models/test_internlm2_5.py,models/test_xverse.py" IGNORED_TEST_FILES: "test_tgi.py,test_gptneox.py,models/test_mixtral.py,models/test_phi_3_moe.py" GPTQMODEL_FORCE_BUILD: 1 repo: ${{ github.event.inputs.repo || github.repository }} @@ -139,7 +138,7 @@ jobs: import os import re - TRANSFORMERS_DIFF_TESTS = '${TRANSFORMERS_DIFF_TESTS}' + LEGACY_TESTS = '${LEGACY_TESTS}' IGNORED_TEST_FILES = '${IGNORED_TEST_FILES}' TEST_NAMES='${{ github.event.inputs.test_names }}' @@ -147,7 +146,7 @@ jobs: input_test_files_list = [f.strip().removesuffix('.py') for f in TEST_NAMES.split(',') if f.strip()] - transformers_test_files = [f.strip().removesuffix('.py') for f in f'{TRANSFORMERS_DIFF_TESTS}'.split(',') if f.strip()] + transformers_test_files = [f.strip().removesuffix('.py') for f in f'{LEGACY_TESTS}'.split(',') if f.strip()] transformers_test_files = [f for f in transformers_test_files if not input_test_files_list or f in input_test_files_list] all_tests = [f.removesuffix('.py') for f in os.listdir('tests/') if f.startswith('test_') and f.endswith('.py') and f.strip().removesuffix('py') not in f'{IGNORED_TEST_FILES}'] @@ -190,8 +189,8 @@ jobs: echo "Conditions:" echo "will build run: ${{ github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.torch-files != '[]' && needs.list-test-files.outputs.transformers-files != '[]' && !(needs.list-test-files.outputs.m4-files == '[]' && needs.list-test-files.outputs.m4-files == '[]') }}" - echo "will transformers_diff run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.transformers-files != '[]' }}" - echo "will torch2_5 run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.torch-files != '[]' }}" + echo "will legacy run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.transformers-files != '[]' }}" + echo "will torch run: ${{ (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && github.event.inputs.m4-only != 'true' && needs.list-test-files.outputs.torch-files != '[]' }}" echo "will m4 run: ${{ (github.event.inputs.test_names == '' || contains(github.event.inputs.test_names, 'apple') || contains(github.event.inputs.test_names, 'mlx') ) && (needs.list-test-files.outputs.m4-files != '' || needs.list-test-files.outputs.m4-files != '[]') }}" build: @@ -202,6 +201,12 @@ jobs: if: github.event.inputs.m4-only != 'true' && (needs.list-test-files.outputs.torch-files != '[]' || needs.list-test-files.outputs.transformers-files != '[]') container: image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v5 + options: --device /dev/dri --ipc=host --runtime=nvidia --gpus all + volumes: + - 
/dev/dri/by-path:/dev/dri/by-path + - /home/ci/models:/monster/data/model + - /home/ci/models/huggingface:/github/home/.cache/huggingface + steps: - name: Checkout Codes uses: actions/checkout@v4 @@ -286,7 +291,7 @@ jobs: if: always() run: pip cache purge && uv cache clean && rm -rf ./* ./.* - transformers_diff: + legacy: needs: - build - list-test-files @@ -383,6 +388,7 @@ jobs: - name: Install wheel run: | + uv pip install colorlog uv pip install git+https://github.com/ModelCloud/Tokenicer -U echo "===== install optimum bitblas parameterized uvicorn =====" uv pip install optimum bitblas==0.0.1.dev13 parameterized uvicorn -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple @@ -441,7 +447,7 @@ jobs: if: always() run: pip cache purge && uv cache clean && rm -rf ./* ./.* - torch2_5: + torch: needs: - build - list-test-files @@ -541,22 +547,26 @@ jobs: - name: Install wheel run: | - if [ "${{ matrix.test_script }}" == "test_quant_formats" ] || [ "${{ matrix.test_script }}" == "test_perplexity" ]; then - echo "===== install auto_round =====" - uv pip install auto_round -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple - fi - if [ "${{ matrix.test_script }}" == "models/test_cohere2" ] || [ "${{ matrix.test_script }}" == "models/test_gemma" ]; then - echo "===== install transformers from git =====" - uv pip install -U git+https://github.com/huggingface/transformers.git -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple + uv pip install colorlog + echo "===== updateing latest transformers =====" + uv pip install -U transformers + + if [ "${{ matrix.test_script }}" == "test_quant_formats" ] || [ "${{ matrix.test_script }}" == "test_perplexity" ] || [ "${{ matrix.test_script }}" == "test_q4_bitblas" ]; then + echo "===== install auto_round bitblas==0.0.1.dev13 =====" + uv pip install auto_round bitblas==0.0.1.dev13 -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi + if [[ "${{ matrix.test_script }}" == *xpu* ]]; then source /etc/profile.d/pyenv.sh && pyenv activate xpu + uv pip install colorlog fi if [[ "${{ matrix.test_script }}" == *"mlx"* ]]; then uv pip install mlx_lm --no-build-isolation -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi + if [[ "${{ matrix.test_script }}" == "test_modelscope" ]]; then + echo "===== installing modelscope =====" uv pip install modelscope --no-build-isolation -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }} --extra-index-url https://pypi.org/simple fi @@ -622,7 +632,9 @@ jobs: - name: Clean cache if: always() - run: pip cache purge && uv cache clean && rm -rf ./* ./.* + run: | + rm ~/.cache/evalplus/*pkl || true + pip cache purge && uv cache clean && rm -rf ./* ./.* show-statistics: runs-on: [ self-hosted, xeon5 ] @@ -630,8 +642,8 @@ jobs: container: image: modelcloud/gptqmodel:alpine-ci-v1 needs: - - transformers_diff - - torch2_5 + - legacy + - torch steps: - name: Print statistics run: curl "http://10.0.14.248/gpu/get_vram_logs?id=${{ github.run_id }}" diff --git a/examples/benchmark/generation_speed.py b/examples/benchmark/generation_speed.py index 
add850be4..ad7eaea4c 100644 --- a/examples/benchmark/generation_speed.py +++ b/examples/benchmark/generation_speed.py @@ -195,8 +195,8 @@ def load_model_tokenizer( def benchmark_generation_speed(model, tokenizer, examples, generation_config): generation_time_list = [] num_generated_tokens_list = [] - progress_bar = ProgressBar(examples) - for example in progress_bar: + pb = ProgressBar(examples) + for example in pb: input_ids = example["input_ids"].to(model.device) start = time.time() @@ -217,7 +217,7 @@ def benchmark_generation_speed(model, tokenizer, examples, generation_config): ) num_generated_tokens_list.append(num_generated_tokens) - progress_bar.set_postfix( + pb.set_postfix( num_tokens=num_generated_tokens_list[-1], time=generation_time_list[-1], speed=f"{num_generated_tokens_list[-1] / generation_time_list[-1]:.3f} tokens/s", diff --git a/examples/quantization/basic_usage_wikitext2.py b/examples/quantization/basic_usage_wikitext2.py index 7c87a6b6f..ac1ba63d9 100644 --- a/examples/quantization/basic_usage_wikitext2.py +++ b/examples/quantization/basic_usage_wikitext2.py @@ -68,9 +68,6 @@ def main(): # with value under torch.LongTensor type. model.quantize(traindataset) - # save quantized model - model.save(quantized_model_id) - # save quantized model using safetensors model.save(quantized_model_id) diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index f015202a9..4a13698b4 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + from .models import GPTQModel, get_best_device from .quantization import BaseQuantizeConfig, QuantizeConfig from .utils import BACKEND from .utils.exllama import exllama_set_max_input_length from .version import __version__ -import os if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: try: from modelscope.utils.hf_util.patcher import patch_hub diff --git a/gptqmodel/adapter/adapter.py b/gptqmodel/adapter/adapter.py index 7717a2326..5791c6948 100644 --- a/gptqmodel/adapter/adapter.py +++ b/gptqmodel/adapter/adapter.py @@ -28,7 +28,7 @@ def validate_path(self, local_only=False): raise ValueError(f"Adapter: `path` str in this context must be a local os path: actual = `{self.path}`.") # override me - def apply(self, x: torch.Tensor, out: torch.Tensor): + def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: pass # override me @@ -67,15 +67,18 @@ def parameter_keys(cls) -> List[str]: return ["lora_A", "lora_B"] def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): - print("Lora compile") - self.apply = torch_compile(self.apply, backend=backend, mode=mode, fullgraph=fullgraph) + pass + #logger.info("Adapter: optimize (compile)") + #self.apply = torch_compile(self.apply, backend=backend, mode=mode, fullgraph=fullgraph) - def apply(self, x: torch.Tensor, out: torch.Tensor): + def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: # original code # out = out + ((x @ self.lora_A) @ self.lora_B) # fix batch for lora - if out.shape[0] > 1: + # Some kernels do not reshape x, such as marlin / exllama / exllamav2. 
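This Lora.apply change folds the low-rank EoRA/Lora delta into the kernel output in place, as the lines below show. A minimal standalone sketch of that path, assuming `lora_A` is shaped `(in_features, rank)` and `lora_B` is `(rank, out_features)` as implied by the matmuls in the hunk (the helper name is hypothetical):

```python
import torch

def lora_apply(x: torch.Tensor, out: torch.Tensor,
               lora_A: torch.Tensor, lora_B: torch.Tensor) -> torch.Tensor:
    # Kernels such as marlin / exllama / exllamav2 keep x flattened to 2D while
    # `out` may still carry a leading batch dim, so flatten `out` only when its
    # rank exceeds x's rank and the batch dimension is non-trivial.
    if out.dim() > x.dim() and out.shape[0] > 1:
        orig_shape = out.shape
        out = out.view(-1, out.shape[-1])
        out.add_((x @ lora_A) @ lora_B)
        return out.view(orig_shape)
    out.add_((x @ lora_A) @ lora_B)
    return out
```

The in-place `add_` reuses the kernel's output buffer, so the adapter costs only the two small rank-`r` matmuls.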
+ # out.dim() > x.dim() is used to exclude these kernels without additional processing + if out.dim() > x.dim() and out.shape[0] > 1: out_orgi_shape = out.shape out = out.view(-1, out.shape[-1]) out.add_((x @ self.lora_A) @ self.lora_B) diff --git a/gptqmodel/eora/eora.py b/gptqmodel/eora/eora.py index 140905c92..22c43c9a3 100644 --- a/gptqmodel/eora/eora.py +++ b/gptqmodel/eora/eora.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 NVIDIA +# Copyright 2024-2025 NVIDIA CORPORATION # EoRA arXiv: https://arxiv.org/abs/2410.21271 # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,7 +22,7 @@ logger = setup_logger() -def eora_process_input(input: Tensor, name: str, eigen_scaling_diag_matrix: Dict[str, torch.float32], sample_size: int): +def eora_process_input(input: Tensor, name: str, eigen_scaling_diag_matrix: Dict[str, torch.dtype], sample_size: int): inp = input[0].to(dtype=torch.float32) if inp.dim() == 2: inp = inp.unsqueeze(0) @@ -38,9 +38,9 @@ def eora_process_input(input: Tensor, name: str, eigen_scaling_diag_matrix: Dict def eora_compute_lora( device: torch.device, - w_wq_delta: Tensor, # need the w (original weight) and wq (quantized qeight) delta in float32 + w_wq_delta: Tensor, # need the w (original weight) and wq (quantized qweight) delta in float32 module: NamedModule, - eigen_scaling_diag_matrix: torch.float32, + eigen_scaling_diag_matrix: torch.dtype, rank: int) -> Tuple[Tensor, Tensor]: assert w_wq_delta.dtype == torch.float32 diff --git a/gptqmodel/looper/dequantize_processor.py b/gptqmodel/looper/dequantize_processor.py index 66d2e4637..9540627b5 100644 --- a/gptqmodel/looper/dequantize_processor.py +++ b/gptqmodel/looper/dequantize_processor.py @@ -26,7 +26,8 @@ class DequantizeProcessor(LoopProcessor): def __init__(self, quantized_modules: Dict[str, TorchQuantLinear]): - super().__init__(tokenizer=None, qcfg=None, calibration_dataset=None, calibration_dataset_concat_size=None, batch_size=1, + super().__init__(tokenizer=None, qcfg=None, calibration_dataset=None, calibration_dataset_concat_size=None, + prepare_dataset_func=None, batch_size=1, logger_board="", require_fwd=True) self.quantized_modules = quantized_modules diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py index bfe578d76..337a4adec 100644 --- a/gptqmodel/looper/eora_processor.py +++ b/gptqmodel/looper/eora_processor.py @@ -30,18 +30,20 @@ from gptqmodel.quantization.gptq import CPU from gptqmodel.utils.logger import setup_logger from gptqmodel.utils.model import move_to -from gptqmodel.utils.torch import torch_sync, torch_compile +from gptqmodel.utils.torch import torch_compile, torch_sync from torch.nn import Module logger = setup_logger() class EoraProcessor(LoopProcessor): - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, + def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, prepare_dataset_func, calibration_dataset_concat_size: Optional[int], batch_size: int, logger_board: str = "", require_fwd: bool = True, ): - super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration_dataset=calibration_dataset, calibration_dataset_concat_size=calibration_dataset_concat_size, batch_size=batch_size, + super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, logger_board=logger_board, require_fwd=require_fwd) # dict: key is module name, value is the accumulated 
eigen_scaling_diag_matrix @@ -113,7 +115,7 @@ def tmp(_, input: Tuple[torch.Tensor, ...], output: torch.Tensor): def process(self, module: NamedModule): assert isinstance(module.adapter_cfg, Lora) - self.pb.set_description(f"EoRA gen: {module.name} in layer {module.layer_index} of {self.layer_count - 1}") + self.pb.info(f"EoRA gen: {module.name} in layer {module.layer_index} of {self.layer_count - 1}") start = time.time() diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 8fa23a3d9..dc5bca773 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -34,11 +34,13 @@ logger = setup_logger() class GPTQProcessor(LoopProcessor): - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, + def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, prepare_dataset_func, calibration_dataset_concat_size: Optional[int], batch_size: int, logger_board: str = "", require_fwd: bool = True, retain_w: bool = False): - super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration_dataset=calibration_dataset, calibration_dataset_concat_size=calibration_dataset_concat_size, batch_size=batch_size, + super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, logger_board=logger_board, require_fwd=require_fwd) self.retain_w = retain_w @@ -111,7 +113,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): return tmp def process(self, module: NamedModule): - self.pb.set_description(f"Quantizing {module.name} in layer {module.layer_index} of {self.layer_count - 1}") + self.pb.info(f"Quantizing {module.name} in layer {module.layer_index} of {self.layer_count - 1}") gptq = self.tasks # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py index 9b01a7760..fc4a0e860 100644 --- a/gptqmodel/looper/loop_processor.py +++ b/gptqmodel/looper/loop_processor.py @@ -33,7 +33,7 @@ # LoopProcessor is a singleton(), not per module instance class LoopProcessor: - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, + def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, prepare_dataset_func, calibration_dataset_concat_size: Optional[int], batch_size: int, logger_board: str = "", require_fwd: bool = True): @@ -95,7 +95,7 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration_dataset, logger.warning(f"Calibration dataset size should be more than {min_calibration_dataset_size}. " f"Current: {len(calibration_dataset)}.") - calibration_dataset = self.prepare_dataset(calibration_dataset=calibration_dataset, + calibration_dataset = prepare_dataset_func(calibration_dataset=calibration_dataset, calibration_dataset_concat_size=calibration_dataset_concat_size, batch_size=batch_size) @@ -137,131 +137,6 @@ def result_get(self, key: str, default: Any = None) -> Any: def results(self): return self._results - def prepare_dataset( - self, - calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[List[int]]], - # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model. 
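The processor constructors above now accept a `prepare_dataset_func`, and the block removed below is the old LoopProcessor-owned implementation. A reduced sketch of the injection pattern, with names taken from the diff but the class trimmed to the relevant call (this is a sketch, not the real class):

```python
from typing import Callable, Optional

class LoopProcessorSketch:
    # Dataset preparation is injected instead of living on the processor itself.
    def __init__(self, tokenizer, calibration_dataset,
                 prepare_dataset_func: Optional[Callable] = None,
                 calibration_dataset_concat_size: Optional[int] = None,
                 batch_size: int = 1):
        self.tokenizer = tokenizer
        if calibration_dataset is not None and prepare_dataset_func is not None:
            # e.g. BaseGPTQModel.quantize() passes its own prepare_dataset here
            calibration_dataset = prepare_dataset_func(
                calibration_dataset=calibration_dataset,
                calibration_dataset_concat_size=calibration_dataset_concat_size,
                batch_size=batch_size,
            )
        self.calibration_dataset = calibration_dataset
```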
- calibration_dataset_concat_size: Optional[int] = None, - batch_size: int = 1, - ): - if isinstance(calibration_dataset[0], (str, list)) or ( - isinstance(calibration_dataset[0], list) and all(isinstance(x, int) for x in calibration_dataset[0])): - if self.tokenizer is None: - raise ValueError( - f"tokenizer must be provided when calibration_dataset is List[str] or List[int], type: {type(calibration_dataset[0])}") - - # Convert strings/ints to tokenized format - new_calibration_dataset = [] - for data in calibration_dataset: - # convert to tensor directly if already in token ids format (ints) - if isinstance(data, list) and all(isinstance(x, int) for x in data): - input_ids = torch.tensor([data], dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - new_calibration_dataset.append({ - "input_ids": input_ids, - "attention_mask": attention_mask - }) - # call tokenizer if dataset still string format (str) - else: - tokenized = self.tokenizer(data, return_tensors="pt") - new_calibration_dataset.append({ - "input_ids": tokenized["input_ids"], - "attention_mask": tokenized["attention_mask"] - }) - calibration_dataset = new_calibration_dataset - - def _convert_tensor_to_list(tensor): - if isinstance(tensor, torch.Tensor): - if len(tensor.shape) == 1: - tensor = tensor.unsqueeze(0) - tensor = tensor.long() - return tensor.cpu().numpy().tolist() - return [tensor] - - new_calibration_dataset = [] - for example in calibration_dataset: - input_ids = _convert_tensor_to_list(example["input_ids"]) - attention_mask = _convert_tensor_to_list(example["attention_mask"]) - - new_calibration_dataset.append( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - ) - - if calibration_dataset_concat_size: - concatenated_data = [] - input_ids_buff = [] - attention_mask_buff = [] - current_length = 0 - - new_line = self.tokenizer(CALIBRATION_DATASET_CONCAT_CHAR, return_tensors="pt") - new_line_input_ids = _convert_tensor_to_list(new_line["input_ids"])[0] - new_line_attention_mask = _convert_tensor_to_list(new_line["attention_mask"])[0] - new_line_input_ids_len = len(new_line_input_ids) - - for example in new_calibration_dataset: - input_ids = example["input_ids"][0] - attention_mask = example["attention_mask"][0] - - if current_length + len(input_ids) + new_line_input_ids_len >= calibration_dataset_concat_size: - if len(input_ids_buff) > 0: - remaining_space = calibration_dataset_concat_size - current_length - # if there is remaining space, add the remaining input to the current block - if remaining_space > 0: - input_ids_buff.extend(new_line_input_ids) - input_ids_buff.extend(input_ids[:remaining_space - new_line_input_ids_len]) - attention_mask_buff.extend(new_line_attention_mask) - attention_mask_buff.extend(attention_mask[:remaining_space - new_line_input_ids_len]) - - concatenated_data.append({ - "input_ids": [input_ids_buff], - "attention_mask": [attention_mask_buff] - }) - else: - # if there is no remaining space, add the current block to the concatenated data - concatenated_data.append({ - "input_ids": [input_ids_buff], - "attention_mask": [attention_mask_buff] - }) - - input_ids_buff = input_ids[:calibration_dataset_concat_size] - attention_mask_buff = attention_mask[:calibration_dataset_concat_size] - current_length = len(input_ids_buff) - else: - input_ids_buff = input_ids[:calibration_dataset_concat_size] - attention_mask_buff = attention_mask[:calibration_dataset_concat_size] - current_length = len(input_ids_buff) - else: - if len(input_ids_buff) > 0: - 
input_ids_buff.extend(new_line_input_ids) - attention_mask_buff.extend(new_line_attention_mask) - current_length += new_line_input_ids_len - - input_ids_buff.extend(input_ids) - attention_mask_buff.extend(attention_mask) - current_length += len(input_ids) - - if input_ids_buff: - padding_length = calibration_dataset_concat_size - len(input_ids_buff) - if padding_length > 0: - input_ids_buff.extend([self.tokenizer.pad_token_id] * padding_length) - attention_mask_buff.extend([0] * padding_length) - concatenated_data.append({ - "input_ids": [input_ids_buff], - "attention_mask": [attention_mask_buff] - }) - - new_calibration_dataset = concatenated_data - - new_calibration_dataset_batched = [ - collate_data(new_calibration_dataset[start: start + batch_size], self.tokenizer.pad_token_id) - for start in range(0, len(new_calibration_dataset), batch_size) - ] - - return new_calibration_dataset_batched - def collect_memory_info(self, layer_index: int): if self.logger_task is not None: gpu_memory = get_gpu_usage_memory() diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 528d48760..47dd8cc9e 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -207,11 +207,11 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal is_lm_head_module = layer_index >= layer_count if is_lm_head_module: - quant_modules_pb.set_description("Quantizing lm_head") + quant_modules_pb.info("Quantizing lm_head") module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head) layer_inputs = self.gptq_model.lm_head_pre_quantize_generate_hook(layer_inputs) else: - quant_modules_pb.set_description(f"Quantizing layer {layer_index} of {layer_count - 1}") + quant_modules_pb.info(f"Quantizing layer {layer_index} of {layer_count - 1}") module = layers[layer_index] if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index e3fbf0d5c..b2937adef 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -18,12 +18,10 @@ import os +from gptqmodel.adapter.adapter import Adapter, Lora, normalize_adapter from lm_eval.utils import make_table from tokenicer import Tokenicer - -from gptqmodel.adapter.adapter import Adapter, Lora, normalize_adapter - from ..nn_modules.qlinear.torch import TorchQuantLinear from ..quantization.gptq import CPU from ..utils.torch import torch_empty_cache @@ -308,17 +306,16 @@ def from_quantized( def eval( cls, model_or_id_or_path: str=None, - tokenizer: PreTrainedTokenizerBase=None, + tokenizer: Union[PreTrainedTokenizerBase, Tokenicer]=None, tasks: Union[EVAL.LM_EVAL, EVAL.EVALPLUS, List[EVAL.LM_EVAL], List[EVAL.EVALPLUS]] = None, # set to None to fix mutable warning - framework: EVAL = EVAL.LM_EVAL, - batch_size: int = 1, + framework: Union[Type[EVAL.LM_EVAL],Type[EVAL.EVALPLUS]] = EVAL.LM_EVAL, + batch_size: Union[int, str] = 1, trust_remote_code: bool = False, output_path: Optional[str] = None, llm_backend: str = 'gptqmodel', backend: BACKEND = BACKEND.AUTO, # gptqmodel arg only random_seed: int = 1234, # only for framework=EVAL.LM_EVAL backend=vllm model_args: Dict[str, Any] = None, # only for framework=EVAL.LM_EVAL backend=vllm - **args ): if model_args is None: @@ -354,34 +351,17 @@ def eval( if isinstance(model, BaseGPTQModel): tokenizer = model.tokenizer elif isinstance(model, PreTrainedModel) or model_id_or_path.strip(): - tokenizer = Tokenicer.load(model_id_or_path).tokenizer # lm-eval checks 
if tokenizer's type is PretrainedTokenizer + tokenizer = Tokenicer.load(model_id_or_path) if tokenizer is None: raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.") - if llm_backend=="gptqmodel": # vllm loads tokenizer - model_args["tokenizer"] = tokenizer - - if isinstance(model_or_id_or_path, str): - model = None - model_id_or_path = model_or_id_or_path - elif isinstance(model_or_id_or_path, BaseGPTQModel) or isinstance(model_or_id_or_path, PreTrainedModel): - model = model_or_id_or_path - model_id_or_path = model.config.name_or_path # - else: - raise ValueError(f"`model_or_id_or_path` is invalid. expected: `model instance or str` actual: `{model_or_id_or_path}`") - - if tokenizer is None: - if isinstance(model, BaseGPTQModel): - tokenizer = model.tokenizer - elif isinstance(model, PreTrainedModel) or model_id_or_path.strip(): - tokenizer = Tokenicer.load(model_id_or_path).tokenizer # lm-eval checks if tokenizer's type is PretrainedTokenizer - - if tokenizer is None: - raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.") if backend=="gptqmodel": # vllm loads tokenizer - model_args["tokenizer"] = tokenizer + if isinstance(tokenizer, Tokenicer): + model_args["tokenizer"] = tokenizer.tokenizer # lm-eval checks if tokenizer's type is PretrainedTokenizer + else: + model_args["tokenizer"] = tokenizer if framework == EVAL.LM_EVAL: for task in tasks: @@ -396,9 +376,7 @@ def eval( try: from lm_eval import simple_evaluate - from lm_eval.loggers import EvaluationTracker, WandbLogger from lm_eval.models.huggingface import HFLM - from lm_eval.utils import handle_non_serializable except BaseException: raise ValueError("lm_eval is not installed. Please install via `pip install gptqmodel[eval]`.") diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 14ae4547c..dbb631e47 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -19,9 +19,8 @@ import copy import json import os -import shutil import time -from typing import Any, Dict, List, Optional, Tuple, Union, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch._dynamo @@ -29,7 +28,8 @@ from packaging import version from packaging.version import Version from tokenicer import Tokenicer -from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils +from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel, + PreTrainedTokenizerBase, ProcessorMixin, modeling_utils) from ..adapter.adapter import Adapter from ..nn_modules.hooked_linear import replace_linear_with_hooked_linear @@ -45,7 +45,7 @@ from ..utils.model import (MODALITY, check_to_quantized, find_modules, get_device, get_module, get_module_by_name_prefix, get_moe_layer_modules, move_to, nested_move_to, pack_model) from ..utils.progress import ProgressBar -from ..utils.torch import torch_empty_cache, torch_compile +from ..utils.torch import torch_compile, torch_empty_cache from ._const import CALIBRATION_DATASET_CONCAT_CHAR, CPU, DEFAULT_MAX_SHARD_SIZE, DEVICE, SUPPORTS_MODULE_TYPES from .loader import ModelLoader from .writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, @@ -91,6 +91,9 @@ class BaseGPTQModel(nn.Module): require_dtype: Optional[str|torch.dtype] = None require_fast_init: bool = True + # some models require Processor? For example, Qwen2VLImageProcessor. 
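The `require_load_processor` flag introduced here, together with the `AutoProcessor` loading that follows and the `processor.save_pretrained` call added in writer.py further down, lets multimodal definitions such as Qwen2-VL carry their processor through quantize-and-save instead of hand-copying `preprocessor_config.json`. A reduced sketch of that flow, assuming `model_local_path` points at the source checkpoint:

```python
from transformers import AutoProcessor, ProcessorMixin

class ProcessorAwareModelSketch:
    require_load_processor: bool = True  # set by model defs such as Qwen2VLGPTQ

    def __init__(self, model_local_path: str):
        self.processor: ProcessorMixin = None
        if self.require_load_processor:
            # e.g. the Qwen2VL image processor is loaded alongside the model
            self.processor = AutoProcessor.from_pretrained(model_local_path)

    def save_quantized(self, save_dir: str):
        # preprocessor_config.json / chat_template.json are written by the
        # processor itself (see the writer.py hunk below)
        if isinstance(self.processor, ProcessorMixin):
            self.processor.save_pretrained(save_dir)
```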
+ require_load_processor = False + # TODO: use a better name and what if the value is not at the config root? # allow dynamic expert n-count layer extraction # so moe model defs do not need to write out 64 layers if expert size is 64 (Qwen2Moe) @@ -152,6 +155,10 @@ def __init__( # stores all per-layer quant stats such as avg loss and processing time self.quant_log = [] + self.processor: ProcessorMixin = None + if self.require_load_processor: + self.processor = AutoProcessor.from_pretrained(model_local_path) + # apply patching of broken trust_remote_code models here if self.require_monkeypatch: self.monkey_patch() @@ -167,7 +174,7 @@ def __init__( if all(hasattr(m.adapter, name) for name in Lora.parameter_keys()): loaded_loras += 1 - logger.info(f"Adapter: `{loaded_loras}` EoRA/Lora adapters loaded.") + logger.info(f"Adapter: `{loaded_loras}` EoRA/Lora adapters loaded for `{len(qmodules)}` modules.") # print kernel info: loaded_kernels = self.kernels() @@ -378,6 +385,7 @@ def quantize( tokenizer=self.tokenizer, qcfg=self.quantize_config, calibration_dataset=calibration_dataset, + prepare_dataset_func=self.prepare_dataset, calibration_dataset_concat_size=calibration_dataset_concat_size, batch_size=batch_size, logger_board=logger_board, @@ -392,6 +400,7 @@ def quantize( tokenizer=self.tokenizer, qcfg=self.quantize_config, calibration_dataset=adapter_calibration_dataset, + prepare_dataset_func=self.prepare_dataset, calibration_dataset_concat_size=calibration_dataset_concat_size, batch_size=batch_size, logger_board=logger_board, @@ -454,6 +463,7 @@ def _eora_generate( tokenizer=self.tokenizer, qcfg=self.quantize_config, calibration_dataset=calibration_dataset, + prepare_dataset_func=self.prepare_dataset, calibration_dataset_concat_size=calibration_dataset_concat_size, batch_size=batch_size, logger_board=logger_board, @@ -816,11 +826,11 @@ def store_input_hook(_, args, kwargs): for module_index in quant_modules_pb: is_lm_head_module = module_index >= layer_count if is_lm_head_module: - quant_modules_pb.set_description("Quantizing lm_head") + quant_modules_pb.info("Quantizing lm_head") module = get_module(self.model, key=self.lm_head) layer_inputs = self.lm_head_pre_quantize_generate_hook(layer_inputs) else: - quant_modules_pb.set_description(f"Quantizing layer {module_index} of {layer_count - 1}") + quant_modules_pb.info(f"Quantizing layer {module_index} of {layer_count - 1}") module = layers[module_index] if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): @@ -962,7 +972,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): for name_index, name in enumerate(subset): layer_name = self.lm_head if is_lm_head_module else f"{self.layers_node}.{module_index}.{name}" - quant_modules_pb.set_description(f"Quantizing {name} in layer {module_index} of {layer_count - 1}") + quant_modules_pb.info(f"Quantizing {name} in layer {module_index} of {layer_count - 1}") # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") ## Need to return the quantized_weight for offloading @@ -1147,14 +1157,6 @@ def save( eora_path: Optional[str] = None, **kwargs, ): - extra_json_file_names = ["preprocessor_config.json", "chat_template.json"] - for name in extra_json_file_names: - json_path = os.path.join(self.model_local_path, name) - if os.path.exists(json_path): - os.makedirs(save_dir, exist_ok=True) - - shutil.copyfile(json_path, os.path.join(save_dir, name)) - if self.quantized: # Safetensors is unable to save tied weights, so we untie them here. 
Reference: https://github.com/huggingface/safetensors/issues/202 #untie_weights(self.model) diff --git a/gptqmodel/models/definitions/minicpm.py b/gptqmodel/models/definitions/minicpm.py index 092389fbc..00df27e63 100644 --- a/gptqmodel/models/definitions/minicpm.py +++ b/gptqmodel/models/definitions/minicpm.py @@ -29,5 +29,4 @@ class MiniCPMGPTQ(BaseGPTQModel): ["self_attn.v_proj"], ["self_attn.o_proj"], ["mlp.gate_proj", "mlp.up_proj","mlp.down_proj"], - ["mlp.c_proj"], ] diff --git a/gptqmodel/models/definitions/qwen2_vl.py b/gptqmodel/models/definitions/qwen2_vl.py index 3e2d0928f..14c58dc18 100644 --- a/gptqmodel/models/definitions/qwen2_vl.py +++ b/gptqmodel/models/definitions/qwen2_vl.py @@ -45,6 +45,8 @@ class Qwen2VLGPTQ(BaseGPTQModel): modality = [MODALITY.TEXT, MODALITY.IMAGE_TO_TEXT] + require_load_processor = True + quant_override_files = { "preprocessor_config.json": { "do_convert_rgb": True, diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 42dd73929..b153a8b78 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -23,6 +23,7 @@ import torch import transformers + if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: try: from modelscope import snapshot_download @@ -33,7 +34,6 @@ from gptqmodel.adapter.adapter import Adapter from huggingface_hub import snapshot_download - from packaging.version import InvalidVersion, Version from transformers import AutoConfig, AutoTokenizer, PretrainedConfig from transformers.modeling_utils import no_init_weights @@ -412,8 +412,17 @@ def skip(*args, **kwargs): init_contexts = [no_init_weights()] with ContextManagers(init_contexts): + if config.architectures: + model_class = getattr(transformers, config.architectures[0], None) + if model_class is not None and hasattr(model_class, "_supports_flash_attn_2"): + supports_flash_attn = model_class._supports_flash_attn_2 + else: + supports_flash_attn = None + else: + supports_flash_attn = None + args = {} - if device in [DEVICE.CUDA, DEVICE.ROCM]: + if supports_flash_attn and device in [DEVICE.CUDA, DEVICE.ROCM]: if ATTN_IMPLEMENTATION in kwargs: args[ATTN_IMPLEMENTATION] = kwargs.pop(ATTN_IMPLEMENTATION, None) if USE_FLASH_ATTENTION_2 in kwargs: diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index b5c8c869b..5709ab44e 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -30,7 +30,7 @@ from huggingface_hub.constants import SAFETENSORS_WEIGHTS_FILE_PATTERN from safetensors.torch import save_file from safetensors.torch import save_file as safe_save -from transformers import AutoConfig, PreTrainedTokenizerFast +from transformers import AutoConfig, GenerationConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.modeling_utils import no_init_weights from transformers.models.auto.tokenization_auto import get_tokenizer_config from transformers.utils.generic import ContextManagers @@ -212,6 +212,41 @@ def save_quantized( model_id_or_path=self.model_local_path, ) + # --- start config save block --- + # Save quantized config + config.quantization_config = quantize_config.to_dict() + self.model.config = config + + # Hack validator so it skips validation on save + original_validator = None + if hasattr(self, "generation_config") and isinstance(self.generation_config, GenerationConfig): + try: + self.generation_config.validate() + except Exception as e: + logger.warning(f"Model `generation_config` validation failed. 
We will allow model save to continue but please fix discrepancies: {e}") + + original_validator = self.generation_config.validate + def dummy_validate(**kwargs): + pass + + self.generation_config.validate = dummy_validate + + # Save model config, including generation_config + # Use empty state_dict hack to bypass saving weights + self.model.save_pretrained(save_dir, state_dict={}) + + # Restore validator + if original_validator is not None: + self.generation_config.validate = original_validator + + # Save `quantize_config.json` + quantize_config.save_pretrained(save_dir) + + # Save processor related config files. For example: preprocessor_config.json, chat_template.json + if hasattr(self,"processor") and isinstance(self.processor, ProcessorMixin): + self.processor.save_pretrained(save_dir) + # --- end config save block --- + model.to(CPU) state_dict = get_state_dict_for_save(model) @@ -345,11 +380,6 @@ def save_quantized( logger.info(f"Quantized model size: {total_size_mb:.2f}MB, {total_size_gb:.2f}GB") logger.info(f"Size difference: {size_diff_mb:.2f}MB, {size_diff_gb:.2f}GB - {percent_diff:.2f}%") - config.quantization_config = quantize_config.to_dict() - config.save_pretrained(save_dir) - - quantize_config.save_pretrained(save_dir) - # need to copy .py files for model/tokenizers not yet merged to HF transformers if self.trust_remote_code: copy_py_files(save_dir, model_id_or_path=self.model_local_path) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 7034eb2f0..96fbd1735 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -39,7 +39,7 @@ class BaseQuantLinear(nn.Module): SUPPORTS_OUT_FEATURES_DIVISIBLE_BY: List[int] = None SUPPORTS_PACK_DTYPES: List[t.dtype] = None - SUPORTS_ADAPTERS: List[Adapter] = None + SUPPORTS_ADAPTERS: List[Adapter] = None SUPPORTS_DEVICES: List[DEVICE] = None SUPPORTS_PLATFORM: List[PLATFORM] = None @@ -238,7 +238,7 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, adapter:Optional[Adapter]=None) -> Tuple[bool, Optional[Exception]]: cls.verify_supports_params() - if adapter is not None and adapter.__class__ not in cls.SUPORTS_ADAPTERS: + if adapter is not None and adapter.__class__ not in cls.SUPPORTS_ADAPTERS: err = f"{cls} does not support adapter: {adapter}" return False, NotImplementedError(err) @@ -264,7 +264,8 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: if bits not in cls.SUPPORTS_BITS: err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual bits = `{bits}`" return False, NotImplementedError(err) - if group_size not in cls.SUPPORTS_GROUP_SIZE: + # valid group size is set of cls.SUPPORTS_GROUP_SIZE + in_features; group_size = -1 is alias for group_size == in_features + if group_size not in cls.SUPPORTS_GROUP_SIZE and group_size != in_features: err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}`" return False, NotImplementedError(err) if sym not in cls.SUPPORTS_SYM: @@ -340,8 +341,8 @@ def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool pass class PackableQuantLinear(BaseQuantLinear): - def __init__(self, **kwargs): - super().__init__(**kwargs) + def post_init(self, **kwargs): + super().post_init(**kwargs) if self.bits in [2, 4, 8]: wf = t.tensor(list(range(0, self.pack_dtype_bits, self.bits)), 
dtype=t.int32).unsqueeze(0).to( @@ -412,7 +413,7 @@ def dequantize_weight(self, num_itr: int = 1): return weights - def pack(self, linear, scales, zeros, g_idx=None): + def pack(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_idx: t.Tensor=None): W = linear.weight.data.clone() if isinstance(linear, nn.Conv2d): W = W.flatten(1) diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index 12e34e0d3..8ea70a505 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -97,7 +97,7 @@ class BitBLASQuantLinear(PackableQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA] SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] SUPPORTS_PACK_DTYPES = [torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512] zeros_mode = "quantized" # "original" or "rescale" or "quantized" diff --git a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py index 744b2d0b0..25fd81ff7 100644 --- a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py +++ b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py @@ -48,7 +48,7 @@ class DynamicCudaQuantLinear(TorchQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] SUPPORTS_PACK_DTYPES = [torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "cuda" diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py index 55a81cad6..5169edf40 100644 --- a/gptqmodel/nn_modules/qlinear/exllama.py +++ b/gptqmodel/nn_modules/qlinear/exllama.py @@ -70,7 +70,7 @@ class ExllamaQuantLinear(PackableQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "exllama" @@ -168,12 +168,15 @@ def forward(self, x): if x.size(-1) != self.in_features: x = F.pad(x, self.in_features_padding_shape) - out = ext_q4_matmul(x, self.q4, self.width) - if self.adapter: - out = self.adapter.apply(x=x, out=out) - - if self.bias is not None: - out.add_(self.bias) + if self.bias: + out = self.adapter.apply(x=x, out=ext_q4_matmul(x, self.q4, self.width)).add_(self.bias) + else: + out = self.adapter.apply(x=x, out=ext_q4_matmul(x, self.q4, self.width)) + else: + if self.bias: + out = ext_q4_matmul(x, self.q4, self.width).add_(self.bias) + else: + out = ext_q4_matmul(x, self.q4, self.width) return out.to(x_dtype) diff --git a/gptqmodel/nn_modules/qlinear/exllama_eora.py b/gptqmodel/nn_modules/qlinear/exllama_eora.py index aad56a867..6adce0c25 100644 --- a/gptqmodel/nn_modules/qlinear/exllama_eora.py +++ b/gptqmodel/nn_modules/qlinear/exllama_eora.py @@ -72,7 +72,7 @@ class ExllamaEoraQuantLinear(BaseQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "exllama_v2v" diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py index e4853d159..2998342b3 100644 --- a/gptqmodel/nn_modules/qlinear/exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/exllamav2.py @@ -134,7 +134,7 @@ class ExllamaV2QuantLinear(BaseQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM] SUPPORTS_PLATFORM = 
[PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "exllamav2" @@ -231,13 +231,16 @@ def forward(self, x, force_cuda=False): if x.size(-1) != self.in_features: x = F.pad(x, self.in_features_padding_shape) - output = ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda) - if self.adapter: - output = self.adapter.apply(x=x, out=output) - - if self.bias is not None: - output.add_(self.bias) + if self.bias: + output = self.adapter.apply(x=x, out=ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda)).add_(self.bias) + else: + output = self.adapter.apply(x=x, out=ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda)) + else: + if self.bias: + output = ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda).add_(self.bias) + else: + output = ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda) return output.to(dtype=x_dtype) diff --git a/gptqmodel/nn_modules/qlinear/ipex.py b/gptqmodel/nn_modules/qlinear/ipex.py index 9121e90e7..40939c1bc 100644 --- a/gptqmodel/nn_modules/qlinear/ipex.py +++ b/gptqmodel/nn_modules/qlinear/ipex.py @@ -19,10 +19,10 @@ import torch from gptqmodel.adapter.adapter import Adapter, Lora from gptqmodel.models._const import DEVICE, PLATFORM -from .torch import TorchQuantLinear from ...utils.logger import setup_logger -from ...utils.torch import HAS_XPU +from ...utils.torch import torch_compile +from . import PackableQuantLinear logger = setup_logger() @@ -45,7 +45,7 @@ def ipex_dtype() -> torch.dtype: raise ImportError("intel_extension_for_pytorch not installed. " "Please install via `pip install intel_extension_for_pytorch`") - return torch.float16 if HAS_XPU else torch.bfloat16 + return torch.float16 def convert_dtype_torch2str(dtype): @@ -85,13 +85,13 @@ def convert_idx(self, g_idx, k): # if import GPTQShuffle failed, do nothing pass -class IPEXQuantLinear(TorchQuantLinear): +class IPEXQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [16, 32, 64, 128] SUPPORTS_DESC_ACT = [True, False] SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True - SUPPORTS_TRAINING = True + SUPPORTS_TRAINING = False SUPPORTS_AUTO_PADDING = False SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] @@ -99,7 +99,7 @@ class IPEXQuantLinear(TorchQuantLinear): SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "ipex" @@ -114,7 +114,6 @@ def __init__( bias: bool = False, pack_dtype: torch.dtype = torch.int32, adapter: Adapter = None, - training=False, **kwargs, ): super().__init__( @@ -130,105 +129,40 @@ def __init__( register_buffers=True, **kwargs) - # FIX ME IPEX CPU has no float16 support - self.weight_dtype = torch.float16 if HAS_XPU else torch.bfloat16 - self.training = training - self.ipex_linear = None # None means not init, False means no ipex, else is good + self.weight_dtype = torch.float16 @classmethod - def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: + def validate(cls, bias: bool = False, adapter: Optional[Adapter] = None, **args) -> Tuple[bool, Optional[Exception]]: if not HAS_IPEX: return False, IPEX_ERROR_LOG return cls._validate(**args) def post_init(self): - pass - - def init_ipex_linear(self, x: torch.Tensor): - if not self.training and HAS_IPEX and 
not x.requires_grad: - self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, - self.in_features, self.out_features, None, self.bias, - self.group_size, self.g_idx, quant_method=QuantMethod.GPTQ_GEMM, dtype=QuantDtype.INT4) - assert self.ipex_linear is not None - else: - self.ipex_linear = False - + self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight( + self.qweight, + self.scales, + self.qzeros, + self.in_features, + self.out_features, + None, + # bias: if adapter, do not let ipex do apply bias, do it after adapter.apply + self.bias if not self.adapter else None, + self.group_size, + self.g_idx, + quant_method=QuantMethod.GPTQ_GEMM, + dtype=QuantDtype.INT4) + + @torch.no_grad() def forward(self, x: torch.Tensor): - if self.ipex_linear is None: # None is special value meaning ipex_linear init is not called yet - self.init_ipex_linear(x) - - if self.ipex_linear: - with torch.no_grad(): - outputs = self.ipex_linear(x) - return outputs - - return super().forward(x) - - -# @torch.no_grad() -# def unpack_to_8bit_signed(qweight, qzeros, bits, g_idx=None): -# wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0) -# zeros = None -# if not torch.all(torch.eq(qzeros, 2004318071 if bits == 4 else 0b01111111011111110111111101111111)): -# zp_shape = list(qzeros.shape) -# zp_shape[1] = zp_shape[1] * (32 // bits) -# -# zeros = torch.bitwise_right_shift( -# torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0) -# ).to(torch.int16 if bits == 8 else torch.int8) -# torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) -# if bits == 8: -# zeros = zeros.to(torch.uint8) -# zeros = zeros + 1 -# try: -# zeros = zeros.reshape(zp_shape) -# except Exception: -# # zeros and scales have different iteam numbers. 
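The rewritten IPEX forward above follows the same ordering adopted by the exllama, exllamav2, and marlin hunks: the kernel produces its output without bias, the EoRA/Lora adapter is applied first, and the bias is added last so the adapter sees the raw kernel result. A minimal sketch of that ordering (the function name is hypothetical):

```python
import torch

def finalize_kernel_output(x: torch.Tensor, out: torch.Tensor,
                           adapter=None, bias: torch.Tensor = None) -> torch.Tensor:
    # adapter first, bias last, mirroring the forward() changes in this diff
    if adapter is not None:
        out = adapter.apply(x=x, out=out)
    if bias is not None:
        out.add_(bias)
    return out
```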
-# # remove 1 (due to 0 + 1 in line 252) -# zeros = zeros[zeros != 1] -# zeros = zeros.reshape(zp_shape) -# -# try: -# r = torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1) -# except BaseException as e: -# print(e) -# weight = torch.bitwise_right_shift( -# r, wf.unsqueeze(-1) -# ).to(torch.int16 if bits == 8 else torch.int8) -# weight.bitwise_and_((2**bits) - 1) -# weight = weight.view(-1, weight.shape[-1]) -# -# if g_idx is not None: -# group_size = weight.shape[0] // qzeros.shape[0] -# weight2 = weight.clone() -# group_dict = {} -# for i in range(len(g_idx)): -# group_idx = g_idx[i].item() -# if group_idx not in group_dict: -# target_idx = group_idx * group_size -# group_dict[group_idx] = 0 -# else: -# group_dict[group_idx] = group_dict[group_idx] + 1 -# target_idx = group_idx * group_size + group_dict[group_idx] -# weight2[target_idx] = weight[i] -# weight = weight2 -# -# return weight, zeros -# -# -# # Copied from marlin.py -# @torch.no_grad() -# def dequantize_weight(qweight, qzeros, scales, bits): -# unpacked_qweight, unpacked_qzeros = unpack_to_8bit_signed(qweight, qzeros, bits) -# group_size = unpacked_qweight.shape[0] // scales.shape[0] -# scales = scales.repeat_interleave(group_size, dim=0) -# if unpacked_qzeros is not None: -# unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) -# else: -# unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32) -# unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales -# -# return unpacked_qweight, unpacked_qzeros + if self.adapter: + if self.bias: + return self.adapter(x=x, out=self.ipex_linear(x)).add_(self.bias) + else: + return self.adapter(x=x, out=self.ipex_linear(x)) + else: + return self.ipex_linear(x) + def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): + self.forward = torch_compile(self.forward, backend=backend, mode=mode, fullgraph=fullgraph) __all__ = ["IPEXQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py index 015225f64..b2faa0366 100644 --- a/gptqmodel/nn_modules/qlinear/marlin.py +++ b/gptqmodel/nn_modules/qlinear/marlin.py @@ -171,7 +171,7 @@ class MarlinQuantLinear(BaseQuantLinear): SUPPORTS_DEVICES = [DEVICE.CUDA] SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "marlin" @@ -389,10 +389,13 @@ def forward(self, A: torch.Tensor): output_size_per_partition=self.out_features, input_size_per_partition=self.in_features, is_k_full=self.is_k_full, - bias=self.bias) + bias=self.bias if not self.adapter else None) if self.adapter: - output = self.adapter.apply(x=A, out=output) + if self.bias: + output = self.adapter.apply(x=A, out=output).add_(self.bias) + else: + output = self.adapter.apply(x=A, out=output) return output diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index e8c4654c2..632243763 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -43,7 +43,7 @@ class TorchQuantLinear(PackableQuantLinear): SUPPORTS_DEVICES = [DEVICE.ALL] SUPPORTS_PLATFORM = [PLATFORM.ALL] SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "torch" @@ -97,8 +97,8 @@ def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool # compile 
dequantize self.dequantize_weight = torch_compile(self.dequantize_weight, backend=backend, mode=mode, fullgraph=fullgraph) - #if self.adapter: - # self.adapter.g_compile(backend=backend, mode=mode, fullgraph=fullgraph) + if self.adapter: + self.adapter.optimize(backend=backend, mode=mode, fullgraph=fullgraph) def forward(self, x: torch.Tensor): if x.size(-1) != self.padded_infeatures: diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index 086dca620..7b49aca8d 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -61,7 +61,7 @@ class TritonV2QuantLinear(PackableQuantLinear, TritonModuleMixin): SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.XPU] SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] SUPPORTS_PACK_DTYPES = [torch.int32, torch.int16, torch.int8] - SUPORTS_ADAPTERS = [Lora] + SUPPORTS_ADAPTERS = [Lora] # for transformers/optimum tests compat QUANT_TYPE = "tritonv2" diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index fb003329a..8299863d8 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -195,26 +195,26 @@ def __post_init__(self): if isinstance(self.pack_dtype, str): self.pack_dtype = self.pack_dtype.lower() if self.pack_dtype not in ["int64", "int32", "int16", "int8"]: - raise ValueError(f"Unsupported pack_dtype: {self.pack_dtype}") + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") self.pack_dtype = getattr(torch, self.pack_dtype) elif isinstance(self.pack_dtype, torch.dtype): if self.pack_dtype not in [torch.int64, torch.int32, torch.int16, torch.int8]: - raise ValueError(f"Unsupported pack_dtype: {self.pack_dtype}") + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") else: - raise ValueError(f"Unsupported pack_dtype: {self.pack_dtype}") + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") # validate quant method and format is matched valid_formats = QUANT_METHOD_FORMAT_MAPPING.get(self.quant_method, None) if valid_formats is None: - raise ValueError(f"Unsupported quantization method: {self.quant_method}") + raise ValueError(f"QuantizeConfig: Unsupported `quant_method`: {self.quant_method}") if self.format not in valid_formats: raise ValueError( - f"The checkpoint format used is {self.format}, and the quantization method is {self.quant_method}. " + f"QuantizeConfig: checkpoint `format` used is {self.format}, and the quantization method is {self.quant_method}. 
" ) if self.bits not in fields_info[0].metadata["choices"]: - raise ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.") + raise ValueError(f"QuantizeConfig: `bits` must be in the set of `{fields_info[0].metadata['choices']}`.") if self.dynamic is not None: self.dynamic = { @@ -225,33 +225,33 @@ def __post_init__(self): for layer, layer_dict in self.dynamic.items(): for key, value in layer_dict.items(): if key == "bits" and value not in fields_info[0].metadata["choices"]: - raise ValueError(f"Layer {layer}: only support quantize to {fields_info[0].metadata['choices']} bits.") + raise ValueError(f"QuantizeConfig: Layer `{layer}` only support quantization of `{fields_info[0].metadata['choices']}` bits.") elif key == "group_size" and value != -1 and value <= 0: - raise ValueError("unless equal to -1, group_size must greater then 0.") + raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.") if self.group_size != -1 and self.group_size <= 0: - raise ValueError("unless equal to -1, group_size must greater than 0.") + raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.") if not (0 < self.damp_percent < 1): - raise ValueError("damp_percent must between 0 and 1.") + raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.") if self.damp_auto_increment < 0: - raise ValueError("damp_auto_increment must greater than 0.") + raise ValueError("QuantizeConfig:: `damp_auto_increment` must greater than 0.") # validate meta if self.meta is not None: if not isinstance(self.meta, dict): - raise ValueError("meta must be a dictionary") + raise ValueError("QuantizeConfig: `meta` must be a dictionary") for key, value in self.meta.items(): if not isinstance(key, str): - raise ValueError("Keys in the meta dictionary must be strings") + raise ValueError("QuantizeConfig: `meta` keys must be strings") else: self.meta = {} # adapter normalize self.adapter = normalize_adapter(self.adapter) - print(f"adapter: {self.adapter}") + #print(f"adapter: {self.adapter}") def extension_set(self, key: str, value: Any): if self.adapter is None: @@ -313,9 +313,9 @@ def from_quant_config(cls, quantize_cfg, format: str = None): # compat: format can be passed in via from_quantized() if field missing from json if format: if format not in valid_formats: - raise ValueError(f"Unknown quantization checkpoint format: {format}.") + raise ValueError(f"QuantizeConfig: Unknown quantization checkpoint format: {format}.") if quantize_cfg.get(FORMAT_FIELD_JSON): - raise ValueError("Conflict: quantization format is passed in and also exists in model config.") + raise ValueError("QuantizeConfig: Conflicting quantization format passed in manually and also exists in model config.") # compat: warn if checkpoint_format is missing elif quantize_cfg.get(FORMAT_FIELD_JSON) is None: format_auto_inferred = True @@ -340,7 +340,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None): if val in {FORMAT.GPTQ, FORMAT.GPTQ_V2, FORMAT.MARLIN, FORMAT.BITBLAS}: normalized[key] = val else: - raise ValueError(f"Unknown quantization format: {val}.") + raise ValueError(f"QuantizeConfig: Unknown quantization format: `{val}`.") elif key == QUANT_METHOD_FIELD: val = val.lower() # compat: some hf models use quant_method=marlin or bitblas @@ -349,7 +349,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None): elif val == FORMAT.BITBLAS: normalized[FORMAT_FIELD_CODE] = FORMAT.BITBLAS elif val not in {QUANT_METHOD.GPTQ, 
QUANT_METHOD.AUTO_ROUND}: - raise ValueError(f"Unknown quantization method: {val}.") + raise ValueError(f"QuantizeConfig: Unknown quantization method: `{val}`.") else: normalized[QUANT_METHOD_FIELD] = val elif key == FORMAT_FIELD_COMPAT_MARLIN and val: @@ -357,10 +357,10 @@ def from_quant_config(cls, quantize_cfg, format: str = None): elif key in field_names: normalized[key] = val else: - logger.info(f"Ignoring unknown parameter in the quantization configuration: {key}.") + logger.info(f"QuantizeConfig: Ignoring unknown parameter in the quantization configuration: {key}.") if format_auto_inferred: - logger.info(f"`{FORMAT_FIELD_JSON}` is missing from the quantization configuration and is automatically inferred to {normalized[FORMAT_FIELD_CODE]}") + logger.info(f"QuantizeConfig: `{FORMAT_FIELD_JSON}` is missing from the quantization configuration and is automatically inferred to {normalized[FORMAT_FIELD_CODE]}") if normalized[FORMAT_FIELD_CODE] in {FORMAT.BITBLAS}: # AWQ and Marlin do not reorder the rows. @@ -368,8 +368,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None): if "sym" not in normalized: logger.warning( - "The quantization configuration does not contain an entry `sym` (symmetric quantization). " - "This may result in silent errors. Defaulting to `sym=True`." + "QuantizeConfig: config does not contain `sym` (symmetric quantization). This may result in silent errors. Defaulting to `sym=True`." ) return cls(**normalized) @@ -389,7 +388,7 @@ def from_pretrained(cls, save_dir: str, **kwargs): if resolved_config_file is None: raise ValueError( - "No quantize_config.json, quant_config.json or config.json file was found in the model repository." + "QuantizeConfig: No quantize_config.json, quant_config.json or config.json file was found in the model repository." ) with open(resolved_config_file, "r", encoding="utf-8") as f: @@ -510,4 +509,4 @@ def to_dict(self): class BaseQuantizeConfig(QuantizeConfig): def __init__(self, **kwargs): super().__init__(**kwargs) - logger.warning("BaseQuantizeConfig is re-named and pending deprecation. Please use `QuantizeConfig` instead.") + logger.warning("QuantizeConfig: BaseQuantizeConfig is re-named and pending deprecation. Please use `QuantizeConfig` instead.") diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index 6d9367e53..aa0b6f400 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -26,7 +26,7 @@ class BACKEND(str, Enum): TRITON = "triton" EXLLAMA_V1 = "exllama_v1" EXLLAMA_V2 = "exllama_v2" - EXLLAMA_V2V = "exllama_v2v" + # EXLLAMA_EORA = "exllama_eora" MARLIN = "marlin" BITBLAS = "bitblas" IPEX = "ipex" diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index cf562a262..5acf5f7e3 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -92,7 +92,7 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool # Note that due to tvm compilation of per layer modules shapes, the first layer loop is # relatively much slower if caching is not available. 
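Editor's note on the re-worded `QuantizeConfig` validation errors above: they are raised eagerly from `__post_init__`, so a bad field fails at construction time. A minimal sketch of that behavior, assuming only the public `QuantizeConfig` constructor and the `bits`/`group_size`/`damp_percent` fields visible in this hunk:

from gptqmodel import QuantizeConfig

# A well-formed config: 4-bit quantization with group size 128.
qcfg = QuantizeConfig(bits=4, group_size=128, damp_percent=0.01)

# An out-of-range group_size is rejected immediately with the new
# "QuantizeConfig:"-prefixed message (group_size must be -1 or > 0).
try:
    QuantizeConfig(bits=4, group_size=0)
except ValueError as err:
    print(err)
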
estimate time remaining is highly inaccurate - for name, module in ProgressBar(model.named_modules(), desc=message, total=len(list(model.named_modules()))): + for name, module in ProgressBar(model.named_modules(), info=message, total=len(list(model.named_modules()))): if not isinstance(module, model_quantlinear): continue diff --git a/gptqmodel/utils/eval.py b/gptqmodel/utils/eval.py index 75e50b6ec..60c0eadad 100644 --- a/gptqmodel/utils/eval.py +++ b/gptqmodel/utils/eval.py @@ -21,6 +21,7 @@ from .evalplus import patch_evalplus + class EVAL: class LM_EVAL(str, Enum): ARC_CHALLENGE = "arc_challenge" diff --git a/gptqmodel/utils/evalplus.py b/gptqmodel/utils/evalplus.py index 368c91fa0..c873e831b 100644 --- a/gptqmodel/utils/evalplus.py +++ b/gptqmodel/utils/evalplus.py @@ -15,6 +15,7 @@ def patch_evalplus(model): if isinstance(model, BaseGPTQModel) or isinstance(model, PreTrainedModel): model.strip = types.MethodType(patch_strip, model) model.__str__ = types.MethodType(patch_tostring, model) + model.__repr__ = types.MethodType(patch_tostring, model) import torch from evalplus.provider.base import DecoderBase diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index ce79a638f..da7a5a83a 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -26,7 +26,6 @@ from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear from ..nn_modules.qlinear.dynamic_cuda import DynamicCudaQuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear -from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear from ..nn_modules.qlinear.ipex import IPEXQuantLinear from ..nn_modules.qlinear.marlin import MarlinQuantLinear @@ -53,8 +52,8 @@ }) FORMAT_DICT = { - FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2V, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH], - FORMAT.GPTQ_V2: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2V, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH], + FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH], # BACKEND.EXLLAMA_EORA + FORMAT.GPTQ_V2: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH], # , BACKEND.EXLLAMA_EORA FORMAT.MARLIN: [BACKEND.MARLIN], FORMAT.BITBLAS: [BACKEND.BITBLAS], FORMAT.IPEX: [BACKEND.IPEX], @@ -231,8 +230,8 @@ def select_quant_linear( qlinear = BitBLASQuantLinear elif backend == BACKEND.MARLIN: qlinear = MarlinQuantLinear - elif backend == BACKEND.EXLLAMA_V2V: - qlinear = ExllamaEoraQuantLinear + # elif backend == BACKEND.EXLLAMA_EORA: + # qlinear = ExllamaEoraQuantLinear elif backend == BACKEND.EXLLAMA_V2: qlinear = ExllamaV2QuantLinear elif backend == BACKEND.EXLLAMA_V1: @@ -242,7 +241,7 @@ def select_quant_linear( elif backend == BACKEND.IPEX: from ..nn_modules.qlinear.ipex import HAS_IPEX if not HAS_IPEX: - raise ValueError("IPEX is not available. please install it with `pip install gptqmodel['ipex']`") + raise ValueError("IPEX is not available. Please install it by `pip install gptqmodel['ipex']`") from device_smi import Device diff --git a/gptqmodel/utils/logger.py b/gptqmodel/utils/logger.py index 0b3f8e92b..bfde3a9bb 100644 --- a/gptqmodel/utils/logger.py +++ b/gptqmodel/utils/logger.py @@ -15,21 +15,75 @@ # limitations under the License. 
import logging +import sys +from typing import Callable + +from colorlog import ColoredFormatter # global static/shared logger instance logger = None +last_logging_src = 1 # one for logger, 2 for progressbar + +def update_logging_src(src: int): + global last_logging_src + last_logging_src = src def setup_logger(): global logger if logger is not None: return logger + class CustomLogger(logging.Logger): + def critical(self, msg, *args, **kwargs): + op = super().critical + self._process(op, msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + op = super().warning + self._process(op, msg, *args, **kwargs) + + def debug(self, msg, *args, **kwargs): + op = super().debug + self._process(op, msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + op = super().info + self._process(op, msg, *args, **kwargs) + + def _process(self, op: Callable, msg, *args, **kwargs): + global last_logging_src + if last_logging_src == 2: + print(" ", flush=True) + last_logging_src = 1 + op(msg, *args, **kwargs) + + logging.setLoggerClass(CustomLogger) + logger = logging.getLogger(__name__) - handler = logging.StreamHandler() - formatter = logging.Formatter("%(levelname)s - %(message)s") - handler.setFormatter(formatter) logger.propagate = False - logger.addHandler(handler) logger.setLevel(logging.DEBUG) + # Create a colored formatter + formatter = ColoredFormatter( + "%(log_color)s%(levelname)-8s%(reset)s %(message)s", + datefmt=None, + reset=True, + log_colors={ + 'DEBUG': 'cyan', + 'INFO': 'green', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'red,bg_white', + }, + secondary_log_colors={}, + style='%' + ) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + handler.flush = sys.stdout.flush + logger.addHandler(handler) + return logger + + diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py index 41a902629..42b1edb71 100644 --- a/gptqmodel/utils/marlin.py +++ b/gptqmodel/utils/marlin.py @@ -110,7 +110,7 @@ def convert_to_marlin( # TODO: load directly Marlin QuantLinear. 
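Editor's note on the `CustomLogger`/`update_logging_src` pairing above: it exists so a colored log record never lands on the same terminal row as the in-place progress bar, which prints with a carriage return and no trailing newline. A standalone sketch of that hand-off, using plain `logging` rather than the module itself; the names below are illustrative only:

import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format="%(levelname)-8s %(message)s")

last_src = 1  # 1 = logger wrote last, 2 = progress bar wrote last

def bar_update(text: str) -> None:
    """In-place progress update: carriage return, no newline."""
    global last_src
    print(f"\r{text}", end="", flush=True)
    last_src = 2

def log_info(msg: str) -> None:
    """Log a message, first breaking off the progress-bar row if needed."""
    global last_src
    if last_src == 2:
        print(flush=True)  # move to a fresh line before the log record
        last_src = 1
    logging.getLogger("demo").info(msg)

bar_update("|███-------| 30%")
log_info("checkpoint reached")  # appears on its own line, not over the bar
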
message = "Overriding QuantLinear layers to use Marlin's QuantLinear" - for name, module in ProgressBar(model.named_modules(), desc=message, total=len(list(model.named_modules()))): + for name, module in ProgressBar(model.named_modules(), info=message, total=len(list(model.named_modules()))): if not isinstance(module, model_quantlinear): continue diff --git a/gptqmodel/utils/mlx.py b/gptqmodel/utils/mlx.py index 83fa43374..8d790de19 100644 --- a/gptqmodel/utils/mlx.py +++ b/gptqmodel/utils/mlx.py @@ -51,7 +51,7 @@ def convert_gptq_to_mlx_weights(model_id_or_path: str, model: Union[PreTrainedMo n = 1 pb = ProgressBar(model.named_modules(), prefix="Converting to mlx:", total=len(list(model.named_modules()))) for name, module in pb: - pb.set_description(f"{name}") + pb.info(f"{name}") if isinstance(module, TorchQuantLinear): weights[f"{name}.weight"] = mx.array( module.dequantize_weight().T.detach().to("cpu", torch.float16).numpy() diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index ec59fbcc1..b2571575e 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -26,7 +26,7 @@ import shutil from concurrent.futures import ThreadPoolExecutor from enum import Enum -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import accelerate import threadpoolctl as tctl @@ -175,7 +175,7 @@ def make_quant( pack: bool = False, device: DEVICE = None, from_quantized: bool = False, -) -> BaseQuantLinear: +) -> Type[BaseQuantLinear]: bits = qcfg.bits group_size =qcfg.group_size @@ -205,15 +205,15 @@ def make_quant( logger.info(f"Kernel: candidates -> `{quant_linear_candidates}`") # loop over actual QLinear init, catch errors and use fallbacks if applicable - for linear in quant_linear_candidates: + for cls in quant_linear_candidates: try: # if linear is not selectedQLinear: # logger.info(f"make_quant: Faild linear: `{selectedQLinear}` failed, trying to use fallback: `{linear}`") # else: # logger.info("make_quant: Testing linear: {linear}") - linear_instance = create_quant_layer( - linear=linear, + linear_cls = create_quant_layer( + linear_cls=cls, bits=bits, desc_act=desc_act, dynamic=dynamic, @@ -226,10 +226,11 @@ def make_quant( pack_dtype=pack_dtype, adapter=qcfg.adapter, ) - logger.info(f"Kernel: selected -> `{linear}`.") - return linear_instance + logger.info(f"Kernel: selected -> `{linear_cls}`.") + return linear_cls except NotImplementedError as e: - logger.info(f"Kernel: skipped -> `{linear}`.") + logger.info(f"Kernel: skipped -> `{linear_cls}`.") + # only fallback to other quant linears when backend is auto. 
if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]: raise e @@ -238,7 +239,7 @@ def make_quant( def create_quant_layer( - linear: nn.Module, + linear_cls: Type[BaseQuantLinear], bits: int, desc_act: bool, dynamic, @@ -250,10 +251,9 @@ def create_quant_layer( lm_head_name: str, pack_dtype: torch.dtype, adapter: Optional[Adapter] = None, - - ) -> BaseQuantLinear: - if isinstance(module, linear): - return linear +) -> Type[BaseQuantLinear]: + if isinstance(module, linear_cls): + return linear_cls for name, submodule in module.named_modules(): # skip non-quantized modules if name not in quant_result: @@ -306,7 +306,7 @@ def create_quant_layer( # when loading a quantized model, device is target device passed in GPTQModel.load() # check in_features and out_features validate - _, err = linear.validate( + _, err = linear_cls.validate( bits=tmp_bits, group_size=tmp_group_size, desc_act=tmp_desc_act, @@ -320,7 +320,7 @@ def create_quant_layer( if err is not None: raise err - new_layer = linear( + new_layer = linear_cls( bits=tmp_bits, group_size=tmp_group_size, desc_act=tmp_desc_act, @@ -336,7 +336,7 @@ def create_quant_layer( ) new_layer.device = ori_layer_device recurse_setattr(module, name, new_layer.to(ori_layer_device)) - return linear + return linear_cls # public/stable api exposed to transformer/optimum def hf_convert_gptq_v1_to_v2_format( @@ -502,7 +502,7 @@ def pack_module(name, qModules, quant_result, layers, pbar=None): # Limit pack() thread usage to avoid auto-parallizataion regression with tctl.threadpool_limits(limits=1): if pbar: - pbar.set_description(f"Packing {name}") + pbar.info(f"Packing {name}") r = quant_result[name] scale, zero, g_idx = r.get("scale"), r.get("zero"), r.get("g_idx") # TODO FIX ME: use const, not string for field names layer_device = qModules[name].device @@ -542,25 +542,15 @@ def pack_model( dynamic=dynamic, pack_dtype=pack_dtype, ) - quantLinear = select_quant_linear( - bits=bits, - dynamic=dynamic, - group_size=group_size, - desc_act=desc_act, - sym=sym, - backend=backend, - format=format, - pack=True, - pack_dtype=pack_dtype, - ) model.to(CPU) logger.info("Packing model...") modules = find_modules(model) + modules = {n: modules[n] for n in quant_result} - make_quant( + quant_linear_cls = make_quant( model, quant_result=quant_result, qcfg=qcfg, @@ -568,7 +558,11 @@ def pack_model( lm_head_name=lm_head_name, pack=True, ) - qModules = find_modules(model, [quantLinear]) + + qModules = find_modules(model, [quant_linear_cls]) + + assert len(qModules) > 0, f"No quantizeed modules[{quant_linear_cls}] found in the model." 
+ names = list(qModules.keys()) if parallel_packing: @@ -585,7 +579,7 @@ def wrapper(name): pass logger.info("Model packed.") - return quantLinear + return quant_linear_cls def verify_model_hash(file_path: str, verify_hash: str): diff --git a/gptqmodel/utils/perplexity.py b/gptqmodel/utils/perplexity.py index f5073aee3..653adb776 100644 --- a/gptqmodel/utils/perplexity.py +++ b/gptqmodel/utils/perplexity.py @@ -149,7 +149,7 @@ def calculate(self, n_ctx=512, n_batch=512): curr_ppl = 0 all_perplexity = [] - with ProgressBar(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress: + with ProgressBar(range(len(tokens[0]) // n_ctx), info="Perplexity: - ") as progress: for i in progress: # Process each batch of tokens nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count) @@ -157,7 +157,7 @@ def calculate(self, n_ctx=512, n_batch=512): # Calculate and display the current perplexity curr_ppl = np.exp(nll / count) all_perplexity.append(curr_ppl) - progress.set_description(f"Perplexity: {curr_ppl:.4f}") + progress.info(f"Perplexity: {curr_ppl:.4f}") return all_perplexity diff --git a/gptqmodel/utils/progress.py b/gptqmodel/utils/progress.py index 6bd63d6ca..19efeb9fc 100644 --- a/gptqmodel/utils/progress.py +++ b/gptqmodel/utils/progress.py @@ -15,9 +15,15 @@ # limitations under the License. import datetime +import os +import sys import time +from typing import Iterable from warnings import warn +from gptqmodel.utils.logger import setup_logger, update_logging_src + +logger = setup_logger() class ProgressBarWarning(Warning): def __init__(self, msg, fp_write=None, *a, **k): @@ -27,7 +33,17 @@ def __init__(self, msg, fp_write=None, *a, **k): super().__init__(msg, *a, **k) class ProgressBar: - def __init__(self, iterable=None, total=None, prefix='', bar_length=40, fill='█', desc=""): + def __init__(self, + iterable: Iterable=None, + total=None, + prefix:str = '', + bar_length:int =60, + fill:str = '█', + info:str = ""): + + # max info length over the life ot the pb + self.max_info_length = len(info) + if total is None and iterable is not None: try: total = len(iterable) @@ -45,20 +61,43 @@ def __init__(self, iterable=None, total=None, prefix='', bar_length=40, fill=' self.prefix = prefix self.bar_length = bar_length self.fill = fill - self.description = desc - self.current = 0 + self.info_text = info + self.current_iteration = 0 self.time = time.time() - def set_description(self, description): - self.description = description + def info(self, info:str): + if len(info) > self.max_info_length: + self.max_info_length = len(info) + + self.info_text = info - def progress(self, iteration = None): + def progress(self, iteration:int = None): if not iteration: - iteration = self.current - percent = ("{0:.1f}").format(100 * (iteration / float(len(self)))) - filled_length = int(self.bar_length * iteration // len(self)) - bar = self.fill * filled_length + '-' * (self.bar_length - filled_length) - self.log(bar, f"{self.calc_time(iteration)} [{iteration}/{len(self)}] {percent}%") + iteration = self.current_iteration + + columns, _ = terminal_size() + bar_length = columns + bar_length -= len(self.prefix) # +1 for space + bar_length -= len(self.info_text) + + percent_num = iteration / float(len(self)) + percent = ("{0:.1f}").format(100 * (percent_num)) + log = f"{self.calc_time(iteration)} [{iteration}/{len(self)}] {percent}%" + + bar_length -= len(log) + bar_length -= 5 # space + | chars + + # calculate padding + if len(self.info_text) < self.max_info_length: + padding = " " * 
(self.max_info_length - len(self.info_text)) + else: + padding = "" + + bar_length -= len(padding) + + filled_length = int(bar_length * iteration // len(self)) + bar = self.fill * filled_length + '-' * (bar_length - filled_length) + self.log(bar=bar, log=log, padding=padding, end='\n' if percent_num >= 1.0 else '') def calc_time(self, iteration): used_time = int(time.time() - self.time) @@ -66,8 +105,14 @@ def calc_time(self, iteration): remaining = str(datetime.timedelta(seconds=int((used_time / max(iteration, 1)) * len(self)))) return f"{formatted_time} / {remaining}" - def log(self, bar, log): - print(f'\r{self.prefix} {self.description} |{bar}| {log}', end='', flush=True) + def log(self, bar:str, log:str, padding:str = "", end: str = ""): + # print(f'\r{self.prefix} {self.info_text} |{bar}| {log}', end='', flush=True) + if self.prefix: + print(f'\r{self.prefix} {self.info_text}{padding} |{bar}| {log}', end=end, flush=True) + else: + print(f'\r{self.info_text}{padding} |{bar}| {log}', end=end, flush=True) + + update_logging_src(src=2) # let logger now we logged def __bool__(self): if self.total is not None: @@ -84,6 +129,7 @@ def __len__(self): else self.iterable.__length_hint__() if hasattr(self.iterable, "__length_hint__") else getattr(self, "total", None)) + # TODO FIXME: I have no cluse why the try/catch is catching nothing here def __reversed__(self): try: orig = self.iterable @@ -102,6 +148,7 @@ def __contains__(self, item): def __enter__(self): return self + # TODO FIXME: I don't understand the exception here. What are we catching? yield error? def __exit__(self, exc_type, exc_value, traceback): try: self.close() @@ -125,12 +172,60 @@ def __iter__(self): iterable = self.iterable for obj in iterable: - self.current+=1 + self.current_iteration+=1 self.progress() yield obj + + self.progress() return def close(self): - self.log(f"{'-' * self.bar_length}", "100.0%") - + pass + #self.log(f"{self.fill * self.bar_length}", "100.0%", end="\n") + +# copied from github.com/onsim/shutils +def terminal_size(fallback=(80, 24)): + """Get the size of the terminal window. + + For each of the two dimensions, the environment variable, COLUMNS + and LINES respectively, is checked. If the variable is defined and + the value is a positive integer, it is used. + + When COLUMNS or LINES is not defined, which is the common case, + the terminal connected to sys.__stdout__ is queried + by invoking os.get_terminal_size. + + If the terminal size cannot be successfully queried, either because + the system doesn't support querying, or because we are not + connected to a terminal, the value given in fallback parameter + is used. Fallback defaults to (80, 24) which is the default + size used by many terminal emulators. + + The value returned is a named tuple of type os.terminal_size. 
+ """ + # columns, lines are the working values + try: + columns = int(os.environ['COLUMNS']) + except (KeyError, ValueError): + columns = 0 + + try: + lines = int(os.environ['LINES']) + except (KeyError, ValueError): + lines = 0 + + # only query if necessary + if columns <= 0 or lines <= 0: + try: + size = os.get_terminal_size(sys.__stdout__.fileno()) + except (AttributeError, ValueError, OSError): + # stdout is None, closed, detached, or not a terminal, or + # os.get_terminal_size() is unsupported + size = os.terminal_size(fallback) + if columns <= 0: + columns = size.columns or fallback[0] + if lines <= 0: + lines = size.lines or fallback[1] + + return (columns, lines) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 9fd988181..dbe8c69bb 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -18,9 +18,8 @@ from typing import Callable, Union import torch -from packaging.version import Version - from gptqmodel.utils.logger import setup_logger +from packaging.version import Version HAS_CUDA = False HAS_XPU = False diff --git a/requirements.txt b/requirements.txt index 56ab58ea9..6fab58144 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ datasets>=3.2.0 numpy>=2.2.2 torch>=2.2.0 safetensors>=0.5.2 -transformers>=4.48.3 +transformers>=4.49.0 threadpoolctl>=3.5.0 packaging>=24.2 device-smi==0.3.3 @@ -12,4 +12,5 @@ pillow>=11.1.0 hf_transfer>=0.1.9 huggingface_hub>=0.28.1 lm-eval==0.4.7 -tokenicer>=0.0.2 +colorlog>=6.9.0 +tokenicer>=0.0.2 \ No newline at end of file diff --git a/setup.py b/setup.py index 5b3d2a947..e9bd9084e 100644 --- a/setup.py +++ b/setup.py @@ -211,20 +211,20 @@ def get_version_tag() -> str: ] extensions = [ - cpp_ext.CUDAExtension( - 'gptqmodel_exllama_eora', - [ - "gptqmodel_ext/exllama_eora/q_gemm.cu", - "gptqmodel_ext/exllama_eora/pybind.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - #include_dirs=[os.path.abspath("."), os.path.abspath("eora_test")], - # extra_compile_args={ - # 'cxx': ['-std=c++20'], - # 'nvcc': ['-std=c++20'], - # } - ), + # cpp_ext.CUDAExtension( + # 'gptqmodel_exllama_eora', + # [ + # "gptqmodel_ext/exllama_eora/q_gemm.cu", + # "gptqmodel_ext/exllama_eora/pybind.cu", + # ], + # extra_link_args=extra_link_args, + # extra_compile_args=extra_compile_args, + # #include_dirs=[os.path.abspath("."), os.path.abspath("eora_test")], + # # extra_compile_args={ + # # 'cxx': ['-std=c++20'], + # # 'nvcc': ['-std=c++20'], + # # } + # ), cpp_ext.CUDAExtension( "gptqmodel_cuda_64", [ diff --git a/tests/benchmark/benchmark_test.py b/tests/benchmark/benchmark_test.py index b995bd698..ff84a693f 100644 --- a/tests/benchmark/benchmark_test.py +++ b/tests/benchmark/benchmark_test.py @@ -66,7 +66,7 @@ def benchmark(self, backend, device, tokens_per_second: int, warmup_iter: int = times = [] pb = ProgressBar(range(self.NUM_RUNS)) for i in pb: - pb.set_description(f"run index {i} of {self.NUM_RUNS -1}") + pb.info(f"run index {i} of {self.NUM_RUNS - 1}") start_time = time.time() _ = model.generate(**inp,min_new_tokens=self.MIN_NEW_TOKENS, max_new_tokens=self.MAX_NEW_TOKENS) diff --git a/tests/cpu/test_progress_bar.py b/tests/cpu/test_progress_bar.py new file mode 100644 index 000000000..30cd73f88 --- /dev/null +++ b/tests/cpu/test_progress_bar.py @@ -0,0 +1,14 @@ +import unittest +from time import sleep + +from gptqmodel.utils.progress import ProgressBar + + +class TestBits(unittest.TestCase): + def test_progress_bar(self): + pb = ProgressBar(range(1,101)) + for i in pb: + 
pb.info(f"Test run index {i} of 100") + sleep(0.1) + + diff --git a/tests/inference_speed.py b/tests/inference_speed.py index 08e073308..7281aa41f 100644 --- a/tests/inference_speed.py +++ b/tests/inference_speed.py @@ -70,7 +70,7 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, if warmup_runs > 0: pb = ProgressBar(range(warmup_runs)) for i in pb: - pb.set_description(f"warmup run index {i} of {self.NUM_RUNS - 1}") + pb.info(f"warmup run index {i} of {self.NUM_RUNS - 1}") start_time = time.time() result = model.generate(**inp, max_new_tokens=self.MAX_NEW_TOEKNS, pad_token_id=tokenizer.pad_token_id) end_time = time.time() @@ -97,7 +97,7 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, pb = ProgressBar(range(self.NUM_RUNS)) for i in pb: - pb.set_description(f"run index {i} of {self.NUM_RUNS - 1}") + pb.info(f"run index {i} of {self.NUM_RUNS - 1}") start_time = time.time() result = model.generate(**inp, max_new_tokens=self.MAX_NEW_TOEKNS, pad_token_id=tokenizer.pad_token_id) end_time = time.time() diff --git a/tests/models/model_test.py b/tests/models/model_test.py index d0645e439..e643fd371 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -19,8 +19,6 @@ import sys from typing import Dict, List -from gptqmodel.utils.eval import EVAL - if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -40,6 +38,7 @@ from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT # noqa: E402 from gptqmodel.quantization.config import QuantizeConfig # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402 @@ -260,6 +259,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del } else: model_args = {} + if extra_args: + model_args.update(extra_args) from lm_eval.tasks import TaskManager from lm_eval.utils import make_table results = GPTQModel.eval( diff --git a/tests/pytest.ini b/tests/pytest.ini index 603f470f8..6ecfee9ef 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,3 +1,4 @@ [pytest] addopts=-s -v log_cli=true +norecursedirs = tasks evalplus_results \ No newline at end of file diff --git a/tests/test_bits.py b/tests/test_bits.py index a927fb7aa..64d5c8a9a 100644 --- a/tests/test_bits.py +++ b/tests/test_bits.py @@ -17,14 +17,12 @@ # -- do not touch import os - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import logging # noqa: E402 import tempfile # noqa: E402 import traceback # noqa: E402 import unittest # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 @@ -37,6 +35,7 @@ from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from lm_eval.utils import make_table # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 logger = logging.getLogger(__name__) diff --git a/tests/test_eval.py b/tests/test_eval.py index a6a991476..9232f4f0f 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -20,15 +20,7 @@ import tempfile # noqa: E402 import unittest # noqa: E402 -from typing import 
Union # noqa: E402 - -from gptqmodel import GPTQModel # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 -from lm_eval.tasks import TaskManager # noqa: E402 -from parameterized import parameterized # noqa: E402 - -import tempfile # noqa: E402 -import unittest # noqa: E402 +from typing import Type # noqa: E402 from typing import Union # noqa: E402 from gptqmodel import GPTQModel # noqa: E402 @@ -52,7 +44,7 @@ def setUpClass(self): (EVAL.LM_EVAL, EVAL.LM_EVAL.GPQA, 'vllm'), ] ) - def test_eval_gptqmodel(self, framework: EVAL, task: Union[EVAL.LM_EVAL, EVAL.EVALPLUS], llm_backend: str): + def test_eval_gptqmodel(self, framework: Union[Type[EVAL.LM_EVAL],Type[EVAL.EVALPLUS]], task: Union[EVAL.LM_EVAL, EVAL.EVALPLUS], llm_backend: str): with tempfile.TemporaryDirectory() as tmp_dir: output_path = f"{tmp_dir}/result.json" model_args = {} diff --git a/tests/test_evalplus.py b/tests/test_evalplus.py index ff4f29b68..13d7251b7 100644 --- a/tests/test_evalplus.py +++ b/tests/test_evalplus.py @@ -25,7 +25,6 @@ from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.utils.eval import evalplus # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class TestEvalplus(unittest.TestCase): @@ -37,7 +36,7 @@ def test_evalplus(self): with tempfile.TemporaryDirectory() as tmp_dir: output_file = f"{tmp_dir}/result.json" - model = GPTQModel.load(self.MODEL_ID, tokenizer=AutoTokenizer.from_pretrained(self.MODEL_ID)) + model = GPTQModel.load(self.MODEL_ID) base_formatted, plus_formatted, _ = evalplus(model=model, dataset='humaneval', output_file=output_file) self.assertGreaterEqual(float(base_formatted), 0.26, "Base score does not match expected result") diff --git a/tests/test_group_size.py b/tests/test_group_size.py index 26b45e4c1..719866080 100644 --- a/tests/test_group_size.py +++ b/tests/test_group_size.py @@ -17,7 +17,6 @@ # -- do not touch import os - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import logging # noqa: E402 @@ -25,9 +24,7 @@ import traceback # noqa: E402 import unittest # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.dynamic_cuda import DynamicCudaQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 @@ -36,7 +33,9 @@ from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 from lm_eval.utils import make_table # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 logger = logging.getLogger(__name__) diff --git a/tests/test_lm_eval.py b/tests/test_lm_eval.py index eef80e3af..1ceaffaf1 100644 --- a/tests/test_lm_eval.py +++ b/tests/test_lm_eval.py @@ -17,19 +17,14 @@ # -- do not touch import os - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import tempfile # noqa: E402 import unittest # noqa: E402 - from gptqmodel import BACKEND, GPTQModel - -from lm_eval.utils import make_table # noqa: E402 - -from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 class 
TestLmEval(unittest.TestCase): @@ -59,7 +54,7 @@ def test_eval_direct(self): print(make_table(results, "groups")) print('--------lm_eval Result End---------') - acc_score = results['results'].get(self.task.value, {}).get('acc,none') + results['results'].get(self.task.value, {}).get('acc,none') acc_norm_score = results['results'].get(self.task.value, {}).get('acc_norm,none') # self.assertGreaterEqual(acc_score, self.acc_score, "acc score does not match expected result") diff --git a/tests/test_lm_head.py b/tests/test_lm_head.py index 00b01f048..c5d39bacf 100644 --- a/tests/test_lm_head.py +++ b/tests/test_lm_head.py @@ -46,7 +46,7 @@ def test_eval(self): class TestLmHeadQuant(ModelTest): APPLY_CHAT_TEMPLATE = True - EXPECT_LM_HEAD_LOSS = 31.11202 + EXPECT_LM_HEAD_LOSS = 23.84 sample_length = 1024 samples = 128 diff --git a/tests/test_modelscope.py b/tests/test_modelscope.py index 95fc43bf9..22fcf2663 100644 --- a/tests/test_modelscope.py +++ b/tests/test_modelscope.py @@ -1,7 +1,8 @@ import os + os.environ["GPTQMODEL_USE_MODELSCOPE"] = "True" -from models.model_test import ModelTest # noqa: E402 from gptqmodel import GPTQModel # noqa: E402 +from models.model_test import ModelTest # noqa: E402 class TestLoadModelscope(ModelTest): @@ -17,4 +18,4 @@ def test_load_modelscope(self): str_output = model.tokenizer.decode(result) assert "beijing" in str_output.lower() or "bei-jing" in str_output.lower() - del model \ No newline at end of file + del model diff --git a/tests/test_post_quant_eora.py b/tests/test_post_quant_eora.py index 631f808ae..1ded29448 100644 --- a/tests/test_post_quant_eora.py +++ b/tests/test_post_quant_eora.py @@ -51,7 +51,7 @@ def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): raise AssertionError(" `paris` not found in `result`") bench_result = GPTQModel.eval( - model_or_path=model, + model_or_id_or_path=model, framework=EVAL.LM_EVAL, tasks=[EVAL.LM_EVAL.ARC_CHALLENGE] ) diff --git a/tests/test_q4_cuda.py b/tests/test_q4_cuda.py index e42bc359b..51af7c270 100644 --- a/tests/test_q4_cuda.py +++ b/tests/test_q4_cuda.py @@ -16,16 +16,13 @@ # -- do not touch import os -import tempfile - -from gptqmodel.utils import Perplexity os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import torch # noqa: E402 -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py index 8f4c31f10..5e9d5a20e 100644 --- a/tests/test_quant_and_eora.py +++ b/tests/test_quant_and_eora.py @@ -50,9 +50,10 @@ def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): assert "paris" in result.lower(), f"`paris` not found in `{result}`" bench_result = GPTQModel.eval( - model_or_path=model, + model_or_id_or_path=model, framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE] + tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.GSM8K_COT], + batch_size=32, ) del model @@ -84,8 +85,13 @@ def test_quant_and_eora(self): calibration_dataset_concat_size = 0 # disable auto_gc = False adapter_file_name = "eora.safetensors" + dataset_id = "allenai/c4" + dataset_files = "en/c4-train.00001-of-01024.json.gz" config_dict = { + "model_id": self.NATIVE_MODEL_ID, + "dataset_id": dataset_id, + "dataset_files": dataset_files, "bits": bits, "group_size": group_size, "desc_act": desc_act, @@ -98,8 
+104,8 @@ def test_quant_and_eora(self): } calibration_dataset = load_dataset( - "allenai/c4", - data_files="en/c4-train.00001-of-01024.json.gz", + dataset_id, + data_files=dataset_files, split="train" ).select(range(calibration_dataset_rows))["text"] @@ -143,18 +149,18 @@ def test_quant_and_eora(self): base_bench = bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only eora_bench = bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora) - print('--------Quant/EoRA Config ---------') + print('--------GPTQModel + EoRA Config ---------') # Convert the dictionary to a list of lists for tabulate table_data = [[key, value] for key, value in config_dict.items()] print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid")) - print('--------Eval Base Result---------') + print('--------Eval GPTQ Result---------') print(make_table(base_bench)) if "groups" in base_bench: print(make_table(base_bench, "groups")) - print('--------Eval EoRA Result---------') + print('--------Eval GPTQ + EoRA Result---------') print(make_table(eora_bench)) if "groups" in eora_bench: print(make_table(eora_bench, "groups")) diff --git a/tests/test_vllm.py b/tests/test_vllm.py index d5e9c7cd3..16534b9cb 100644 --- a/tests/test_vllm.py +++ b/tests/test_vllm.py @@ -21,11 +21,8 @@ # -- end do not touch import importlib.util # noqa: E402 -import subprocess # noqa: E402 -import sys # noqa: E402 import tempfile # noqa: E402 -import torch # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
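
Editor's note: the ProgressBar call-site churn across this diff is mechanical; the `desc=` keyword becomes `info=` and `set_description()` becomes `info()`. A usage sketch mirroring the new tests/cpu/test_progress_bar.py:

from time import sleep

from gptqmodel.utils.progress import ProgressBar

# `info=` replaces the old `desc=` keyword; `info()` replaces `set_description()`.
pb = ProgressBar(range(1, 101), info="warming up")
for i in pb:
    pb.info(f"step {i} of 100")
    sleep(0.01)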