Commit f49e743

Fix test

Signed-off-by: Jee Jee Li <[email protected]>
1 parent 47c4ac6 commit f49e743

1 file changed: +41 -37 lines

tests/lora/test_llama_tp.py (41 additions, 37 deletions)
@@ -11,19 +11,27 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
-from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
-
-MODEL_PATH = "meta-llama/Llama-2-7b-hf"
+PROMPT_TEMPLATE = """<|eot_id|><|start_header_id|>user<|end_header_id|>
+I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
+"
+##Instruction:
+candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
+Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
+The People_ID of candidate is the foreign key of People_ID of people.
+###Input:
+{context}
+###Response:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+""" # noqa: E501
 
 EXPECTED_LORA_OUTPUT = [
-    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
-    " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
-    " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
-    " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
-    " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
-    " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501
+    "SELECT count(*) FROM candidate",
+    "SELECT count(*) FROM candidate",
+    "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
+    "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
 ]
 
+MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"
+
 
 def do_sample(
     llm: vllm.LLM,
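
The template above replaces the old per-question schema prompts with one fixed schema and a {context} slot. A minimal sketch of how it is filled in, assuming the PROMPT_TEMPLATE and EXPECTED_LORA_OUTPUT names defined in this hunk:

# Illustrative only: build the first test prompt from the new template.
prompt = PROMPT_TEMPLATE.format(context="How many candidates are there?")
# Under greedy decoding with the LoRA adapter loaded, the test expects the
# completion to match EXPECTED_LORA_OUTPUT[0]: "SELECT count(*) FROM candidate"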
@@ -32,18 +40,19 @@ def do_sample(
     tensorizer_config_dict: dict | None = None,
 ) -> list[str]:
     prompts = [
-        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
-        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
-        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
-        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
-        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
-        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501
+        PROMPT_TEMPLATE.format(context="How many candidates are there?"),
+        PROMPT_TEMPLATE.format(context="Count the number of candidates."),
+        PROMPT_TEMPLATE.format(
+            context="Which poll resource provided the most number of candidate information?" # noqa: E501
+        ),
+        PROMPT_TEMPLATE.format(
+            context="Return the poll resource associated with the most candidates."
+        ),
     ]
 
     sampling_params = vllm.SamplingParams(
-        temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"]
+        temperature=0, max_tokens=64, stop=["<|im_end|>"]
     )
-
     if tensorizer_config_dict is not None:
         outputs = llm.generate(
             prompts,
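
The sampling change drops the old [/assistant] stop string and skip_special_tokens=False in favor of a shorter decode with a chat-style stop token. A self-contained sketch of the generate call this test drives, assuming vLLM's public LLM/LoRARequest API; the adapter name and path are placeholders, not the test's llama32_lora_files fixture:

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("meta-llama/Llama-3.2-3B-Instruct", enable_lora=True)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64, stop=["<|im_end|>"])
outputs = llm.generate(
    ["...prompt built from PROMPT_TEMPLATE..."],  # placeholder prompt text
    sampling_params,
    # LoRARequest takes an adapter name, an integer id, and a local adapter path.
    lora_request=LoRARequest("sql_lora", 1, "/path/to/adapter"),
)
generated = [out.outputs[0].text.strip() for out in outputs]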
@@ -75,13 +84,15 @@ def do_sample(
     return generated_texts
 
 
-def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = None):
+def generate_and_test(
+    llm, llama32_lora_files, tensorizer_config_dict: dict | None = None
+):
     print("lora adapter created")
     print("lora 1")
     assert (
         do_sample(
             llm,
-            sql_lora_files,
+            llama32_lora_files,
             tensorizer_config_dict=tensorizer_config_dict,
             lora_id=1,
         )
@@ -92,7 +103,7 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
     assert (
         do_sample(
             llm,
-            sql_lora_files,
+            llama32_lora_files,
             tensorizer_config_dict=tensorizer_config_dict,
             lora_id=2,
         )
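
Condensed, the two checks in generate_and_test reduce to the loop below (illustrative; uses the names defined in this file). Both lora_id values point at the same llama32_lora_files adapter, so each greedy run must reproduce EXPECTED_LORA_OUTPUT:

# Equivalent formulation of the two assertions above.
for lora_id in (1, 2):
    assert do_sample(llm, llama32_lora_files, lora_id=lora_id) == EXPECTED_LORA_OUTPUT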
@@ -102,12 +113,10 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
     print("removing lora")
 
 
-@create_new_process_for_each_test()
 @pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
-def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
+def test_llama_lora(llama32_lora_files, cudagraph_specialize_lora: bool):
     llm = vllm.LLM(
         MODEL_PATH,
-        tokenizer=sql_lora_files,
         enable_lora=True,
         # also test odd max_num_seqs
         max_num_seqs=13,
@@ -116,39 +125,35 @@ def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
             cudagraph_specialize_lora=cudagraph_specialize_lora,
         ),
     )
-    generate_and_test(llm, sql_lora_files)
+    generate_and_test(llm, llama32_lora_files)
 
 
-@multi_gpu_test(num_gpus=4)
-def test_llama_lora_tp4(sql_lora_files):
+def test_llama_lora_tp4(llama32_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        tokenizer=sql_lora_files,
         enable_lora=True,
         max_num_seqs=16,
         max_loras=4,
         tensor_parallel_size=4,
     )
-    generate_and_test(llm, sql_lora_files)
+    generate_and_test(llm, llama32_lora_files)
 
 
-@multi_gpu_test(num_gpus=4)
-def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
+def test_llama_lora_tp4_fully_sharded_loras(llama32_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        tokenizer=sql_lora_files,
         enable_lora=True,
         max_num_seqs=16,
         max_loras=4,
         tensor_parallel_size=4,
         fully_sharded_loras=True,
     )
-    generate_and_test(llm, sql_lora_files)
+    generate_and_test(llm, llama32_lora_files)
 
 
-@multi_gpu_test(num_gpus=2)
 def test_tp2_serialize_and_deserialize_lora(
-    tmp_path, sql_lora_files, sql_lora_huggingface_id
+    tmp_path,
+    llama32_lora_files,
 ):
     # Run the tensorizing of the LoRA adapter and the model in a subprocess
     # to guarantee cleanup
@@ -157,7 +162,7 @@ def test_tp2_serialize_and_deserialize_lora(
     model_name = "model-rank-%03d.tensors"
 
     model_ref = MODEL_PATH
-    lora_path = sql_lora_huggingface_id
+    lora_path = llama32_lora_files
     suffix = "test"
     try:
         result = subprocess.run(
@@ -195,7 +200,6 @@ def test_tp2_serialize_and_deserialize_lora(
 
     loaded_llm = LLM(
         model=model_ref,
-        tokenizer=sql_lora_files,
         load_format="tensorizer",
         enable_lora=True,
         enforce_eager=True,
@@ -211,7 +215,7 @@ def test_tp2_serialize_and_deserialize_lora(
     print("lora 1")
     assert (
         do_sample(
-            loaded_llm, sql_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
+            loaded_llm, llama32_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1
         )
         == EXPECTED_LORA_OUTPUT
     )
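
Taken together, the tail of the tensorizer test now reduces to roughly the following round trip (a sketch stitched from the hunks above; tc_as_dict stands for the TensorizerConfig dict produced earlier in the test):

# Reload the tensorized model with LoRA enabled, without the removed
# tokenizer=sql_lora_files override...
loaded_llm = LLM(
    model=model_ref,
    load_format="tensorizer",
    enable_lora=True,
    enforce_eager=True,
)
# ...then verify the adapter outputs survive serialization intact.
assert (
    do_sample(loaded_llm, llama32_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1)
    == EXPECTED_LORA_OUTPUT
)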
