diff --git a/examples/quantization/basic_usage_wikitext2.py b/examples/quantization/basic_usage_wikitext2.py
index 7c87a6b6f..ac1ba63d9 100644
--- a/examples/quantization/basic_usage_wikitext2.py
+++ b/examples/quantization/basic_usage_wikitext2.py
@@ -68,9 +68,6 @@ def main():
     # with value under torch.LongTensor type.
     model.quantize(traindataset)
 
-    # save quantized model
-    model.save(quantized_model_id)
-
     # save quantized model using safetensors
     model.save(quantized_model_id)
 
diff --git a/gptqmodel/models/definitions/minicpm.py b/gptqmodel/models/definitions/minicpm.py
index 092389fbc..00df27e63 100644
--- a/gptqmodel/models/definitions/minicpm.py
+++ b/gptqmodel/models/definitions/minicpm.py
@@ -29,5 +29,4 @@ class MiniCPMGPTQ(BaseGPTQModel):
         ["self_attn.v_proj"],
         ["self_attn.o_proj"],
         ["mlp.gate_proj", "mlp.up_proj","mlp.down_proj"],
-        ["mlp.c_proj"],
     ]
diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index 26d86ad02..3cffb7a0a 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -205,7 +205,8 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym:
         if bits not in cls.SUPPORTS_BITS:
             err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual bits = `{bits}`"
             return False, NotImplementedError(err)
-        if group_size not in cls.SUPPORTS_GROUP_SIZE:
+        # valid group size is set of cls.SUPPORTS_GROUP_SIZE + in_features; group_size = -1 is alias for group_size == in_features
+        if group_size not in cls.SUPPORTS_GROUP_SIZE and group_size != in_features:
             err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
             return False, NotImplementedError(err)
         if sym not in cls.SUPPORTS_SYM:
@@ -365,3 +366,5 @@ def pack(self, linear, scales, zeros, g_idx=None):
                 col += 1
 
         self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))
+
+        # print("self qw", self.qweight, self.scales, self.qzeros)
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 16b00e88d..3d5cc9a24 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -156,7 +156,7 @@ def make_quant(
     device: DEVICE = None,
     from_quantized: bool = False,
     pack_dtype: torch.dtype = None,
-) -> BaseQuantLinear:
+) -> Type[BaseQuantLinear]:
     # returns multiple validated kernels
     quant_linear_candidates = select_quant_linear(
         bits=bits,
@@ -175,15 +175,15 @@ def make_quant(
     logger.info(f"make_quant: Linear candidates: {quant_linear_candidates}")
 
     # loop over actual QLinear init, catch errors and use fallbacks if applicable
-    for linear in quant_linear_candidates:
+    for cls in quant_linear_candidates:
         try:
            # if linear is not selectedQLinear:
            #     logger.info(f"make_quant: Faild linear: `{selectedQLinear}` failed, trying to use fallback: `{linear}`")
            # else:
            #     logger.info("make_quant: Testing linear: {linear}")
 
-            linear_instance = create_quant_layer(
-                linear=linear,
+            linear_cls = create_quant_layer(
+                linear_cls=cls,
                 bits=bits,
                 desc_act=desc_act,
                 dynamic=dynamic,
@@ -194,10 +194,10 @@ def make_quant(
                 device=device,
                 lm_head_name=lm_head_name,
                 pack_dtype=pack_dtype)
-            logger.info(f"make_quant: Selected linear: `{linear}`.")
-            return linear_instance
+            logger.info(f"make_quant: Selected linear: `{cls}`.")
+            return linear_cls
         except NotImplementedError as e:
-            logger.info(f"make_quant: Skipped linear: `{linear}`.")
+            logger.info(f"make_quant: Skipped linear: `{cls}`. ")
             # only fallback to other quant linears when backend is auto.
             if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]:
                 raise e
@@ -206,7 +206,7 @@ def make_quant(
 
 
 def create_quant_layer(
-    linear: nn.Module,
+    linear_cls: Type[BaseQuantLinear],
     bits: int,
     desc_act: bool,
     dynamic,
@@ -216,10 +216,9 @@ def create_quant_layer(
     sym: bool,
     device: DEVICE,
     lm_head_name: str,
-    pack_dtype: torch.dtype,
-    ) -> BaseQuantLinear:
-    if isinstance(module, linear):
-        return linear
+    pack_dtype: torch.dtype) -> Type[BaseQuantLinear]:
+    if isinstance(module, linear_cls):
+        return linear_cls
     for name, submodule in module.named_modules():
         if name in names:
             ori_layer_device = next(submodule.parameters()).device
@@ -266,7 +265,7 @@ def create_quant_layer(
 
             # when loading a quantized model, device is target device passed in GPTQModel.load()
            # check in_features and out_features validate
-            _, err = linear.validate(
+            _, err = linear_cls.validate(
                 bits=tmp_bits,
                 group_size=tmp_group_size,
                 desc_act=tmp_desc_act,
@@ -278,7 +277,7 @@ def create_quant_layer(
             if err is not None:
                 raise err
 
-            new_layer = linear(
+            new_layer = linear_cls(
                 bits=tmp_bits,
                 group_size=tmp_group_size,
                 desc_act=tmp_desc_act,
@@ -293,7 +292,7 @@ def create_quant_layer(
             )
             new_layer.device = ori_layer_device
             recurse_setattr(module, name, new_layer.to(ori_layer_device))
-    return linear
+    return linear_cls
 
 # public/stable api exposed to transformer/optimum
 def hf_convert_gptq_v1_to_v2_format(
@@ -489,25 +488,13 @@ def pack_model(
     parallel_packing: bool = True,
     pack_dtype: torch.dtype = None,
 ):
-    quantLinear = select_quant_linear(
-        bits=bits,
-        dynamic=dynamic,
-        group_size=group_size,
-        desc_act=desc_act,
-        sym=sym,
-        backend=backend,
-        format=format,
-        pack=True,
-        pack_dtype=pack_dtype,
-    )
-
     model.to(CPU)
 
     logger.info("Packing model...")
 
     modules = find_modules(model)
     modules = {n: modules[n] for n in quantizers}
-    make_quant(
+    quant_linear_cls = make_quant(
         model,
         quantizers,
         bits,
@@ -520,7 +507,10 @@ def make_quant(
         dynamic=dynamic,
         pack_dtype=pack_dtype,
     )
-    qModules = find_modules(model, [quantLinear])
+    qModules = find_modules(model, [quant_linear_cls])
+
+    assert len(qModules) > 0, f"No quantized modules[{quant_linear_cls}] found in the model."
+
     names = list(qModules.keys())
 
     if parallel_packing:
@@ -537,7 +527,7 @@ def wrapper(name):
                 pass
 
     logger.info("Model packed.")
-    return quantLinear
+    return quant_linear_cls
 
 
 def verify_model_hash(file_path: str, verify_hash: str):
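
The `_validate` hunk above widens the accepted group sizes to `cls.SUPPORTS_GROUP_SIZE` plus the layer's `in_features` (with `group_size = -1` documented as an alias for `in_features`), and `make_quant` now returns the selected `Type[BaseQuantLinear]` so `pack_model` can reuse that class for `find_modules` instead of calling `select_quant_linear` a second time. Below is a minimal, standalone sketch of the relaxed group-size check only; the names `SUPPORTS_GROUP_SIZE` and `check_group_size` are illustrative stand-ins, not GPTQModel API.

# Standalone sketch of the relaxed group_size acceptance rule (names are hypothetical).
from typing import Sequence

SUPPORTS_GROUP_SIZE: Sequence[int] = (32, 64, 128)  # example kernel support list

def check_group_size(group_size: int, in_features: int) -> bool:
    """Accept group_size if the kernel lists it, or if it equals the layer width."""
    # Assumption: -1 is normalized to in_features before the membership test,
    # mirroring the "group_size = -1 is alias for group_size == in_features" comment.
    if group_size == -1:
        group_size = in_features
    return group_size in SUPPORTS_GROUP_SIZE or group_size == in_features

if __name__ == "__main__":
    assert check_group_size(128, 4096)      # explicitly supported
    assert check_group_size(4096, 4096)     # per-channel: group_size == in_features
    assert check_group_size(-1, 4096)       # -1 treated as alias for in_features
    assert not check_group_size(96, 4096)   # still rejected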