Commit db4361d

joerunde, prashantgupta24, and comaniac authored and committed
[Core] Reduce TTFT with concurrent partial prefills (vllm-project#10235)
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Co-authored-by: Prashant Gupta <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
1 parent 4abde6f commit db4361d
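
This commit adds a max_num_partial_prefills setting so the chunked-prefill token budget can be shared by several prefilling requests at once instead of prefilling one request at a time; short prompts queued behind a long one therefore reach their first token sooner, reducing time-to-first-token (TTFT). A minimal sketch of enabling it through EngineArgs, modeled on the test_chunked_prefill_with_actual_engine test added in this commit (the model name, prompts, and budget values here are illustrative, not part of the commit):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

# Illustrative configuration: 2 concurrent partial prefills sharing a
# 64-token per-step budget.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    max_num_partial_prefills=2,
    max_num_batched_tokens=64,
    max_num_seqs=8,
)
engine = LLMEngine.from_engine_args(engine_args)

params = SamplingParams(temperature=0)
engine.add_request("long-prompt", "hello " * 400, params)
engine.add_request("short-prompt", "hello " * 10, params)

# With two prefill slots, both requests share the 64-token budget on the
# first step rather than the short prompt waiting behind the long prefill.
request_outputs = engine.step()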

File tree

6 files changed: +701, -108 lines changed


tests/basic_correctness/test_chunked_prefill.py

Lines changed: 10 additions & 20 deletions
@@ -8,7 +8,6 @@
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
 import os
-from contextlib import nullcontext
 
 import pytest
 
@@ -233,7 +232,6 @@ def test_with_prefix_caching(
 
     max_num_batched_tokens = max_num_seqs = chunk_size
     outputs = {}  # type: ignore
-    check_result = True
     for enable in (True, False):
         with vllm_runner(
                 model,
@@ -245,25 +243,17 @@ def test_with_prefix_caching(
                 enforce_eager=enforce_eager,
                 max_num_seqs=max_num_seqs,
         ) as vllm_model:
-            # It should fail when prefix caching is enable and chunk
-            # size is not a multiple of block size (16).
-            should_fail = chunk_size % 16 != 0 and enable
-            check_result &= not should_fail
             outputs[enable] = []
-            # Send the request one-by-one to ensure the cache is populated.
-            with pytest.raises(ValueError) if should_fail else nullcontext():
-                for prompt in full_prompts:
-                    outputs[enable] += vllm_model.generate_greedy([prompt],
-                                                                  max_tokens)
-
-    # Check results only if we did not expect a failure.
-    if check_result:
-        check_outputs_equal(
-            outputs_0_lst=outputs[False],
-            outputs_1_lst=outputs[True],
-            name_0="w/o prefix caching",
-            name_1="with prefix caching",
-        )
+            for prompt in full_prompts:
+                outputs[enable] += vllm_model.generate_greedy([prompt],
+                                                              max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=outputs[False],
+        outputs_1_lst=outputs[True],
+        name_0="w/o prefix caching",
+        name_1="with prefix caching",
+    )
 
 
 @pytest.mark.parametrize("model", ["facebook/opt-125m"])

tests/core/test_chunked_prefill_scheduler.py

Lines changed: 314 additions & 2 deletions
@@ -7,6 +7,9 @@
 
 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.scheduler import Scheduler
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob, SequenceGroup
 
 from .utils import create_dummy_prompt
@@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output):
     return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
 
 
-def append_new_token(seq_group, token_id: int):
+def append_new_token(seq_group: SequenceGroup, token_id: int):
     for seq in seq_group.get_seqs():
         seq.append_token_id(token_id, {token_id: Logprob(token_id)})
 
@@ -123,6 +126,232 @@ def test_chunk():
     assert out.num_batched_tokens == 57
 
 
+def test_concurrent_chunking():
+    """Verify prefills are chunked properly when
+    --max-num-partial-prefills is > 1"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 32
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=60,
+                                           block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Verify both requests are chunked with half of max_num_batched_tokens each
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 32
+    assert seq_group_meta[1].token_chunk_size == 32
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+
+    # After one iteration, both should have 60 - 32 = 28 tokens left to prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
+
+
+def test_concurrent_chunking_large_requests():
+    """Verify large prefill requests are run one at a time"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+
+    # Verify only a single request is chunked, and it gets all 64 tokens
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 1
+    assert seq_group_meta[0].token_chunk_size == 64
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
+
+
+def test_short_prompts_jump_long_prompts_in_queue():
+    """Verify large prefill requests are punted behind smaller ones if
+    another large prefill request is already running"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    long_seqs: List[SequenceGroup] = []
+    short_seqs: List[SequenceGroup] = []
+
+    # Add 2 large seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        long_seqs.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Add 2 small seq groups behind them
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i + 2),
+            prompt_length=40,  # Very small prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        short_seqs.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Verify one large req and 1 small req chunked
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large req gets 32 tokens
+    assert seq_group_meta[1].token_chunk_size == 32  # small req gets 32 tokens
+
+    # all 4 are prefilling
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert short_seqs[0].is_prefill()
+    assert short_seqs[1].is_prefill()
+    # First short and first long sequences have been scheduled
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 0
+
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+
+    # in the second iteration,
+    # the first small request had only 8 tokens left
+    # so it went to decode
+    # The other small req is scheduled
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # the new small req got 64 - (32+8) tokens
+    assert seq_group_meta[0].token_chunk_size == 24
+    assert seq_group_meta[1].token_chunk_size == 32  # large req still got 32
+    # the other small request had only 8 tokens left
+    assert seq_group_meta[2].token_chunk_size == 8  # 40-32
+
+    # The first small request got to decode now
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert not short_seqs[0].is_prefill()
+    assert short_seqs[1].is_prefill()
+    # Both small requests have started in front of the second long request
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 24
+
+    assert out.num_prefill_groups == 3
+    assert out.num_batched_tokens == 64
+    # the first small seq group has a new token appended.
+    append_new_token(short_seqs[0], 1)
+
+    # in the third iteration,
+    # the first small request is already decoding
+    # the second small request only has 16 tokens left and will enter decoding
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large still got 32
+    # small req finished prefilling 40-24=16 tokens
+    assert seq_group_meta[1].token_chunk_size == 16
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 49  # (32+16+1 decode)
+
+    # both small requests have now reached decode
+    assert long_seqs[0].is_prefill()
+    assert long_seqs[1].is_prefill()
+    assert not short_seqs[0].is_prefill()
+    assert not short_seqs[1].is_prefill()
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
+    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
+    assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
+    assert short_seqs[1].first_seq.get_num_computed_tokens() == 40
+
+    # both the small seq groups have a new token appended
+    append_new_token(short_seqs[0], 1)
+    append_new_token(short_seqs[1], 1)
+
+    # in the fourth iteration, both small requests are decoding
+    # so large request gets all the budget
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+
+    # large req gets 62 tokens (minus 2 for decode)
+    assert seq_group_meta[0].token_chunk_size == 62
+    assert seq_group_meta[1].token_chunk_size == 1  # decode
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
+
+    assert long_seqs[0].first_seq.get_num_computed_tokens() == 158
+
+    # assert long_seqs[0].is_prefill()
+    # assert long_seqs[1].is_prefill()
+    # assert not short_seqs[0].is_prefill()
+    # assert not short_seqs[1].is_prefill()
+
+    # # both the small seq groups have a new token appended
+    # append_new_token(short_seqs[0], 1)
+    # append_new_token(short_seqs[1], 1)
+
+    # # in the fifth iteration, large request gets all the budget
+    # # while both small requests are decoding
+    # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # assert seq_group_meta[0].token_chunk_size == 62
+    # assert seq_group_meta[1].token_chunk_size == 1  # decode
+    # assert seq_group_meta[2].token_chunk_size == 1  # decode
+    # assert out.num_prefill_groups == 1
+    # assert out.num_batched_tokens == 64
+
 def test_complex():
     block_size = 4
     max_seqs = 60
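
The chunk sizes asserted in test_short_prompts_jump_long_prompts_in_queue above follow from an even split of the 64-token budget across the two partial-prefill slots, with any leftover going to the next waiting prefill and each running decode consuming a single token. A back-of-the-envelope sketch of the first two scheduler iterations (plain arithmetic mirroring the test's assertions, not the scheduler implementation):

# Iteration 1: 64-token budget, 2 prefill slots -> 32 tokens each for the
# first long prompt (1200 tokens) and the first short prompt (40 tokens).
budget, slots = 64, 2
per_slot = budget // slots                        # 32

# Iteration 2: the first short prompt needs only 40 - 32 = 8 more tokens,
# the long prompt takes its 32-token share, and the remaining
# 64 - (32 + 8) = 24 tokens go to the second short prompt in the queue.
short_remaining = 40 - per_slot                   # 8
leftover = budget - (per_slot + short_remaining)  # 24
assert (per_slot, short_remaining, leftover) == (32, 8, 24)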
@@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs():
     assert not running[1].is_prefill()
 
 
-def test_perfix_caching():
+def test_prefix_caching():
     """Verify allocating full blocks when prefix caching is enabled."""
     block_size = 4
     max_seqs = 10
@@ -548,3 +777,86 @@ def test_perfix_caching():
     assert seq_group_meta[1].token_chunk_size == 12
     assert out.num_prefill_groups == 2
     assert out.num_batched_tokens == 62
+
+
+def test_prefix_caching_with_concurrent_partial_prefills():
+    """Verify allocating full blocks when prefix caching is enabled with
+       --max-num-partial-prefills > 1."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 8000
+    max_num_batched_tokens = 60  # With two slots, each slot will get 30 tokens
+    scheduler_config = SchedulerConfig("generate",
+                                       max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       max_num_partial_prefills=2)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # To partially prefill both sequences, both can chunk up to 30 tokens
+    # But the next lowest multiple of the block size (4) is 28
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
+
+    # On the next iteration, both sequences should finish prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # Both sequences have 50 - 28 = 22 tokens left to prefill.
+    # This is not a multiple of the block size, but we don't care since we don't
+    # cache the final partial block of prefix sequences
+    assert seq_group_meta[0].token_chunk_size == 22
+    assert seq_group_meta[1].token_chunk_size == 22
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 44
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
+def test_chunked_prefill_with_actual_engine(model: str,
+                                            max_num_partial_prefills: int):
+    """Make sure the model can actually sample with concurrent
+    partial prefills
+    """
+
+    prompt = "hello" * 40
+
+    engine_args = EngineArgs(
+        model=model,
+        max_num_partial_prefills=max_num_partial_prefills,
+        max_num_batched_tokens=40,
+        max_num_seqs=8,
+        enable_chunked_prefill=True,
+        gpu_memory_utilization=0.8,
+    )
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams(temperature=0)
+
+    for req_num in range(max_num_partial_prefills):
+        engine.add_request(f"{req_num}", prompt, sampling_params)
+    # first step
+    request_outputs = engine.step()
+    # means all are prefilling
+    assert len(request_outputs) == 0
+    assert len(engine.scheduler[0].running) == max_num_partial_prefills
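
When prefix caching is on, each partial-prefill slot's share of the budget is additionally rounded down to a multiple of the block size so that only whole blocks are cached, which is what test_prefix_caching_with_concurrent_partial_prefills asserts above. A small sketch of that rounding using the test's numbers (illustrative arithmetic only, not the scheduler code):

# 60-token budget split across 2 partial-prefill slots -> 30 tokens each,
# rounded down to the nearest multiple of the block size (4) -> 28.
budget, slots, block_size = 60, 2, 4
per_slot = budget // slots                       # 30
aligned = (per_slot // block_size) * block_size  # 28
assert aligned == 28

# Each 50-token prompt then has 50 - 28 = 22 tokens left; the second
# iteration schedules them as-is, since the trailing partial block of a
# prefill is not prefix-cached.
remaining = 50 - aligned
assert remaining == 22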
