
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
-# technically this needs Mistral-7B-v0.1 as base, but we're not testing
-# generation quality here
+# technically these adapters use a different base model,
+# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
+PA_NAME = "swapnilbp/llama_tweet_ptune"
+# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
+# need to change to match the prompt adapter
+PA_NUM_VIRTUAL_TOKENS = 8


@pytest.fixture(scope="module")
@@ -28,7 +32,12 @@ def zephyr_lora_files():


@pytest.fixture(scope="module")
-def server(zephyr_lora_files):
+def zephyr_pa_files():
+    return snapshot_download(repo_id=PA_NAME)
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_files, zephyr_pa_files):
    with RemoteOpenAIServer([
        "--model",
        MODEL_NAME,
@@ -37,8 +46,10 @@ def server(zephyr_lora_files):
        "bfloat16",
        "--max-model-len",
        "8192",
+        "--max-num-seqs",
+        "128",
        "--enforce-eager",
-        # lora config below
+        # lora config
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora={zephyr_lora_files}",
@@ -47,7 +58,14 @@ def server(zephyr_lora_files):
        "64",
        "--max-cpu-loras",
        "2",
-        "--max-num-seqs",
+        # pa config
+        "--enable-prompt-adapter",
+        "--prompt-adapters",
+        f"zephyr-pa={zephyr_pa_files}",
+        f"zephyr-pa2={zephyr_pa_files}",
+        "--max-prompt-adapters",
+        "2",
+        "--max-prompt-adapter-token",
        "128",
    ]) as remote_server:
        yield remote_server
@@ -60,11 +78,14 @@ def client(server):

@pytest.mark.asyncio
@pytest.mark.parametrize(
-    # first test base model, then test loras
-    "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    # first test base model, then test loras, then test prompt adapters
+    "model_name,num_virtual_tokens",
+    [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
+     ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
+     ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
)
-async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
+                                 num_virtual_tokens: int):
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
@@ -77,28 +98,30 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
    assert len(choice.text) >= 5
    assert choice.finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
+        completion_tokens=5,
+        prompt_tokens=6 + num_virtual_tokens,
+        total_tokens=11 + num_virtual_tokens)

    # test using token IDs
    completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
-    assert len(completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
-    # first test base model, then test loras
+    # first test base model, then test loras, then test prompt adapters
    "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
@@ -110,14 +133,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):

@pytest.mark.asyncio
@pytest.mark.parametrize(
-    # just test 1 lora hereafter
+    # just test 1 lora and 1 pa hereafter
    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
@@ -133,12 +156,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
@@ -154,15 +177,15 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                            model_name: str):

    with pytest.raises(
            (openai.BadRequestError, openai.APIError)):  # test using token IDs
        await client.completions.create(
-            model=MODEL_NAME,
+            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
@@ -174,7 +197,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
    with pytest.raises(
            (openai.BadRequestError, openai.APIError)):  # test using token IDs
        stream = await client.completions.create(
-            model=MODEL_NAME,
+            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
@@ -199,7 +222,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str):
@@ -233,7 +256,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
-    ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                         model_name: str):
@@ -369,9 +392,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,

@pytest.mark.asyncio
@pytest.mark.parametrize(
-    # just test 1 lora hereafter
    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
    # test both text and token IDs
@@ -623,7 +645,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
)
async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

    for add_special in [False, True]:
        prompt = "This is a test prompt."
@@ -650,7 +672,7 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3]
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

    prompt = "This is a test prompt."
    tokens = tokenizer.encode(prompt, add_special_tokens=False)
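
For context, and not part of the diff above: a minimal sketch of how one of the newly registered prompt adapters would be queried once the server fixture is running. It assumes the same openai.AsyncOpenAI client the tests already use and the "zephyr-pa" name passed to --prompt-adapters; the usage accounting mirrors what test_single_completion asserts.

import openai

async def query_prompt_adapter(client: openai.AsyncOpenAI) -> None:
    # "zephyr-pa" is one of the adapter names registered via --prompt-adapters
    completion = await client.completions.create(
        model="zephyr-pa",
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
    )
    # the adapter's virtual tokens count toward prompt_tokens, which is why the
    # test adds PA_NUM_VIRTUAL_TOKENS (8 for this adapter) to the expected usage
    print(completion.usage)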