
 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.scheduler import Scheduler
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob, SequenceGroup

 from .utils import create_dummy_prompt
@@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output):
     return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


-def append_new_token(seq_group, token_id: int):
+def append_new_token(seq_group: SequenceGroup, token_id: int):
     for seq in seq_group.get_seqs():
         seq.append_token_id(token_id, {token_id: Logprob(token_id)})

@@ -123,6 +126,232 @@ def test_chunk():
     assert out.num_batched_tokens == 57


+def test_concurrent_chunking():
+    """Verify prefills are chunked properly when
+    --max-num-partial-prefills is > 1"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
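+    # Budget math this test relies on (an illustration, not scheduler
+    # internals): with max_num_partial_prefills=2, each concurrent prefill
+    # is expected to get at most max_num_batched_tokens // 2 = 32 tokens
+    # per scheduler step.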
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 32
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=60,
+                                           block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Verify both requests are chunked with half of max_num_batched_tokens each
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 32
+    assert seq_group_meta[1].token_chunk_size == 32
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+
+    # After one iteration, both should have 60 - 32 = 28 tokens left to prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
+
+
+def test_concurrent_chunking_large_requests():
+    """Verify large prefill requests are run one at a time"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+
+    # Verify only a single request is chunked, and it gets all 64 tokens
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 1
+    assert seq_group_meta[0].token_chunk_size == 64
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
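+    # Although two large prompts are queued, only one is scheduled: per the
+    # docstring, large prefills are run one at a time rather than
+    # concurrently, so the single admitted request gets the entire 64-token
+    # budget to itself.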
+
+
+def test_short_prompts_jump_long_prompts_in_queue():
+    """Verify large prefill requests are punted behind smaller ones if
+    another large prefill request is already running"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    long_seqs: List[SequenceGroup] = []
+    short_seqs: List[SequenceGroup] = []
+
+    # Add 2 large seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        long_seqs.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Add 2 small seq groups behind them
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i + 2),
+            prompt_length=40,  # Very small prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        short_seqs.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Verify one large req and one small req are chunked
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large req gets 32 tokens
+    assert seq_group_meta[1].token_chunk_size == 32  # small req gets 32 tokens
+
+    # All 4 requests are still prefilling
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert short_seqs[0].is_prefill()
+    assert short_seqs[1].is_prefill()
+    # The first short and the first long sequence have been scheduled
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 0
+
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
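+    # Note the queue order was [long0, long1, short0, short1], yet the two
+    # 32-token slots went to long0 and short0: the second large prefill is
+    # held back behind the shorter prompts, as the docstring describes.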
+
+    # In the second iteration, the first small request has only
+    # 40 - 32 = 8 prefill tokens left, so it finishes prefill this step,
+    # and the second small request is scheduled into the freed budget.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # the new small req got 64 - (32 + 8) = 24 tokens
+    assert seq_group_meta[0].token_chunk_size == 24
+    assert seq_group_meta[1].token_chunk_size == 32  # large req still got 32
+    # the first small request had only 8 tokens left
+    assert seq_group_meta[2].token_chunk_size == 8  # 40 - 32
+
+    # The first small request has finished prefill and moves on to decode
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert not short_seqs[0].is_prefill()
+    assert short_seqs[1].is_prefill()
+    # Both small requests have started in front of the second long request
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 24
+
+    assert out.num_prefill_groups == 3
+    assert out.num_batched_tokens == 64
+    # the first small seq group has a new token appended.
+    append_new_token(short_seqs[0], 1)
+
+    # In the third iteration, the first small request is already decoding,
+    # and the second small request has only 40 - 24 = 16 prefill tokens left,
+    # so it will enter decode after this step.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large req still got 32
+    # the second small req finished prefilling its last 40 - 24 = 16 tokens
+    assert seq_group_meta[1].token_chunk_size == 16
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 49  # 32 + 16 + 1 decode token
+
+    # both small requests have now reached decode
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert not short_seqs[0].is_prefill()
+    assert not short_seqs[1].is_prefill()
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 40
+
+    # both small seq groups have a new token appended
+    append_new_token(short_seqs[0], 1)
+    append_new_token(short_seqs[1], 1)
+
+    # In the fourth iteration, both small requests are decoding,
+    # so the large request gets all of the remaining budget
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+
+    # the large req gets 64 - 2 = 62 tokens; each decode takes 1 token
+    assert seq_group_meta[0].token_chunk_size == 62
+    assert seq_group_meta[1].token_chunk_size == 1  # decode
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
+
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 158  # 32+32+32+62
+
+    # assert long_seqs[0].is_prefill()
+    # assert long_seqs[1].is_prefill()
+    # assert not short_seqs[0].is_prefill()
+    # assert not short_seqs[1].is_prefill()
+
+    # # both the small seq groups have a new token appended
+    # append_new_token(short_seqs[0], 1)
+    # append_new_token(short_seqs[1], 1)
+
+    # # in the fifth iteration, large request gets all the budget
+    # # while both small requests are decoding
+    # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # assert seq_group_meta[0].token_chunk_size == 62
+    # assert seq_group_meta[1].token_chunk_size == 1  # decode
+    # assert seq_group_meta[2].token_chunk_size == 1  # decode
+    # assert out.num_prefill_groups == 1
+    # assert out.num_batched_tokens == 64
+
+
 def test_complex():
     block_size = 4
     max_seqs = 60
@@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs():
     assert not running[1].is_prefill()


-def test_perfix_caching():
+def test_prefix_caching():
     """Verify allocating full blocks when prefix caching is enabled."""
     block_size = 4
     max_seqs = 10
@@ -548,3 +777,86 @@ def test_perfix_caching():
     assert seq_group_meta[1].token_chunk_size == 12
     assert out.num_prefill_groups == 2
     assert out.num_batched_tokens == 62
+
+
+def test_prefix_caching_with_concurrent_partial_prefills():
+    """Verify allocating full blocks when prefix caching is enabled with
+    --max-num-partial-prefills > 1."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 8000
+    max_num_batched_tokens = 60  # With two slots, each slot will get 30 tokens
+    scheduler_config = SchedulerConfig("generate",
+                                       max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       max_num_partial_prefills=2)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # To partially prefill both sequences, each can chunk up to 30 tokens,
+    # but the chunk is rounded down to the next multiple of the block size
+    # (4), which is 28
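+    # (Illustrative arithmetic: 60 // 2 = 30 tokens per slot, and
+    # 30 - 30 % block_size = 28, a full-block-aligned chunk.)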
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
+
+    # On the next iteration, both sequences should finish prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # Both sequences have 50 - 28 = 22 tokens left to prefill.
+    # This is not a multiple of the block size, which is fine, since the
+    # final partial block of a prefix sequence is not cached anyway
+    assert seq_group_meta[0].token_chunk_size == 22
+    assert seq_group_meta[1].token_chunk_size == 22
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 44
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
+def test_chunked_prefill_with_actual_engine(model: str,
+                                            max_num_partial_prefills: int):
+    """Make sure the model can actually sample with concurrent
+    partial prefills
+    """
+
+    prompt = "hello" * 40
+
+    engine_args = EngineArgs(
+        model=model,
+        max_num_partial_prefills=max_num_partial_prefills,
+        max_num_batched_tokens=40,
+        max_num_seqs=8,
+        enable_chunked_prefill=True,
+        gpu_memory_utilization=0.8,
+    )
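+    # Each prompt is long relative to the 40-token batch budget, so the
+    # first step should have to split that budget across several partial
+    # prefills, which is exactly the behavior under test here.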
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams(temperature=0)
+
+    for req_num in range(max_num_partial_prefills):
+        engine.add_request(f"{req_num}", prompt, sampling_params)
+    # first step
+    request_outputs = engine.step()
+    # no outputs yet: all requests are still prefilling
+    assert len(request_outputs) == 0
+    assert len(engine.scheduler[0].running) == max_num_partial_prefills