     :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
+    :finite_adapters (bool): Set True to enable finite adapter mode with the QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for the API specification.
     :args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM.
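For context, a minimal usage sketch of the new flag (not part of the diff; the import path and the adapter card passed to ``from_pretrained`` are assumptions based on the argument descriptions above):

# Hedged sketch: import path and adapter card are assumptions.
from QEfficient.peft import QEffAutoPeftModelForCausalLM

# finite_adapters=True enables the finite adapter mode backed by
# QEffAutoLoraModelForCausalLM, per the argument description above.
m = QEffAutoPeftModelForCausalLM.from_pretrained(
    "predibase/gsm8k",
    finite_adapters=True,
)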
QEfficient/peft/lora/auto.py (+52 −34)
@@ -29,17 +29,11 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):

     Args:
         :model (nn.Module): PyTorch model
-        :base_model_name (str): Model card name for base model
-        :adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping
-        :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping
-        :max_num_adapters (int): Total number of active adapters that to be exported and compiled
-        :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping
-        :lora_rank (int): The consistent lora rank across all active adapters
-        :target_modules_for_all_adapters (List[str]): The consistent set of target modules across all active adapters
+        :continuous_batching (bool): Whether this model will be used for continuous batching in the future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.

     .. code-block:: python

-        from QEfficient.lora import QEffAutoLoraModelForCausalLM
+        from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

         m = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
         m.load_adapter("predibase/gsm8k", "gsm8k")
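Since ``continuous_batching`` must be decided when the model is loaded, a minimal sketch of that workflow (that the keyword is accepted by ``from_pretrained`` is an assumption drawn from the docstring above):

# Hedged sketch: continuous_batching is assumed to be a from_pretrained
# keyword, per the docstring above. It must be True at load time, or the
# model cannot be exported/compiled for continuous batching later.
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

m = QEffAutoLoraModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    continuous_batching=True,
)
m.load_adapter("predibase/gsm8k", "gsm8k")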
@@ -53,14 +47,13 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):

-        assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
+        if len(self.active_adapter_to_id) == 0:
+            raise ValueError(
+                "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
+            )
         If the number of prompts is not divisible by the ``batch_size``, the last partial batch will be dropped.

         ``Mandatory`` Args:
+            :tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer): The tokenizer used for inference
             :prompts (List[str]): List of prompts to run the execution.
             :device_id (List[int]): Device IDs for running the qpc; pass [0] for a normal model, or [0, 1, 2, 3] for a tensor-sliced model
+            :prompt_to_adapter_mapping (List[str]): Adapter names matched one-to-one with the prompts; the corresponding adapter is applied to each prompt. Use "base" for the base model (no adapter).

         ``optional`` Args:
             :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
-            :prompt_to_adapter_mapping (List[str]): A list of adapter names that maps to the prompts, specifying which adapter the prompt wants to apply. "base" for base model (no adapter).
+
         """
         if runtime != "AI_100":
             raise ValueError("Only AI_100 runtime is supported right now via generate API")

             f"Number of prompts should match number of prompt_to_adapter_mapping, got len(prompts) = {len(prompts)}, len(prompt_to_adapter_mapping) = {len(prompt_to_adapter_mapping)}"