|
 device_woqlinear_mapping = {"cpu": INCWeightOnlyLinear, "hpu": HPUWeightOnlyLinear}


-def save(model, output_dir="./saved_results"):
+def save(model, output_dir="./saved_results", format=LoadFormat.DEFAULT, **kwargs):
     """Save the quantized model and config to the output path.

     Args:
         model (torch.nn.module): raw fp32 model or prepared model.
         output_dir (str, optional): output path to save.
+        format (str, optional): The format in which to save the model. Options include "default" and "huggingface". Defaults to "default".
+        kwargs: Additional arguments for specific formats. For example:
+            - safe_serialization (bool): Whether to use safe serialization when saving (only applicable for 'huggingface' format). Defaults to True.
+            - tokenizer (Tokenizer, optional): The tokenizer to be saved along with the model (only applicable for 'huggingface' format).
+            - max_shard_size (str, optional): The maximum size for each shard (only applicable for 'huggingface' format). Defaults to "5GB".
     """
     os.makedirs(output_dir, exist_ok=True)
+    if format == LoadFormat.HUGGINGFACE:  # pragma: no cover
+        config = model.config
+        quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
+        # guard against a missing quantization_config before inspecting its backend
+        if (
+            quantization_config is not None
+            and "backend" in quantization_config
+            and "auto_round" in quantization_config["backend"]
+        ):
+            safe_serialization = kwargs.get("safe_serialization", True)
+            tokenizer = kwargs.get("tokenizer", None)
+            max_shard_size = kwargs.get("max_shard_size", "5GB")
+            if tokenizer is not None:
+                tokenizer.save_pretrained(output_dir)
+            # detach the save method bound onto the model so it does not interfere
+            # with the transformers save_pretrained call below
+            del model.save
+            model.save_pretrained(output_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+            return
+
     qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
     # saving process
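A minimal usage sketch for the new Hugging Face save path; the quantized `model` and `tokenizer` objects are assumed to come from a prior auto-round weight-only quantization run and are placeholders, not part of this diff:

    # hedged sketch: the docstring documents "default" and "huggingface" as string options
    save(
        model,
        output_dir="./saved_results",
        format="huggingface",
        tokenizer=tokenizer,  # saved alongside the model
        safe_serialization=True,
        max_shard_size="5GB",
    )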
@@ -203,8 +221,15 @@ def load_hf_format_woq_model(self):

         # get model class and config
         model_class, config = self._get_model_class_and_config()
-        self.quantization_config = config.quantization_config
-
+        self.quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
+        if self.quantization_config is not None and (
+            "backend" in self.quantization_config and "auto_round" in self.quantization_config["backend"]
+        ):  # pragma: no cover
+            # load autoround format quantized model
+            from auto_round import AutoRoundConfig  # noqa: F401  # imported for its side effect of registering auto-round support
+
+            model = model_class.from_pretrained(self.model_name_or_path)
+            return model
         # get loaded state_dict
         self.loaded_state_dict = self._get_loaded_state_dict(config)
         self.loaded_state_dict_keys = list(set(self.loaded_state_dict.keys()))
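On the load side, this branch would be reached roughly as sketched below; the `load` import path and the checkpoint path are assumptions for illustration, not shown in this diff:

    from neural_compressor.torch.quantization import load

    # format="huggingface" routes to load_hf_format_woq_model; a quantization_config
    # whose backend mentions auto_round takes the new early-return path
    model = load("./saved_results", format="huggingface")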
@@ -400,7 +425,7 @@ def _get_model_class_and_config(self):
         trust_remote_code = self.kwargs.pop("trust_remote_code", None)
         kwarg_attn_imp = self.kwargs.pop("attn_implementation", None)

-        config = AutoConfig.from_pretrained(self.model_name_or_path)
+        config = AutoConfig.from_pretrained(self.model_name_or_path, trust_remote_code=trust_remote_code)
         # quantization_config = config.quantization_config

         if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:  # pragma: no cover