Skip to content

Commit 28943d3

Browse files
[v1] Move block pool operations to a separate class (#13973)
Signed-off-by: Chen Zhang <[email protected]> Co-authored-by: Cody Yu <[email protected]>
1 parent b526ca6 commit 28943d3

File tree

3 files changed

+360
-277
lines changed

3 files changed

+360
-277
lines changed

tests/v1/core/test_prefix_caching.py

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# SPDX-License-Identifier: Apache-2.0
22
"""Compare the with and without prefix caching."""
3+
from typing import List
4+
35
import pytest
46

57
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
68
from vllm.sampling_params import SamplingParams
79
from vllm.utils import cdiv
10+
from vllm.v1.core.block_pool import BlockPool
811
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
9-
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
12+
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
13+
hash_block_tokens)
1014

1115

1216
def make_request(request_id,
@@ -62,14 +66,14 @@ def test_prefill():
6266
for block_id in (0, 1, 2):
6367
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
6468
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
65-
assert manager.block_pool[block_id].block_hash == block_hash
66-
assert manager.block_pool[block_id].ref_cnt == 1
69+
assert manager.block_pool.blocks[block_id].block_hash == block_hash
70+
assert manager.block_pool.blocks[block_id].ref_cnt == 1
6771
parent_block_hash = block_hash.hash_value
6872

6973
# Check partial/preallocated block metadata
7074
for block_id in (3, 4):
71-
assert manager.block_pool[block_id].block_hash is None
72-
assert manager.block_pool[block_id].ref_cnt == 1
75+
assert manager.block_pool.blocks[block_id].block_hash is None
76+
assert manager.block_pool.blocks[block_id].ref_cnt == 1
7377

7478
# Cache hit in the common prefix when the original block is still in use.
7579
# Incomplete 1 block (5 tokens)
@@ -86,20 +90,21 @@ def test_prefill():
8690
assert block.ref_cnt == 2
8791

8892
# At this point, we should have 3 free blocks left.
89-
assert manager.free_block_queue.num_free_blocks == 3
93+
assert manager.block_pool.free_block_queue.num_free_blocks == 3
9094

9195
manager.free(req0)
9296
manager.free(req1)
9397

9498
# All blocks should be available.
95-
assert manager.free_block_queue.num_free_blocks == 10
99+
assert manager.block_pool.free_block_queue.num_free_blocks == 10
96100
# The order should be
97101
# [unallocated (7, 8, 9)]
98102
# [unique_req0 (4, 3)]
99103
# [unique_req1 (6, 5)]
100104
# [common (2, 1, 0)]
101105
assert [
102-
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
106+
b.block_id
107+
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
103108
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
104109

105110
# Cache hit in the common prefix when the original block is already free.
@@ -116,12 +121,14 @@ def test_prefill():
116121

117122
# Although we only have 5 free blocks, we have 8 blocks in
118123
# the free block queue due to lazy removal.
119-
assert manager.free_block_queue.num_free_blocks == 5
124+
assert manager.block_pool.free_block_queue.num_free_blocks == 5
120125
assert all([
121-
b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
126+
b.ref_cnt == 0
127+
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
122128
])
123-
assert len([b
124-
for b in manager.free_block_queue.get_all_free_blocks()]) == 5
129+
assert len([
130+
b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
131+
]) == 5
125132

126133
manager.free(req2)
127134

@@ -133,9 +140,9 @@ def test_prefill():
133140
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
134141
# This block ID order also checks the eviction order.
135142
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
136-
assert manager.free_block_queue.num_free_blocks == 0
137-
assert manager.free_block_queue.free_list_head is None
138-
assert manager.free_block_queue.free_list_tail is None
143+
assert manager.block_pool.free_block_queue.num_free_blocks == 0
144+
assert manager.block_pool.free_block_queue.free_list_head is None
145+
assert manager.block_pool.free_block_queue.free_list_tail is None
139146

140147

141148
def test_decode():
@@ -219,13 +226,14 @@ def test_evict():
219226
assert len(blocks) == 3 # 3 full blocks
220227
last_token_id += 3 * 16
221228

222-
assert manager.free_block_queue.num_free_blocks == 0
229+
assert manager.block_pool.free_block_queue.num_free_blocks == 0
223230

224231
manager.free(req0)
225232
manager.free(req1)
226-
assert manager.free_block_queue.num_free_blocks == 10
233+
assert manager.block_pool.free_block_queue.num_free_blocks == 10
227234
assert [
228-
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
235+
b.block_id
236+
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
229237
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
230238

231239
# Touch the first 2 blocks.
@@ -235,7 +243,7 @@ def test_evict():
235243
assert num_computed_tokens == 2 * 16
236244
blocks = manager.allocate_slots(req2, 3, computed_blocks)
237245
assert [b.block_id for b in blocks] == [6, 5]
238-
assert manager.free_block_queue.num_free_blocks == 6
246+
assert manager.block_pool.free_block_queue.num_free_blocks == 6
239247

240248

241249
def test_hash_block_correct_reuse():
@@ -274,7 +282,7 @@ def test_hash_block_correct_reuse():
274282
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
275283
assert len(blocks) == 1
276284

277-
assert manager.block_pool[blocks[0].block_id].block_hash is None
285+
assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
278286

279287

280288
def test_computed_blocks_not_evicted():
@@ -413,13 +421,9 @@ def test_cache_blocks():
413421
function of KVCacheManager.
414422
"""
415423
block_size = 4
416-
manager = KVCacheManager(
417-
block_size=block_size,
424+
block_pool = BlockPool(
418425
num_gpu_blocks=5,
419-
max_model_len=8192,
420-
sliding_window=None,
421426
enable_caching=True,
422-
num_preallocate_tokens=0,
423427
)
424428
# Req:
425429
# Block 0: [0, 1, 2, 3]
@@ -430,26 +434,31 @@ def test_cache_blocks():
430434

431435
# Test that blocks are cached correctly for 2 full blocks from the start.
432436
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
437+
block_hashes: List[BlockHashType] = []
433438

434-
manager._cache_full_blocks(
439+
block_pool.cache_full_blocks(
435440
request=req,
436-
blk_start_idx=0,
437-
full_blocks=blocks,
438-
prev_block=None,
441+
blocks=blocks,
442+
block_hashes=block_hashes,
443+
num_cached_blocks=0,
444+
num_full_blocks=2,
445+
block_size=block_size,
439446
)
440447

441-
assert len(manager.cached_block_hash_to_block) == 2
448+
assert len(block_pool.cached_block_hash_to_block) == 2
442449
assert all([block.block_hash is not None for block in blocks])
443450

444451
# Test that blocks that don't start from the beginning are cached correctly.
445-
blocks = [KVCacheBlock(block_id=2)]
446-
manager._cache_full_blocks(
452+
blocks += [KVCacheBlock(block_id=2)]
453+
block_pool.cache_full_blocks(
447454
request=req,
448-
blk_start_idx=2,
449-
full_blocks=blocks,
450-
prev_block=None,
455+
blocks=blocks,
456+
block_hashes=block_hashes,
457+
num_cached_blocks=2,
458+
num_full_blocks=3,
459+
block_size=block_size,
451460
)
452-
assert len(manager.cached_block_hash_to_block) == 3
461+
assert len(block_pool.cached_block_hash_to_block) == 3
453462
assert blocks[0].block_hash is not None
454463

455464

@@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
580589
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
581590
# but it cannot be allocated due to insufficient free blocks (2).
582591
# In this case, the ref_cnt of the computed blocks should not be changed.
583-
assert manager.free_block_queue.num_free_blocks == 5
592+
assert manager.block_pool.free_block_queue.num_free_blocks == 5
584593
req3 = make_request("3", common_token_ids * 3)
585594
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
586595
assert computed_blocks == block_part1
@@ -621,12 +630,12 @@ def test_reset_prefix_cache():
621630

622631
# Failed to reset prefix cache because some blocks are not freed yet.
623632
assert not manager.reset_prefix_cache()
624-
assert manager.cached_block_hash_to_block
633+
assert manager.block_pool.cached_block_hash_to_block
625634

626635
# Free the blocks.
627636
manager.free(req0)
628637
manager.free(req1)
629638

630639
assert manager.reset_prefix_cache()
631-
assert not manager.cached_block_hash_to_block
632-
assert all([blk.block_hash is None for blk in manager.block_pool])
640+
assert not manager.block_pool.cached_block_hash_to_block
641+
assert all([blk.block_hash is None for blk in manager.block_pool.blocks])

0 commit comments

Comments
 (0)