1
1
# SPDX-License-Identifier: Apache-2.0
2
2
"""Compare the with and without prefix caching."""
3
+ from typing import List
4
+
3
5
import pytest
4
6
5
7
from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
6
8
from vllm .sampling_params import SamplingParams
7
9
from vllm .utils import cdiv
10
+ from vllm .v1 .core .block_pool import BlockPool
8
11
from vllm .v1 .core .kv_cache_manager import KVCacheManager , Request
9
- from vllm .v1 .core .kv_cache_utils import KVCacheBlock , hash_block_tokens
12
+ from vllm .v1 .core .kv_cache_utils import (BlockHashType , KVCacheBlock ,
13
+ hash_block_tokens )
10
14
11
15
12
16
def make_request (request_id ,
@@ -62,14 +66,14 @@ def test_prefill():
62
66
for block_id in (0 , 1 , 2 ):
63
67
block_tokens = tuple (all_token_ids [block_id * 16 :(block_id + 1 ) * 16 ])
64
68
block_hash = hash_block_tokens (parent_block_hash , block_tokens )
65
- assert manager .block_pool [block_id ].block_hash == block_hash
66
- assert manager .block_pool [block_id ].ref_cnt == 1
69
+ assert manager .block_pool . blocks [block_id ].block_hash == block_hash
70
+ assert manager .block_pool . blocks [block_id ].ref_cnt == 1
67
71
parent_block_hash = block_hash .hash_value
68
72
69
73
# Check partial/preallocated block metadata
70
74
for block_id in (3 , 4 ):
71
- assert manager .block_pool [block_id ].block_hash is None
72
- assert manager .block_pool [block_id ].ref_cnt == 1
75
+ assert manager .block_pool . blocks [block_id ].block_hash is None
76
+ assert manager .block_pool . blocks [block_id ].ref_cnt == 1
73
77
74
78
# Cache hit in the common prefix when the original block is still in use.
75
79
# Incomplete 1 block (5 tokens)
@@ -86,20 +90,21 @@ def test_prefill():
86
90
assert block .ref_cnt == 2
87
91
88
92
# At this point, we should have 3 free blocks left.
89
- assert manager .free_block_queue .num_free_blocks == 3
93
+ assert manager .block_pool . free_block_queue .num_free_blocks == 3
90
94
91
95
manager .free (req0 )
92
96
manager .free (req1 )
93
97
94
98
# All blocks should be available.
95
- assert manager .free_block_queue .num_free_blocks == 10
99
+ assert manager .block_pool . free_block_queue .num_free_blocks == 10
96
100
# The order should be
97
101
# [unallocated (7, 8, 9)]
98
102
# [unique_req0 (4, 3)]
99
103
# [unique_req1 (6, 5)]
100
104
# [common (2, 1, 0)]
101
105
assert [
102
- b .block_id for b in manager .free_block_queue .get_all_free_blocks ()
106
+ b .block_id
107
+ for b in manager .block_pool .free_block_queue .get_all_free_blocks ()
103
108
] == [7 , 8 , 9 , 4 , 3 , 6 , 5 , 2 , 1 , 0 ]
104
109
105
110
# Cache hit in the common prefix when the original block is already free.
@@ -116,12 +121,14 @@ def test_prefill():
116
121
117
122
# Although we only have 5 free blocks, we have 8 blocks in
118
123
# the free block queue due to lazy removal.
119
- assert manager .free_block_queue .num_free_blocks == 5
124
+ assert manager .block_pool . free_block_queue .num_free_blocks == 5
120
125
assert all ([
121
- b .ref_cnt == 0 for b in manager .free_block_queue .get_all_free_blocks ()
126
+ b .ref_cnt == 0
127
+ for b in manager .block_pool .free_block_queue .get_all_free_blocks ()
122
128
])
123
- assert len ([b
124
- for b in manager .free_block_queue .get_all_free_blocks ()]) == 5
129
+ assert len ([
130
+ b for b in manager .block_pool .free_block_queue .get_all_free_blocks ()
131
+ ]) == 5
125
132
126
133
manager .free (req2 )
127
134
@@ -133,9 +140,9 @@ def test_prefill():
133
140
blocks = manager .allocate_slots (req3 , 16 * 9 , computed_blocks )
134
141
# This block ID order also checks the eviction order.
135
142
assert [b .block_id for b in blocks ] == [9 , 4 , 3 , 6 , 5 , 8 , 7 , 2 , 1 , 0 ]
136
- assert manager .free_block_queue .num_free_blocks == 0
137
- assert manager .free_block_queue .free_list_head is None
138
- assert manager .free_block_queue .free_list_tail is None
143
+ assert manager .block_pool . free_block_queue .num_free_blocks == 0
144
+ assert manager .block_pool . free_block_queue .free_list_head is None
145
+ assert manager .block_pool . free_block_queue .free_list_tail is None
139
146
140
147
141
148
def test_decode ():
@@ -219,13 +226,14 @@ def test_evict():
219
226
assert len (blocks ) == 3 # 3 full blocks
220
227
last_token_id += 3 * 16
221
228
222
- assert manager .free_block_queue .num_free_blocks == 0
229
+ assert manager .block_pool . free_block_queue .num_free_blocks == 0
223
230
224
231
manager .free (req0 )
225
232
manager .free (req1 )
226
- assert manager .free_block_queue .num_free_blocks == 10
233
+ assert manager .block_pool . free_block_queue .num_free_blocks == 10
227
234
assert [
228
- b .block_id for b in manager .free_block_queue .get_all_free_blocks ()
235
+ b .block_id
236
+ for b in manager .block_pool .free_block_queue .get_all_free_blocks ()
229
237
] == [6 , 5 , 4 , 3 , 2 , 1 , 0 , 9 , 8 , 7 ]
230
238
231
239
# Touch the first 2 blocks.
@@ -235,7 +243,7 @@ def test_evict():
235
243
assert num_computed_tokens == 2 * 16
236
244
blocks = manager .allocate_slots (req2 , 3 , computed_blocks )
237
245
assert [b .block_id for b in blocks ] == [6 , 5 ]
238
- assert manager .free_block_queue .num_free_blocks == 6
246
+ assert manager .block_pool . free_block_queue .num_free_blocks == 6
239
247
240
248
241
249
def test_hash_block_correct_reuse ():
@@ -274,7 +282,7 @@ def test_hash_block_correct_reuse():
274
282
blocks = manager .allocate_slots (req , num_tokens - 1 , computed_blocks )
275
283
assert len (blocks ) == 1
276
284
277
- assert manager .block_pool [blocks [0 ].block_id ].block_hash is None
285
+ assert manager .block_pool . blocks [blocks [0 ].block_id ].block_hash is None
278
286
279
287
280
288
def test_computed_blocks_not_evicted ():
@@ -413,13 +421,9 @@ def test_cache_blocks():
413
421
function of KVCacheManager.
414
422
"""
415
423
block_size = 4
416
- manager = KVCacheManager (
417
- block_size = block_size ,
424
+ block_pool = BlockPool (
418
425
num_gpu_blocks = 5 ,
419
- max_model_len = 8192 ,
420
- sliding_window = None ,
421
426
enable_caching = True ,
422
- num_preallocate_tokens = 0 ,
423
427
)
424
428
# Req:
425
429
# Block 0: [0, 1, 2, 3]
@@ -430,26 +434,31 @@ def test_cache_blocks():
430
434
431
435
# Test that blocks are cached correctly for 2 full blocks from the start.
432
436
blocks = [KVCacheBlock (block_id = i ) for i in range (2 )]
437
+ block_hashes : List [BlockHashType ] = []
433
438
434
- manager . _cache_full_blocks (
439
+ block_pool . cache_full_blocks (
435
440
request = req ,
436
- blk_start_idx = 0 ,
437
- full_blocks = blocks ,
438
- prev_block = None ,
441
+ blocks = blocks ,
442
+ block_hashes = block_hashes ,
443
+ num_cached_blocks = 0 ,
444
+ num_full_blocks = 2 ,
445
+ block_size = block_size ,
439
446
)
440
447
441
- assert len (manager .cached_block_hash_to_block ) == 2
448
+ assert len (block_pool .cached_block_hash_to_block ) == 2
442
449
assert all ([block .block_hash is not None for block in blocks ])
443
450
444
451
# Test that blocks that don't start from the beginning are cached correctly.
445
- blocks = [KVCacheBlock (block_id = 2 )]
446
- manager . _cache_full_blocks (
452
+ blocks + = [KVCacheBlock (block_id = 2 )]
453
+ block_pool . cache_full_blocks (
447
454
request = req ,
448
- blk_start_idx = 2 ,
449
- full_blocks = blocks ,
450
- prev_block = None ,
455
+ blocks = blocks ,
456
+ block_hashes = block_hashes ,
457
+ num_cached_blocks = 2 ,
458
+ num_full_blocks = 3 ,
459
+ block_size = block_size ,
451
460
)
452
- assert len (manager .cached_block_hash_to_block ) == 3
461
+ assert len (block_pool .cached_block_hash_to_block ) == 3
453
462
assert blocks [0 ].block_hash is not None
454
463
455
464
@@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
580
589
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
581
590
# but it cannot be allocated due to insufficient free blocks (2).
582
591
# In this case, the ref_cnt of the computed blocks should not be changed.
583
- assert manager .free_block_queue .num_free_blocks == 5
592
+ assert manager .block_pool . free_block_queue .num_free_blocks == 5
584
593
req3 = make_request ("3" , common_token_ids * 3 )
585
594
computed_blocks , num_computed_tokens = manager .get_computed_blocks (req3 )
586
595
assert computed_blocks == block_part1
@@ -621,12 +630,12 @@ def test_reset_prefix_cache():
621
630
622
631
# Failed to reset prefix cache because some blocks are not freed yet.
623
632
assert not manager .reset_prefix_cache ()
624
- assert manager .cached_block_hash_to_block
633
+ assert manager .block_pool . cached_block_hash_to_block
625
634
626
635
# Free the blocks.
627
636
manager .free (req0 )
628
637
manager .free (req1 )
629
638
630
639
assert manager .reset_prefix_cache ()
631
- assert not manager .cached_block_hash_to_block
632
- assert all ([blk .block_hash is None for blk in manager .block_pool ])
640
+ assert not manager .block_pool . cached_block_hash_to_block
641
+ assert all ([blk .block_hash is None for blk in manager .block_pool . blocks ])
0 commit comments