Skip to content

Commit 596b307

Browse files
committed
[Prefix Prefill] Allow larger batch size for prefix prefill
1 parent ef65dcf commit 596b307

File tree

6 files changed

+424
-11
lines changed

6 files changed

+424
-11
lines changed

tests/core/test_block_manager.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,208 @@ def test_sliding_window_multi_seq():
361361

362362
# assert all blocks are free now
363363
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
364+
365+
366+
def test_seq_cached_blocks_num():
    """get_num_cached_blocks() must agree with the allocator's real behavior.

    Allocates four prompts that share token prefixes and checks that the
    reported cached-block count matches how many GPU blocks actually get
    reused instead of freshly allocated.
    """
    block_size = 16
    manager = BlockSpaceManager(block_size,
                                64,
                                32,
                                enable_caching=True)
    assert manager.get_num_free_gpu_blocks() == 64
    assert manager.get_num_free_cpu_blocks() == 32

    def new_seq(idx, prompt, length):
        # Prompt token ids are 0..length-1 so prefixes are shared across seqs.
        return Sequence(seq_id=idx,
                        prompt=prompt,
                        block_size=block_size,
                        prompt_token_ids=list(range(length)))

    def new_group(idx, seq):
        return SequenceGroup(request_id=idx,
                             seqs=[seq],
                             sampling_params=SamplingParams(),
                             arrival_time=time.time())

    seq = new_seq(0, "zero to sixty three", 64)
    manager.allocate(new_group(0, seq))
    # 64 tokens fill exactly four blocks, all of which become cached.
    assert manager.get_num_cached_blocks(seq) == 4
    assert manager.get_num_free_gpu_blocks() == 60  # 64 - 4

    seq1 = new_seq(1, "zero to thirty one", 32)
    # Both of seq1's blocks are already cached by seq 0's prefix.
    assert manager.get_num_cached_blocks(seq1) == 2
    manager.allocate(new_group(1, seq1))
    assert manager.get_num_free_gpu_blocks() == 60  # 64 - 4 - (2 - 2)

    seq2 = new_seq(2, "zero to forty six", 47)
    # 47 tokens span three blocks; only the two full ones are cached.
    assert manager.get_num_cached_blocks(seq2) == 2
    manager.allocate(new_group(2, seq2))
    assert manager.get_num_free_gpu_blocks() == 59  # 64 - 4 - (2 - 2) - (3 - 2)

    seq3 = new_seq(3, "zero to ninety five", 96)
    # 96 tokens span six blocks; the first four are cached from seq 0.
    assert manager.get_num_cached_blocks(seq3) == 4
    manager.allocate(new_group(3, seq3))
    assert manager.get_num_free_gpu_blocks() == 57  # ... - (6 - 4)
438+
439+
440+
def test_seq_computed_blocks_num():
    """get_num_computed_blocks() must track allocation, free and eviction.

    Verifies that the number of computed blocks reported for a sequence
    matches get_all_computed_blocks(), that freeing sequences does not
    change the count, and that evicting (or un-marking) an interior block
    breaks the computed-prefix chain for every dependent sequence.
    """
    # Initialize the block manager
    block_size = 16
    num_gpu_blocks = 64
    num_cpu_blocks = 32
    block_manager = BlockSpaceManager(block_size,
                                      num_gpu_blocks,
                                      num_cpu_blocks,
                                      enable_caching=True)
    assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
    assert block_manager.get_num_free_cpu_blocks() == num_cpu_blocks

    seq_prompt_length = 64
    seq = Sequence(seq_id=0,
                   prompt="zero to sixty three",
                   block_size=block_size,
                   prompt_token_ids=list(range(seq_prompt_length)))
    seq_group = SequenceGroup(request_id=0,
                              seqs=[seq],
                              sampling_params=SamplingParams(),
                              arrival_time=time.time())
    block_manager.allocate(seq_group)
    block_manager.mark_blocks_as_computed(seq_group)

    seq_num_computed_blocks = block_manager.get_num_computed_blocks(seq)
    assert seq_num_computed_blocks == 3
    # Ensure the computed blocks number aligns with the real allocation behavior
    seq_computed_blocks = block_manager.get_all_computed_blocks(seq)
    assert seq_num_computed_blocks == len(seq_computed_blocks)

    seq1_prompt_length = 48
    seq1 = Sequence(seq_id=1,
                    prompt="zero to forty seven",
                    block_size=block_size,
                    prompt_token_ids=list(range(seq1_prompt_length)))
    seq1_num_computed_blocks = block_manager.get_num_computed_blocks(seq1)
    assert seq1_num_computed_blocks == 2
    seq1_group = SequenceGroup(request_id=1,
                               seqs=[seq1],
                               sampling_params=SamplingParams(),
                               arrival_time=time.time())
    # Ensure the computed blocks number aligns with the real allocation behavior
    block_manager.allocate(seq1_group)
    block_manager.mark_blocks_as_computed(seq1_group)
    seq1_computed_blocks = block_manager.get_all_computed_blocks(seq1)
    assert seq1_num_computed_blocks == len(seq1_computed_blocks)

    seq2_prompt_length = 55
    seq2 = Sequence(seq_id=2,
                    prompt="zero to fifty four",
                    block_size=block_size,
                    prompt_token_ids=list(range(seq2_prompt_length)))
    seq2_num_computed_blocks = block_manager.get_num_computed_blocks(seq2)
    assert seq2_num_computed_blocks == 3
    # Fix: use a distinct request_id (was 1, duplicating seq1_group's id).
    seq2_group = SequenceGroup(request_id=2,
                               seqs=[seq2],
                               sampling_params=SamplingParams(),
                               arrival_time=time.time())
    # Ensure the computed blocks number aligns with the real allocation behavior
    block_manager.allocate(seq2_group)
    block_manager.mark_blocks_as_computed(seq2_group)
    seq2_computed_blocks = block_manager.get_all_computed_blocks(seq2)
    assert seq2_num_computed_blocks == len(seq2_computed_blocks)

    seq3_prompt_length = 81
    seq3 = Sequence(seq_id=3,
                    prompt="zero to eighty",
                    block_size=block_size,
                    prompt_token_ids=list(range(seq3_prompt_length)))
    seq3_num_computed_blocks = block_manager.get_num_computed_blocks(seq3)
    assert seq3_num_computed_blocks == 4
    seq3_group = SequenceGroup(request_id=3,
                               seqs=[seq3],
                               sampling_params=SamplingParams(),
                               arrival_time=time.time())
    # Ensure the computed blocks number aligns with the real allocation behavior
    block_manager.allocate(seq3_group)
    block_manager.mark_blocks_as_computed(seq3_group)
    seq3_computed_blocks = block_manager.get_all_computed_blocks(seq3)
    assert seq3_num_computed_blocks == len(seq3_computed_blocks)

    # Test the computed blocks num after the sequences are freed
    # Free operation doesn't influence the computed blocks number
    block_manager.free(seq)
    block_manager.free(seq1)
    block_manager.free(seq2)
    block_manager.free(seq3)
    seq_num_computed_blocks = block_manager.get_num_computed_blocks(seq)
    assert seq_num_computed_blocks == 3
    seq1_num_computed_blocks = block_manager.get_num_computed_blocks(seq1)
    assert seq1_num_computed_blocks == 2
    seq2_num_computed_blocks = block_manager.get_num_computed_blocks(seq2)
    assert seq2_num_computed_blocks == 3
    seq3_num_computed_blocks = block_manager.get_num_computed_blocks(seq3)
    assert seq3_num_computed_blocks == 4

    # Test the computed blocks num
    # after the second block(token 15~31) are evicted
    # Since the second block is evicted, the caches are not continuous
    # from the second block. Therefore, all seqs' computed blocks numbers
    # are reduced to 1.
    evicted_block_hash = seq.hash_of_block(1)
    evicted_block = block_manager.gpu_allocator.evictor.remove(
        evicted_block_hash)
    seq_num_computed_blocks = block_manager.get_num_computed_blocks(seq)
    assert seq_num_computed_blocks == 1
    seq1_num_computed_blocks = block_manager.get_num_computed_blocks(seq1)
    assert seq1_num_computed_blocks == 1
    seq2_num_computed_blocks = block_manager.get_num_computed_blocks(seq2)
    assert seq2_num_computed_blocks == 1
    seq3_num_computed_blocks = block_manager.get_num_computed_blocks(seq3)
    assert seq3_num_computed_blocks == 1

    # Test the computed blocks num
    # after the second block(token 15~31) are marked as not computed
    # Since the second block is marked as not computed, the caches are not
    # continuous from the second block. Therefore, all seqs' computed blocks
    # numbers are reduced to 1.
    evicted_block.computed = False
    block_manager.gpu_allocator.cached_blocks[
        evicted_block.block_hash] = evicted_block
    seq_num_computed_blocks = block_manager.get_num_computed_blocks(seq)
    assert seq_num_computed_blocks == 1
    seq1_num_computed_blocks = block_manager.get_num_computed_blocks(seq1)
    assert seq1_num_computed_blocks == 1
    seq2_num_computed_blocks = block_manager.get_num_computed_blocks(seq2)
    assert seq2_num_computed_blocks == 1
    seq3_num_computed_blocks = block_manager.get_num_computed_blocks(seq3)
    assert seq3_num_computed_blocks == 1

tests/core/test_scheduler.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import List
22
import pytest # noqa
3+
import time
34

5+
from vllm import SamplingParams
46
from vllm.config import CacheConfig, SchedulerConfig
57
from vllm.core.scheduler import Scheduler
6-
from vllm.sequence import SequenceGroup, Logprob
8+
from vllm.sequence import Sequence, SequenceGroup, Logprob
79

810
from .utils import create_dummy_prompt
911

@@ -168,3 +170,117 @@ def test_scheduler_max_seqs():
168170
# and one is prompting.
169171
_, out = scheduler.schedule()
170172
assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]])
173+
174+
175+
def test_scheduler_with_cache():
    """Scheduling with prefix caching counts only the uncached/uncomputed work.

    With a 96-token batch budget, five sequences that share a cached prefix
    need 16 + 16 + 8 + 16 = 56 new tokens for the first four, so all four fit
    in one step; the fifth (32 new tokens) would exceed the budget.
    """
    # Scheduler budget: 96 batched tokens, 8 seq groups, model len 96,
    # 256 paddings.
    scheduler_config = SchedulerConfig(96, 8, 96, 256)

    block_size = 16
    cache_config = CacheConfig(block_size,
                               1.0,
                               1,
                               "auto",
                               enable_prefix_caching=True)
    cache_config.num_gpu_blocks = 8
    cache_config.num_cpu_blocks = 8

    scheduler = Scheduler(scheduler_config, cache_config, None)

    def build_seq(idx, prompt, length):
        return Sequence(seq_id=idx,
                        prompt=prompt,
                        block_size=block_size,
                        prompt_token_ids=list(range(length)))

    def build_group(idx, seq):
        return SequenceGroup(request_id=idx,
                             seqs=[seq],
                             sampling_params=SamplingParams(),
                             arrival_time=time.time())

    def extra_blocks(seq):
        # Blocks the sequence still needs beyond what the cache provides.
        return (len(seq.logical_token_blocks) -
                scheduler.block_manager.get_num_cached_blocks(seq))

    def uncomputed_tokens(seq):
        # Tokens that still have to be batched for this sequence.
        return (seq.get_len() -
                scheduler.block_manager.get_num_computed_tokens(seq))

    seq0 = build_seq(0, "zero to sixty three", 64)
    seq0_group = build_group(0, seq0)
    # Warm the cache: allocate seq0's 4 blocks and mark them computed.
    scheduler.block_manager.allocate(seq0_group)
    scheduler.block_manager.mark_blocks_as_computed(seq0_group)
    # Requires 0 extra blocks, 16 batched tokens.
    scheduler.add_seq_group(seq0_group)
    assert extra_blocks(seq0) == 0
    assert uncomputed_tokens(seq0) == 16

    # (seq_id, prompt, prompt length, expected extra blocks,
    #  expected batched tokens)
    followers = [
        (1, "zero to forty seven", 48, 0, 16),
        (2, "zero to fifty four", 56, 1, 8),
        (3, "zero to seventy nine", 80, 1, 16),
        (4, "zero to ninety five", 96, 2, 32),
    ]
    for idx, prompt, length, want_blocks, want_tokens in followers:
        follower = build_seq(idx, prompt, length)
        scheduler.add_seq_group(build_group(idx, follower))
        assert extra_blocks(follower) == want_blocks
        assert uncomputed_tokens(follower) == want_tokens

    scheduler_outputs = scheduler._schedule()
    scheduled_ids = sorted(
        group.request_id for group in scheduler_outputs.scheduled_seq_groups)
    # seq4 cannot be scheduled: adding its 32 tokens would push the batch
    # past the 96-token limit.
    assert scheduled_ids == [0, 1, 2, 3]

vllm/block.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Token blocks."""
2-
from typing import List
2+
from typing import List, Optional
33

44
from vllm.utils import Device
55

@@ -25,6 +25,7 @@ def __init__(
2525

2626
self.token_ids = [_BLANK_TOKEN_ID] * block_size
2727
self.num_tokens = 0
28+
self.block_hash: Optional[int] = None
2829

2930
def is_empty(self) -> bool:
    """Return True when no tokens have been appended to this block."""
    return not self.num_tokens

0 commit comments

Comments
 (0)