
Commit 8d48e39

Address review comments
Signed-off-by: Jou-An Chen <[email protected]>
1 parent 96ce832 commit 8d48e39


8 files changed, +69 −51 lines changed


QEfficient/peft/auto.py

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@
 from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.lora import QEffAutoLoraModelForCausalLM
+from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
 from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
 from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
 from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
@@ -147,6 +147,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
         """
         Args:
             :pretrained_name_or_path (str): Model card name from huggingface or local path to model directory.
+            :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification.
             :args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM.
         """
         if kwargs.get("full_batch_size"):
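
For context, the ``finite_adapters`` flag documented above routes ``from_pretrained`` to the finite-adapter LoRA implementation. A minimal sketch of opting in, mirroring the call that appears in examples/lora_models.py later in this commit (the adapter card and adapter name are taken from there):

from QEfficient import QEffAutoPeftModelForCausalLM

# With finite_adapters=True, from_pretrained constructs a QEffAutoLoraModelForCausalLM
# instead of the default PEFT wrapper (see the class further down this commit).
qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
    "predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True
)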

QEfficient/lora/__init__.py renamed to QEfficient/peft/lora/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 #
 # ----------------------------------------------------------------------------
 
-from QEfficient.lora.auto import QEffAutoLoraModelForCausalLM
+from QEfficient.peft.lora.auto import QEffAutoLoraModelForCausalLM
 
 __all__ = [
     "QEffAutoLoraModelForCausalLM",

QEfficient/lora/auto.py renamed to QEfficient/peft/lora/auto.py

Lines changed: 52 additions & 34 deletions
@@ -29,17 +29,11 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
 
     Args:
         :model (nn.Module): PyTorch model
-        :base_model_name (str): Model card name for base model
-        :adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping
-        :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping
-        :max_num_adapters (int): Total number of active adapters that to be exported and compiled
-        :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping
-        :lora_rank (int): The consistent lora rank across all active adapters
-        :target_modules_for_all_adapters (List[str]): The consistent set of target modules across all active adapters
+        :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
 
     .. code-block:: python
 
-        from QEfficient.lora import QEffAutoLoraModelForCausalLM
+        from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
 
         m = QEffAutoPeftModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
         m.load_adapter("predibase/gsm8k", "gsm8k")
@@ -53,14 +47,13 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
 
     def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs) -> None:
         super().__init__(model, continuous_batching)
-        assert (
-            type(self.model).__name__ == "QEffMistralForCausalLM" or type(self.model).__name__ == "QEffLlamaForCausalLM"
-        ), f"Only QEffMistralForCausalLM and QEffLlamaForCausalLM model are supported but get {type(self.model).__name__}"
+        if self.model.__class__.__name__ not in ["QEffMistralForCausalLM", "QEffLlamaForCausalLM"]:
+            raise NotImplementedError(
+                f"Only QEffMistralForCausalLM and QEffLlamaForCausalLM model are supported but get {self.model.__class__.__name__}"
+            )
 
-        self.base_model_name = self.model.model.config._name_or_path
         self.adapter_weights = {}
         self.adapter_configs = {}
-        self.max_num_adapters = 0
         self.active_adapter_to_id = {}
 
         self.lora_rank = 0
@@ -101,11 +94,15 @@ def download_adapter(
         adapter_weight: Optional[dict] = None,
         adapter_config: Optional[PeftConfig] = None,
     ):
-        """Loads a new adapter from huggingface hub or local path into CPU cache
+        """
+        Loads a new adapter from huggingface hub or local path into CPU cache
 
-        Args:
+        ``Mandatory`` Args:
             :adapter_model_id (str): Adapter model ID from huggingface hub or local path
-            :adapter_name (str): Adapter name to be used to set this adapter as current
+            :adapter_name (str): Adapter name to be used to downloaded this adapter
+        ``Optional`` Args:
+            :adapter_weight (dict): Adapter weight tensors in dictionary format
+            :adapter_config (PeftConfig): Adapter config in the format of PeftConfig
         """
 
         # check if adapter name already loaded
@@ -128,7 +125,16 @@ def load_adapter(
         adapter_weight: Optional[dict] = None,
         adapter_config: Optional[PeftConfig] = None,
     ):
-        "Load adapter into CPU cache and Sets active adapter from one of the loaded adapters"
+        """
+        Load adapter into CPU cache and set it as active
+
+        ``Mandatory`` Args:
+            :adapter_model_id (str): Adapter model ID from huggingface hub or local path
+            :adapter_name (str): Adapter name to be used to load this adapter
+        ``Optional`` Args:
+            :adapter_weight (dict): Adapter weight tensors in dictionary format
+            :adapter_config (PeftConfig): Adapter config in the format of PeftConfig
+        """
 
         # check if adapter name already exist and activated
         if adapter_name in self.active_adapter_to_id.keys():
@@ -151,22 +157,23 @@ def load_adapter(
 
         # set active adapter id to current max if adapter_name is new
         if adapter_name not in self.active_adapter_to_id.keys():
-            self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1  # reserve 0 for base
-
-        # add active adapter to set
-        self.max_num_adapters = len(self.active_adapter_to_id)
+            self.active_adapter_to_id[adapter_name] = len(self.active_adapter_to_id) + 1  # reserve 0 for base
 
         return self.active_adapter_to_id[adapter_name]
 
     def unload_adapter(self, adapter_name: str):
-        "Deactivate adpater and remove it from CPU cache"
+        """
+        Deactivate adpater and remove it from CPU cache
+
+        ``Mandatory`` Args:
+            :adapter_name (str): Adapter name to be unloaded
+        """
 
         # step1: remove from active list if it's there
         if adapter_name not in self.active_adapter_to_id.keys():
             logger.info(f"Adapter name {adapter_name} is not set active yet")
             return False
 
-        self.max_num_adapters -= 1
         self.active_adapter_to_id.pop(adapter_name)
 
         # renumbering of active adapter id
@@ -197,9 +204,9 @@ def _load_adapter_weights_to_model(self):
         for i in range(num_hidden_layers):
             for target_module in self.target_modules_for_all_adapters:
                 # stack all adapters weights
-                a_tensor_list = list(range(self.max_num_adapters + 1))
-                b_tensor_list = list(range(self.max_num_adapters + 1))
-                s_tensor_list = list(range(self.max_num_adapters + 1))
+                a_tensor_list = list(range(len(self.active_adapter_to_id) + 1))
+                b_tensor_list = list(range(len(self.active_adapter_to_id) + 1))
+                s_tensor_list = list(range(len(self.active_adapter_to_id) + 1))
 
                 for lora_name, lora_id in self.active_adapter_to_id.items():
                     if target_module in ["q_proj", "k_proj", "v_proj", "o_proj"]:
@@ -256,10 +263,6 @@ def _load_adapter_weights_to_model(self):
     def _init_adapter_model(self):
         "Initialize the fixed lora model with multiple adapter weigths standby"
 
-        # assume all adapters have same target_modules and ranks
-        if self.max_num_adapters != len(self.active_adapter_to_id):
-            raise ValueError("Inconsistent max_num_adapters and active adapters")
-
         # set lora rank
         self.lora_rank = list(self.adapter_configs.values())[0].r
 
@@ -268,7 +271,7 @@ def _init_adapter_model(self):
 
         self.target_modules_for_all_adapters = list(self.adapter_configs.values())[0].target_modules
         _, transformed = TargetModulesTransform.apply(
-            self.model, self.target_modules_for_all_adapters, self.lora_rank, self.max_num_adapters
+            self.model, self.target_modules_for_all_adapters, self.lora_rank, len(self.active_adapter_to_id)
         )
 
         # load_weight to model
@@ -287,7 +290,11 @@ def export(self, export_dir: Optional[str] = None) -> str:
         """
 
         # initialize the adapter model
-        assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
+        if len(self.active_adapter_to_id) == 0:
+            raise ValueError(
+                "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
+            )
+
         self._init_adapter_model()
 
         bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
@@ -333,6 +340,7 @@ def generate(
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
         prompts: List[str],
         device_id: List[int] = None,
+        prompt_to_adapter_mapping: List[str] = None,
         runtime: str = "AI_100",
         **kwargs,
     ):
@@ -342,18 +350,28 @@ def generate(
         If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped.
 
         ``Mandatory`` Args:
+            :tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer): The tokenizer used in the inference
             :prompts (List[str]): List of prompts to run the execution.
             :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
+            :prompt_to_adapter_mapping (List[str]): The sequence of the adapter names will be matched with sequence of prompts and corresponding adapters will be used for the prompts."base" for base model (no adapter).
         ``optional`` Args:
             :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
-            :prompt_to_adapter_mapping (List[str]): A list of adapter names that maps to the prompts, specifying which adapter the prompt wants to apply. "base" for base model (no adapter).
+
         """
         if runtime != "AI_100":
             raise ValueError("Only AI_100 runtime is supported right now via generate API")
         if not isinstance(self.qpc_path, Path):
             raise TypeError("Please run compile API first!")
         generation_len = kwargs.pop("generation_len", None)
-        prompt_to_adapter_mapping = kwargs.pop("prompt_to_adapter_mapping", ["base" for _ in range(len(prompts))])
+
+        if not prompt_to_adapter_mapping:
+            prompt_to_adapter_mapping = ["base" for _ in range(len(prompts))]
+
+        if len(prompt_to_adapter_mapping) != len(prompts):
+            raise RuntimeError(
+                f"Number of prompts should match number of prompt_to_adapter_mapping, got len(prompts) = {len(prompts)}, len(prompt_to_adapter_mapping) = {len(prompt_to_adapter_mapping)}"
+            )
+
         return QEfficient.cloud_ai_100_exec_kv(
             tokenizer,
             self.qpc_path,
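
Putting the new surface together, here is a hedged end-to-end sketch of the finite-adapter flow after this change. The base model card comes from the class docstring above and the adapter cards from examples/lora_models.py below; the compile arguments, prompts, and the use of transformers.AutoTokenizer (rather than the repo's load_hf_tokenizer helper) are illustrative assumptions, not code from this commit:

from transformers import AutoTokenizer

from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

base_model = "mistralai/Mistral-7B-v0.1"
qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model)

# load_adapter() returns the adapter id; ids are now derived from
# len(active_adapter_to_id) + 1, with id 0 reserved for the base model.
gsm8k_id = qeff_model.load_adapter("predibase/gsm8k", "gsm8k")
tldr_id = qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen")

# export() now raises ValueError (instead of asserting) when no adapter is loaded.
qeff_model.export()
qeff_model.compile(prefill_seq_len=32, ctx_len=256, num_cores=16)  # illustrative values

tokenizer = AutoTokenizer.from_pretrained(base_model)
prompts = ["What is 12 times 7?", "Summarize the release notes in one line.", "Hello!"]

# prompt_to_adapter_mapping is now an explicit parameter; its length must match
# len(prompts), and "base" runs a prompt on the base model with no adapter.
qeff_model.generate(
    tokenizer=tokenizer,
    prompts=prompts,
    device_id=[0],
    prompt_to_adapter_mapping=["gsm8k", "tldr_content_gen", "base"],
)
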
File renamed without changes.
File renamed without changes.
File renamed without changes.

examples/lora_models.py

Lines changed: 12 additions & 8 deletions
@@ -22,7 +22,9 @@
 )
 
 # (alternative) non-cb compilation
-# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True)
+# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
+#     "predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True
+# )
 
 ## STEP 2 -- load adapter adapter
 qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen")
@@ -47,13 +49,15 @@
 )
 
 # (alternative) non-cb compilation
-# qpc_path = qeff_model.compile(batch_size=2,
-#                               prefill_seq_len=seq_len,
-#                               ctx_len=ctx_len,
-#                               num_devices=len(device_group),
-#                               num_cores=16,
-#                               mxfp6_matmul=True,
-#                               mxint8_kv_cache=True)
+# qpc_path = qeff_model.compile(
+#     batch_size=2,
+#     prefill_seq_len=seq_len,
+#     ctx_len=ctx_len,
+#     num_devices=len(device_group),
+#     num_cores=16,
+#     mxfp6_matmul=True,
+#     mxint8_kv_cache=True,
+# )
 
 ## STEP 4 -- run inference on the generate function
 prompts = [
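
The hunk stops at the start of STEP 4, so the script's actual prompt list is not shown here. As a purely hypothetical illustration of how that step pairs prompts with the two adapters loaded in STEP 2 (the prompts below are invented, and tokenizer / device_group are assumed to be defined earlier in the script):

prompts = [
    "James buys 5 packs of beef. Each pack weighs 4 pounds at $5.50 per pound. How much did he pay?",  # invented prompt
    "Generate a one-line summary of the following content: ...",                                       # invented prompt
]

# One adapter name per prompt; "base" would run a prompt on the base model without an adapter.
qeff_model.generate(
    tokenizer=tokenizer,
    prompts=prompts,
    device_id=device_group,
    prompt_to_adapter_mapping=["gsm8k", "tldr_content_gen"],
)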

tests/lora/test_lora_model.py

Lines changed: 2 additions & 7 deletions
@@ -13,7 +13,7 @@
 from transformers import AutoConfig, AutoModelForCausalLM
 
 from QEfficient import QEffAutoPeftModelForCausalLM
-from QEfficient.lora import QEffAutoLoraModelForCausalLM
+from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
 from QEfficient.utils import load_hf_tokenizer
 
 configs = [
@@ -56,10 +56,8 @@ def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapt
     model_hf = AutoModelForCausalLM.from_pretrained(base_model_name)
     qeff_model = QEffAutoLoraModelForCausalLM(model_hf)
 
-    assert qeff_model.base_model_name == base_model_name
     assert len(qeff_model.adapter_weights) == 0
     assert len(qeff_model.adapter_configs) == 0
-    assert qeff_model.max_num_adapters == 0
     assert len(qeff_model.active_adapter_to_id) == 0
 
 
@@ -68,10 +66,8 @@ def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapt
 def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1):
     qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name)
 
-    assert qeff_model.base_model_name == base_model_name
     assert len(qeff_model.adapter_weights) == 0
     assert len(qeff_model.adapter_configs) == 0
-    assert qeff_model.max_num_adapters == 0
     assert len(qeff_model.active_adapter_to_id) == 0
 
 
@@ -80,10 +76,9 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_
 def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1):
     qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(adapter_id_0, "id_0", finite_adapters=True)
 
-    assert qeff_model.base_model_name == base_model_name
+    assert isinstance(qeff_model, QEffAutoLoraModelForCausalLM)
     assert len(qeff_model.adapter_weights) == 1
     assert len(qeff_model.adapter_configs) == 1
-    assert qeff_model.max_num_adapters == 1
     assert len(qeff_model.active_adapter_to_id) == 1
 
 
0 commit comments
