You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: QEfficient/lora/auto.py
+47-29Lines changed: 47 additions & 29 deletions
Original file line number
Diff line number
Diff line change
@@ -29,13 +29,7 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
29
29
30
30
Args:
31
31
:model (nn.Module): PyTorch model
32
-
:base_model_name (str): Model card name for base model
33
-
:adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping
34
-
:adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping
35
-
:max_num_adapters (int): Total number of active adapters that to be exported and compiled
36
-
:active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping
37
-
:lora_rank (int): The consistent lora rank across all active adapters
38
-
:target_modules_for_all_adapters (List[str]): The consistent set of target modules across all active adapters
32
+
:continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
assertself.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
293
+
iflen(self.active_adapter_to_id) ==0:
294
+
raiseValueError(
295
+
"Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped.
343
351
344
352
``Mandatory`` Args:
353
+
:tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer): The tokenizer used in the inference
345
354
:prompts (List[str]): List of prompts to run the execution.
346
355
:device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
356
+
:prompt_to_adapter_mapping (List[str]): The sequence of the adapter names will be matched with sequence of prompts and corresponding adapters will be used for the prompts."base" for base model (no adapter).
347
357
``optional`` Args:
348
358
:runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
349
-
:prompt_to_adapter_mapping (List[str]): A list of adapter names that maps to the prompts, specifying which adapter the prompt wants to apply. "base" for base model (no adapter).
359
+
350
360
"""
351
361
ifruntime!="AI_100":
352
362
raiseValueError("Only AI_100 runtime is supported right now via generate API")
f"Number of prompts should match number of prompt_to_adapter_mapping, got len(prompts) = {len(prompts)}, len(prompt_to_adapter_mapping) = {len(prompt_to_adapter_mapping)}"
## STEP 4 -- run inference on the generate function
59
63
prompts= [
60
64
"""Please answer the following question: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\n\nAnswer:""",
0 commit comments