diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index 04094c22f..26d86ad02 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -344,23 +344,24 @@ def pack(self, linear, scales, zeros, g_idx=None):
         elif self.bits == 3:
             i = 0
             col = 0
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
-            i += 10
-            qzeros[:, col] |= zeros[:, i] << 30
-            col += 1
-            qzeros[:, col] |= (zeros[:, i] >> 2) & 1
-            i += 1
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
-            i += 10
-            qzeros[:, col] |= zeros[:, i] << 31
-            col += 1
-            qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
-            i += 1
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
-            i += 10
-            col += 1
+            while col < qzeros.shape[1]:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
 
         self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index f5b071af6..6fb6ccfc8 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -58,6 +58,7 @@ class ModelTest(unittest.TestCase):
     TORCH_DTYPE = "auto"
     BATCH_SIZE = "auto"
     LOAD_BACKEND = BACKEND.AUTO
+    QUANT_BACKEND = BACKEND.AUTO
     USE_VLLM = False
     INPUTS_MAX_LENGTH = 2048
     MODEL_MAX_LEN = 4096
@@ -83,6 +84,8 @@ class ModelTest(unittest.TestCase):
     LM_HEAD_LOSS_MAX_DELTA_PERCENT = 0.1  # ±10%
     EXPECT_LM_HEAD_LOSS = None
 
+    QUANTIZE_CONFIG_BITS = 4
+
     def assertInference(self, model, tokenizer=None, keywords=None, prompt=INFERENCE_PROMPT):
         # gptqmodel can auto init tokenizer internally
         if keywords is None:
@@ -148,7 +151,7 @@ def check_kernel(self, model, expected_kernels):
 
     def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="auto", need_eval=True, batch_size: int = 4, **kwargs):
         quantize_config = QuantizeConfig(
-            bits=4,
+            bits=self.QUANTIZE_CONFIG_BITS,
             group_size=128,
             format=self.QUANT_FORMAT,
             desc_act=self.DESC_ACT,
@@ -189,7 +192,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="aut
         is_ovis_model = model.__class__.__name__ == "OvisGPTQ"
         need_create_processor = is_image_to_text_model and not is_ovis_model
         if not is_quantized:
-            model.quantize(calibration_dataset, batch_size=batch_size)
+            model.quantize(calibration_dataset, backend=self.QUANT_BACKEND, batch_size=batch_size)
 
             self.check_kernel(model, self.KERNEL_QUANT)
diff --git a/tests/test_bits.py b/tests/test_bits.py
index b50e11ae5..0c22b8a19 100644
--- a/tests/test_bits.py
+++ b/tests/test_bits.py
@@ -54,14 +54,14 @@ class TestBits(unittest.TestCase):
         BACKEND.MARLIN: MarlinQuantLinear,
     }
 
-    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.025  # -2.5%
-    QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.025  # +2.5%
+    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.1
+    QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.1
 
     CUDA_QLINEAR_QUANTIZED_MODEL_ARC_CHALLENGE_EXPECTS = {
-        2: {'acc,none': 0.22610921501706485, 'acc_norm,none': 0.2909556313993174},
-        3: {'acc,none': 0.21245733788395904, 'acc_norm,none': 0.24744027303754265},
-        4: {'acc,none': 0.2738907849829352, 'acc_norm,none': 0.3122866894197952},
-        8: {'acc,none': 0.2841296928327645, 'acc_norm,none': 0.302901023890785},
+        2: {'acc,none': 0.2175767918088737, 'acc_norm,none': 0.26535836177474403},
+        3: {'acc,none': 0.22696245733788395, 'acc_norm,none': 0.2627986348122867},
+        4: {'acc,none': 0.26621160409556316, 'acc_norm,none': 0.3148464163822526},
+        8: {'acc,none': 0.29948805460750855, 'acc_norm,none': 0.3293515358361775},
     }
 
     def calculatorPer(self, filter, value, base_value):
@@ -92,22 +92,29 @@ def test_bits(self):
         # quantize
         model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        dataset = [
-            "gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
+        dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
         calibration_dataset = [tokenizer(example) for example in dataset]
+
+        errors = []
         for quant_backend in self.pack_backends:
             supports_bits = self.QLINEAR_DICT[quant_backend].SUPPORTS_BITS
             for bits in supports_bits:
-                print("-----------------------quant-----------------------")
+                print(f"-----------------------quant backend: {quant_backend}-- bits: {bits} ---------------------")
                 quantize_config = QuantizeConfig(bits=bits, group_size=128, sym=True, desc_act=False)
                 print(f"bits: {quantize_config.bits}, quant_backend: {quant_backend} start quant")
                 try:
                     self.quant_and_eval(calibration_dataset, model_id, quant_backend, quantize_config, tokenizer)
                 except Exception:
-                    print(f"bits: {quantize_config.bits}, quant_backend: {quant_backend} An error occurred")
+                    error_log=f"bits: {quantize_config.bits}, quant_backend: {quant_backend} An error occurred"
+                    print(error_log)
+                    errors.append(error_log)
+                    traceback.print_exc()
+                    continue
 
+        self.assertTrue(len(errors) == 0, '\n'.join(errors))
+
     def quant_and_eval(self, calibration_dataset, model_id, quant_backend, quantize_config, tokenizer):
         model = GPTQModel.load(
             model_id,
@@ -127,11 +134,7 @@ def quant_and_eval(self, calibration_dataset, model_id, quant_backend, quantize_
                 # Skip inference_backend that does not support the current bits
                 continue
 
-            try:
-                self.eval(inference_backend, quant_backend, quantize_config, tmp_dir)
-            except Exception:
-                traceback.print_exc()
-                continue
+            self.eval(inference_backend, quant_backend, quantize_config, tmp_dir)
 
     def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir):
         print("-----------------------eval-----------------------")
@@ -165,8 +168,7 @@ def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir):
             metric: value for metric, value in results['results'].get(TASK_NAME, {}).items()
             if metric != 'alias' and 'stderr' not in metric
         }
-        print(
-            f"bits is: {quantize_config.bits}, quant_backend: {quant_backend}, inference_backend: {inference_backend} -> task_results: {task_results}")
+        print(f"bits is: {quantize_config.bits}, quant_backend: {quant_backend}, inference_backend: {inference_backend} -> task_results: {task_results}")
 
         del model
         self.check_results(quantize_config.bits, task_results)
diff --git a/tests/test_q4_cuda.py b/tests/test_q4_cuda.py
index de6c6ca5a..e42bc359b 100644
--- a/tests/test_q4_cuda.py
+++ b/tests/test_q4_cuda.py
@@ -16,13 +16,16 @@
 # -- do not touch
 import os
+import tempfile
+
+from gptqmodel.utils import Perplexity
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 # -- end do not touch
 
 import torch  # noqa: E402
-from gptqmodel import BACKEND, GPTQModel  # noqa: E402
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig  # noqa: E402
 from models.model_test import ModelTest  # noqa: E402
 from parameterized import parameterized  # noqa: E402
 from transformers import AutoTokenizer  # noqa: E402
@@ -74,4 +77,3 @@ def test_generation_desc_act_false(self, torch_dtype, device):
         self.assertInference(model=model_q,tokenizer=self.tokenizer)
         # This one does not.
         self.assertInference(model=model_q.model,tokenizer=self.tokenizer)
-
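Note on the qlinear change above: without the "while col < qzeros.shape[1]:" loop, the 3-bit branch packed only the first group of 32 zero-points per row and left the remaining qzeros columns at zero; the patch repeats the same block until every packed column is filled. The sketch below is a standalone NumPy illustration of that bit layout, not GPTQModel's pack() method; the function name, shapes, and the self-check at the bottom are assumptions for demonstration only.

# Standalone sketch (assumed shapes and helper name, not the library's pack())
# of the patched 3-bit zero-point layout: each group of 32 3-bit values fills
# exactly three uint32 columns (10 values + a split value, twice, then 10 more).
import numpy as np


def pack_zeros_3bit(zeros: np.ndarray) -> np.ndarray:
    """zeros: (rows, n) array of 3-bit values in [0, 7], with n divisible by 32."""
    zeros = zeros.astype(np.uint32)
    qzeros = np.zeros((zeros.shape[0], zeros.shape[1] * 3 // 32), dtype=np.uint32)
    i = 0
    col = 0
    while col < qzeros.shape[1]:                     # repeat for every group of 32 values
        for j in range(i, i + 10):                   # values 0-9  -> bits 0..29 of this column
            qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
        i += 10
        qzeros[:, col] |= zeros[:, i] << 30          # value 10: low 2 bits -> bits 30..31
        col += 1
        qzeros[:, col] |= (zeros[:, i] >> 2) & 1     # value 10: high bit -> bit 0 of next column
        i += 1
        for j in range(i, i + 10):                   # values 11-20 -> bits 1..30
            qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
        i += 10
        qzeros[:, col] |= zeros[:, i] << 31          # value 21: low bit -> bit 31
        col += 1
        qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3   # value 21: high 2 bits -> bits 0..1 of next column
        i += 1
        for j in range(i, i + 10):                   # values 22-31 -> bits 2..31
            qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
        i += 10
        col += 1
    return qzeros


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    zeros = rng.integers(0, 8, size=(4, 64), dtype=np.uint32)  # two 32-value groups per row
    packed = pack_zeros_3bit(zeros)
    assert packed.shape == (4, 6)   # 64 values * 3 bits / 32 bits = 6 packed columns
    # Without the while-loop only columns 0..2 would be written and the second
    # group's columns would stay zero; with it, every column receives data.
    assert packed[:, 3:].any()

The test-side changes line up with this fix: ModelTest now exposes QUANT_BACKEND and QUANTIZE_CONFIG_BITS as class-level knobs so subclasses can exercise other kernels and bit widths, and test_bits.py records each failing quant_backend/bits combination in errors and asserts at the end, so a broken packer fails the suite instead of only printing a traceback.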