diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index e8069b8c6d7f..df487ec2ccaa 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -542,7 +542,7 @@ def test_allocate_with_lookahead():
         num_tokens=3,
         num_lookahead_tokens=2,  # Total required: 3+2=5 tokens
     )
-    assert len(blocks) == 2  # ceil(5/4)=2 blocks
+    assert len(blocks.blocks) == 2  # ceil(5/4)=2 blocks
 
     # Test case 2: With precomputed blocks
     kv_cache_manager = KVCacheManager(kv_cache_config=config,
@@ -553,7 +553,7 @@ def test_allocate_with_lookahead():
         num_tokens=3,
         num_lookahead_tokens=2,
     )
-    assert len(blocks) == 2
+    assert len(blocks.blocks) == 2
 
     # Test case 3: With precomputed blocks
    # required_blocks = ceil((3 + 4) / 4) = 2
@@ -564,4 +564,4 @@ def test_allocate_with_lookahead():
         num_tokens=3,
         num_lookahead_tokens=4,
     )
-    assert len(blocks) == 2
+    assert len(blocks.blocks) == 2
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 4c05e0b87fc5..01295e848ee9 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -79,10 +79,10 @@ def test_prefill(hash_algo):
     req0 = make_request("0", all_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
     assert len(manager.req_to_block_hashes[req0.request_id]) == 3
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     # Check full block metadata
     parent_block_hash = None
@@ -105,12 +105,12 @@ def test_prefill(hash_algo):
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [5]
-    for block in computed_blocks:
+    assert blocks.get_block_ids() == [5]
+    for block in computed_blocks.blocks:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -137,11 +137,11 @@ def test_prefill(hash_algo):
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [6]
+    assert blocks.get_block_ids() == [6]
 
     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -159,11 +159,11 @@ def test_prefill(hash_algo):
     # Cache miss and eviction.
     req3 = make_request("3", [99] * (16 * 10))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
     # This block ID order also checks the eviction order.
-    assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
+    assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -195,11 +195,11 @@ def test_prefill_plp():
     req0 = make_request("0", all_token_ids, prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
     assert len(manager.req_to_block_hashes[req0.request_id]) == 3
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
-    req0_block_hashes = [b.block_hash for b in blocks]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    req0_block_hashes = [b.block_hash for b in blocks.blocks]
 
     # Check full block metadata
     parent_block_hash = None
@@ -223,12 +223,12 @@ def test_prefill_plp():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [5]
-    for block in computed_blocks:
+    assert blocks.get_block_ids() == [5]
+    for block in computed_blocks.blocks:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -257,12 +257,12 @@ def test_prefill_plp():
                         prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req2, 55, computed_blocks)
-    block_ids = [b.block_id for b in blocks]
+    block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
-    assert [b.block_hash for b in blocks] == req0_block_hashes
+    assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
     assert block_ids != [1, 2, 3, 4]
 
     # Request #2 block hashes are valid since request #0 hashes are.
@@ -288,17 +288,17 @@ def test_decode():
     unique_token_ids = [3] * 7
     req0 = make_request("0", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
     for _ in range(4):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 4)
-    assert new_blocks is not None and len(new_blocks) == 0
+    assert new_blocks is not None and len(new_blocks.blocks) == 0
     assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
 
     # Append slots with allocating a new block.
@@ -308,7 +308,7 @@ def test_decode():
     for _ in range(9 + 10):
         req0.append_output_token_ids(7)
     new_blocks = manager.allocate_slots(req0, 19)
-    assert new_blocks is not None and len(new_blocks) == 1
+    assert new_blocks is not None and len(new_blocks.blocks) == 1
     assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
     assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
@@ -323,19 +323,19 @@ def test_evict():
     last_token_id = 5 * 16 + 7
     req0 = make_request("0", list(range(last_token_id)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
-    assert len(blocks) == 6  # 5 full + 1 partial
+    assert len(blocks.blocks) == 6  # 5 full + 1 partial
 
     # 3 blocks.
     req1 = make_request("1",
                         list(range(last_token_id, last_token_id + 3 * 16)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
-    assert len(blocks) == 3  # 3 full blocks
+    assert len(blocks.blocks) == 3  # 3 full blocks
     last_token_id += 3 * 16
 
     # 10 - (6 + 3) == 1
@@ -352,10 +352,10 @@ def test_evict():
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert [b.block_id for b in computed_blocks] == [1, 2]
+    assert computed_blocks.get_block_ids() == [1, 2]
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3, computed_blocks)
-    assert [b.block_id for b in blocks] == [10]
+    assert blocks.get_block_ids() == [10]
 
     assert manager.block_pool.free_block_queue.num_free_blocks == 7
 
@@ -375,10 +375,10 @@ def test_hash_block_correct_reuse():
     num_tokens = block_size * 1
     req = make_request("0", list(range(num_tokens)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
-    assert len(blocks) == 1
+    assert len(blocks.blocks) == 1
 
     # Deallocate the block.
     manager.free(req)
@@ -387,12 +387,13 @@ def test_hash_block_correct_reuse():
     # block is cleared.
     req = make_request("1", list(range(num_tokens - 1)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
-    assert len(blocks) == 1
+    assert len(blocks.blocks) == 1
 
-    assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
+    assert manager.block_pool.blocks[
+        blocks.blocks[0].block_id].block_hash is None
 
 
 def test_computed_blocks_not_evicted():
@@ -411,20 +412,20 @@ def test_computed_blocks_not_evicted():
     num_tokens = block_size * 1
     req0 = make_request("0", list(range(num_tokens)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
-    assert len(blocks) == 1
-    assert blocks[0].block_id == 1
+    assert len(blocks.blocks) == 1
+    assert blocks.blocks[0].block_id == 1
 
     # Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 2 + assert len(blocks.blocks) == 1 + assert blocks.blocks[0].block_id == 2 # Free the blocks. manager.free(req0) @@ -434,14 +435,14 @@ def test_computed_blocks_not_evicted(): # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks) == 1 - assert computed_blocks[0].block_id == 1 + assert len(computed_blocks.blocks) == 1 + assert computed_blocks.blocks[0].block_id == 1 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, computed_blocks) - assert len(blocks) == 1 - assert blocks[0].block_id == 2 + assert len(blocks.blocks) == 1 + assert blocks.blocks[0].block_id == 2 def test_basic_prefix_caching_disabled(): @@ -458,10 +459,10 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, computed_blocks) - assert len(blocks) == 3 + assert len(blocks.blocks) == 3 # Free the blocks. manager.free(req1) @@ -469,15 +470,15 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, computed_blocks) - assert len(blocks) == 4 + assert len(blocks.blocks) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, computed_blocks) assert not blocks @@ -569,7 +570,7 @@ def test_mm_prefix_caching(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks + assert not computed_blocks.blocks assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 @@ -578,14 +579,14 @@ def test_mm_prefix_caching(): assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5) - assert new_blocks is not None and len(new_blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks) == 0 # The just completed block should have hashes with extra keys. 
     assert len(block_hashes) == 4
@@ -603,7 +604,7 @@ def test_mm_prefix_caching():
                         mm_positions=mm_positions,
                         mm_hashes=mm_hashes)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert len(computed_blocks) == 3
+    assert len(computed_blocks.blocks) == 3
     assert num_computed_tokens == 3 * 16
 
 
@@ -626,7 +627,7 @@ def test_cache_key_salting():
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
 
     # Completed block should have hashes with extra keys.
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req0.request_id]
     assert len(block_hashes) == 3
@@ -635,14 +636,14 @@ def test_cache_key_salting():
     assert block_hashes[2].extra_keys is None
 
     blocks = manager.allocate_slots(req0, 59, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
     for _ in range(5):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 5)
-    assert new_blocks is not None and len(new_blocks) == 0
+    assert new_blocks is not None and len(new_blocks.blocks) == 0
 
     # Now one more block that should not have extra keys.
     assert len(block_hashes) == 4
@@ -653,14 +654,14 @@ def test_cache_key_salting():
     req1 = make_request("1", token_ids, cache_salt="salt1")
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     # Should match only a prefix of 3 blocks.
-    assert len(computed_blocks) == 3
+    assert len(computed_blocks.blocks) == 3
     assert num_computed_tokens == 3 * block_size
 
     # Test cache miss with same content but different salt.
     token_ids = common_token_ids + [4] * 11
     req2 = make_request("2", token_ids, cache_salt="salt2")
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(computed_blocks) == 0
+    assert len(computed_blocks.blocks) == 0
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req2.request_id]
     assert len(block_hashes) == 3
@@ -685,7 +686,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     common_token_ids = [i for i in range(3) for _ in range(16)]
     req0 = make_request("0", common_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req0, 48, computed_blocks)
     block_part0 = manager.req_to_blocks[req0.request_id]
@@ -693,7 +694,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
     req1 = make_request("1", common_token_ids * 2)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert computed_blocks == block_part0
+    assert computed_blocks.blocks == block_part0
     assert num_computed_tokens == 3 * 16
     manager.allocate_slots(req1, 48, computed_blocks)
     block_part1 = manager.req_to_blocks[req1.request_id]
@@ -707,7 +708,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     # | Req1-5(F)| Req2-0 | Req2-1 | ... |
     req2 = make_request("2", [7] * block_size * 2)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req2, block_size * 2, computed_blocks)
 
@@ -717,7 +718,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert manager.block_pool.free_block_queue.num_free_blocks == 5
     req3 = make_request("3", common_token_ids * 3)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert computed_blocks == block_part1
+    assert computed_blocks.blocks == block_part1
     assert num_computed_tokens == 6 * 16
     # Req3 cannot be allocated.
     assert manager.allocate_slots(req3, 48, computed_blocks) is None
@@ -739,16 +740,16 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     blocks = manager.allocate_slots(req0, 55)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
     req1 = make_request("1", all_token_ids)
     computed_blocks, _ = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert len(computed_blocks) == 3
+    assert len(computed_blocks.blocks) == 3
     blocks = manager.allocate_slots(req1, 7, computed_blocks)
-    assert [b.block_id for b in blocks] == [5]
+    assert blocks.get_block_ids() == [5]
 
     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
@@ -776,7 +777,7 @@ def test_prefix_cache_stats_disabled():
     # Call all functions that check whether log_stats is disabled.
     req = make_request("0", list(range(16)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req, 16, computed_blocks)
     manager.reset_prefix_cache()
@@ -866,7 +867,7 @@ def test_eagle_enabled_removes_last_block():
     # Should retain 1 block:
     # 1. Original 3 blocks → pop last hash → 2 matched blocks
     # 2. drop last matched block → 1 remaining block
-    assert len(computed_blocks) == 1
+    assert len(computed_blocks.blocks) == 1
     assert num_tokens == 1 * block_size
 
 
@@ -892,7 +893,7 @@ def test_eagle_with_partial_blocks():
     req_eagle = make_request("partial_eagle", token_ids)
     computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
-    assert len(computed_blocks) == 1
+    assert len(computed_blocks.blocks) == 1
     assert num_tokens == 1 * block_size
 
 
@@ -934,7 +935,7 @@ def test_eagle_with_sliding_window():
     req_eagle = make_request("partial_eagle", token_ids)
     computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
-    assert len(computed_blocks) == 1
+    assert len(computed_blocks.blocks) == 1
     assert num_tokens == 1 * block_size
 
     # Evict the first block in the request
@@ -948,5 +949,5 @@ def test_eagle_with_sliding_window():
     # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
     # not considered. But after dropping the last matched block due to eagle,
     # there will be no matched prefix.
-    assert len(computed_blocks) == 0
+    assert len(computed_blocks.blocks) == 0
     assert num_tokens == 0
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index a2fa5825bb1a..9e172b6bdb00 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -2,6 +2,7 @@
 
 from collections import defaultdict
 from collections.abc import Iterable
+from dataclasses import dataclass
 from typing import Optional
 
 from vllm.distributed.kv_events import KVCacheEvent
@@ -18,6 +19,24 @@
 logger = init_logger(__name__)
 
 
+@dataclass
+class KVCacheBlocks:
+    blocks: list[KVCacheBlock]
+
+    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
+        """Adds two KVCacheBlocks instances."""
+        return KVCacheBlocks(self.blocks + other.blocks)
+
+    @classmethod
+    def create_empty(cls) -> "KVCacheBlocks":
+        """Creates a new KVCacheBlocks instance with no blocks."""
+        return cls([])
+
+    def get_block_ids(self) -> list[int]:
+        """Converts the KVCacheBlocks instance to a list of block IDs."""
+        return [block.block_id for block in self.blocks]
+
+
 class KVCacheManager:
 
     def __init__(
@@ -94,8 +113,8 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
         self.prefix_cache_stats = PrefixCacheStats()
         return stats
 
-    def get_computed_blocks(
-            self, request: Request) -> tuple[list[KVCacheBlock], int]:
+    def get_computed_blocks(self,
+                            request: Request) -> tuple[KVCacheBlocks, int]:
         """Get the computed (cached) blocks for the request.
         Note that the computed blocks must be full.
 
@@ -109,7 +128,7 @@ def get_computed_blocks(
         """
         if not self.enable_caching:
             # Prefix caching is disabled.
-            return [], 0
+            return KVCacheBlocks.create_empty(), 0
 
         # The block hashes for the request may already be computed
         # if the scheduler has tried to schedule the request before.
@@ -124,7 +143,7 @@ def get_computed_blocks(
             self.prefix_cache_stats.requests += 1
         # When the request requires prompt logprobs, we skip prefix caching.
         if request.sampling_params.prompt_logprobs is not None:
-            return [], 0
+            return KVCacheBlocks.create_empty(), 0
 
         if len(block_hashes) * self.block_size == request.num_tokens:
             # When prompt length is divisible by the block size and all
@@ -157,15 +176,15 @@ def get_computed_blocks(
         # sharing, `num_computed_tokens` is always a multiple of
         # `block_size`.
         num_computed_tokens = len(computed_blocks) * self.block_size
-        return computed_blocks, num_computed_tokens
+        return KVCacheBlocks(computed_blocks), num_computed_tokens
 
     def allocate_slots(
         self,
         request: Request,
         num_tokens: int,
-        new_computed_blocks: Optional[list[KVCacheBlock]] = None,
+        new_computed_blocks: Optional[KVCacheBlocks] = None,
         num_lookahead_tokens: int = 0,
-    ) -> Optional[list[KVCacheBlock]]:
+    ) -> Optional[KVCacheBlocks]:
         """Add slots for a request with new tokens to append.
 
         Args:
@@ -173,7 +192,7 @@ def allocate_slots(
             num_tokens: The number of tokens to allocate, including external
                 tokens. Note that this does not include tokens that have
                 already been computed locally (i.e. new_computed_blocks).
-            new_computed_blocks: A list of new computed blocks just hitting the
+            new_computed_blocks: The new computed blocks just hitting the
                 prefix caching.
             num_lookahead_tokens: The number of speculative tokens to allocate.
                 This is used by spec decode proposers with kv-cache such
                 as eagle.
@@ -199,7 +218,10 @@ def allocate_slots(
         if num_tokens == 0:
             raise ValueError("num_tokens must be greater than 0")
 
-        new_computed_blocks = new_computed_blocks or []
+        if new_computed_blocks is not None:
+            new_computed_block_list = new_computed_blocks.blocks
+        else:
+            new_computed_block_list = []
 
         req_blocks = self.req_to_blocks[request.request_id]
 
@@ -216,17 +238,18 @@ def allocate_slots(
         # The number of computed tokens is the number of computed tokens plus
         # the new prefix caching hits
         num_computed_tokens = (request.num_computed_tokens +
-                               len(new_computed_blocks) * self.block_size)
+                               len(new_computed_block_list) * self.block_size)
         num_required_blocks = cdiv(
             num_computed_tokens + num_tokens + num_lookahead_tokens,
             self.block_size)
         num_new_blocks = (num_required_blocks - len(req_blocks) -
-                          len(new_computed_blocks))
+                          len(new_computed_block_list))
 
         # If a computed block of a request is an eviction candidate (in the
         # free queue and ref_cnt == 0), it cannot be counted as a free block
         # when allocating this request.
-        num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
+        num_evictable_computed_blocks = sum(1
+                                            for blk in new_computed_block_list
                                             if blk.ref_cnt == 0)
         if (num_new_blocks > self.block_pool.get_num_free_blocks() -
                 num_evictable_computed_blocks):
@@ -235,15 +258,15 @@ def allocate_slots(
 
         # Touch the computed blocks to make sure they won't be evicted.
         if self.enable_caching:
-            self.block_pool.touch(new_computed_blocks)
+            self.block_pool.touch(new_computed_block_list)
         else:
-            assert not new_computed_blocks, (
+            assert not new_computed_block_list, (
                 "Computed blocks should be empty when "
                 "prefix caching is disabled")
 
         # Append the new computed blocks to the request blocks until now to
         # avoid the case where the new blocks cannot be allocated.
-        req_blocks.extend(new_computed_blocks)
+        req_blocks.extend(new_computed_block_list)
 
         # Start to handle new blocks
 
@@ -267,12 +290,12 @@ def allocate_slots(
         req_blocks.extend(new_blocks)
 
         if not self.enable_caching:
-            return new_blocks
+            return KVCacheBlocks(new_blocks)
 
-        # Use `new_computed_blocks` for a new request, and `num_cached_block`
-        # for a running request.
-        num_cached_blocks = self.num_cached_block.get(request.request_id,
-                                                      len(new_computed_blocks))
+        # Use `new_computed_block_list` for a new request, and
+        # `num_cached_block` for a running request.
+        num_cached_blocks = self.num_cached_block.get(
+            request.request_id, len(new_computed_block_list))
         # Speculated tokens might be rejected in the future, so we does
         # not cache any speculated tokens. We only cache blocks with
         # generated (accepted) tokens.
@@ -291,7 +314,7 @@ def allocate_slots(
         self.num_cached_block[
             request.request_id] = num_full_blocks_after_append
 
-        return new_blocks
+        return KVCacheBlocks(new_blocks)
 
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 05472ea573d3..258e0d570e3e 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -261,9 +261,8 @@ def schedule(self) -> SchedulerOutput:
                 # Therefore, we might introduce some additional
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[request.request_id] = req_index
-            req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in new_blocks
-            ]
+            req_to_new_block_ids[request.request_id] = (
+                new_blocks.get_block_ids())
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
@@ -407,9 +406,8 @@ def schedule(self) -> SchedulerOutput:
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
 
-            req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in computed_blocks + new_blocks
-            ]
+            req_to_new_block_ids[request.request_id] = (
+                computed_blocks + new_blocks).get_block_ids()
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
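
Note (not part of the patch): a minimal, self-contained sketch of how the KVCacheBlocks wrapper introduced in vllm/v1/core/kv_cache_manager.py behaves. The KVCacheBlock class below is a hypothetical stand-in for vllm.v1.core.kv_cache_utils.KVCacheBlock, reduced to the one field these helpers touch.

    from dataclasses import dataclass

    @dataclass
    class KVCacheBlock:
        # Stand-in: only block_id is needed by get_block_ids().
        block_id: int

    @dataclass
    class KVCacheBlocks:
        blocks: list[KVCacheBlock]

        def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
            # Concatenation used by the scheduler:
            # (computed_blocks + new_blocks).get_block_ids()
            return KVCacheBlocks(self.blocks + other.blocks)

        @classmethod
        def create_empty(cls) -> "KVCacheBlocks":
            # Returned by get_computed_blocks() on a cache miss or when
            # prefix caching is disabled.
            return cls([])

        def get_block_ids(self) -> list[int]:
            return [block.block_id for block in self.blocks]

    computed = KVCacheBlocks([KVCacheBlock(1), KVCacheBlock(2)])
    new = KVCacheBlocks([KVCacheBlock(5)])
    assert (computed + new).get_block_ids() == [1, 2, 5]
    assert KVCacheBlocks.create_empty().get_block_ids() == []

The wrapper replaces the bare list[KVCacheBlock] return type, which is why the tests switch from len(blocks) to len(blocks.blocks) and from list comprehensions over blocks to blocks.get_block_ids().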