Commit 0c7a832

g-eojjimpang authored and committed
[Bugfix][CI/Build] Test prompt adapters in openai entrypoint tests (vllm-project#6419)
1 parent 9feb0a4 commit 0c7a832

2 files changed: +54 -31 lines

tests/entrypoints/openai/test_completion.py

Lines changed: 51 additions & 29 deletions
@@ -17,9 +17,13 @@

 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
-# technically this needs Mistral-7B-v0.1 as base, but we're not testing
-# generation quality here
+# technically these adapters use a different base model,
+# but we're not testing generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
+PA_NAME = "swapnilbp/llama_tweet_ptune"
+# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
+# need to change to match the prompt adapter
+PA_NUM_VIRTUAL_TOKENS = 8


 @pytest.fixture(scope="module")
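
Note: PA_NUM_VIRTUAL_TOKENS has to stay in sync with the num_virtual_tokens field in the adapter's adapter_config.json, which is the value the server reads in serving_engine.py below. A minimal sketch for checking it locally, assuming huggingface_hub is available (this snippet is illustrative, not part of the change):

import json
import pathlib

from huggingface_hub import snapshot_download

# Hypothetical one-off check; the repo id matches PA_NAME above.
adapter_dir = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
with pathlib.Path(adapter_dir, "adapter_config.json").open() as f:
    print(json.load(f)["num_virtual_tokens"])  # expected to print 8
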
@@ -28,7 +32,12 @@ def zephyr_lora_files():


 @pytest.fixture(scope="module")
-def server(zephyr_lora_files):
+def zephyr_pa_files():
+    return snapshot_download(repo_id=PA_NAME)
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_files, zephyr_pa_files):
     with RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
@@ -37,8 +46,10 @@ def server(zephyr_lora_files):
         "bfloat16",
         "--max-model-len",
         "8192",
+        "--max-num-seqs",
+        "128",
         "--enforce-eager",
-        # lora config below
+        # lora config
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
@@ -47,7 +58,14 @@
         "64",
         "--max-cpu-loras",
         "2",
-        "--max-num-seqs",
+        # pa config
+        "--enable-prompt-adapter",
+        "--prompt-adapters",
+        f"zephyr-pa={zephyr_pa_files}",
+        f"zephyr-pa2={zephyr_pa_files}",
+        "--max-prompt-adapters",
+        "2",
+        "--max-prompt-adapter-token",
         "128",
     ]) as remote_server:
         yield remote_server
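
Each name=path pair registered with --lora-modules or --prompt-adapters is served under its own model id, which is what the parametrized tests below rely on. A minimal client-side sketch against such a server (base URL and API key are placeholders, not part of this change):

import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def demo():
    # "zephyr-pa" is the prompt adapter registered via --prompt-adapters above.
    completion = await client.completions.create(model="zephyr-pa",
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)
    print(completion.choices[0].text)

asyncio.run(demo())
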
@@ -60,11 +78,14 @@ def client(server):


 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # first test base model, then test loras
-    "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    # first test base model, then test loras, then test prompt adapters
+    "model_name,num_virtual_tokens",
+    [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
+     ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
+     ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
 )
-async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
+                                 num_virtual_tokens: int):
     completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
@@ -77,28 +98,30 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
     assert len(choice.text) >= 5
     assert choice.finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
+        completion_tokens=5,
+        prompt_tokens=6 + num_virtual_tokens,
+        total_tokens=11 + num_virtual_tokens)

     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
     )
-    assert len(completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 1


 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # first test base model, then test loras
+    # first test base model, then test loras, then test prompt adapters
     "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
 )
 async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
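
The usage assertion above accounts for the adapter's virtual tokens, which are prepended to the request and therefore counted as prompt tokens. A worked sketch of the expected numbers, using the test's constants (6 prompt tokens for "Hello, my name is", max_tokens=5, PA_NUM_VIRTUAL_TOKENS=8):

def expected_usage(num_virtual_tokens: int) -> tuple[int, int, int]:
    # (completion_tokens, prompt_tokens, total_tokens) as asserted in the test
    prompt_tokens = 6 + num_virtual_tokens
    completion_tokens = 5
    return completion_tokens, prompt_tokens, prompt_tokens + completion_tokens

assert expected_usage(0) == (5, 6, 11)   # base model and LoRA adapters
assert expected_usage(8) == (5, 14, 19)  # prompt adapters (8 virtual tokens)
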
@@ -110,14 +133,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # just test 1 lora hereafter
+    # just test 1 lora and 1 pa hereafter
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
@@ -133,12 +156,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
     # test using token IDs
     completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=[0, 0, 0, 0, 0],
         max_tokens=5,
         temperature=0.0,
@@ -154,15 +177,15 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                             model_name: str):

     with pytest.raises(
             (openai.BadRequestError, openai.APIError)):  # test using token IDs
         await client.completions.create(
-            model=MODEL_NAME,
+            model=model_name,
             prompt=[0, 0, 0, 0, 0],
             max_tokens=5,
             temperature=0.0,
@@ -174,7 +197,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     with pytest.raises(
             (openai.BadRequestError, openai.APIError)):  # test using token IDs
         stream = await client.completions.create(
-            model=MODEL_NAME,
+            model=model_name,
             prompt=[0, 0, 0, 0, 0],
             max_tokens=5,
             temperature=0.0,
@@ -199,7 +222,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_completion_streaming(client: openai.AsyncOpenAI,
                                     model_name: str):
@@ -233,7 +256,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
-    ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                          model_name: str):
@@ -369,9 +392,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # just test 1 lora hereafter
     "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
 )
 async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     # test both text and token IDs
@@ -623,7 +645,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
 )
 async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

     for add_special in [False, True]:
         prompt = "This is a test prompt."
@@ -650,7 +672,7 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
 )
 async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
     base_url = str(client.base_url)[:-3]
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

     prompt = "This is a test prompt."
     tokens = tokenizer.encode(prompt, add_special_tokens=False)

vllm/entrypoints/openai/serving_engine.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
 import json
+import pathlib
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -74,8 +75,8 @@ def __init__(
         self.prompt_adapter_requests = []
         if prompt_adapters is not None:
             for i, prompt_adapter in enumerate(prompt_adapters, start=1):
-                with open(f"./{prompt_adapter.local_path}"
-                          f"/adapter_config.json") as f:
+                with pathlib.Path(prompt_adapter.local_path,
+                                  "adapter_config.json").open() as f:
                     adapter_config = json.load(f)
                     num_virtual_tokens = adapter_config["num_virtual_tokens"]
                     self.prompt_adapter_requests.append(
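
This swaps string concatenation for pathlib: the old f"./{...}" string always produced a path relative to the current working directory, whereas pathlib.Path joins the adapter directory and the config file name as given, including absolute directories such as those returned by snapshot_download. A standalone sketch of the pattern (the directory name is a made-up example, not part of the change):

import json
import pathlib

local_path = "/tmp/example_prompt_adapter"  # hypothetical adapter directory

# Path() joins the segments portably and keeps absolute paths intact.
config_path = pathlib.Path(local_path, "adapter_config.json")
if config_path.is_file():
    with config_path.open() as f:
        num_virtual_tokens = json.load(f)["num_virtual_tokens"]
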

0 commit comments
