From 7fb14e67680b2823fac524a340ac3c7217e7a43f Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Fri, 21 Feb 2025 18:00:42 +0800 Subject: [PATCH] revert is_marlin_format check --- gptqmodel/quantization/config.py | 5 +++++ tests/test_quant_formats.py | 7 ++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index 2245a437f..d670b64c7 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -33,6 +33,7 @@ FORMAT_FIELD_CODE = "format" FORMAT_FIELD_JSON = "checkpoint_format" +FORMAT_FIELD_COMPAT_MARLIN = "is_marlin_format" QUANT_METHOD_FIELD = "quant_method" PACK_DTYPE_FIELD = "pack_dtype" QUANT_CONFIG_FILENAME = "quantize_config.json" @@ -184,6 +185,8 @@ class QuantizeConfig(): # pending used field adapter: Optional[Union[Dict[str, Any], Lora]] = field(default=None) + is_marlin_format: bool = False + def __post_init__(self): fields_info = fields(self) @@ -351,6 +354,8 @@ def from_quant_config(cls, quantize_cfg, format: str = None): raise ValueError(f"QuantizeConfig: Unknown quantization method: `{val}`.") else: normalized[QUANT_METHOD_FIELD] = val + elif key == FORMAT_FIELD_COMPAT_MARLIN and val: + normalized[FORMAT_FIELD_CODE] = FORMAT.MARLIN elif key in field_names: normalized[key] = val else: diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py index 704d398e7..59f23308c 100644 --- a/tests/test_quant_formats.py +++ b/tests/test_quant_formats.py @@ -49,9 +49,9 @@ def setUpClass(self): @parameterized.expand( [ - # (QUANT_METHOD.GPTQ, BACKEND.AUTO, False, FORMAT.GPTQ, 8), + (QUANT_METHOD.GPTQ, BACKEND.AUTO, False, FORMAT.GPTQ, 8), (QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, True, FORMAT.GPTQ_V2, 4), - # (QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, False, FORMAT.GPTQ, 4), + (QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, False, FORMAT.GPTQ, 4), ] ) def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, format: FORMAT, bits: int): @@ -115,12 +115,13 @@ def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, forma if not sym and format == FORMAT.GPTQ or format == FORMAT.IPEX: return - # test compat: 1) with simple dict type + # test compat: 1) with simple dict type 2) is_marlin_format compat_quantize_config = { "bits": bits, "group_size": 128, "sym": sym, "desc_act": False if format == FORMAT.MARLIN else True, + "is_marlin_format": backend == BACKEND.MARLIN, } model = GPTQModel.load(