diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index 8d5098ef9..3dfbd2e46 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -553,6 +553,7 @@ jobs:
uv pip install -U transformers
uv pip install -U logbar==0.0.3
if [ "${{ matrix.test_script }}" == "test_perplexity" ] || \
+ [ "${{ matrix.test_script }}" == "test_inference_speed" ] || \
[ "${{ matrix.test_script }}" == "test_q4_bitblas" ] || \
[ "${{ matrix.test_script }}" == "test_save_loaded_quantized_model" ]; then
echo "===== install bitblas==0.0.1.dev13 ====="
diff --git a/gptqmodel/adapter/adapter.py b/gptqmodel/adapter/adapter.py
index ac5f6790a..5d4a1cf77 100644
--- a/gptqmodel/adapter/adapter.py
+++ b/gptqmodel/adapter/adapter.py
@@ -1,32 +1,66 @@
-import os
-from dataclasses import dataclass, field
-from typing import Dict, List, Union
-from urllib.parse import urlparse
+
+import re
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
import safetensors
import torch
from ..utils.logger import setup_logger
+from .peft import LoraConfig
+from .remote import resolve_path
logger = setup_logger()
LORA_MERGED_WEIGHT_PATHS = [None, ""]
+HF_ADAPTER_FILE_NAME = "adapter_model.safetensors"
+HF_ADAPTER_CONFIG_FILE_NAME = "adapter_config.json"
+HF_ADAPTER_WEIGHT_KEY_PREFIX = "base_model.model."
+
+
+class AdapterCache():
+ cache: Dict[str, Dict[str, Union[LoraConfig, torch.Tensor]]] = {} # first level key is `path`, second level keys [ `config` = LoraConfig, `weights` = Dict[str, Tensors]
+
+ @classmethod
+ def get(cls, path: str) -> Optional[Tuple[LoraConfig, Dict[str, torch.Tensor]]]:
+ data = cls.cache.get(path)
+ if not data:
+ return None
+ else:
+ return data["config"], data["weights"]
+
+ @classmethod
+ def reset(cls):
+ logger.info("Adapter Cache: Resetting cache")
+ cls.cache = {}
+
+ @classmethod
+ def add(cls, path: str, config: LoraConfig, weights: Dict[str, torch.Tensor]):
+ cls.cache[path] = {"config": config, "weights": weights}
+
+ @classmethod
+ def remove(cls, path):
+ cls.cache.pop(path, None)
-# TODO FIX ME: cache of adapter tensors loaded from disk
-adapter_load_cache = None
class Adapter():
- def __init__(self, rank: int, path: str = None):
- self.rank = rank
+ def __init__(self, rank: int = None, path: str = None):
+ self.rank = rank # rank may be zero, when loading, and rank will be re-populated by loading saved LoraConfig file
self.path = path.lower().strip() if isinstance(path, str) else path
- def validate_path(self, local_only=False):
+ def validate_path(self, local=False):
if not self.path or not isinstance(self.path, str):
raise ValueError("Adapter: `path` str is required.")
- if local_only:
+ # path should not be a file but a directory
+ if self.path.endswith(".safetensors"):
+ raise ValueError(
+ f"Adapter: `path` must be a directory path or repo depending if you are saving (directory path) or loading (repo): actual = `{self.path}`")
+
+ if local:
if self.path.startswith("http"):
raise ValueError(f"Adapter: `path` str in this context must be a local os path: actual = `{self.path}`.")
+
# override me
def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
pass
@@ -97,52 +131,69 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N
self.lora_A, self.lora_B = lora_A, lora_B
return
- global adapter_load_cache
- if adapter_load_cache is None:
- if os.path.isfile(self.path):
- lora_path = self.path
- logger.info(f"Adapter: Loading `{self.path}` tensors from disk") # {adapter_load_cache}
- elif self.path.startswith("http"):
- from huggingface_hub import hf_hub_download
- result = self.parse_url(self.path)
- if len(result) == 3:
- logger.info(f"Adapter: Downloading adapter weights from hf repo: `{result[0]}` revision: `{result[1]}` file: `{result[2]}`")
- lora_path = hf_hub_download(repo_id=result[0], revision=result[1], filename=result[2])
- elif len(result) == 1:
- logger.info(f"Adapter: Downloading adapter weights from uri = `{self.path}`")
- import requests
- response = requests.get(self.path, stream=True)
- lora_path = "lora.safetensors"
- with open(lora_path, "wb") as f:
- for chunk in response.iter_content(chunk_size=8192):
- f.write(chunk)
- else:
- raise Exception(f"Adapter: Lora path is invalid: `{self.path}`")
+ lora_cache = AdapterCache.get(self.path)
+ if lora_cache is None:
+ # get lora config
+ lora_cfg = LoraConfig.from_pretrained(path=self.path, filename=HF_ADAPTER_CONFIG_FILE_NAME)
+ lora_cfg.gptqmodel_path = self.path # hack: save this
+
+ if not isinstance(lora_cfg, LoraConfig):
+ raise ValueError(f"Adapter: Expected `LoraConfig` in `{self.path}`, actual = `{lora_cfg}`")
+
+ if self.rank is None:
+ self.rank = lora_cfg.r
else:
- from huggingface_hub import HfApi, hf_hub_download
- files = [f for f in HfApi().list_repo_files(self.path) if f in ["lora.safetensors", "eora_test.safetensors"]]
+ if self.rank != lora_cfg.r:
+ raise ValueError(f"Adapter: `rank` must match `LoraConfig.r`, expected `{self.rank}`, actual = `{lora_cfg.r}`")
+
+ lora_path = resolve_path(self.path, HF_ADAPTER_FILE_NAME)
+
+ # save to adapter cache
+ AdapterCache.add(self.path, lora_cfg, safetensors.torch.load_file(lora_path))
- if files:
- lora_path = hf_hub_download(repo_id=self.path, filename=files[0])
- # print(f"Adapter tensors loaded from `{self.path}`")
- else:
- raise Exception(f"Adapter: There's no lora.safetensors or eora_test.safetensors on repo `{self.path}`")
+ lora_cache = AdapterCache.get(self.path)
+ assert lora_cache is not None
- adapter_load_cache = safetensors.torch.load_file(lora_path)
+ # lora_cache result is a tuple
+ lora_cfg, lora_weights = lora_cache
weight_key = weight_key.lower()
# hack for HF Auto compat
- if not f"{weight_key}.lora_A.weight" in adapter_load_cache:
- weight_key = weight_key.removeprefix("model.")
+ lora_A_weight_key = f"{weight_key}.lora_A.weight"
+ lora_B_weight_key = f"{weight_key}.lora_B.weight"
- #print(f"loaded lora weight keys: {adapter_load_cache.keys()}")
- lora_A = adapter_load_cache.pop(f"{weight_key}.lora_A.weight").T
- lora_B = adapter_load_cache.pop(f"{weight_key}.lora_B.weight").T
+ # print(f"lora_A_weight_key = {lora_A_weight_key}, lora_B_weight_key = {lora_B_weight_key}")
+ pop_keys = []
+ for k, v in lora_weights.items():
+ if k.endswith(lora_A_weight_key):
+ lora_A = v.T
+ pop_keys.append(k)
+ elif k.endswith(lora_B_weight_key):
+ lora_B = v.T
+ pop_keys.append(k)
- # since loder cache is singleton, we need to reset to None to ci loop tests can pass
- if len(adapter_load_cache) == 0:
- adapter_load_cache = None
+
+ if pop_keys:
+ for k in pop_keys:
+ lora_weights.pop(k) # releasee lora weights from cache memory
+
+ # we have consumed all modules
+ if len(lora_weights) == 0:
+ AdapterCache.remove(self.path)
+ logger.info("Adapter: Consumed all Lora weights")
+
+ else:
+ logger.warn(f"Adapter: Lora weights not found for `{weight_key}`")
+
+ assert lora_A is not None and lora_B is not None, f"Adapter: `lora_A` and `lora_B` must both be present in the weights: actual = `{lora_A}` and `{lora_B}`"
+
+ # check for rank override from base config
+ self.dynamic_rank_override(lora_cfg=lora_cfg, weight_key=weight_key)
+
+ # # since loder cache is singleton, we need to reset to None to ci loop tests can pass
+ # if len(lora_weights) == 0:
+ # adapter_load_cache = None
# print(f"Adapter: {self.name()}, loaded lora_A shape: {lora_A.shape}")
# print(f"Adapter: {self.name()}, loaded lora_B shape: {lora_B.shape}")
@@ -155,21 +206,22 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N
#print(f"Adapter: lora_A {lora_A.shape}: `{lora_B}`")
#print(f"Adapter: lora_B {lora_B.shape}: `{lora_B}`")
- def parse_url(self, url: str):
- parsed_url = urlparse(url)
+ def dynamic_rank_override(self, lora_cfg: LoraConfig, weight_key: str) -> bool:
+ assert lora_cfg.rank_pattern is not None and weight_key is not None
+ if lora_cfg.rank_pattern:
+ for k, v in lora_cfg.rank_pattern.items():
+ assert isinstance(k, str) and isinstance(v, int)
+ k = k.lower()
+ assert v > 0 # check for invalid rank range
+ # first do string full match, then suffix match, then regex match
+ if weight_key == k or k.endswith(weight_key) or re.match(k, weight_key):
+ self.rank = v
+ logger.info(f"Adapter: Base Lora `rank` = `{self.rank}` has been overridden by `{k}` due to dynamic `LoraConfig.rank_pattern` control.")
+ return True
+
+ return False
- if parsed_url.netloc.endswith("huggingface.co") or parsed_url.netloc.endswith("hf.co"):
- parts = parsed_url.path.strip("/").split("/")
- if "blob" in parts:
- idx = parts.index("blob")
- repo_id = "/".join(parts[:idx])
- rev = parts[idx + 1]
- filename = parts[idx + 2].split("?")[0] # remove ?download=true
- return [repo_id, rev, filename]
- else:
- return [url]
- return []
def to_dict(self):
return {
diff --git a/gptqmodel/adapter/peft.py b/gptqmodel/adapter/peft.py
new file mode 100644
index 000000000..ad0f0e620
--- /dev/null
+++ b/gptqmodel/adapter/peft.py
@@ -0,0 +1,396 @@
+import json
+import os
+from dataclasses import asdict, dataclass, field, fields
+from typing import Any, Literal, Optional, Set, Union
+
+from ..adapter.remote import resolve_path
+from ..utils.logger import setup_logger
+
+log = setup_logger()
+
+@dataclass
+class LoraConfig():
+ """
+ This is the configuration class to store the configuration of a [`LoraModel`].
+
+ Args:
+ r (`int`):
+ Lora attention dimension (the "rank").
+ target_modules (`Optional[Union[List[str], str]]`):
+ The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
+ names will be replaced. When passing a string, a regex match will be performed. When passing a list of
+ strings, either an exact match will be performed or it is checked if the name of the module ends with any
+ of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen,
+ excluding the output layer. If this is not specified, modules will be chosen according to the model
+ architecture. If the architecture is not known, an error will be raised -- in this case, you should specify
+ the target modules manually.
+ exclude_modules (`Optional[Union[List[str], str]]`):
+ The names of the modules to not apply the adapter. When passing a string, a regex match will be performed.
+ When passing a list of strings, either an exact match will be performed or it is checked if the name of the
+ module ends with any of the passed strings.
+ lora_alpha (`int`):
+ The alpha parameter for Lora scaling.
+ lora_dropout (`float`):
+ The dropout probability for Lora layers.
+ fan_in_fan_out (`bool`):
+ Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
+ `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
+ bias (`str`):
+ Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the corresponding biases
+ will be updated during training. Be aware that this means that, even when disabling the adapters, the model
+ will not produce the same output as the base model would have without adaptation.
+ use_rslora (`bool`):
+ When set to True, uses Rank-Stabilized LoRA which
+ sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it was proven to work better.
+ Otherwise, it will use the original default value of `lora_alpha/r`.
+ modules_to_save (`List[str]`):
+ List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
+ init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]`):
+ How to initialize the weights of the adapter layers. Passing True (default) results in the default
+ initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian
+ initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to
+ completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing
+ `'eva'` results in a data-driven initialization of Explained
+ Variance Adaptation. EVA initalizes LoRA based on the SVD of layer input activations and achieves SOTA
+ performance due to its ability to adapt to the finetuning data. Pass `'olora'` to use OLoRA initialization.
+ Passing `'pissa'` results in the initialization of Principal
+ Singular values and Singular vectors Adaptation (PiSSA), which converges more rapidly than LoRA and
+ ultimately achieves superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA,
+ leading to further enhancements. Passing `'pissa_niter_[number of iters]'` initiates Fast-SVD-based PiSSA
+ initialization, where `[number of iters]` indicates the number of subspace iterations to perform FSVD, and
+ must be a nonnegative integer. When `[number of iters]` is set to 16, it can complete the initialization of
+ a 7B model within seconds, and the training effect is approximately equivalent to using SVD.
+ layers_to_transform (`Union[List[int], int]`):
+ The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
+ that are specified in this list. If a single integer is passed, it will apply the transformations on the
+ layer at this index.
+ layers_pattern (`Optional[Union[List[str], str]]`):
+ The layer pattern name, used only if `layers_to_transform` is different from `None`. This should target the
+ `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
+ rank_pattern (`dict`):
+ The mapping from layer names or regexp expression to ranks which are different from the default rank
+ specified by `r`.
+ alpha_pattern (`dict`):
+ The mapping from layer names or regexp expression to alphas which are different from the default alpha
+ specified by `lora_alpha`.
+ megatron_config (`Optional[dict]`):
+ The TransformerConfig arguments for Megatron. It is used to create LoRA's parallel linear layer. You can
+ get it like this, `core_transformer_config_from_args(get_args())`, these two functions being from Megatron.
+ The arguments will be used to initialize the TransformerConfig of Megatron. You need to specify this
+ parameter when you want to apply LoRA to the ColumnParallelLinear and RowParallelLinear layers of megatron.
+ megatron_core (`Optional[str]`):
+ The core module from Megatron to use, defaults to `"megatron.core"`.
+ loftq_config (`Optional[LoftQConfig]`):
+ The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
+ and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
+ quantized model in this case, as LoftQ will quantize the model itself.
+ eva_config (`Optional[EvaConfig]`):
+ The configuration of EVA. At a minimum the dataset argument needs to be set (use the same dataset as for
+ finetuning).
+ use_dora (`bool`):
+ Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights
+ into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
+ handled by a separate learnable parameter. This can improve the performance of LoRA especially at low
+ ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure
+ LoRA, so it is recommended to merge weights for inference. For more information, see
+ https://arxiv.org/abs/2402.09353.
+ layer_replication (`List[Tuple[int, int]]`):
+ Build a new stack of layers by stacking the original model layers according to the ranges specified. This
+ allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
+ all have separate LoRA adapters attached to them.
+ runtime_config (`LoraRuntimeConfig`):
+ Runtime configurations (which are not saved or restored).
+ lora_bias (`bool`):
+ Defaults to `False`. Whether to enable the bias term for the LoRA B parameter. Typically, this should be
+ disabled. The main use case for this is when the LoRA weights were extracted from fully fine-tuned
+ parameters so the bias of those parameters can be taken into account.
+ """
+ base_model_name_or_path: str = field(default="", metadata={"help": "base_model_name_or_path"}) # required by hf auto api
+
+ r: int = field(default=8, metadata={"help": "Lora attention dimension"})
+ target_modules: Optional[Union[list[str], str]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "List of module names or regex expression of the module names to replace with LoRA."
+ "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'."
+ "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
+ "If not specified, modules will be chosen according to the model architecture, If the architecture is "
+ "not known, an error will be raised -- in this case, you should specify the target modules manually."
+ ),
+ },
+ )
+ exclude_modules: Optional[Union[list[str], str]] = field(
+ default=None,
+ metadata={"help": "List of module names or regex expression of the module names to exclude from Lora."},
+ )
+ lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
+ lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
+ fan_in_fan_out: bool = field(
+ default=False,
+ metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
+ )
+ bias: Literal["none", "all", "lora_only"] = field(
+ default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}
+ )
+ use_rslora: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "When set to True, uses Rank-Stabilized LoRA"
+ " which sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it"
+ " was proven to work better. Otherwise, it will use the original default"
+ " value of `lora_alpha/r`."
+ )
+ },
+ )
+ modules_to_save: Optional[list[str]] = field(
+ default=None,
+ metadata={
+ "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
+ "For example, in Sequence Classification or Token Classification tasks, "
+ "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
+ },
+ )
+ init_lora_weights: (
+ bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]
+ ) = field(
+ default=True,
+ metadata={
+ "help": (
+ "How to initialize the weights of the LoRA layers. Passing `'True'` (default) results in the default "
+ "initialization from the reference implementation from Microsoft. Passing `'gaussian'` results "
+ "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization "
+ "to `'False'` leads to completely random initialization and *is discouraged.*"
+ "Pass `'eva'` results in a data-driven initialization of Explained Variance Adaptation."
+ "Passing `'olora'` results in OLoRA initialization."
+ "Passing `'pissa'` results in PiSSA initialization."
+ "Passing `'pissa_niter_[number of iters]'` initiates Fast-SVD-based PiSSA initialization, "
+ "where [number of iters] indicates the number of subspace iterations to perform fsvd, and must be a nonnegative integer."
+ "Pass `'loftq'` to use LoftQ initialization"
+ ),
+ },
+ )
+ layers_to_transform: Optional[Union[list[int], int]] = field(
+ default=None,
+ metadata={
+ "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index. "
+ "This only works when target_modules is a list of str."
+ },
+ )
+ layers_pattern: Optional[Union[list[str], str]] = field(
+ default=None,
+ metadata={
+ "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
+ "This only works when target_modules is a list of str. This should target the `nn.ModuleList` of the "
+ "model, which is often called `'layers'` or `'h'`."
+ },
+ )
+ rank_pattern: Optional[dict] = field(
+ default_factory=dict,
+ metadata={
+ "help": (
+ "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. "
+ "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}"
+ )
+ },
+ )
+ alpha_pattern: Optional[dict] = field(
+ default_factory=dict,
+ metadata={
+ "help": (
+ "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. "
+ "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}"
+ )
+ },
+ )
+ megatron_config: Optional[dict] = field(
+ default=None,
+ metadata={
+ "help": (
+ "The TransformerConfig from Megatron. It is used to create LoRA's parallel linear layer."
+ "You can get it like this, `core_transformer_config_from_args(get_args())`, "
+ "these two functions being from Megatron."
+ "You need to specify this parameter when you want to apply LoRA to the ColumnParallelLinear and "
+ "RowParallelLinear layers of megatron."
+ "It should be noted that we may not be able to use the `save_pretrained` and `from_pretrained` "
+ "functions, because TransformerConfig may not necessarily be serialized."
+ "But when using megatron, we can use `get_peft_model_state_dict` function and "
+ "megatron's framework, they can also save and load models and configurations."
+ )
+ },
+ )
+ megatron_core: Optional[str] = field(
+ default="megatron.core",
+ metadata={
+ "help": (
+ "The core module from Megatron, it is used to create LoRA's parallel linear layer. "
+ "It only needs to be passed in when you need to use your own modified megatron core module. "
+ "Otherwise, it will use the default value `megatron.core`. "
+ )
+ },
+ )
+ # dict type is used when loading config.json
+ loftq_config: Union[Any, dict] = field(
+ default_factory=dict,
+ metadata={
+ "help": (
+ "The configuration of LoftQ. If this is passed, then LoftQ will be used to quantize the backbone "
+ "weights and initialize Lora layers. Also set `init_lora_weights='loftq'` in this case."
+ )
+ },
+ )
+ eva_config: Optional[Any] = field(
+ default=None,
+ metadata={
+ "help": (
+ "The configuration of EVA. If this is passed, then EVA will be used to intialize the LoRA layers. "
+ "Also set `init_lora_weights='eva'` in this case. "
+ )
+ },
+ )
+ use_dora: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the "
+ "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the "
+ "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
+ "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger"
+ "overhead than pure LoRA, so it is recommended to merge weights for inference."
+ )
+ },
+ )
+ # Enables replicating layers in a model to expand it to a larger model.
+ layer_replication: Optional[list[tuple[int, int]]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "This enables using LoRA to effectively expand a transformer model to a larger size by repeating some layers. "
+ "The transformation handles models (currently Llama, Bert or Falcon compatible architectures) with "
+ "a module list in the model which it modifies to expand the number of modules. "
+ "Base weights are shared so the memory usage is close to the original model. The intended use is these base weights "
+ "remain fixed during finetuning but each layer has a separate LoRA adapter so the layers can be specialed via "
+ "the adapter layers fit during fine tuning."
+ "The format is a list of [start, end) pairs which specify the layer ranges to stack. For example:\n"
+ " Original model has 5 layers labelled by their position in the model: `[0, 1, 2, 3, 4]`\n"
+ " layer_replication: `[[0, 4], [2, 5]]`\n"
+ " Final model will have this arrangement of original layers: `[0, 1, 2, 3, 2, 3, 4]`\n"
+ "This format is based on what is used for pass-through merges in mergekit. It makes it simple to select sequential "
+ "ranges of a model and stack them while reusing layers at either end of each sequence."
+ )
+ },
+ )
+ lora_bias: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to enable the bias term for the LoRA B parameter. Typically, this should be disabled. The "
+ "main use case for this is when the LoRA weights were extracted from fully fine-tuned parameters so "
+ "the bias of those parameters can be taken into account."
+ )
+ },
+ )
+
+ # from PeftConfigMixin
+ task_type: Optional[str] = field(default="CAUSAL_LM", metadata={"help": "The type of task."})
+ peft_type: Optional[str] = field(default="LORA", metadata={"help": "The type of PEFT model."})
+ auto_mapping: Optional[dict] = field(
+ default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
+ )
+
+ def to_dict(self):
+ """
+ Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
+ """
+
+ kv = asdict(self)
+
+ # remove all None valued Keys and values that are empty (str, list, dict, set
+ removed_keys = []
+ for k, v in kv.items():
+ if v in [None, {}, []]:
+ removed_keys.append(k)
+ # FIX set is not serializable by json
+ elif isinstance(v, Set):
+ kv[k] = list(v)
+
+ for k in removed_keys:
+ kv.pop(k)
+
+ return kv
+
+ def save_pretrained(self, save_dir: str):
+ from ..adapter.adapter import HF_ADAPTER_CONFIG_FILE_NAME
+
+ log.info(f"Adapter: Saving EoRA/Lora config to -> `{save_dir}`")
+
+ os.makedirs(save_dir, exist_ok=True)
+ with open(os.path.join(save_dir, HF_ADAPTER_CONFIG_FILE_NAME), "w", encoding="utf-8") as f:
+ d = self.to_dict()
+ json_str = json.dumps(d, indent=2)
+ log.info(f"Saved Adapter Config: \n{json_str}")
+ f.write(json_str)
+
+ def __post_init__(self):
+ self.peft_type = "LORA"
+ self.target_modules = (
+ set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
+ )
+ self.exclude_modules = (
+ set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
+ )
+
+ # if target_modules is a regex expression, then layers_to_transform should be None
+ if isinstance(self.target_modules, str) and self.layers_to_transform is not None:
+ raise ValueError("`Adapter: layers_to_transform` cannot be used when `target_modules` is a str.")
+
+ # if target_modules is a regex expression, then layers_pattern should be None
+ if isinstance(self.target_modules, str) and self.layers_pattern is not None:
+ raise ValueError("Adapter: `layers_pattern` cannot be used when `target_modules` is a str.")
+
+ # check for layers_to_transform and layers_pattern
+ if self.layers_pattern and not self.layers_to_transform:
+ raise ValueError("Adapter: When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
+
+ if self.use_dora and self.megatron_config:
+ raise ValueError("Adapter: DoRA is not supported")
+
+ # handle init_lora_weights and loftq_config
+ if self.init_lora_weights == "loftq":
+ raise ValueError("Adapter: LoftQ is not supported")
+ elif self.loftq_config:
+ raise ValueError("Adapter: LoftQ is not supported")
+ elif self.init_lora_weights == "eva" and self.eva_config is None:
+ raise ValueError("Adapter: Eva is not supported")
+ elif self.init_lora_weights != "eva" and self.eva_config is not None:
+ raise ValueError("Adapter: Eva is not supported")
+
+ if self.lora_bias:
+ raise ValueError("Adapter: Lora bias is not supported")
+
+ # Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
+ # be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
+ # this when they'll eventually call save_pretrained (i.e. if they'll pass
+ # path_initial_model_for_weight_conversionl). Therefore, we only warn but don't raise an error here.
+ if (
+ self.use_rslora
+ and (self.rank_pattern or self.alpha_pattern)
+ and (
+ (isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa")))
+ or (self.init_lora_weights == "olora")
+ )
+ ):
+ raise ValueError("Adapter: RSLora is not supported")
+
+ @classmethod
+ def from_pretrained(cls, path: str, filename: str):
+ resolved_path = resolve_path(path=path, filename=filename)
+ with open(resolved_path, "r") as file:
+ config_dict = json.load(file)
+
+ # Remove any keys that are not part of the LoraConfig dataclass fields or empty
+ valid_fields = {field.name for field in fields(cls)}
+ config_dict = {k: v for k, v in config_dict.items() if k in valid_fields and v not in [None, "", [], {}, ()]}
+
+ return cls(**config_dict)
\ No newline at end of file
diff --git a/gptqmodel/adapter/remote.py b/gptqmodel/adapter/remote.py
new file mode 100644
index 000000000..44c007e02
--- /dev/null
+++ b/gptqmodel/adapter/remote.py
@@ -0,0 +1,69 @@
+import os
+from urllib.parse import urlparse
+
+from ..utils.logger import setup_logger
+
+log = setup_logger()
+
+def parse_url(url: str):
+ parsed_url = urlparse(url)
+
+ if parsed_url.netloc.endswith("huggingface.co") or parsed_url.netloc.endswith("hf.co"):
+ parts = parsed_url.path.strip("/").split("/")
+
+ if "blob" in parts:
+ idx = parts.index("blob")
+ repo_id = "/".join(parts[:idx])
+ rev = parts[idx + 1]
+ filename = parts[idx + 2].split("?")[0] # remove ?download=true
+ return [repo_id, rev, filename]
+ else:
+ return [url]
+ return []
+
+def resolve_path(path: str, filename: str) -> str: # return a valid file path to read
+ if os.path.isdir(path):
+ resolved_path = f"{path.removesuffix('/')}/{filename}"
+ log.info(f"Resolver: Local path: `{resolved_path}`")
+ if not os.path.isfile(resolved_path):
+ raise ValueError(f"Resolver: Cannot find file in path: `{resolved_path}`")
+
+ return resolved_path
+ elif path.startswith("http"):
+ from huggingface_hub import hf_hub_download
+
+ result = parse_url(path)
+ if len(result) == 3:
+ log.info(
+ f"Resolver: Downloading file from HF repo: `{result[0]}` revision: `{result[1]}` file: `{result[2]}`")
+ resolved_path = hf_hub_download(repo_id=result[0], revision=result[1], filename=result[2])
+ return resolved_path
+ else:
+ raise ValueError(f"Resolver: We only support local file path or HF repo id; actual = path: `{path}`, filename = `{filename}`")
+ # logger.info(f"Adapter: Downloading adapter weights from uri = `{self.path}`")
+ # import requests
+ # response = requests.get(self.path, stream=True)
+ # lora_path = HF_ADAPTER_FILE_NAME
+ # with open(lora_path, "wb") as f:
+ # for chunk in response.iter_content(chunk_size=8192):
+ # f.write(chunk)
+ elif not path.startswith("/"):
+ path = path.rstrip("/")
+ subfolder = None
+
+ # fix HF subfoler path like: sliuau/llama3.2-1b-4bit-group128/llama3.2-1b-4bit-group128-eora-rank128-arc
+ if path.count("/") > 1:
+ path_split = path.split("/")
+ path = f"{path_split[0]}/{path_split[1]}"
+ subfolder = "/".join(path_split[2:])
+
+ from huggingface_hub import HfApi, hf_hub_download
+
+ # _ = HfApi().list_repo_files(path)
+
+ resolved_path = hf_hub_download(repo_id=path, filename=filename, subfolder=subfolder)
+ return resolved_path
+ # print(f"Adapter tensors loaded from `{self.path}`")
+ else:
+ raise ValueError(
+ f"Resolver: We only support local file path or HF repo id; actual = path: `{path}`, filename = `{filename}`")
\ No newline at end of file
diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py
index e5499325b..5da732acc 100644
--- a/gptqmodel/looper/eora_processor.py
+++ b/gptqmodel/looper/eora_processor.py
@@ -88,9 +88,12 @@ def preprocess(self, module: NamedModule, **kwargs):
adapter_cfg = copy.deepcopy(self.qcfg.adapter)
- # dynamic overrides
- if self.qcfg.dynamic is not None:
- adapter_cfg.adapter = self.qcfg.dynamic_get(module.full_name, "adapter", adapter_cfg)
+ # dynamic override of adapter.rank
+ adapter_cfg.rank = self.qcfg.dynamic_get(
+ module.full_name,
+ key="adapter",
+ sub_key="rank",
+ default=adapter_cfg.rank)
# hack store property inside module
module.adapter_cfg = adapter_cfg
@@ -183,6 +186,7 @@ def process(self, module: NamedModule):
# logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
self.result_save(module.full_name, {
+ "rank": module.adapter_cfg.rank,
"lora_A.weight": move_to(A.to(dtype=torch.float16), device=CPU, stream=self.stream),
"lora_B.weight": move_to(B.to(dtype=torch.float16), device=CPU, stream=self.stream),
})
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index d60287350..d897517b9 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -149,7 +149,7 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": True, "desc_act": False, "mse": 2.4}
if self.gptq_model.quantize_config.dynamic is None:
self.gptq_model.quantize_config.dynamic = {self.gptq_model.lm_head: lm_head_quant_config}
- elif self.gptq_model.quantize_config.dynamic_get(self.gptq_model.lm_head, default_value=None) is None:
+ elif self.gptq_model.quantize_config.dynamic_get(self.gptq_model.lm_head, default=None) is None:
self.gptq_model.quantize_config.dynamic[self.gptq_model.lm_head] = lm_head_quant_config
forward_pass_use_cache = self.gptq_model.model.config.use_cache if hasattr(self.gptq_model.model.config, "use_cache") else False
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 1a5bd65c5..fbb4e457d 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -44,7 +44,6 @@
import numpy # noqa: E402
import torch # noqa: E402
from huggingface_hub import list_repo_files # noqa: E402
-from lm_eval.utils import make_table # noqa: E402
from tokenicer import Tokenicer # noqa: E402
from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402
@@ -313,6 +312,7 @@ def eval(
model_args: Dict[str, Any] = None, # only for framework=EVAL.LM_EVAL backend=vllm
**args
):
+ from peft import PeftModel
if model_args is None:
model_args = {}
if tasks is None:
@@ -336,7 +336,7 @@ def eval(
if isinstance(model_or_id_or_path, str):
model = GPTQModel.load(model_id_or_path=model_or_id_or_path, backend=backend)
model_id_or_path = model_or_id_or_path
- elif isinstance(model_or_id_or_path, BaseGPTQModel) or isinstance(model_or_id_or_path, PreTrainedModel):
+ elif isinstance(model_or_id_or_path, BaseGPTQModel) or isinstance(model_or_id_or_path, (PreTrainedModel, PeftModel)):
model = model_or_id_or_path
model_id_or_path = model.config.name_or_path #
else:
@@ -359,6 +359,8 @@ def eval(
model_args["tokenizer"] = tokenizer
if framework == EVAL.LM_EVAL:
+ from lm_eval.utils import make_table # hack: circular import
+
for task in tasks:
if task not in EVAL.get_task_enums():
raise ValueError(f"Eval.lm_eval supported `tasks`: `{EVAL.get_all_tasks_string()}`, actual = `{task}`")
@@ -521,7 +523,7 @@ def generate(
if not adapter or not isinstance(adapter, Lora):
raise ValueError(f"Adapter: expected `adapter` type to be `Lora`: actual = `{adapter}`.")
- adapter.validate_path(local_only=True)
+ adapter.validate_path(local=True)
quantized_model = GPTQModel.load(
model_id_or_path=quantized_model_id_or_path,
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index c5534ab64..590b851d7 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -707,7 +707,7 @@ def collate_batch(batch):
lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": True, "desc_act": False, "mse": 2.4}
if self.quantize_config.dynamic is None:
self.quantize_config.dynamic = {self.lm_head: lm_head_quant_config}
- elif self.quantize_config.dynamic_get(self.lm_head, default_value=None) is None:
+ elif self.quantize_config.dynamic_get(self.lm_head, default=None) is None:
self.quantize_config.dynamic[self.lm_head] = lm_head_quant_config
forward_pass_use_cache = self.model.config.use_cache if hasattr(self.model.config, "use_cache") else False
diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 4f44e30ea..7100812d8 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -35,6 +35,8 @@
from transformers.models.auto.tokenization_auto import get_tokenizer_config
from transformers.utils.generic import ContextManagers
+from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora
+from ..adapter.peft import LoraConfig
from ..quantization.config import (FORMAT, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT, META_FIELD_MSE,
META_FIELD_QUANTIZER, META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL,
META_FIELD_URI, META_QUANTIZER_GPTQMODEL, META_VALUE_URI, MIN_VERSION_WITH_V2)
@@ -70,38 +72,53 @@ def save_pretrained(
cls.save_pretrained = save_pretrained
- def eora_save(self, eora_path: str):
+ def _eora_save(self, save_dir: str, model_save_dir: str):
+ assert isinstance(self.quantize_config.adapter, Lora)
+
+ assert hasattr(self, 'lora_results')
+
# save lora tensors
- if hasattr(self, 'lora_results'): # hack: TODO
+ if self.lora_results: # TODO REFRACTOR
weights = {}
-
+ target_modules = set()
# convert the dict into safetensors compatible dict
for key, d in self.lora_results.items():
- # must normalize key since HF can load weights as `model.` or not based on what AutoModel is used
- key = key.lower().removeprefix("model.")
- for lora_key, lora_weight in d.items():
- if isinstance(lora_weight, torch.Tensor):
- weights[f"{key}.{lora_key}"] = lora_weight
- logger.info(f"lora weight: `{key}.{lora_key}`")
+ key = key.lower()
+ simple_module_name = key.split(".")[-1] # mlp.gate_proj => gate_proj
+ target_modules.add(simple_module_name)
- # then lora_path from `save()` then lora.path
- eora_path = eora_path if eora_path else self.quantize_config.adapter.path
+ # while key.startswith('model.'):
+ # key = key.removeprefix('model.') # some HF models use model. or model.model.
- if not eora_path:
- raise ValueError(f"Invalid EoRA lora path: actual = `{eora_path}`")
+ # must normalize key since HF can load weights as `model.` or not based on what AutoModel is used
+ key = f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}{key}"
+ lora_rank = d.pop("rank")
+ for lora_key, lora_weight in d.items():
+ assert isinstance(lora_weight, torch.Tensor)
+ weights[f"{key}.{lora_key}"] = lora_weight
+ logger.info(f"Adapter: EoRA weights found -> `{key}.{lora_key}`, rank = `{lora_rank}`")
- is_file = eora_path.endswith(".safetensors")
+ weight_file_path = f"{save_dir.removesuffix('/')}/{HF_ADAPTER_FILE_NAME}"
- if not is_file:
- eora_path = f"{eora_path}/eora.safetensors"
+ # dynamic rank
+ rank_pattern = {}
+ if self.quantize_config.dynamic:
+ rank_pattern = self.quantize_config.extract_adapter_rank_patterns()
- logger.info(f"Found EoRA lora weights: saving to {eora_path}")
+ lora_cfg = LoraConfig(base_model_name_or_path=model_save_dir,
+ r=self.quantize_config.adapter.rank,
+ lora_alpha=self.quantize_config.adapter.rank,
+ target_modules=list(target_modules),
+ rank_pattern=rank_pattern)
+ lora_cfg.save_pretrained(save_dir=save_dir)
- os.makedirs(os.path.dirname(eora_path), exist_ok=True)
+ logger.info(f"Adapter: Saving EoRA weights to -> `{save_dir}`")
+ os.makedirs(os.path.dirname(save_dir), exist_ok=True)
+ save_file(tensors=weights, filename=weight_file_path, metadata={"format": "pt"})
- save_file(tensors=weights, filename=eora_path, metadata={"format": "pt"})
+ del self.lora_results # TODO REFRACTOR
- cls.eora_save = eora_save
+ cls.eora_save = _eora_save
def save_quantized(
self,
@@ -368,7 +385,8 @@ def debug_saved_config(path):
f.write(content)
# save lora
- eora_save(self, eora_path=eora_path)
+ if self.quantize_config.adapter:
+ _eora_save(self, save_dir=eora_path if eora_path else self.quantize_config.adapter.path, model_save_dir=save_dir)
# If the saved model is a loaded quantized model, do not calculate the size diff.
if not self.load_quantized_model:
diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py
index 9a80b170e..3c0a046cf 100644
--- a/gptqmodel/nn_modules/qlinear/exllama.py
+++ b/gptqmodel/nn_modules/qlinear/exllama.py
@@ -23,7 +23,7 @@
from ...adapter.adapter import Adapter, Lora
from ...models._const import DEVICE, PLATFORM
-from ...nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear
+from ...nn_modules.qlinear import PackableQuantLinear
from ...utils.backend import BACKEND
exllama_import_exception = None
diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index ae2c85749..a4a717b36 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -119,10 +119,10 @@ def dict_scale_dtype_to_str(d: Dict[str, Any]) -> None:
dict_scale_dtype_to_str(value)
def dynamic_get(dynamic: Dict[str, Dict[str, Union[int, bool]]], module_name: str, key: str = None,
- default_value: Union[int, bool] = None) -> Union[Dict, int, bool]:
+ default: Union[int, bool] = None, sub_key: str = None) -> Union[Dict, int, bool]:
if dynamic is None:
- return default_value
+ return default
for pattern, overrides in dynamic.items():
if pattern.startswith("-:"):
@@ -132,8 +132,16 @@ def dynamic_get(dynamic: Dict[str, Dict[str, Union[int, bool]]], module_name: st
if key is None:
return overrides
else:
- return overrides.get(key, default_value)
- return default_value
+ # subkey example: Lora override format: `{ "adapter": { "rank": 512 } }`
+ if sub_key:
+ sub_value = overrides.get(key, None)
+ if isinstance(sub_value, Dict):
+ return sub_value.get(sub_key, default)
+ else:
+ logger.info(f"QuantConfig: Dynamic `sub_key`: `{sub_key}` failed extraction from `sub_value`: `{sub_value}`")
+ else:
+ return overrides.get(key, default)
+ return default
@dataclass
class QuantizeConfig():
@@ -270,9 +278,9 @@ def meta_set(self, key: str, value: Any):
def meta_get(self, key: str) -> Any:
return self.meta.get(key)
- def dynamic_get(self, layer_name: str, key: str = None, default_value: Union[int, bool, float] = None
+ def dynamic_get(self, layer_name: str, key: str = None, default: Union[int, bool, float] = None, sub_key: str = None
) -> Union[Dict, int, bool, float]:
- return dynamic_get(self.dynamic, layer_name, key, default_value)
+ return dynamic_get(self.dynamic, layer_name, key, default, sub_key)
# versionable is a meta.property that pairs value with version i.e "value:1.0.0"
def meta_set_versionable(self, key: str, value: List[str]):
@@ -303,9 +311,30 @@ def is_quantized_by_v2(self) -> bool:
return False
+ def extract_adapter_rank_patterns(self) -> Optional[Dict[str, int]]:
+ adapter_rank_patterns = {}
+
+ # no rank can be had if there is no dynamic or adapter
+ if not self.dynamic or not self.adapter:
+ return adapter_rank_patterns
+
+ # override format: `{ "adapter": { "rank": 512 } }`
+ for k, v in self.dynamic.items():
+ adapter_override = v.get("adapter", None) # TODO use const, not str
+ if adapter_override and isinstance(adapter_override, Dict):
+ rank = adapter_override.get("rank", None)
+ if rank and isinstance(rank, int):
+ # need to strip `+:` positive prefix
+ adapter_rank_patterns[k.lstrip("+:")] = rank # TODO use const, not str
+
+ return adapter_rank_patterns
+
def save_pretrained(self, save_dir: str, **kwargs):
with open(join(save_dir, QUANT_CONFIG_FILENAME), "w", encoding="utf-8") as f:
- json.dump(self.to_dict(), f, indent=2)
+ d = self.to_dict()
+ json_str = json.dumps(d, indent=2)
+ logger.info(f"Saved Quantize Config: \n{json_str}")
+ f.write(json_str)
@classmethod
# normalize quant config for compat and also performs validation
@@ -418,6 +447,17 @@ def to_dict(self):
# ADAPTER_FIELD: self.adapter.to_dict() if self.adapter else None,
}
+ dynamic = out["dynamic"]
+ if dynamic:
+ # dynamic adapter config is only used in the quantize phase and is deleted when saving.
+ for _, v in dynamic.items():
+ v.pop("adapter", None)
+
+ # clear empty dynamic value
+ keys_to_delete = [key for key, value in dynamic.items() if not value]
+ for key in keys_to_delete:
+ del dynamic[key]
+
# simplify: clean keys where the value is None or empty [list, dict]
out = {k: v for k, v in out.items() if v is not None and (v not in [None, {}])}
diff --git a/requirements.txt b/requirements.txt
index d75a86c3c..5d36c1fcc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,4 @@ pillow>=11.1.0
hf_transfer>=0.1.9
huggingface_hub>=0.28.1
tokenicer==0.0.4
-logbar==0.0.3
+logbar==0.0.3
\ No newline at end of file
diff --git a/tests/test_kernel_output.py b/tests/test_kernel_output.py
index 7a0d1d802..c294e7f1e 100644
--- a/tests/test_kernel_output.py
+++ b/tests/test_kernel_output.py
@@ -2,56 +2,50 @@
import torch
from gptqmodel import BACKEND, GPTQModel
-from gptqmodel.adapter.adapter import Lora
-from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear
-from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
-from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear
+from gptqmodel.adapter.adapter import AdapterCache, Lora
from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
-from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear
from gptqmodel.utils.model import find_modules
from parameterized import parameterized
-from safetensors import safe_open
from torch import Tensor
+CUDA = torch.device("cuda:0")
class TestKernelOutput(unittest.TestCase):
- model_path = "/monster/data/model/sliuau-llama3.2-1b-4bit-group128/"
- lora_path = "/monster/data/model/sliuau-llama3.2-1b-4bit-group128/llama3.2-1b-4bit-group128-eora-rank128-arc/adapter_model.safetensors"
+ model_path = "sliuau/llama3.2-1b-4bit-group128" # hf "sliuau/llama3.2-1b-4bit-group128"
+
target_qliner_map = {
- BACKEND.EXLLAMA_V1: ExllamaQuantLinear,
- # BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear,
- BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear,
- BACKEND.TRITON: TritonV2QuantLinear,
+ # BACKEND.EXLLAMA_V1: ExllamaQuantLinear,
+ # # BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear,
+ # BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear,
+ # BACKEND.TRITON: TritonV2QuantLinear,
# BACKEND.CUDA: DynamicCudaQuantLinear,
BACKEND.TORCH: TorchQuantLinear,
# BACKEND.BITBLAS: BitBLASQuantLinear,
# BACKEND.IPEX: IPEXQuantLinear,
- BACKEND.MARLIN: MarlinQuantLinear,
- BACKEND.MARLIN_FP16: MarlinQuantLinear,
+ # BACKEND.MARLIN: MarlinQuantLinear,
+ # BACKEND.MARLIN_FP16: MarlinQuantLinear,
}
+ target = 'model.layers.6.self_attn.v_proj'
+
@classmethod
def setUpClass(cls):
- cls.target = 'model.layers.6.self_attn.v_proj'
- eora_tensors = {}
- # with safe_open("/home/shihyangl/llama3.2-1b-4bit-group128-eora-rank128-arc-v2/adapter_model.safetensors",
- with safe_open(
- cls.lora_path,
- framework="pt", device=0) as f:
- for k in f.keys():
- if cls.target in k:
- eora_tensors[k] = f.get_tensor(k)
+ lora_path = "sliuau/llama3.2-1b-4bit-group128-eora-rank128-arc" # adapter_model.safetensors
+ # hf "sliuau-llama3.2-1b-4bit-group128/llama3.2-1b-4bit-group128-eora-rank128-arc/"
- m = 1
- k = eora_tensors[f'{cls.target}.lora_A.weight'].shape[1]
- eora_tensors[f'{cls.target}.lora_B.weight'].shape[0]
+ cls.m = 1
+ cls.k = -1
+ cls.x = None # random X input of shape (m, k)
+ cls.adapter = Lora(
+ rank=128,
+ path=lora_path)
- cls.x = torch.rand((m, k), device='cuda', dtype=torch.float16)
- cls.eora_a = eora_tensors[f'{cls.target}.lora_A.weight'].to('cuda:0').T
- cls.eora_b = eora_tensors[f'{cls.target}.lora_B.weight'].to('cuda:0').T
+ cls.adapter.post_init(cls.target, device=CUDA) # trigger adapter weight load from disk
+ cls.k = cls.adapter.lora_A.shape[0]
- cls.adapter = Lora(path=cls.lora_path, rank=128)
+ cls.x = torch.rand((cls.m, cls.k), device=CUDA, dtype=torch.float16)
+ AdapterCache.reset() # allow next load to complete since we are hacking to get consume only 1 lora module
# TORCH as reference output
cls.torch_kernel_out = cls.forward(cls, backend=BACKEND.TORCH)
@@ -82,11 +76,11 @@ def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005):
@parameterized.expand([
(BACKEND.TORCH, 0.0000, 0.0000),
- (BACKEND.TRITON, 0.00001, 0.00001),
- (BACKEND.EXLLAMA_V1, 0.09, 0.0001),
- (BACKEND.EXLLAMA_V2, 0.136, 0.0001),
- (BACKEND.MARLIN, 0.00005, 0.00005),
- (BACKEND.MARLIN_FP16, 0.0001, 0.0035),
+ # (BACKEND.TRITON, 0.00001, 0.00001),
+ # (BACKEND.EXLLAMA_V1, 0.09, 0.0001),
+ # (BACKEND.EXLLAMA_V2, 0.136, 0.0001),
+ # (BACKEND.MARLIN, 0.00005, 0.00005),
+ # (BACKEND.MARLIN_FP16, 0.0001, 0.0035),
# (BACKEND.EXLLAMA_EORA)
])
def test_kernel_output(self, backend: BACKEND, r_tolerance: float, a_tolerance: float):
@@ -100,11 +94,11 @@ def test_kernel_output(self, backend: BACKEND, r_tolerance: float, a_tolerance:
@parameterized.expand([
(BACKEND.TORCH, 0.0000, 0.0000),
- (BACKEND.TRITON, 0.00001, 0.00001),
- (BACKEND.EXLLAMA_V1, 0.015, 0.0008),
- (BACKEND.EXLLAMA_V2, 0.16, 0.0003),
- (BACKEND.MARLIN, 0.00001, 0.00003),
- (BACKEND.MARLIN_FP16, 0.0001, 0.0035),
+ # (BACKEND.TRITON, 0.00001, 0.00001),
+ # (BACKEND.EXLLAMA_V1, 0.015, 0.0008),
+ # (BACKEND.EXLLAMA_V2, 0.16, 0.0003),
+ # (BACKEND.MARLIN, 0.00001, 0.00003),
+ # (BACKEND.MARLIN_FP16, 0.0001, 0.0035),
# (BACKEND.EXLLAMA_EORA)
])
def test_kernel_output_with_lora(self, backend: BACKEND, r_tolerance: float, a_tolerance: float):
diff --git a/tests/test_packable.py b/tests/test_packable.py
index 412310aaf..1aae581ea 100644
--- a/tests/test_packable.py
+++ b/tests/test_packable.py
@@ -3,9 +3,6 @@
from typing import Dict
import torch
-from parameterized import parameterized
-from safetensors.torch import load_file
-
from gptqmodel import BACKEND, GPTQModel
from gptqmodel.utils.model import find_modules, convert_gptq_v2_to_v1_format
@@ -16,6 +13,9 @@
from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402
from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402
from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402
+from gptqmodel.utils.model import find_modules
+from parameterized import parameterized
+from safetensors.torch import load_file
class TestPackable(unittest.TestCase):
diff --git a/tests/test_packing.py b/tests/test_packing.py
index 011730830..7b08099a4 100644
--- a/tests/test_packing.py
+++ b/tests/test_packing.py
@@ -17,12 +17,11 @@
# -- do not touch
import os
-from parameterized import parameterized
-
from gptqmodel import BACKEND
from gptqmodel.nn_modules.qlinear.dynamic_cuda import DynamicCudaQuantLinear
from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear
from gptqmodel.nn_modules.qlinear.ipex import IPEXQuantLinear
+from parameterized import parameterized
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py
index 902ff4b1f..6a907d4df 100644
--- a/tests/test_quant_and_eora.py
+++ b/tests/test_quant_and_eora.py
@@ -33,9 +33,9 @@
class Test(ModelTest):
- NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/"
+ #NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/"
#NATIVE_MODEL_ID = "/monster/data/model/tinyllama-15M-stories"
- #NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B"
+ NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B"
NATIVE_ARC_CHALLENGE_ACC = 0.3567
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
@@ -54,7 +54,7 @@ def test_quant_and_eora(self):
calibration_dataset_rows = 512
calibration_dataset_concat_size = 0 # disable
auto_gc = False
- adapter_file_name = "eora.safetensors"
+ adapter_path = "eora"
dataset_id = "allenai/c4"
dataset_files = "en/c4-train.00001-of-01024.json.gz"
@@ -70,7 +70,7 @@ def test_quant_and_eora(self):
"calibration_dataset_rows": calibration_dataset_rows,
"calibration_dataset_concat_size": calibration_dataset_concat_size,
"auto_gc": auto_gc,
- "adapter_file_name": adapter_file_name,
+ "adapter_path": adapter_path,
}
calibration_dataset = load_dataset(
@@ -82,7 +82,7 @@ def test_quant_and_eora(self):
with tempfile.TemporaryDirectory() as tmpdir:
eora = Lora(
# for quant, path is save path. for load, it is loading path
- path=os.path.join(tmpdir, adapter_file_name),
+ path=os.path.join(tmpdir, adapter_path),
rank=rank,
)
@@ -114,7 +114,7 @@ def test_quant_and_eora(self):
torch_empty_cache()
# BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA,
- for backend in [ BACKEND.MARLIN ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
+ for backend in [ BACKEND.AUTO ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only
eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora)
@@ -145,7 +145,7 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
tokens = model.generate("Capital of France is")[0]
result = model.tokenizer.decode(tokens)
print(f"BACKEND: {backend}, Result: {result}")
- assert "paris" in result.lower(), f"`paris` not found in `{result}`"
+ #assert "paris" in result.lower(), f"`paris` not found in `{result}`"
bench_result = GPTQModel.eval(
model_or_id_or_path=model,
diff --git a/tests/test_quant_and_eora_transformers.py b/tests/test_quant_and_eora_transformers.py
new file mode 100644
index 000000000..9f3b1f86e
--- /dev/null
+++ b/tests/test_quant_and_eora_transformers.py
@@ -0,0 +1,211 @@
+# Copyright 2025 ModelCloud
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -- do not touch
+import os
+
+import torch
+from safetensors.torch import load_file
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from peft.tuners.lora.gptq import GPTQLoraLinear
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# -- end do not touch
+
+import tempfile # noqa: E402
+from typing import Optional # noqa: E402
+
+from datasets import load_dataset # noqa: E402
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402
+from gptqmodel.adapter.adapter import Lora, HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX # noqa: E402
+from gptqmodel.utils.eval import EVAL # noqa: E402
+from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
+from lm_eval.utils import make_table # noqa: E402
+from logbar import LogBar
+from models.model_test import ModelTest # noqa: E402
+from tabulate import tabulate # noqa: E402
+
+log = LogBar.shared()
+
+
+class Test(ModelTest):
+ # NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/"
+ # NATIVE_MODEL_ID = "/monster/data/model/tinyllama-15M-stories"
+ NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B"
+
+ NATIVE_ARC_CHALLENGE_ACC = 0.3567
+ NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
+ QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36
+
+ @classmethod
+ def setUpClass(cls):
+ pass
+
+ def test_quant_and_eora(self):
+ bits = 4
+ group_size = 128
+ desc_act = True
+ rank = 128
+ batch_size = 1
+ calibration_dataset_rows = 512
+ calibration_dataset_concat_size = 0 # disable
+ auto_gc = False
+ adapter_path = "eora"
+ dataset_id = "allenai/c4"
+ dataset_files = "en/c4-train.00001-of-01024.json.gz"
+
+ config_dict = {
+ "model_id": self.NATIVE_MODEL_ID,
+ "dataset_id": dataset_id,
+ "dataset_files": dataset_files,
+ "bits": bits,
+ "group_size": group_size,
+ "desc_act": desc_act,
+ "rank": rank,
+ "batch_size": batch_size,
+ "calibration_dataset_rows": calibration_dataset_rows,
+ "calibration_dataset_concat_size": calibration_dataset_concat_size,
+ "auto_gc": auto_gc,
+ "adapter_path": adapter_path,
+ }
+
+ calibration_dataset = load_dataset(
+ dataset_id,
+ data_files=dataset_files,
+ split="train"
+ ).select(range(calibration_dataset_rows))["text"]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ eora = Lora(
+ # for quant, path is save path. for load, it is loading path
+ path=os.path.join(tmpdir, adapter_path),
+ rank=rank,
+ )
+
+ quant_config = QuantizeConfig(
+ bits=bits,
+ group_size=group_size,
+ desc_act=desc_act, # bitblas only supports DESC_ACT=False
+ adapter=eora,
+ dynamic={
+ ".*\\.gate_proj.*": {
+ "adapter": {
+ "rank": 256
+ }
+ }
+ },
+ )
+
+ model = GPTQModel.load(
+ model_id_or_path=self.NATIVE_MODEL_ID,
+ quantize_config=quant_config,
+ )
+
+ model.quantize(
+ calibration_dataset=calibration_dataset,
+ batch_size=batch_size,
+ auto_gc=auto_gc,
+ calibration_dataset_concat_size=calibration_dataset_concat_size,
+ ) #
+
+ # EoRA adapter is saved according to Lora.path property
+ # if Lora.path is not set, we will save the lora as "lora.safetensors" in the same path as quant model
+ # You can also pass `eora_path` to `model.save()` to override this save path
+ model.save(tmpdir)
+
+ del model
+ torch_empty_cache()
+
+ # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA,
+ for backend in [BACKEND.MARLIN]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
+ eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora)
+ base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only
+
+ print('--------GPTQModel + EoRA Config ---------')
+
+ # Convert the dictionary to a list of lists for tabulate
+ table_data = [[key, value] for key, value in config_dict.items()]
+ print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid"))
+
+ print('--------Eval GPTQ Result---------')
+ print(make_table(base_bench))
+ if "groups" in base_bench:
+ print(make_table(base_bench, "groups"))
+
+ print('--------Eval GPTQ + EoRA Result---------')
+ print(make_table(eora_bench))
+ if "groups" in eora_bench:
+ print(make_table(eora_bench, "groups"))
+
+ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
+ # test post-quant inference
+ if adapter:
+ adapter_weights = load_file(os.path.join(adapter.path, HF_ADAPTER_FILE_NAME))
+ origin_lora_a_weight = adapter_weights[
+ f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.5.self_attn.v_proj.lora_A.weight"]
+ origin_lora_b_weight = adapter_weights[
+ f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}model.layers.5.self_attn.v_proj.lora_B.weight"]
+
+ model = AutoModelForCausalLM.from_pretrained(path, device_map="cuda")
+ log.info("PEFT: converting model to lora model")
+ model.load_adapter(adapter.path)
+
+ self.assert_adapter_load(model, origin_lora_a_weight, origin_lora_b_weight)
+ del model
+
+ model = AutoModelForCausalLM.from_pretrained(adapter.path, device_map="cuda")
+ log.info("PEFT: load model by adapter.path")
+
+ self.assert_adapter_load(model, origin_lora_a_weight, origin_lora_b_weight)
+ print("peft model", model)
+
+ # assert dynamic rank
+ v_proj_module = model.model.layers[5].self_attn.v_proj
+ assert v_proj_module.lora_A["default"].weight.data.shape[0] == 128
+ assert v_proj_module.lora_B["default"].weight.data.shape[1] == 128
+ gate_proj_module = model.model.layers[5].mlp.gate_proj
+ assert gate_proj_module.lora_A["default"].weight.data.shape[0] == 256
+ assert gate_proj_module.lora_B["default"].weight.data.shape[1] == 256
+
+ del origin_lora_a_weight, origin_lora_b_weight, adapter_weights
+ else:
+ model = AutoModelForCausalLM.from_pretrained(path, device_map="cuda")
+ print("model", model)
+
+ tokenizer = AutoTokenizer.from_pretrained(path)
+ inp = tokenizer("Capital of France is", return_tensors="pt").to(model.device)
+ tokens = model.generate(**inp)[0]
+ result = tokenizer.decode(tokens)
+ print(f"BACKEND: {backend}, Result: {result}")
+ # assert "paris" in result.lower(), f"`paris` not found in `{result}`"
+
+ bench_result = GPTQModel.eval(
+ model_or_id_or_path=model,
+ framework=EVAL.LM_EVAL,
+ tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU],
+ batch_size=self.get_batch_size(),
+ )
+
+ del model
+ torch_empty_cache()
+
+ return bench_result
+
+ def assert_adapter_load(self, model, origin_lora_a_weight, origin_lora_b_weight):
+ module = model.model.layers[5].self_attn.v_proj
+ assert isinstance(module, GPTQLoraLinear)
+ assert torch.equal(origin_lora_a_weight.to(model.device), module.lora_A["default"].weight.data)
+ assert torch.equal(origin_lora_b_weight.to(model.device), module.lora_B["default"].weight.data)