39 changes: 39 additions & 0 deletions QEfficient/generation/text_generation_inference.py
@@ -230,6 +230,7 @@ def cloud_ai_100_exec_kv(
stream: bool = True,
write_io_dir: Optional[str] = None,
automation=False,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
):
"""
This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -277,6 +278,7 @@ def cloud_ai_100_exec_kv(
stream=stream,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
prompt_to_lora_id_mapping=prompt_to_lora_id_mapping,
)
if full_batch_size is None:
exec_info = [
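
For context, a minimal usage sketch of the new `prompt_to_lora_id_mapping` argument. The tokenizer, QPC path, prompts, and adapter indices are illustrative placeholders, and argument names other than the new one are assumptions based on the surrounding signature, not values from this PR:

```python
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder base model

exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="path/to/qpcs",                          # placeholder path to the compiled QPC
    prompt=["first prompt", "second prompt", "third prompt"],
    device_id=[0],
    prompt_to_lora_id_mapping=[1, 2, 0],              # one adapter id per prompt, in order
)
```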
@@ -313,6 +315,7 @@ def __init__(
qpc_path: str,
prompt: List[str],
full_batch_size: Optional[int] = None,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
ctx_len: Optional[int] = None,
generation_len: Optional[int] = None,
device_id: Optional[List[int]] = None,
@@ -342,6 +345,16 @@ def __init__(
full_batch_size if full_batch_size else self._fetch_full_batch_size()
) # Check and fetch full batch size if CB is enabled

if prompt_to_lora_id_mapping:
self.prompt_to_lora_id_mapping_prefill = deque(prompt_to_lora_id_mapping)
if self.full_batch_size:
self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping
else:
self.prompt_to_lora_id_mapping_decode = deque(prompt_to_lora_id_mapping)
else:
self.prompt_to_lora_id_mapping_prefill = None
self.prompt_to_lora_id_mapping_decode = None
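
The two copies exist because prefill and decode consume the mapping differently; a standalone illustration of that intent, with hypothetical values:

```python
from collections import deque

# Hypothetical mapping: four prompts, one adapter id each (0 assumed to mean "base model").
prompt_to_lora_id_mapping = [1, 2, 0, 1]

prefill_ids = deque(prompt_to_lora_id_mapping)   # prefill pops one id each time a prompt is filled
assert prefill_ids.popleft() == 1                # first prefill -> adapter 1

# Continuous batching keeps the full list so a decode slot that finishes one prompt can look up
# the id of its next prompt by position; static batching just consumes ids in order.
decode_ids_cb = prompt_to_lora_id_mapping
decode_ids_static = deque(prompt_to_lora_id_mapping)
```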

self.set_tokenizer_params() # set tokenizer params

# Initialize the storage variables.
@@ -461,6 +474,16 @@ def prepare_decode_inputs(self):
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index

if self.prompt_to_lora_id_mapping_decode:
if self.full_batch_size:
first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)]
decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape(
self.full_batch_size, 1
)
else:
batch_lora_ids = [self.prompt_to_lora_id_mapping_decode.popleft() for _ in range(self.batch_size)]
decode_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

return decode_inputs
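
For reference, a small sketch of the decode-time buffer this produces (shapes only; the mapping values are hypothetical):

```python
import numpy as np

full_batch_size = 4
mapping = [1, 2, 0, 1]  # hypothetical per-prompt adapter ids

lora_ids = np.array(mapping[:full_batch_size], dtype=np.int64).reshape(full_batch_size, 1)
print(lora_ids.shape)  # (4, 1): one adapter id per decode slot, matching the batch dimension
```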

def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None):
@@ -549,6 +572,15 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
if decode_batch_id is not None:
inputs["batch_index"] = decode_batch_id

if self.prompt_to_lora_id_mapping_prefill:
if self.full_batch_size:
inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape(
1, 1
)
else:
batch_lora_ids = [self.prompt_to_lora_id_mapping_prefill.popleft() for _ in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -625,6 +657,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

self.session.set_buffers({"logits": logits_out_placeholder})
decode_pause_time += perf_counter() - start

if self.prompt_to_lora_id_mapping_decode:
decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[
batch_id_map[decode_batch_id]
]

else:
current_decode_ongoing[decode_batch_id] = False
else:
@@ -636,6 +674,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
)

generated_id_current_index[decode_batch_id] += 1

return decode_pause_time
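
To illustrate the refresh step above: when a decode slot moves on to a new prompt, its lora id is re-read from the per-prompt mapping via the slot-to-prompt index map. This is a sketch with hypothetical values, assuming `batch_id_map` maps decode slot to prompt index as the indexing in the diff suggests:

```python
import numpy as np

prompt_to_lora_id_mapping_decode = [1, 2, 0, 1, 2]   # hypothetical per-prompt adapter ids
batch_id_map = {0: 3, 1: 1, 2: 4}                     # hypothetical: slot 0 now serves prompt 3, ...

lora_ids = np.zeros((len(batch_id_map), 1), dtype=np.int64)
for decode_batch_id, prompt_index in batch_id_map.items():
    lora_ids[decode_batch_id] = prompt_to_lora_id_mapping_decode[prompt_index]
# lora_ids is now [[1], [2], [2]] -- the buffer read on the next decode step
```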

def run_decode(self, decode_inputs, generation_len):
26 changes: 24 additions & 2 deletions QEfficient/peft/auto.py
@@ -12,7 +12,7 @@

import numpy as np
import torch
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, load_peft_weights
from peft import AutoPeftModelForCausalLM, PeftConfig, PeftModelForCausalLM, load_peft_weights
from torch import nn
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
@@ -21,6 +21,7 @@
from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
@@ -38,6 +39,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel):

Args:
:model (nn.Module): PyTorch model
:finite_adapters (bool): Set True to enable finite adapter mode via the QEffAutoLoraModelForCausalLM class. Refer to QEffAutoLoraModelForCausalLM for the API specification.

.. code-block:: python

@@ -80,6 +82,9 @@ def __init__(self, model: nn.Module):
for adapter_name in model.peft_config
}

def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()

@property
def model_name(self) -> str:
mname = self.model.get_base_model().__class__.__name__ + "-lora"
@@ -145,14 +150,31 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
"""
Args:
:pretrained_name_or_path (str): Model card name from huggingface or local path to model directory.
:finite_adapters (bool): Set True to enable finite adapter mode via the QEffAutoLoraModelForCausalLM class. Refer to QEffAutoLoraModelForCausalLM for the API specification.
:adapter_name (str): Name used to identify the loaded adapter.
:args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM.
"""
if kwargs.get("full_batch_size"):
raise NotImplementedError("Continuous batching currently not supported for PEFT models")
if kwargs.get("use_cache") is False:
warnings.warn("Overriding to use_cache=True")
kwargs["use_cache"] = True
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)

if kwargs.pop("finite_adapters", False): # initialize through finite_adapters class
obj = QEffAutoLoraModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=PeftConfig.from_pretrained(
pretrained_name_or_path
).base_model_name_or_path,
**kwargs,
)
if adapter_name := kwargs.pop("adapter_name", None):
obj.load_adapter(pretrained_name_or_path, adapter_name=adapter_name)
return obj
if len(args) == 0 or not isinstance(list(args)[0], str):
raise TypeError("Required adapter name argument in string format")
obj.load_adapter(pretrained_name_or_path, list(args)[0])
else:
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj
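
A usage sketch of the new branch; the model card names are placeholders, the positional adapter-name form mirrors the check above, and `adapter_name="..."` can be passed instead:

```python
from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM

# finite_adapters=True routes construction through QEffAutoLoraModelForCausalLM,
# resolving the base model from the adapter's PeftConfig, then loads the named adapter.
qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained(
    "predibase/gsm8k",     # placeholder PEFT adapter card
    "gsm8k",               # adapter name (positional)
    finite_adapters=True,
)
```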

def export(self, export_dir: Optional[str] = None) -> str:
12 changes: 12 additions & 0 deletions QEfficient/peft/lora/__init__.py
@@ -0,0 +1,12 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from QEfficient.peft.lora.auto import QEffAutoLoraModelForCausalLM

__all__ = [
"QEffAutoLoraModelForCausalLM",
]
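
The new subpackage simply re-exports the class, so downstream code can import it directly. A short sketch; the base-model and adapter cards are placeholders, and only the calls already exercised in this diff are assumed to exist:

```python
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

# Hypothetical flow: construct from a base model card, then attach a named adapter.
model = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model.load_adapter("predibase/gsm8k", adapter_name="gsm8k")
```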