
Commit 96ce832

Enable init from QEffAutoPeftModelForCausalLM with finite_adapters flag
Signed-off-by: Jou-An Chen <[email protected]>
1 parent 522355a commit 96ce832
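In short, this commit lets a finite-adapters LoRA model be constructed directly through the PEFT front end. Below is a minimal usage sketch, not taken verbatim from the diff; the predibase/gsm8k adapter and the "gsm8k" name are illustrative.

    # Minimal sketch of the flow this commit enables; adapter repo and name are illustrative.
    from QEfficient import QEffAutoPeftModelForCausalLM

    # With finite_adapters=True, from_pretrained() resolves the base model from the
    # adapter's PeftConfig and returns a QEffAutoLoraModelForCausalLM with the first
    # adapter ("gsm8k") already registered.
    qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
        "predibase/gsm8k", "gsm8k", finite_adapters=True
    )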

File tree

5 files changed, +76 -73 lines


QEfficient/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -9,7 +9,6 @@
 from QEfficient.compile.compile_helper import compile
 from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
 from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
-from QEfficient.lora import QEffAutoLoraModelForCausalLM
 from QEfficient.peft import QEffAutoPeftModelForCausalLM
 from QEfficient.transformers.transform import transform
 

@@ -25,6 +24,5 @@
     "QEffAutoModel",
     "QEFFAutoModelForCausalLM",
     "QEffAutoPeftModelForCausalLM",
-    "QEffAutoLoraModelForCausalLM",
     "QEFFCommonLoader",
 ]
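Since the LoRA class is no longer re-exported from the package root, callers that relied on the old top-level import need the subpackage path instead; a one-line sketch of the new import:

    # Old: from QEfficient import QEffAutoLoraModelForCausalLM  (top-level export removed by this commit)
    from QEfficient.lora import QEffAutoLoraModelForCausalLM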

QEfficient/lora/auto.py

Lines changed: 28 additions & 17 deletions

@@ -24,8 +24,8 @@
 
 class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
     """
-    QEff class for loading models with multiple LoRA adapters.
-    Once exported and compiled, the qpc can perform mixed batch inference with provided prompt_to_lora_id_mapping.
+    QEff class for loading models with multiple LoRA adapters. Currently only Mistral and Llama models are supported.
+    Once exported and compiled, the qpc can perform mixed batch inference with the provided `prompt_to_adapter_mapping`.
 
     Args:
         :model (nn.Module): PyTorch model

@@ -34,21 +34,20 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
         :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping
         :max_num_adapters (int): Total number of active adapters that to be exported and compiled
         :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping
+        :lora_rank (int): The consistent lora rank across all active adapters
+        :target_modules_for_all_adapters (List[str]): The consistent set of target modules across all active adapters
 
     .. code-block:: python
 
-        from QEfficient import QEffAutoLoraModelForCausalLM
+        from QEfficient.lora import QEffAutoLoraModelForCausalLM
 
         m = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
         m.load_adapter("predibase/gsm8k", "gsm8k")
         m.load_adapter("predibase/magicoder", "magicoder")
-        gsm8k_id = m.set_adapter("gsm8k")
-        magicoder_id = m.set_adapter("magicoder")
-        m.export(full_batch_size=3)
         m.compile(num_cores=16, device_group=[0])
 
         prompts=["code prompt", "math prompt", "generic"]
-        m.generate(prompts, device_group=[0], prompt_to_lora_id_mapping=[magicoder_id,gsm8k_id,0])
+        m.generate(prompts, device_group=[0], prompt_to_adapter_mapping=["magicoder", "gsm8k", "base"])
 
     """

@@ -188,12 +187,10 @@ def unload_adapter(self, adapter_name: str):
 
         return True
 
-    def get_adapter_id(self, adapter_name):
-        "get the adapter_id that maps to the adapter_name"
+    def set_adapter(self, adapter_name: str):
+        raise NotImplementedError("Set adapter is not supported in finite_adapters mode")
 
-        return self.active_adapter_to_id[adapter_name]
-
-    def load_adapter_weights_to_model(self):
+    def _load_adapter_weights_to_model(self):
         "Loads adapter weights to the model's multilora layer in a stacked format"
 
         num_hidden_layers = len(self.model.model.layers)

@@ -256,7 +253,7 @@ def load_adapter_weights_to_model(self):
             module.lora_b_weights.copy_(stacked_lora_b)
             module.lora_scalings.copy_(stacked_lora_s)
 
-    def init_adapter_model(self):
+    def _init_adapter_model(self):
         "Initialize the fixed lora model with multiple adapter weights on standby"
 
         # assume all adapters have same target_modules and ranks

@@ -275,12 +272,23 @@
             )
 
         # load_weight to model
-        self.load_adapter_weights_to_model()
+        self._load_adapter_weights_to_model()
 
     def export(self, export_dir: Optional[str] = None) -> str:
+        """
+        Exports the model to ``ONNX`` format using ``torch.onnx.export``.
+        We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this.
+
+        ``Optional`` Args:
+            This method does not accept any arguments.
+
+        Returns:
+            :str: Path of the generated ``ONNX`` graph.
+        """
+
         # initialize the adapter model
         assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
-        self.init_adapter_model()
+        self._init_adapter_model()
 
         bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
         seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN

@@ -338,18 +346,21 @@ def generate(
             :device_id (List[int]): Ids of devices for running the qpc; pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
         ``optional`` Args:
             :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
+            :prompt_to_adapter_mapping (List[str]): A list of adapter names mapped one-to-one to the prompts, specifying which adapter each prompt should use; pass "base" for the base model (no adapter).
         """
         if runtime != "AI_100":
             raise ValueError("Only AI_100 runtime is supported right now via generate API")
         if not isinstance(self.qpc_path, Path):
             raise TypeError("Please run compile API first!")
         generation_len = kwargs.pop("generation_len", None)
-        prompt_to_lora_id_mapping = kwargs.pop("prompt_to_lora_id_mapping", [0 for _ in range(len(prompts))])
+        prompt_to_adapter_mapping = kwargs.pop("prompt_to_adapter_mapping", ["base" for _ in range(len(prompts))])
         return QEfficient.cloud_ai_100_exec_kv(
             tokenizer,
             self.qpc_path,
             prompt=prompts,
             device_id=device_id,
             generation_len=generation_len,
-            prompt_to_lora_id_mapping=prompt_to_lora_id_mapping,
+            prompt_to_lora_id_mapping=[
+                self.active_adapter_to_id[name] if name != "base" else 0 for name in prompt_to_adapter_mapping
+            ],
         )
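For clarity, the name-to-id translation that generate() now performs can be read in isolation. The following is an illustrative standalone sketch; the dictionary contents are example ids, not ones produced by load_adapter().

    # Illustrative only: how adapter names become lora ids inside generate().
    active_adapter_to_id = {"gsm8k": 1, "magicoder": 2}  # example ids; real ones come from load_adapter()
    prompt_to_adapter_mapping = ["magicoder", "gsm8k", "base"]

    prompt_to_lora_id_mapping = [
        active_adapter_to_id[name] if name != "base" else 0  # "base" selects the base model (id 0)
        for name in prompt_to_adapter_mapping
    ]
    assert prompt_to_lora_id_mapping == [2, 1, 0]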

QEfficient/peft/auto.py

Lines changed: 14 additions & 2 deletions

@@ -12,7 +12,7 @@
 
 import numpy as np
 import torch
-from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, load_peft_weights
+from peft import AutoPeftModelForCausalLM, PeftConfig, PeftModelForCausalLM, load_peft_weights
 from torch import nn
 from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
 from transformers.generation.streamers import BaseStreamer

@@ -21,6 +21,7 @@
 from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.lora import QEffAutoLoraModelForCausalLM
 from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
 from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
 from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform

@@ -38,6 +39,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel):
 
     Args:
         :model (nn.Module): PyTorch model
+        :finite_adapters (bool): Set True to enable finite adapter mode with the QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for the API specification.
 
     .. code-block:: python
 

@@ -152,7 +154,17 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
         if kwargs.get("use_cache") is False:
             warnings.warn("Overriding to use_cache=True")
         kwargs["use_cache"] = True
-        obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
+
+        if kwargs.pop("finite_adapters", False):  # initialize through finite_adapters class
+            obj = QEffAutoLoraModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=PeftConfig.from_pretrained(
+                    pretrained_name_or_path
+                ).base_model_name_or_path,
+                **kwargs,
+            )
+            obj.load_adapter(pretrained_name_or_path, list(args)[0])
+        else:
+            obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
         return obj
 
     def export(self, export_dir: Optional[str] = None) -> str:
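The branch added to from_pretrained() is equivalent in spirit to the hypothetical helper below. The function name is mine, and only the calls visible in the diff are used; this is a sketch, not part of the codebase.

    # Hypothetical helper mirroring the new finite_adapters branch (illustrative, not in the repo).
    from peft import PeftConfig
    from QEfficient.lora import QEffAutoLoraModelForCausalLM

    def init_via_finite_adapters(adapter_repo: str, adapter_name: str, **kwargs):
        # Resolve the base model recorded in the adapter's PeftConfig,
        # build the finite-adapters model, then register the adapter under adapter_name.
        base = PeftConfig.from_pretrained(adapter_repo).base_model_name_or_path
        model = QEffAutoLoraModelForCausalLM.from_pretrained(base, **kwargs)
        model.load_adapter(adapter_repo, adapter_name)
        return model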

examples/lora_models.py

Lines changed: 17 additions & 27 deletions

@@ -7,7 +7,7 @@
 
 ## This example works on continuous batching with different lora adapters in the same batch ##
 
-from QEfficient import QEffAutoLoraModelForCausalLM
+from QEfficient import QEffAutoPeftModelForCausalLM
 from QEfficient.utils import load_hf_tokenizer
 
 base_model_name = "mistralai/Mistral-7B-v0.1"

@@ -17,37 +17,22 @@
 device_group = [0]
 
 ## STEP 1 -- init base model
-
-# **Option1**: Download model weights from hugging face & Init it with QEffAuto model to apply QEff transforms
-# model_hf = AutoModelForCausalLM.from_pretrained(base_model_name)
-# qeff_model = QEffAutoLoraModelForCausalLM(model_hf, continuous_batching=True)
-
-# **Option2**: Initialize the model using from_pretrained() method
-qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(
-    pretrained_model_name_or_path=base_model_name, continuous_batching=True
+qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
+    "predibase/gsm8k", "gsm8k", continuous_batching=True, finite_adapters=True
 )
 
-# (alternative) non-cb initialization
-# qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name, continuous_batching=False)
+# (alternative) non-cb compilation
+# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True)
 
 ## STEP 2 -- load adapters
-adapter_id_gsm8k = qeff_model.load_adapter("predibase/gsm8k", "gsm8k")
-print(f"Activating gsm8k as adapter_id {adapter_id_gsm8k}")
-
-adapter_id_tldr = qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen")
-print(f"Activating tldr_content_gen as adapter_id {adapter_id_tldr}")
+qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen")
 
-adapter_id_dbpedia = qeff_model.load_adapter("predibase/dbpedia", "dbpedia")
-print(f"Activating dbpedia as adapter_id {adapter_id_dbpedia}")
+qeff_model.load_adapter("predibase/dbpedia", "dbpedia")
 
 # STEP 2 (optional) -- unload adapter
 unload_status = qeff_model.unload_adapter("dbpedia")
 print(f"Unloading dbpedia success: {unload_status}")
 
-# get adapter id
-# NOTE: should rely on get_adapter_id() in case the id obtained at set_adapter() gets updated
-gsm8k_id = qeff_model.get_adapter_id("gsm8k")
-tldr_id = qeff_model.get_adapter_id("tldr_content_gen")
 
 ## STEP 3 -- export & compile qeff model
 qpc_path = qeff_model.compile(

@@ -71,10 +56,6 @@
 # mxint8_kv_cache=True)
 
 ## STEP 4 -- run inference on the generate function
-# prompt_to_lora_id_mapping is a list of lora_ids whose size matches the number of prompts,
-# giving a one-to-one mapping from prompt to lora_id
-# e.g., prompt_to_lora_id_mapping = [{adapter_id_0}, {adapter_id_1}, {adapter_id_0}, {adapter_id_1}, ...]
-# setting 0 means using the base model
 prompts = [
     """Please answer the following question: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\n\nAnswer:""",
     """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: Harvard shrank its insect-inspired microrobot to the size of a penny\n\nContent:""",

@@ -90,7 +71,16 @@
     tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
     prompts=prompts,
     device_id=device_group,
-    prompt_to_lora_id_mapping=[gsm8k_id, tldr_id, gsm8k_id, 0, gsm8k_id, tldr_id, gsm8k_id, tldr_id],
+    prompt_to_adapter_mapping=[
+        "gsm8k",
+        "tldr_content_gen",
+        "gsm8k",
+        "base",
+        "gsm8k",
+        "tldr_content_gen",
+        "gsm8k",
+        "tldr_content_gen",
+    ],
 )
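Note that the mapping in the example is positional: prompt i runs with the adapter named at prompt_to_adapter_mapping[i], and "base" means no adapter. A tiny illustrative check of that pairing (standalone, not part of the example script):

    # Illustrative: each prompt is paired position-wise with one adapter name.
    prompts = ["math question", "headline prompt", "generic prompt"]
    prompt_to_adapter_mapping = ["gsm8k", "tldr_content_gen", "base"]

    assert len(prompts) == len(prompt_to_adapter_mapping)
    for prompt, adapter in zip(prompts, prompt_to_adapter_mapping):
        print(f"{adapter:>16} <- {prompt}")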

tests/lora/test_lora_model.py

Lines changed: 17 additions & 25 deletions

@@ -12,7 +12,8 @@
 from peft import LoraConfig
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from QEfficient import QEffAutoLoraModelForCausalLM
+from QEfficient import QEffAutoPeftModelForCausalLM
+from QEfficient.lora import QEffAutoLoraModelForCausalLM
 from QEfficient.utils import load_hf_tokenizer
 
 configs = [

@@ -74,6 +75,18 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1):
     assert len(qeff_model.active_adapter_to_id) == 0
 
 
+# test peft model initialization using from_pretrained approach
+@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples)
+def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1):
+    qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(adapter_id_0, "id_0", finite_adapters=True)
+
+    assert qeff_model.base_model_name == base_model_name
+    assert len(qeff_model.adapter_weights) == 1
+    assert len(qeff_model.adapter_configs) == 1
+    assert qeff_model.max_num_adapters == 1
+    assert len(qeff_model.active_adapter_to_id) == 1
+
+
 # test the init assertion for models that are not supported
 @pytest.mark.parametrize("base_model_name", ["distilbert/distilgpt2"])
 def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_name):

@@ -156,27 +169,6 @@ def test_auto_lora_model_for_causal_lm_hash():
     assert model_hash_0_1 != model_hash_0_0
 
 
-# test load_adapter() and get_adapter_id()
-@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1])
-def test_auto_lora_model_for_causal_lm_load_get_adapter_id_check(base_model_name, adapter_id_0, adapter_id_1):
-    qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1)
-
-    set_id_0 = qeff_model.load_adapter(adapter_id_0, "adapter_0")
-    set_id_1 = qeff_model.load_adapter(adapter_id_1, "adapter_1")
-    assert set_id_1 == set_id_0 + 1
-
-    qeff_model.load_adapter(adapter_id_1, "adapter_2")
-    qeff_model.unload_adapter("adapter_1")
-
-    update_id_0 = qeff_model.get_adapter_id("adapter_0")
-    update_id_2 = qeff_model.get_adapter_id("adapter_2")
-    assert set_id_0 == update_id_0
-    assert set_id_1 == update_id_2
-
-    with pytest.raises(KeyError):
-        qeff_model.get_adapter_id("adapter_1")
-
-
 # test download_adapter(), load_adapter() and unload_adapter()
 @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[1:])
 def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adapter_id_0, adapter_id_1):

@@ -196,8 +188,8 @@ def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adapter_id_0, adapter_id_1):
 def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path):
     qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1)
 
-    id_0 = qeff_model.load_adapter(adapter_id_0, "adapter_0")
-    id_1 = qeff_model.load_adapter(adapter_id_1, "adapter_1")
+    qeff_model.load_adapter(adapter_id_0, "adapter_0")
+    qeff_model.load_adapter(adapter_id_1, "adapter_1")
 
     # export
     start = perf_counter()

@@ -225,5 +217,5 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path):
         tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
         prompts=prompts,
         device_id=[0],
-        prompt_to_lora_id_mapping=[id_0, id_1, id_0, 0],
+        prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
     )
