
Commit c3c00fc

Fix base model inference index INTMAX issue
Signed-off-by: Jou-An Chen <[email protected]>
1 parent 68cdf3b commit c3c00fc

5 files changed: +28 -28 lines changed

QEfficient/generation/text_generation_inference.py

Lines changed: 4 additions & 4 deletions

@@ -654,10 +654,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

                 generated_id_current_index[decode_batch_id] += 1

-            if self.prompt_to_lora_id_mapping_decode:
-                decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[
-                    batch_id_map[decode_batch_id]
-                ]
+            if self.prompt_to_lora_id_mapping_decode:
+                decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[
+                    batch_id_map[decode_batch_id]
+                ]

         return decode_pause_time
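
A minimal sketch of what this hunk does, assuming the names shown in the diff (prompt_to_lora_id_mapping_decode, batch_id_map, decode_inputs["lora_ids"]); the helper name and the sample values below are hypothetical, not the library implementation. When a decode slot is refilled during continuous batching, its lora_ids entry is refreshed from the per-prompt mapping, and an id of 0 now selects the base model instead of the old INTMAX sentinel.

import numpy as np

def refill_lora_id(decode_inputs, decode_batch_id, batch_id_map, prompt_to_lora_id_mapping_decode):
    # Point decode slot `decode_batch_id` at the adapter id of the prompt it now serves;
    # id 0 means "no adapter", i.e. plain base-model weights.
    if prompt_to_lora_id_mapping_decode:
        decode_inputs["lora_ids"][decode_batch_id] = prompt_to_lora_id_mapping_decode[
            batch_id_map[decode_batch_id]
        ]
    return decode_inputs

# hypothetical values: slot 1 picks up the 4th prompt, which is mapped to the base model (0)
decode_inputs = {"lora_ids": np.zeros((2, 1), dtype=np.int64)}
refill_lora_id(decode_inputs, 1, {0: 0, 1: 3}, [2, 1, 1, 0])
print(decode_inputs["lora_ids"][1])  # [0]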

QEfficient/lora/auto.py

Lines changed: 16 additions & 12 deletions

@@ -7,7 +7,6 @@

 import hashlib
 import os
-import sys
 from pathlib import Path
 from typing import Any, List, Optional

@@ -24,8 +23,6 @@
 from QEfficient.utils.constants import QEFF_MODELS_DIR
 from QEfficient.utils.logging_utils import logger

-INTMAX = sys.maxsize
-

 class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
     """
@@ -54,7 +51,7 @@ class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
         m.compile(num_cores=16, device_group=[0])

         prompts=["code prompt", "math prompt", "generic"]
-        m.generate(prompts, device_group=[0], prompt_to_lora_id_mapping=[magicoder_id,gsm8k_id,INTMAX])
+        m.generate(prompts, device_group=[0], prompt_to_lora_id_mapping=[magicoder_id,gsm8k_id,0])

     """

@@ -148,7 +145,7 @@ def load_adapter(self, adapter_model_id: str, adapter_name: str, **kwargs: Any):

         # set active adapter id to current max if adapter_name is new
         if adapter_name not in self.active_adapter_to_id.keys():
-            self.active_adapter_to_id[adapter_name] = self.max_num_adapters
+            self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1  # reserve 0 for base

         # add active adapter to set
         self.active_adapters.add(adapter_name)
@@ -168,7 +165,7 @@ def unload_adapter(self, adapter_name: str):

         # renumbering of active adapter id
         for index, (key, value) in enumerate(self.active_adapter_to_id.items()):
-            self.active_adapter_to_id[key] = index
+            self.active_adapter_to_id[key] = index + 1

         logger.warning(f"Deleting {adapter_name} from active adapters.")
         if self.onnx_path or self.qpc_path:
@@ -203,9 +200,9 @@ def load_adapter_weights_to_model(self):
         for i in range(num_hidden_layers):
             for target_module in self.target_modules_for_all_adapters:
                 # stack all adapters weights
-                a_tensor_list = list(range(self.max_num_adapters))
-                b_tensor_list = list(range(self.max_num_adapters))
-                c_tensor_list = list(range(self.max_num_adapters))
+                a_tensor_list = list(range(self.max_num_adapters + 1))
+                b_tensor_list = list(range(self.max_num_adapters + 1))
+                c_tensor_list = list(range(self.max_num_adapters + 1))

                 for lora_name, lora_id in self.active_adapter_to_id.items():
                     if (
@@ -232,12 +229,18 @@
                         dtype=torch.float16,
                     )

+                # dummy zero tensor for base model
+                a_tensor_list[0] = torch.zeros_like(a_tensor_list[1])
+                b_tensor_list[0] = torch.zeros_like(b_tensor_list[1])
+                c_tensor_list[0] = torch.zeros_like(c_tensor_list[1])
+
+                # stack weight tensors
                 stacked_lora_A = (
                     torch.stack(a_tensor_list, dim=0).unsqueeze(1).transpose(2, 3)
-                )  # <num_adapters, 1, in_feature, r>
+                )  # <num_loras, 1, in_feature, r>
                 stacked_lora_B = (
                     torch.stack(b_tensor_list, dim=0).unsqueeze(1).transpose(2, 3)
-                )  # <num_adapters, 1, r, out_feature>
+                )  # <num_loras, 1, r, out_feature>
                 stacked_lora_C = (
                     torch.stack(c_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3)
                 )  # <num_loras, 1, 1, 1>
@@ -308,6 +311,7 @@ def export(self, **kwargs) -> str:
         export_dir = kwargs.get("export_dir", None)

         # obtain all necessary information to initialize the model
+        assert self.max_num_adapters, "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage"
         self.init_adapter_model()

         assert self.is_transformed, "Please first run transform on the QEFFAutoModelForCausalLM object"
@@ -411,7 +415,7 @@ def export_and_compile(
     def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs):
         assert isinstance(self.qpc_path, str), "Please run compile API first!"
         generation_len = kwargs.pop("generation_len", None)
-        default_mapping = [INTMAX for _ in range(len(prompts))]
+        default_mapping = [0 for _ in range(len(prompts))]
         prompt_to_lora_id_mapping = kwargs.pop("prompt_to_lora_id_mapping", default_mapping)
         return QEfficient.cloud_ai_100_exec_kv(
             self.tokenizer,
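
The id bookkeeping above can be summarized with a small, hedged sketch (not the actual class; the names below are illustrative only): adapter ids are now 1-based, 0 is permanently reserved for the base model, and unloading an adapter renumbers the survivors starting from 1.

class AdapterRegistry:
    def __init__(self):
        self.active_adapter_to_id = {}   # adapter_name -> id (1-based)
        self.max_num_adapters = 0

    def load_adapter(self, adapter_name):
        if adapter_name not in self.active_adapter_to_id:
            # reserve 0 for the base model, so the first adapter gets id 1
            self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1
            self.max_num_adapters += 1
        return self.active_adapter_to_id[adapter_name]

    def unload_adapter(self, adapter_name):
        self.active_adapter_to_id.pop(adapter_name, None)
        self.max_num_adapters = len(self.active_adapter_to_id)
        # renumber the remaining adapters, still starting from 1
        for index, key in enumerate(self.active_adapter_to_id):
            self.active_adapter_to_id[key] = index + 1

reg = AdapterRegistry()
gsm8k_id = reg.load_adapter("gsm8k")   # -> 1
tldr_id = reg.load_adapter("tldr")     # -> 2
reg.unload_adapter("gsm8k")
print(reg.active_adapter_to_id)        # {'tldr': 1}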

QEfficient/lora/layers.py

Lines changed: 3 additions & 3 deletions

@@ -21,14 +21,14 @@ def multilora_init(self, lora_rank, max_num_adapters):
         self.lora_rank = lora_rank

         self.lora_weight_A = nn.Parameter(
-            self.weight.new_zeros(self.max_num_adapters, 1, self.in_features, self.lora_rank)
+            self.weight.new_zeros(self.max_num_adapters + 1, 1, self.in_features, self.lora_rank)
         )
         self.lora_weight_A.requires_grad = False
         self.lora_weight_B = nn.Parameter(
-            self.weight.new_zeros(self.max_num_adapters, 1, self.lora_rank, self.out_features)
+            self.weight.new_zeros(self.max_num_adapters + 1, 1, self.lora_rank, self.out_features)
         )
         self.lora_weight_B.requires_grad = False
-        self.lora_weight_C = torch.full((self.max_num_adapters, 1, 1, 1), 1.0, dtype=torch.float)
+        self.lora_weight_C = torch.full((self.max_num_adapters + 1, 1, 1, 1), 1.0, dtype=torch.float)

         nn.init.kaiming_uniform_(self.lora_weight_A, a=math.sqrt(5))
         nn.init.zeros_(self.lora_weight_B)
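
Why the extra slot works: a hedged sketch of the gathered-LoRA arithmetic (assumed semantics, not the QEfficient kernel). With the +1 slot holding all-zero A and B weights, selecting lora_id 0 makes the low-rank delta x @ A @ B * C vanish, so the layer reduces to the plain base linear projection.

import torch

def multilora_linear(x, weight, lora_A, lora_B, lora_C, lora_id):
    # lora_A: <num_loras, 1, in_features, r>, lora_B: <num_loras, 1, r, out_features>
    base = x @ weight.T
    delta = (x @ lora_A[lora_id, 0]) @ lora_B[lora_id, 0] * lora_C[lora_id, 0, 0, 0]
    return base + delta

in_f, out_f, r, num_loras = 8, 8, 4, 3            # slot 0 = base, slots 1..2 = adapters
weight = torch.randn(out_f, in_f)
lora_A = torch.randn(num_loras, 1, in_f, r)
lora_B = torch.randn(num_loras, 1, r, out_f)
lora_C = torch.ones(num_loras, 1, 1, 1)
lora_A[0].zero_(); lora_B[0].zero_()              # the "dummy zero tensor for base model"

x = torch.randn(2, in_f)
assert torch.allclose(multilora_linear(x, weight, lora_A, lora_B, lora_C, 0), x @ weight.T)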

examples/lora_models.py

Lines changed: 4 additions & 5 deletions

@@ -7,12 +7,9 @@

 ## This example works on continuous batching with different lora adapters in the same batch ##

-import sys

 from QEfficient import QEffAutoLoraModelForCausalLM

-INTMAX = sys.maxsize
-
 base_model_name = "mistralai/Mistral-7B-v0.1"
 seq_len = 128
 ctx_len = 256
@@ -67,7 +64,7 @@
 # prompt_to_lora_id_mapping is a list of lora_id of which the size matches num of prompts
 # and is a one-on-one mapping for the prompt-to-loraid
 # e.g., prompt_to_lora_id_mapping = [{adapter_id_0}, {adapter_id_1}, {adapter_id_0}, {adapter_id_1}, ...]
-# setting INTMAX means using base model
+# setting 0 means using base model
 prompts = [
     """Please answer the following question: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\n\nAnswer:""",
     """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: Harvard shrank its insect-inspired microrobot to the size of a penny\n\nContent:""",
@@ -81,9 +78,11 @@
 qeff_model.generate(
     prompts,
     device_group,
-    prompt_to_lora_id_mapping=[gsm8k_id, tldr_id, gsm8k_id, INTMAX, gsm8k_id, tldr_id, gsm8k_id, tldr_id],
+    prompt_to_lora_id_mapping=[0, 0, 0, 0, 0, 0, 0, 0],
 )

+# [gsm8k_id, tldr_id, gsm8k_id, 0, gsm8k_id, tldr_id, gsm8k_id, tldr_id]
+
 """
 expected response:
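
A hedged convenience sketch (not part of the example file): since load_adapter() hands out 1-based ids and 0 now means the base model, a per-prompt mapping can be assembled from adapter names with a fallback to 0. The helper name and the literal dict below are hypothetical; active_adapter_to_id is the name-to-id dict maintained in auto.py above.

def mapping_from_names(adapter_names, active_adapter_to_id):
    # prompts without an adapter (None or unknown name) fall back to the base model, id 0
    return [active_adapter_to_id.get(name, 0) for name in adapter_names]

# hypothetical usage: the third prompt runs on the base model
print(mapping_from_names(["gsm8k", "tldr", None], {"gsm8k": 1, "tldr": 2}))  # [1, 2, 0]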

tests/lora/test_lora_model.py

Lines changed: 1 addition & 4 deletions

@@ -4,7 +4,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-import sys
 from pathlib import Path
 from time import perf_counter

@@ -15,8 +14,6 @@

 from QEfficient import QEffAutoLoraModelForCausalLM

-INTMAX = sys.maxsize
-
 configs = [
     pytest.param(
         AutoConfig.for_model(
@@ -226,4 +223,4 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name,

     # test generate
     prompts = ["hello!", "hi", "hello, my name is", "hey"]
-    qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[id_0, id_1, id_0, INTMAX])
+    qeff_model.generate(prompts, [0], prompt_to_lora_id_mapping=[id_0, id_1, id_0, 0])
