37 changes: 19 additions & 18 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -344,23 +344,24 @@ def pack(self, linear, scales, zeros, g_idx=None):
         elif self.bits == 3:
             i = 0
             col = 0
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
-            i += 10
-            qzeros[:, col] |= zeros[:, i] << 30
-            col += 1
-            qzeros[:, col] |= (zeros[:, i] >> 2) & 1
-            i += 1
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
-            i += 10
-            qzeros[:, col] |= zeros[:, i] << 31
-            col += 1
-            qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
-            i += 1
-            for j in range(i, i + 10):
-                qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
-            i += 10
-            col += 1
+            while col < qzeros.shape[1]:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
 
         self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))
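The new while loop fixes a real truncation bug: the old code packed only the first block of 32 zero-points (three 32-bit words per row) and then stopped, leaving every remaining qzeros column zeroed. Each 32-bit word holds ten whole 3-bit values; the two values per block that straddle word boundaries (block indices 10 and 21) are split across adjacent words. A minimal NumPy sketch of the inverse operation, handy for round-trip-testing the layout (unpack3 is a hypothetical helper written for illustration, not a library function):

import numpy as np

def unpack3(qzeros: np.ndarray, n: int) -> np.ndarray:
    # Recover n 3-bit zero-points per row from packed uint32 words.
    # One block of 32 values occupies exactly three 32-bit words,
    # mirroring the packing loop above.
    out = np.zeros((qzeros.shape[0], n), dtype=np.uint32)
    i, col = 0, 0
    while col < qzeros.shape[1]:
        for j in range(i, i + 10):          # block values 0..9
            out[:, j] = (qzeros[:, col] >> (3 * (j - i))) & 0x7
        i += 10
        # block value 10: low 2 bits in word 0 (bits 30-31), high bit in word 1 (bit 0)
        out[:, i] = ((qzeros[:, col] >> 30) & 0x3) | ((qzeros[:, col + 1] & 0x1) << 2)
        col += 1
        i += 1
        for j in range(i, i + 10):          # block values 11..20
            out[:, j] = (qzeros[:, col] >> (3 * (j - i) + 1)) & 0x7
        i += 10
        # block value 21: low bit in word 1 (bit 31), high 2 bits in word 2 (bits 0-1)
        out[:, i] = ((qzeros[:, col] >> 31) & 0x1) | ((qzeros[:, col + 1] & 0x3) << 1)
        col += 1
        i += 1
        for j in range(i, i + 10):          # block values 22..31
            out[:, j] = (qzeros[:, col] >> (3 * (j - i) + 2)) & 0x7
        i += 10
        col += 1
    return out

Round-tripping random 3-bit values through the packing loop and unpack3 is a quick way to confirm the fix fills every column, not just the first three.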
7 changes: 5 additions & 2 deletions tests/models/model_test.py
@@ -58,6 +58,7 @@ class ModelTest(unittest.TestCase):
     TORCH_DTYPE = "auto"
     BATCH_SIZE = "auto"
     LOAD_BACKEND = BACKEND.AUTO
+    QUANT_BACKEND = BACKEND.AUTO
     USE_VLLM = False
     INPUTS_MAX_LENGTH = 2048
     MODEL_MAX_LEN = 4096
@@ -83,6 +83,8 @@ class ModelTest(unittest.TestCase):
     LM_HEAD_LOSS_MAX_DELTA_PERCENT = 0.1  # ±10%
     EXPECT_LM_HEAD_LOSS = None
 
+    QUANTIZE_CONFIG_BITS = 4
+
     def assertInference(self, model, tokenizer=None, keywords=None, prompt=INFERENCE_PROMPT):
         # gptqmodel can auto init tokenizer internally
         if keywords is None:
@@ -148,7 +151,7 @@ def check_kernel(self, model, expected_kernels):
 
     def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="auto", need_eval=True, batch_size: int = 4, **kwargs):
         quantize_config = QuantizeConfig(
-            bits=4,
+            bits=self.QUANTIZE_CONFIG_BITS,
             group_size=128,
             format=self.QUANT_FORMAT,
             desc_act=self.DESC_ACT,
@@ -189,7 +192,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="auto", need_eval=True, batch_size: int = 4, **kwargs):
         is_ovis_model = model.__class__.__name__ == "OvisGPTQ"
         need_create_processor = is_image_to_text_model and not is_ovis_model
         if not is_quantized:
-            model.quantize(calibration_dataset, batch_size=batch_size)
+            model.quantize(calibration_dataset, backend=self.QUANT_BACKEND, batch_size=batch_size)
 
             self.check_kernel(model, self.KERNEL_QUANT)
 
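Together, the two new class attributes let a test subclass vary both the quantization bit width and the packing backend without overriding quantModel. A hypothetical subclass (the class name is invented for illustration; the model path reuses the one from test_bits.py below) might look like:

class TestTriton3Bit(ModelTest):
    NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0"
    QUANTIZE_CONFIG_BITS = 3        # flows into QuantizeConfig(bits=...) in quantModel
    QUANT_BACKEND = BACKEND.TRITON  # forwarded to model.quantize(backend=...)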
36 changes: 19 additions & 17 deletions tests/test_bits.py
@@ -54,14 +54,14 @@ class TestBits(unittest.TestCase):
         BACKEND.MARLIN: MarlinQuantLinear,
     }
 
-    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.025  # -2.5%
-    QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.025  # +2.5%
+    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.1
+    QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.1
 
     CUDA_QLINEAR_QUANTIZED_MODEL_ARC_CHALLENGE_EXPECTS = {
-        2: {'acc,none': 0.22610921501706485, 'acc_norm,none': 0.2909556313993174},
-        3: {'acc,none': 0.21245733788395904, 'acc_norm,none': 0.24744027303754265},
-        4: {'acc,none': 0.2738907849829352, 'acc_norm,none': 0.3122866894197952},
-        8: {'acc,none': 0.2841296928327645, 'acc_norm,none': 0.302901023890785},
+        2: {'acc,none': 0.2175767918088737, 'acc_norm,none': 0.26535836177474403},
+        3: {'acc,none': 0.22696245733788395, 'acc_norm,none': 0.2627986348122867},
+        4: {'acc,none': 0.26621160409556316, 'acc_norm,none': 0.3148464163822526},
+        8: {'acc,none': 0.29948805460750855, 'acc_norm,none': 0.3293515358361775},
     }
 
     def calculatorPer(self, filter, value, base_value):
@@ -92,22 +92,29 @@ def test_bits(self):
         # quantize
         model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        dataset = [
-            "gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
+        dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
         calibration_dataset = [tokenizer(example) for example in dataset]
 
+        errors = []
         for quant_backend in self.pack_backends:
             supports_bits = self.QLINEAR_DICT[quant_backend].SUPPORTS_BITS
             for bits in supports_bits:
-                print("-----------------------quant-----------------------")
+                print(f"-----------------------quant backend: {quant_backend}-- bits: {bits} ---------------------")
                 quantize_config = QuantizeConfig(bits=bits, group_size=128, sym=True, desc_act=False)
                 print(f"bits: {quantize_config.bits}, quant_backend: {quant_backend} start quant")
                 try:
                     self.quant_and_eval(calibration_dataset, model_id, quant_backend, quantize_config, tokenizer)
                 except Exception:
-                    print(f"bits: {quantize_config.bits}, quant_backend: {quant_backend} An error occurred")
+                    error_log = f"bits: {quantize_config.bits}, quant_backend: {quant_backend} An error occurred"
+                    print(error_log)
+                    errors.append(error_log)
 
                     traceback.print_exc()
 
                     continue
 
+        self.assertTrue(len(errors) == 0, '\n'.join(errors))
+
     def quant_and_eval(self, calibration_dataset, model_id, quant_backend, quantize_config, tokenizer):
         model = GPTQModel.load(
             model_id,
@@ -127,11 +134,7 @@ def quant_and_eval(self, calibration_dataset, model_id, quant_backend, quantize_config, tokenizer):
                     # Skip inference_backend that does not support the current bits
                     continue
 
-                try:
-                    self.eval(inference_backend, quant_backend, quantize_config, tmp_dir)
-                except Exception:
-                    traceback.print_exc()
-                    continue
+                self.eval(inference_backend, quant_backend, quantize_config, tmp_dir)
 
     def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir):
         print("-----------------------eval-----------------------")
@@ -165,8 +168,7 @@ def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir):
             metric: value for metric, value in results['results'].get(TASK_NAME, {}).items()
             if metric != 'alias' and 'stderr' not in metric
         }
-        print(
-            f"bits is: {quantize_config.bits}, quant_backend: {quant_backend}, inference_backend: {inference_backend} -> task_results: {task_results}")
+        print(f"bits is: {quantize_config.bits}, quant_backend: {quant_backend}, inference_backend: {inference_backend} -> task_results: {task_results}")
         del model
 
         self.check_results(quantize_config.bits, task_results)
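Loosening QUANT_ARC_MAX_DELTA_FLOOR_PERCENT and QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT from 0.025 to 0.1 widens the accepted band around each per-bit expected score from ±2.5% to ±10%. Assuming the deltas are taken relative to the expected value, the acceptance check reduces to a sketch like this (within_band is illustrative; the real comparison lives in calculatorPer/check_results):

def within_band(value: float, base: float, floor_pct: float = 0.1, ceil_pct: float = 0.1) -> bool:
    # value may undershoot base by at most floor_pct, or overshoot by at most ceil_pct
    return base * (1 - floor_pct) <= value <= base * (1 + ceil_pct)

# e.g. the 4-bit acc_norm baseline 0.3148 would now accept roughly [0.2834, 0.3463]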
6 changes: 4 additions & 2 deletions tests/test_q4_cuda.py
@@ -16,13 +16,16 @@
 
 # -- do not touch
 import os
+import tempfile
+
+from gptqmodel.utils import Perplexity
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 # -- end do not touch
 
 
 import torch  # noqa: E402
-from gptqmodel import BACKEND, GPTQModel  # noqa: E402
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig  # noqa: E402
 from models.model_test import ModelTest  # noqa: E402
 from parameterized import parameterized  # noqa: E402
 from transformers import AutoTokenizer  # noqa: E402
@@ -74,4 +77,3 @@ def test_generation_desc_act_false(self, torch_dtype, device):
         self.assertInference(model=model_q,tokenizer=self.tokenizer)
         # This one does not.
         self.assertInference(model=model_q.model,tokenizer=self.tokenizer)