Commit 892dfaf

Enable init from QEffAutoPeftModelForCausalLM with finite_adapters flag
Signed-off-by: Jou-An Chen <[email protected]>
1 parent: 3f275c3

File tree: 3 files changed, +29 −11 lines


QEfficient/peft/auto.py

Lines changed: 8 additions & 1 deletion
@@ -17,6 +17,7 @@
 from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
 from transformers.generation.streamers import BaseStreamer
 
+from QEfficient import QEffAutoLoraModelForCausalLM
 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
@@ -152,7 +153,13 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
         if kwargs.get("use_cache") is False:
             warnings.warn("Overriding to use_cache=True")
             kwargs["use_cache"] = True
-        obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
+
+        if kwargs.pop("finite_adapters", False):  # initialize through finite_adapters class
+            obj = QEffAutoLoraModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=pretrained_name_or_path, **kwargs
+            )
+        else:
+            obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
         return obj
 
     def export(self, export_dir: Optional[str] = None) -> str:
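With this change, QEffAutoPeftModelForCausalLM.from_pretrained dispatches on a single finite_adapters flag: when set, it returns a QEffAutoLoraModelForCausalLM; otherwise it falls back to the existing _from_pretrained path. Because kwargs.pop removes the flag before the remaining kwargs are forwarded, neither downstream constructor sees an unknown argument. A minimal sketch of both paths (model and adapter names are illustrative, taken from the example below):

from QEfficient import QEffAutoPeftModelForCausalLM

# finite_adapters=True routes to the finite-adapter (LoRA) implementation
lora_style_model = QEffAutoPeftModelForCausalLM.from_pretrained(
    pretrained_name_or_path="mistralai/Mistral-7B-v0.1",  # illustrative base model
    finite_adapters=True,
)

# default path: the usual PEFT flow, loading a trained adapter as before this commit
peft_model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/gsm8k")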

examples/lora_models.py

Lines changed: 6 additions & 9 deletions
@@ -7,7 +7,7 @@
 
 ## This example works on continuous batching with different lora adapters in the same batch ##
 
-from QEfficient import QEffAutoLoraModelForCausalLM
+from QEfficient import QEffAutoPeftModelForCausalLM
 from QEfficient.utils import load_hf_tokenizer
 
 base_model_name = "mistralai/Mistral-7B-v0.1"
@@ -18,17 +18,14 @@
 
 ## STEP 1 -- init base model
 
-# **Option1**: Download model weights from hugging face & Init it with QEffAuto model to apply QEff transforms
-# model_hf = AutoModelForCausalLM.from_pretrained(base_model_name)
-# qeff_model = QEffAutoLoraModelForCausalLM(model_hf, continuous_batching=True)
-
-# **Option2**: Initialize the model using from_pretrained() method
-qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(
-    pretrained_model_name_or_path=base_model_name, continuous_batching=True
+qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
+    pretrained_name_or_path=base_model_name, continuous_batching=True, finite_adapters=True
 )
 
 # (alternative) non-cb initialization
-# qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(pretrained_model_name_or_path=base_model_name, continuous_batching=False)
+# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
+#     pretrained_name_or_path=base_model_name, continuous_batching=False, finite_adapters=True
+# )
 
 ## STEP 2 -- load adapter adapter
 adapter_id_gsm8k = qeff_model.load_adapter("predibase/gsm8k", "gsm8k")
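The hunk ends at STEP 2; the later steps of examples/lora_models.py are unchanged by this commit. For orientation, a sketch of how the loaded adapter is typically exercised afterwards — the empty compile() call is a placeholder for real device/sequence-length options, and the generate() signature with prompt_to_adapter_mapping is assumed from the surrounding example, not shown in this diff:

# (sketch, assumed continuation -- not part of this diff)
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=base_model_name)

# compile the model for the target device (placeholder arguments)
qeff_model.compile()

# run prompts, routing each one to a loaded adapter by name
qeff_model.generate(
    tokenizer=tokenizer,
    prompts=["What is 2 + 2?"],
    prompt_to_adapter_mapping=["gsm8k"],  # assumed parameter name
)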

tests/lora/test_lora_model.py

Lines changed: 15 additions & 1 deletion
@@ -12,7 +12,7 @@
 from peft import LoraConfig
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from QEfficient import QEffAutoLoraModelForCausalLM
+from QEfficient import QEffAutoLoraModelForCausalLM, QEffAutoPeftModelForCausalLM
 from QEfficient.utils import load_hf_tokenizer
 
 configs = [
@@ -74,6 +74,20 @@ def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_
     assert len(qeff_model.active_adapter_to_id) == 0
 
 
+# test peft model initialization using from_pretrained approach
+@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples)
+def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1):
+    qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
+        pretrained_name_or_path=base_model_name, finite_adapters=True
+    )
+
+    assert qeff_model.base_model_name == base_model_name
+    assert len(qeff_model.adapter_weights) == 0
+    assert len(qeff_model.adapter_configs) == 0
+    assert qeff_model.max_num_adapters == 0
+    assert len(qeff_model.active_adapter_to_id) == 0
+
+
 # test the init assertion for models that are not supported
 @pytest.mark.parametrize("base_model_name", ["distilbert/distilgpt2"])
 def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_name):
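The new test mirrors the existing QEffAutoLoraModelForCausalLM test above it: it checks that a model initialized through the PEFT entry point with finite_adapters=True lands on the LoRA class and starts with empty adapter state. To exercise just this test, a standard pytest filter works, e.g. `pytest tests/lora/test_lora_model.py -k test_auto_peft_model_for_causal_lm_from_pretrained`.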
