Merged
28 commits
a2dc904
save eora to hf format
Qubitium Feb 27, 2025
9306a02
test needs to store self.x in cls to stay consistent
Qubitium Feb 27, 2025
6a2765a
temp disable torch kernel auto compile that is causing dynamo errors
Qubitium Feb 27, 2025
d28df61
fix shape error
ZX-ModelCloud Feb 27, 2025
b927f07
cleanup debug logs
Qubitium Feb 27, 2025
e0906ec
re-enable auto torch compile code
Qubitium Feb 27, 2025
e5fd40e
add lora config validation
Qubitium Feb 27, 2025
3bfe530
refactor loading cache into AdapterCache cls
Qubitium Feb 27, 2025
d87ad46
add lora rank override code from LoraConfig
Qubitium Feb 27, 2025
f81b077
remove `peft` dependency
Qubitium Feb 27, 2025
e042f70
comment on original HF repo path for test files
Qubitium Feb 27, 2025
31c8d1f
clean up HF download logic
Qubitium Feb 27, 2025
b22d094
save to PEFT compatible format
Qubitium Feb 27, 2025
964e5fe
add test_quant_and_eora_transformers.py
ZX-ModelCloud Feb 27, 2025
ab9519e
fix missing task_type in adapter_config.json
ZX-ModelCloud Feb 27, 2025
7ce0e8d
fix regex rule prefix not stripped
Qubitium Feb 27, 2025
45935ff
push peft compat changes
Qubitium Feb 28, 2025
5ed4472
prevent peft doing alpha / r scaling. set alpha eq r so math is just… (see the note after this commit list)
Qubitium Feb 28, 2025
32a468d
fix lora load with transformers
ZX-ModelCloud Feb 28, 2025
e590454
fix device
ZX-ModelCloud Feb 28, 2025
95f3e15
format
Qubitium Feb 28, 2025
a316cd2
Merge branch 'main' into lora-format
Qubitium Feb 28, 2025
86a242c
assert lora weight
ZX-ModelCloud Feb 28, 2025
a1803ab
fix empty base_model_name_or_path
ZX-ModelCloud Feb 28, 2025
21dbdf6
assert dynamic rank
ZX-ModelCloud Feb 28, 2025
39e0c96
remove dynamic adapter config when save quantize_config
ZX-ModelCloud Feb 28, 2025
0859d30
fix dynamic is none
CSY-ModelCloud Mar 1, 2025
3fceedf
[CI] install bitblas for test_inference_speed
CSY-ModelCloud Mar 1, 2025
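A note on the "alpha eq r" commit above (a sketch of the standard PEFT LoRA math, not text from the diff): PEFT scales the LoRA delta by lora_alpha / r, so saving lora_alpha equal to r makes that factor 1 and the adapted output reduces to a plain matmul chain:

y = base(x) + (x @ lora_A) @ lora_B  # scaling = lora_alpha / r = 1, no extra multiplier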
1 change: 1 addition & 0 deletions .github/workflows/unit_tests.yml
@@ -553,6 +553,7 @@ jobs:
uv pip install -U transformers
uv pip install -U logbar==0.0.3
if [ "${{ matrix.test_script }}" == "test_perplexity" ] || \
[ "${{ matrix.test_script }}" == "test_inference_speed" ] || \
[ "${{ matrix.test_script }}" == "test_q4_bitblas" ] || \
[ "${{ matrix.test_script }}" == "test_save_loaded_quantized_model" ]; then
echo "===== install bitblas==0.0.1.dev13 ====="
172 changes: 112 additions & 60 deletions gptqmodel/adapter/adapter.py
@@ -1,32 +1,66 @@
import os
from dataclasses import dataclass, field
from typing import Dict, List, Union
from urllib.parse import urlparse

import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import safetensors
import torch

from ..utils.logger import setup_logger
from .peft import LoraConfig
from .remote import resolve_path

logger = setup_logger()
LORA_MERGED_WEIGHT_PATHS = [None, ""]
HF_ADAPTER_FILE_NAME = "adapter_model.safetensors"
HF_ADAPTER_CONFIG_FILE_NAME = "adapter_config.json"
HF_ADAPTER_WEIGHT_KEY_PREFIX = "base_model.model."


class AdapterCache():
cache: Dict[str, Dict[str, Union[LoraConfig, torch.Tensor]]] = {} # first-level key is `path`; second-level keys: `config` = LoraConfig, `weights` = Dict[str, torch.Tensor]

@classmethod
def get(cls, path: str) -> Optional[Tuple[LoraConfig, Dict[str, torch.Tensor]]]:
data = cls.cache.get(path)
if not data:
return None
else:
return data["config"], data["weights"]

@classmethod
def reset(cls):
logger.info("Adapter Cache: Resetting cache")
cls.cache = {}

@classmethod
def add(cls, path: str, config: LoraConfig, weights: Dict[str, torch.Tensor]):
cls.cache[path] = {"config": config, "weights": weights}

@classmethod
def remove(cls, path):
cls.cache.pop(path, None)
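For orientation, a minimal sketch of the intended cache flow (hypothetical `path/to/adapter`; it simply mirrors the calls made by `post_init` further down):

# hypothetical usage of AdapterCache, assuming an adapter repo or directory at `path/to/adapter`
cfg = LoraConfig.from_pretrained(path="path/to/adapter", filename=HF_ADAPTER_CONFIG_FILE_NAME)
weights = safetensors.torch.load_file(resolve_path("path/to/adapter", HF_ADAPTER_FILE_NAME))
AdapterCache.add("path/to/adapter", cfg, weights)
cached = AdapterCache.get("path/to/adapter")  # -> (cfg, weights), or None on a cache miss
AdapterCache.remove("path/to/adapter")        # no-op if the path was never cached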

# TODO FIX ME: cache of adapter tensors loaded from disk
adapter_load_cache = None

class Adapter():
def __init__(self, rank: int, path: str = None):
self.rank = rank
def __init__(self, rank: int = None, path: str = None):
self.rank = rank # rank may be unset (None) when loading; it is re-populated from the saved LoraConfig file
self.path = path.lower().strip() if isinstance(path, str) else path

def validate_path(self, local_only=False):
def validate_path(self, local=False):
if not self.path or not isinstance(self.path, str):
raise ValueError("Adapter: `path` str is required.")

if local_only:
# path should not be a file but a directory
if self.path.endswith(".safetensors"):
raise ValueError(
f"Adapter: `path` must be a directory path or repo depending if you are saving (directory path) or loading (repo): actual = `{self.path}`")

if local:
if self.path.startswith("http"):
raise ValueError(f"Adapter: `path` str in this context must be a local os path: actual = `{self.path}`.")


# override me
def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
pass
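As a rough sketch (not the repo's actual Lora/EoRA implementation), a subclass override typically adds the low-rank product to the quantized module output; because `post_init` below stores `lora_A`/`lora_B` already transposed and the saved `lora_alpha` equals `r`, the math reduces to a plain matmul chain with no scaling factor:

# hypothetical override in a subclass, assuming lora_A: (in_features, r) and lora_B: (r, out_features)
def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    return out + (x @ self.lora_A) @ self.lora_B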
@@ -97,52 +131,69 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N
self.lora_A, self.lora_B = lora_A, lora_B
return

global adapter_load_cache
if adapter_load_cache is None:
if os.path.isfile(self.path):
lora_path = self.path
logger.info(f"Adapter: Loading `{self.path}` tensors from disk") # {adapter_load_cache}
elif self.path.startswith("http"):
from huggingface_hub import hf_hub_download
result = self.parse_url(self.path)
if len(result) == 3:
logger.info(f"Adapter: Downloading adapter weights from hf repo: `{result[0]}` revision: `{result[1]}` file: `{result[2]}`")
lora_path = hf_hub_download(repo_id=result[0], revision=result[1], filename=result[2])
elif len(result) == 1:
logger.info(f"Adapter: Downloading adapter weights from uri = `{self.path}`")
import requests
response = requests.get(self.path, stream=True)
lora_path = "lora.safetensors"
with open(lora_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
else:
raise Exception(f"Adapter: Lora path is invalid: `{self.path}`")
lora_cache = AdapterCache.get(self.path)
if lora_cache is None:
# get lora config
lora_cfg = LoraConfig.from_pretrained(path=self.path, filename=HF_ADAPTER_CONFIG_FILE_NAME)
lora_cfg.gptqmodel_path = self.path # hack: save this

if not isinstance(lora_cfg, LoraConfig):
raise ValueError(f"Adapter: Expected `LoraConfig` in `{self.path}`, actual = `{lora_cfg}`")

if self.rank is None:
self.rank = lora_cfg.r
else:
from huggingface_hub import HfApi, hf_hub_download
files = [f for f in HfApi().list_repo_files(self.path) if f in ["lora.safetensors", "eora_test.safetensors"]]
if self.rank != lora_cfg.r:
raise ValueError(f"Adapter: `rank` must match `LoraConfig.r`, expected `{self.rank}`, actual = `{lora_cfg.r}`")

lora_path = resolve_path(self.path, HF_ADAPTER_FILE_NAME)

# save to adapter cache
AdapterCache.add(self.path, lora_cfg, safetensors.torch.load_file(lora_path))

if files:
lora_path = hf_hub_download(repo_id=self.path, filename=files[0])
# print(f"Adapter tensors loaded from `{self.path}`")
else:
raise Exception(f"Adapter: There's no lora.safetensors or eora_test.safetensors on repo `{self.path}`")
lora_cache = AdapterCache.get(self.path)
assert lora_cache is not None

adapter_load_cache = safetensors.torch.load_file(lora_path)
# lora_cache result is a tuple
lora_cfg, lora_weights = lora_cache

weight_key = weight_key.lower()

# hack for HF Auto compat
if not f"{weight_key}.lora_A.weight" in adapter_load_cache:
weight_key = weight_key.removeprefix("model.")
lora_A_weight_key = f"{weight_key}.lora_A.weight"
lora_B_weight_key = f"{weight_key}.lora_B.weight"

#print(f"loaded lora weight keys: {adapter_load_cache.keys()}")
lora_A = adapter_load_cache.pop(f"{weight_key}.lora_A.weight").T
lora_B = adapter_load_cache.pop(f"{weight_key}.lora_B.weight").T
# print(f"lora_A_weight_key = {lora_A_weight_key}, lora_B_weight_key = {lora_B_weight_key}")
pop_keys = []
for k, v in lora_weights.items():
if k.endswith(lora_A_weight_key):
lora_A = v.T
pop_keys.append(k)
elif k.endswith(lora_B_weight_key):
lora_B = v.T
pop_keys.append(k)

# since the loader cache is a singleton, we need to reset it to None so CI loop tests can pass
if len(adapter_load_cache) == 0:
adapter_load_cache = None

if pop_keys:
for k in pop_keys:
lora_weights.pop(k) # release lora weights from cache memory

# we have consumed all modules
if len(lora_weights) == 0:
AdapterCache.remove(self.path)
logger.info("Adapter: Consumed all Lora weights")

else:
logger.warn(f"Adapter: Lora weights not found for `{weight_key}`")

assert lora_A is not None and lora_B is not None, f"Adapter: `lora_A` and `lora_B` must both be present in the weights: actual = `{lora_A}` and `{lora_B}`"

# check for rank override from base config
self.dynamic_rank_override(lora_cfg=lora_cfg, weight_key=weight_key)

# # since the loader cache is a singleton, we need to reset it to None so CI loop tests can pass
# if len(lora_weights) == 0:
# adapter_load_cache = None

# print(f"Adapter: {self.name()}, loaded lora_A shape: {lora_A.shape}")
# print(f"Adapter: {self.name()}, loaded lora_B shape: {lora_B.shape}")
@@ -155,21 +206,22 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N
#print(f"Adapter: lora_A {lora_A.shape}: `{lora_B}`")
#print(f"Adapter: lora_B {lora_B.shape}: `{lora_B}`")

def parse_url(self, url: str):
parsed_url = urlparse(url)
def dynamic_rank_override(self, lora_cfg: LoraConfig, weight_key: str) -> bool:
assert lora_cfg.rank_pattern is not None and weight_key is not None
if lora_cfg.rank_pattern:
for k, v in lora_cfg.rank_pattern.items():
assert isinstance(k, str) and isinstance(v, int)
k = k.lower()
assert v > 0 # check for invalid rank range
# first do string full match, then suffix match, then regex match
if weight_key == k or k.endswith(weight_key) or re.match(k, weight_key):
self.rank = v
logger.info(f"Adapter: Base Lora `rank` = `{self.rank}` has been overridden by `{k}` due to dynamic `LoraConfig.rank_pattern` control.")
return True

return False
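A small illustration of the dynamic override (hypothetical `rank_pattern`; keys can be exact module names, suffixes, or regex patterns and are lowercased before matching):

# hypothetical LoraConfig.rank_pattern and the resulting per-module rank override
rank_pattern = {"model.layers.0.self_attn.q_proj": 64, r".*mlp\.down_proj": 32}
# weight_key "model.layers.0.self_attn.q_proj" -> exact string match, self.rank becomes 64
# weight_key "model.layers.5.mlp.down_proj"    -> regex match via re.match, self.rank becomes 32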

if parsed_url.netloc.endswith("huggingface.co") or parsed_url.netloc.endswith("hf.co"):
parts = parsed_url.path.strip("/").split("/")

if "blob" in parts:
idx = parts.index("blob")
repo_id = "/".join(parts[:idx])
rev = parts[idx + 1]
filename = parts[idx + 2].split("?")[0] # remove ?download=true
return [repo_id, rev, filename]
else:
return [url]
return []

def to_dict(self):
return {