diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 71ea43383a7e..ab7aa02823ab 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -15,8 +15,8 @@ from vllm.v1.core.kv_cache_utils import ( FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, - get_max_concurrency_for_kv_cache_config, hash_block_tokens, - hash_request_tokens, unify_kv_cache_configs) + get_kv_cache_config, get_max_concurrency_for_kv_cache_config, + hash_block_tokens, hash_request_tokens, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, SlidingWindowSpec) @@ -63,6 +63,20 @@ def new_kv_cache_spec(block_size=16, sliding_window=sliding_window) +def new_sliding_window_spec(block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + use_mla=False, + sliding_window=1): + return SlidingWindowSpec(block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + use_mla=use_mla, + sliding_window=sliding_window) + + def test_none_hash(monkeypatch): import vllm.v1.core.kv_cache_utils @@ -403,10 +417,10 @@ def test_unify_kv_cache_configs(): same_kv_cache_config = [ KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -415,10 +429,10 @@ def test_unify_kv_cache_configs(): ), KVCacheConfig( num_blocks=20, - tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -433,10 +447,10 @@ def test_unify_kv_cache_configs(): need_sort_kv_cache_config = [ KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -445,10 +459,10 @@ def test_unify_kv_cache_configs(): ), KVCacheConfig( num_blocks=20, - tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)), @@ -464,10 +478,10 @@ def test_unify_kv_cache_configs(): diff_kv_cache_config = [ KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -476,10 +490,10 @@ def test_unify_kv_cache_configs(): ), KVCacheConfig( num_blocks=20, - tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], 
kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -636,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config(): kv_cache_config_full_attention = KVCacheConfig( num_blocks=int(1024 * 1.5), - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), @@ -648,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config(): kv_cache_config_sliding_window = KVCacheConfig( num_blocks=129 * 3, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec([f"layer_{i}" for i in range(32)], sliding_window_spec), @@ -660,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config(): kv_cache_config_hybrid_model = KVCacheConfig( num_blocks=(1024 + 129) * 3, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), @@ -678,9 +692,9 @@ def test_allocate_with_lookahead(): block_size = 4 config = KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec(block_size=block_size)), @@ -702,7 +716,7 @@ def test_allocate_with_lookahead(): num_new_tokens=3, num_lookahead_tokens=2, # Total required: 3+2=5 tokens ) - assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks + assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks kv_cache_manager = KVCacheManager(kv_cache_config=config, @@ -713,7 +727,7 @@ def test_allocate_with_lookahead(): num_new_tokens=3, num_lookahead_tokens=2, ) - assert len(blocks.blocks) == 2 + assert len(blocks.get_block_ids()[0]) == 2 # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 @@ -724,4 +738,165 @@ def test_allocate_with_lookahead(): num_new_tokens=3, num_lookahead_tokens=4, ) - assert len(blocks.blocks) == 2 + assert len(blocks.get_block_ids()[0]) == 2 + + +def test_get_kv_cache_config(): + # pass max_model_len to pass check_enough_kv_cache_memory + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + + mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 + # all layers are full attention -> single group + kv_cache_specs_full = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(), + } + kv_cache_config_full = get_kv_cache_config( + vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_full == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) + ]) + + # all layers are sliding window -> single group + kv_cache_specs_sliding = { + 'layer_1': new_sliding_window_spec(), + 'layer_2': new_sliding_window_spec(), + } + kv_cache_config_sliding = get_kv_cache_config( + vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_sliding == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec()) + ]) + + # full + sliding, but 
disable_hybrid_kv_cache_manager + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], + new_kv_cache_spec(sliding_window=1)), + ], + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + + # full + sliding, with hybrid_kv_cache_manager + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=64, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 64, + shared_by=["layer_1", "layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer_2"], new_sliding_window_spec()), + ], + ) + + # 2 full + 4 sliding, 2 layers per group + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(), + 'layer_3': new_sliding_window_spec(), + 'layer_4': new_sliding_window_spec(), + 'layer_5': new_sliding_window_spec(), + 'layer_6': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_3", "layer_5"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_4", "layer_6"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer_3", "layer_4"], + new_sliding_window_spec()), + KVCacheGroupSpec(["layer_5", "layer_6"], + new_sliding_window_spec()), + ], + ) + + # 3 full + 7 sliding, pad to 3 full + 9 sliding + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(), + 'layer_3': new_kv_cache_spec(), + 'layer_4': new_sliding_window_spec(), + 'layer_5': new_sliding_window_spec(), + 'layer_6': new_sliding_window_spec(), + 'layer_7': new_sliding_window_spec(), + 'layer_8': new_sliding_window_spec(), + 'layer_9': new_sliding_window_spec(), + 'layer_10': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_5", "layer_8"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_3", "layer_6", "layer_9"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], + new_kv_cache_spec()), + KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"], + new_sliding_window_spec()), + KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"], + 
new_sliding_window_spec()), + KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()), + ], + ) + + # different hidden size, unimplemented + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(head_size=128), + 'layer_2': new_kv_cache_spec(), + } + with pytest.raises(NotImplementedError): + get_kv_cache_config(vllm_config, kv_cache_specs_hybrid, + mem_per_block_per_layer * 2 * 32) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 897d181ec9d5..bf4cb539ebef 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Compare the with and without prefix caching.""" +import copy from typing import Optional import pytest @@ -13,8 +14,8 @@ from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - hash_block_tokens) +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, + KVCacheBlock, hash_block_tokens) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -47,7 +48,7 @@ def make_request(request_id, def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: return KVCacheConfig( num_blocks=num_blocks, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( ["layer"], @@ -57,6 +58,38 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) +def make_kv_cache_config_hybrid_model(block_size: int, + num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec(block_size, 1, 1, torch.float32, False), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec(block_size, + 1, + 1, + torch.float32, + False, + sliding_window=2 * block_size), + ), + KVCacheGroupSpec( + ["layer3"], + SlidingWindowSpec(block_size, + 1, + 1, + torch.float32, + False, + sliding_window=2 * block_size), + ), + ], + ) + + @pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) def test_prefill(hash_algo): manager = KVCacheManager( @@ -79,10 +112,10 @@ def test_prefill(hash_algo): req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 3 - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] @@ -92,7 +125,8 @@ def test_prefill(hash_algo): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[ + block_id].block_hash.block_hash == block_hash assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value @@ -111,10 +145,10 @@ def test_prefill(hash_algo): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert 
blocks.get_block_ids() == [[5]] - for block in computed_blocks.blocks: + for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. @@ -145,7 +179,7 @@ def test_prefill(hash_algo): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[6]] @@ -165,10 +199,10 @@ def test_prefill(hash_algo): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 10)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 10, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) # This block ID order also checks the eviction order. assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]] @@ -177,6 +211,138 @@ def test_prefill(hash_algo): assert manager.block_pool.free_block_queue.free_list_tail is None +def test_prefill_hybrid_model(): + block_size = 16 + manager = KVCacheManager( + make_kv_cache_config_hybrid_model(block_size, 21), + max_model_len=8192, + enable_caching=True, + ) + + hash_fn = hash + + # Complete 3 blocks (48 tokens) + common_token_ids = [i for i in range(3) for _ in range(block_size)] + + # Fully cache miss + # Incomplete 1 block (7 tokens) + unique_token_ids = [3] * 7 + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert not computed_blocks.blocks[0] + assert num_computed_tokens == 0 + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12]] + + # Check full block metadata + parent_block_hash = None + for length, block_ids in zip((1, 2, 3), + ((1, 5, 9), (2, 6, 10), (3, 7, 11))): + block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, + block_tokens) + for block_id in block_ids: + assert manager.block_pool.blocks[ + block_id].block_hash.block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + parent_block_hash = block_hash.hash_value + + # Check partial block metadata + for block_id in (4, 8, 12): + assert manager.block_pool.blocks[block_id].block_hash is None + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + + # Cache hit in the common prefix + # Incomplete 1 block (5 tokens) + unique_token_ids = [3] * 5 + req1 = make_request("1", common_token_ids + unique_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7], + [0, 10, 11]] + assert num_computed_tokens == 3 * 16 + num_new_tokens = 53 - 3 * 16 + blocks = manager.allocate_slots(req1, num_new_tokens, + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[13], [14], [15]] + for block_per_group in computed_blocks.blocks: + for block in block_per_group: + if block != manager.block_pool.null_block: + assert block.ref_cnt == 2 + + block_hashes = 
manager.req_to_block_hashes[req1.request_id]
+    manager.free(req0)
+    manager.free(req1)
+
+    cached_block_hash_to_block_bak = copy.copy(
+        manager.block_pool.cached_block_hash_to_block)
+
+    def test_partial_request_hit(request_id: str,
+                                 hash_to_evict: list[BlockHashWithGroupId],
+                                 expect_hit_length: int):
+        req = make_request(request_id, common_token_ids + unique_token_ids)
+        for hash_with_group_id in hash_to_evict:
+            manager.block_pool.cached_block_hash_to_block.pop(
+                hash_with_group_id)
+        computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
+        assert len(manager.req_to_block_hashes[req.request_id]) == 3
+        assert num_computed_tokens == expect_hit_length * block_size
+        for block_per_group in computed_blocks.blocks:
+            assert len(block_per_group) == num_computed_tokens // block_size
+        for hash_with_group_id in hash_to_evict:
+            manager.block_pool.cached_block_hash_to_block[
+                hash_with_group_id] = cached_block_hash_to_block_bak[
+                    hash_with_group_id]
+        manager.free(req)
+
+    # Evicting blocks outside the sliding window does not change the hit length.
+    test_partial_request_hit("2", [
+        BlockHashWithGroupId(block_hashes[0], 1),
+        BlockHashWithGroupId(block_hashes[0], 2)
+    ], 3)
+
+    # Evicting the first block of full attention causes a total cache miss.
+    test_partial_request_hit("3", [
+        BlockHashWithGroupId(block_hashes[0], 0),
+    ], 0)
+
+    # Evicting the last block of all layers reduces the hit length to 2.
+    test_partial_request_hit("4", [
+        BlockHashWithGroupId(block_hashes[2], 0),
+        BlockHashWithGroupId(block_hashes[2], 1),
+        BlockHashWithGroupId(block_hashes[2], 2),
+    ], 2)
+
+    # Evicting the last block of full attention reduces the hit length to 2.
+    test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
+                             2)
+
+    # Evicting the last block of sliding window reduces the hit length to 2.
+    test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
+                             2)
+
+    # Evicting the last block of sliding window reduces the hit length to 2.
+    test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
+                             2)
+
+    # Evicting different sets of blocks for full attention and sliding window
+    # causes a total cache miss.
+    # The cache hit length of full attention is 1 * block_size.
+    # The cache hit length of sliding window is 2 * block_size.
+    # The two layer types disagree, so the overall result is a cache miss.
+    test_partial_request_hit("8", [
+        BlockHashWithGroupId(block_hashes[2], 0),
+        BlockHashWithGroupId(block_hashes[0], 1),
+        BlockHashWithGroupId(block_hashes[0], 2),
+    ], 0)
+
+
 def test_prefill_plp():
     '''Test prefill with APC and some prompt logprobs (plp) requests.
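A note on the hunk above: with the hybrid KV cache manager, the prefix cache is keyed by the block hash together with the KV cache group id, which is why the test evicts entries through BlockHashWithGroupId and why a hit must be present for every group. The sketch below is a simplified, standalone model of that lookup rule (plain Python stand-ins, not the real vLLM classes; see BlockPool.get_cached_block later in this diff for the actual implementation):

    from collections import defaultdict
    from typing import NamedTuple, Optional

    class BlockHashWithGroupId(NamedTuple):
        # Simplified stand-in: the real type pairs a BlockHash with a group id.
        block_hash: int
        group_id: int

    # block hash + group id -> {block_id: block}
    cached: dict[BlockHashWithGroupId, dict[int, str]] = defaultdict(dict)
    cached[BlockHashWithGroupId(0xABC, 0)][1] = "block 1 (full attention group)"
    cached[BlockHashWithGroupId(0xABC, 1)][5] = "block 5 (sliding window group)"

    def get_cached_block(block_hash: int,
                         kv_cache_group_ids: list[int]) -> Optional[list[str]]:
        # A hit requires the hash to be cached for *every* requested group.
        result = []
        for group_id in kv_cache_group_ids:
            blocks = cached.get(BlockHashWithGroupId(block_hash, group_id))
            if not blocks:
                return None
            result.append(next(iter(blocks.values())))
        return result

    assert get_cached_block(0xABC, [0, 1]) is not None
    cached.pop(BlockHashWithGroupId(0xABC, 1))  # evict for one group only
    assert get_cached_block(0xABC, [0]) is not None
    assert get_cached_block(0xABC, [0, 1]) is None  # any-group miss is a miss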
@@ -203,13 +369,13 @@ def test_prefill_plp(): req0 = make_request("0", all_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(manager.req_to_block_hashes[req0.request_id]) == 0 - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] - req0_block_hashes = [b.block_hash for b in blocks.blocks] + req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata parent_block_hash = None @@ -217,7 +383,8 @@ def test_prefill_plp(): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[ + block_id].block_hash.block_hash == block_hash assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value @@ -237,10 +404,10 @@ def test_prefill_plp(): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[5]] - for block in computed_blocks.blocks: + for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. @@ -269,14 +436,14 @@ def test_prefill_plp(): prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 0 - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 - assert [b.block_hash for b in blocks.blocks] == req0_block_hashes + assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes assert block_ids != [[1, 2, 3, 4]] # Request #2 block hashes are valid since request #0 hashes are. @@ -302,10 +469,10 @@ def test_decode(): unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] @@ -314,10 +481,10 @@ def test_decode(): for _ in range(4): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 0 - assert manager.single_type_manager.req_to_blocks[ + assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 + assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-1].block_hash is None # Append slots with allocating a new block. 
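The recurring mechanical change in this test file follows from KVCacheBlocks now holding one block list per KV cache group instead of a single flat list: the tests read blocks.blocks[0] (the first, and for non-hybrid configs the only, group), get_block_ids() yields one id list per group, and per-type bookkeeping moves from manager.single_type_manager to manager.coordinator.single_type_managers[group_id]. A rough standalone model of the new shape, using hypothetical stand-in classes rather than vLLM's real ones:

    from dataclasses import dataclass

    @dataclass
    class Block:
        block_id: int

    @dataclass
    class GroupedBlocks:
        # One entry per KV cache group; single-group configs only use index 0.
        blocks: list[list[Block]]

        def get_block_ids(self) -> list[list[int]]:
            return [[b.block_id for b in group] for group in self.blocks]

    # Single full-attention group: everything lives in group 0.
    single = GroupedBlocks(blocks=[[Block(1), Block(2), Block(3), Block(4)]])
    assert single.get_block_ids() == [[1, 2, 3, 4]]
    assert len(single.blocks[0]) == 4

    # Hybrid model with three groups (full attention + two sliding window).
    hybrid = GroupedBlocks(blocks=[[Block(1)], [Block(5)], [Block(9)]])
    assert hybrid.get_block_ids() == [[1], [5], [9]]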
@@ -327,12 +494,12 @@ def test_decode(): for _ in range(9 + 10): req0.append_output_token_ids(7) new_blocks = manager.allocate_slots(req0, 19, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 1 - assert manager.single_type_manager.req_to_blocks[ + assert new_blocks is not None and len(new_blocks.blocks[0]) == 1 + assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-2].block_hash is not None - assert manager.single_type_manager.req_to_blocks[ + assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-1].block_hash is None @@ -346,23 +513,23 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 6 # 5 full + 1 partial + assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 3 # 3 full blocks + assert len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 # 10 - (6 + 3) == 1 @@ -382,7 +549,7 @@ def test_evict(): assert computed_blocks.get_block_ids() == [[1, 2]] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[10]] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -404,12 +571,12 @@ def test_hash_block_correct_reuse(): num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 + assert len(blocks.blocks[0]) == 1 # Deallocate the block. manager.free(req) @@ -418,15 +585,15 @@ def test_hash_block_correct_reuse(): # block is cleared. 
req = make_request("1", list(range(num_tokens - 1))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 + assert len(blocks.blocks[0]) == 1 - assert manager.block_pool.blocks[ - blocks.blocks[0].block_id].block_hash is None + assert manager.block_pool.blocks[blocks.blocks[0] + [0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -445,24 +612,24 @@ def test_computed_blocks_not_evicted(): num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 - assert blocks.blocks[0].block_id == 1 + assert len(blocks.blocks[0]) == 1 + assert blocks.blocks[0][0].block_id == 1 # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 - assert blocks.blocks[0].block_id == 2 + assert len(blocks.blocks[0]) == 1 + assert blocks.blocks[0][0].block_id == 2 # Free the blocks. manager.free(req0) @@ -472,15 +639,15 @@ def test_computed_blocks_not_evicted(): # cached block rather than the first one. req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks.blocks) == 1 - assert computed_blocks.blocks[0].block_id == 1 + assert len(computed_blocks.blocks[0]) == 1 + assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 - assert blocks.blocks[0].block_id == 2 + assert len(blocks.blocks[0]) == 1 + assert blocks.blocks[0][0].block_id == 2 def test_basic_prefix_caching_disabled(): @@ -497,12 +664,12 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 3 + assert len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) @@ -510,20 +677,20 @@ def test_basic_prefix_caching_disabled(): # No caching. 
req2 = make_request("2", list(range(16))) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 4 + assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert not blocks @@ -558,6 +725,7 @@ def test_cache_blocks(hash_fn): num_full_blocks=2, block_size=block_size, hash_fn=hash_fn, + kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 2 @@ -573,11 +741,83 @@ def test_cache_blocks(hash_fn): num_full_blocks=3, block_size=block_size, hash_fn=hash_fn, + kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None +def test_cache_blocks_multi_group(): + """ + This tests that blocks are cached correctly for different kv cache groups. + """ + block_size = 4 + block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + + # Req: + # Block 0/4: [0, 1, 2, 3] + # Block 1/5: [4, 5, 6, 7] + # Block 2/6: [8, 9, 10, 11] + # Block 3/7: [12, 13] + req = make_request("0", list(range(14))) + + # Cache the blocks for group 0. + blocks = [KVCacheBlock(block_id=i) for i in range(2)] + block_hashes: list[BlockHash] = [] + block_pool.cache_full_blocks( + request=req, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=0, + num_full_blocks=2, + block_size=block_size, + hash_fn=hash, + kv_cache_group_id=0, + ) + assert len(block_pool.cached_block_hash_to_block) == 2 + assert len(block_hashes) == 2 + assert all([block.block_hash is not None for block in blocks]) + + # Cache the blocks for group 1. 
+ blocks = [KVCacheBlock(block_id=i) for i in range(3)] + block_pool.cache_full_blocks( + request=req, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=0, + num_full_blocks=3, + block_size=block_size, + hash_fn=hash, + kv_cache_group_id=1, + ) + assert len(block_pool.cached_block_hash_to_block) == 5 + assert len(block_hashes) == 3 + assert all([block.block_hash is not None for block in blocks]) + + # Block hash 0: hit for group 0 and 1 + # Block hash 1: hit for group 0 and 1 + # Block hash 2: hit for group 1 + + assert block_pool.get_cached_block(block_hashes[0], + kv_cache_group_ids=[0]) is not None + assert block_pool.get_cached_block(block_hashes[1], + kv_cache_group_ids=[0]) is not None + assert block_pool.get_cached_block(block_hashes[2], + kv_cache_group_ids=[0]) is None + assert block_pool.get_cached_block(block_hashes[0], + kv_cache_group_ids=[1]) is not None + assert block_pool.get_cached_block(block_hashes[1], + kv_cache_group_ids=[1]) is not None + assert block_pool.get_cached_block(block_hashes[2], + kv_cache_group_ids=[1]) is not None + assert block_pool.get_cached_block(block_hashes[0], + kv_cache_group_ids=[0, 1]) is not None + assert block_pool.get_cached_block(block_hashes[1], + kv_cache_group_ids=[0, 1]) is not None + assert block_pool.get_cached_block(block_hashes[2], + kv_cache_group_ids=[0, 1]) is None + + def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. @@ -614,7 +854,7 @@ def test_mm_prefix_caching(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 @@ -623,7 +863,7 @@ def test_mm_prefix_caching(): assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 @@ -632,9 +872,9 @@ def test_mm_prefix_caching(): for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 @@ -652,7 +892,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(computed_blocks.blocks) == 3 + assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -675,7 +915,7 @@ def test_cache_key_salting(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. 
- assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 @@ -684,7 +924,7 @@ def test_cache_key_salting(): assert block_hashes[2].extra_keys is None blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 @@ -693,9 +933,9 @@ def test_cache_key_salting(): for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 # Now one more block that should not have extra keys. assert len(block_hashes) == 4 @@ -706,14 +946,14 @@ def test_cache_key_salting(): req1 = make_request("1", token_ids, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. - assert len(computed_blocks.blocks) == 3 + assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * block_size # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 req2 = make_request("2", token_ids, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks.blocks) == 0 + assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = manager.req_to_block_hashes[req2.request_id] assert len(block_hashes) == 3 @@ -738,20 +978,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, - len(computed_blocks.blocks) * 16, computed_blocks) - block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id] + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[ + req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert computed_blocks.blocks == block_part0 + assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, - len(computed_blocks.blocks) * 16, computed_blocks) - block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id] + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[ + req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) @@ -762,10 +1006,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Req1-5(F)| Req2-0 | Req2-1 | ... 
| req2 = make_request("2", [7] * block_size * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -773,11 +1018,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert computed_blocks.blocks == block_part1 + assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. assert manager.allocate_slots(req3, 48, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) is None # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} @@ -804,9 +1049,9 @@ def test_reset_prefix_cache(): req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert len(computed_blocks.blocks) == 3 + assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[5]] @@ -836,10 +1081,11 @@ def test_prefix_cache_stats_disabled(): # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req, 16, - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None @@ -918,7 +1164,8 @@ def test_eagle_enabled_removes_last_block(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) manager.free(req) # New request with same tokens + Eagle enabled @@ -928,7 +1175,7 @@ def test_eagle_enabled_removes_last_block(): # Should retain 1 block: # 1. Original 3 blocks → pop last hash → 2 matched blocks # 2. 
drop last matched block → 1 remaining block - assert len(computed_blocks.blocks) == 1 + assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size # 16 tokens @@ -948,14 +1195,15 @@ def test_eagle_with_partial_blocks(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) manager.free(req) # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining - assert len(computed_blocks.blocks) == 1 + assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size @@ -973,7 +1221,7 @@ def test_eagle_with_sliding_window(): manager = KVCacheManager( KVCacheConfig( num_blocks=10, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)], ), max_model_len=8192, @@ -988,7 +1236,8 @@ def test_eagle_with_sliding_window(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) # record the block hash of the first block in the request for later use block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] assert block_hash_first_block is not None @@ -998,13 +1247,14 @@ def test_eagle_with_sliding_window(): req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining - assert len(computed_blocks.blocks) == 1 + assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size # Evict the first block in the request assert manager.block_pool.get_cached_block( - block_hash_first_block) is not None - manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block) + block_hash_first_block, kv_cache_group_ids=[0]) is not None + manager.block_pool.cached_block_hash_to_block.pop( + BlockHashWithGroupId(block_hash_first_block, 0)) # New request req_after_evict = make_request("partial_eagle_after_evict", token_ids) @@ -1012,5 +1262,5 @@ def test_eagle_with_sliding_window(): # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, # there will be no matched prefix. - assert len(computed_blocks.blocks) == 0 + assert len(computed_blocks.blocks[0]) == 0 assert num_tokens == 0 diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index aa074f1bb37f..d348956aa177 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -97,7 +97,7 @@ def create_scheduler( ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, @@ -814,10 +814,10 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req_id in req_ids: - blocks = (scheduler.kv_cache_manager.single_type_manager. - req_to_blocks[req_id]) + blocks = (scheduler.kv_cache_manager.coordinator. 
+ single_type_managers[0].req_to_blocks[req_id]) hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] - assert (scheduler.kv_cache_manager.single_type_manager. + assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS @@ -1198,11 +1198,11 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len( - scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + req_to_blocks) == 0 assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 - assert len( - scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + num_cached_block) == 0 num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 92ce8ea8b8dd..a9e1898df934 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -4,7 +4,8 @@ import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, + KVCacheBlock) from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager from vllm.v1.kv_cache_interface import SlidingWindowSpec @@ -12,9 +13,8 @@ def get_sliding_window_manager(sliding_window_spec, block_pool): return SlidingWindowManager(sliding_window_spec, block_pool, - use_eagle=False, - num_kv_cache_groups=1, - caching_hash_fn=lambda x: x) + caching_hash_fn=lambda x: x, + kv_cache_group_id=0) def test_sliding_window_possible_cached_prefix(): @@ -42,13 +42,18 @@ def run_one_case(block_is_cached, expect_length): for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[block_hash] = { - i: block_pool.blocks[i + 10] - } + block_pool.cached_block_hash_to_block[BlockHashWithGroupId( + block_hash, 0)] = { + i: block_pool.blocks[i + 10], + } computed_blocks = manager.find_longest_cache_hit( - block_hash_list, - len(block_hash_list) * block_size) + block_hashes=block_hash_list, + max_length=len(block_hash_list) * block_size, + kv_cache_group_ids=[0], + block_pool=block_pool, + kv_cache_spec=sliding_window_spec, + use_eagle=False)[0] assert len(computed_blocks) == expect_length assert all(block == block_pool.null_block @@ -95,13 +100,13 @@ def test_sliding_window_remove_skipped_blocks(): null_block_id = block_pool.null_block.block_id - def id_to_block_table(ids): + def id_to_block_table(ids) -> list[KVCacheBlock]: return [ KVCacheBlock(id_) if id_ != null_block_id else block_pool.null_block for id_ in ids ] - def assert_block_id(block_table, ids): + def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): for block, id_ in zip(block_table, ids): if id_ == null_block_id: assert block == block_pool.null_block diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 3eedc535d7f4..d8882b1d9432 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -18,7 +18,7 
@@ class TestConfig: model_config = { "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), - "google/gemma-2-2b-it": TestConfig(4096, (400, 800)), + "google/gemma-3-1b-it": TestConfig(4096, (400, 800)), } @@ -26,7 +26,7 @@ class TestConfig: "model", [ "bigcode/starcoder2-3b", # sliding window only - "google/gemma-2-2b-it", # sliding window + full attention + "google/gemma-3-1b-it", # sliding window + full attention ]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 9b257143d69d..622ab6f35db3 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -36,8 +36,8 @@ def test_basic_inferface(): req_meta = kv_connector_metadata.requests[request_id] for block_id, block in zip( - req_meta.local_block_ids, scheduler.kv_cache_manager. - single_type_manager.req_to_blocks[request_id]): + req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]): assert block_id == block.block_id diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 52dc21a2cdba..ff36a281c413 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -54,8 +54,8 @@ def test_basic_lifecycle(): assert len(scheduler.waiting) == 0 # ... but blocks should not be freed. - blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 2312e2135908..a1156306dc4b 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -51,8 +51,8 @@ def test_basic_lifecycle(): assert (block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE) assert len(block_pool.cached_block_hash_to_block) == 0 - blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] for block in blocks: assert block._block_hash is None @@ -87,8 +87,8 @@ def test_basic_lifecycle(): # Confirm the block are actually allocated. 
num_hashed_blocks = 0 - blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 num_hashed_blocks += (1 if block._block_hash is not None else 0) @@ -261,10 +261,10 @@ def test_no_spurious_prefix_caching(): assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 - local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_local.request_id] - remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501 - request_remote.request_id] + local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_remote.request_id] # Local should have cached blocks (but not all due to preallocate). num_hashed_blocks = 0 @@ -300,8 +300,8 @@ def test_full_block_prompt(): # STEP (1): Initialize a recv. scheduler_output = scheduler.schedule() # All blocks should be allocated. - num_blocks = len(scheduler.kv_cache_manager.single_type_manager. - req_to_blocks[request_id]) + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT scheduler.update_from_output(scheduler_output, model_runner_output) @@ -319,8 +319,8 @@ def test_full_block_prompt(): # We need to recompute the final token of the prompt to generate # the first new token, so we should not have a new block. - num_blocks = len(scheduler.kv_cache_manager.single_type_manager. - req_to_blocks[request_id]) + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index e190e956170d..4a9e3a7ad807 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -32,11 +32,11 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len( - scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + req_to_blocks) == 0 assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 - assert len( - scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
+ num_cached_block) == 0 num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( @@ -96,7 +96,7 @@ def create_scheduler( block_size = vllm_config.cache_config.block_size kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 0553d94de4c2..caacb1652e9a 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -40,12 +40,13 @@ def initialize_kv_cache(runner: GPUModelRunner): tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( num_blocks=NUM_BLOCKS, - tensors={ - "layer.0": KVCacheTensor(size=tensor_size), - }, + kv_cache_tensors=[ + KVCacheTensor(size=tensor_size, shared_by=["layer.0"]), + ], kv_cache_groups=[ KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) - ]) + ], + ) runner.kv_cache_config = kv_cache_config runner.input_batch = InputBatch( max_num_reqs=runner.max_num_reqs, @@ -518,9 +519,9 @@ def test_init_kv_cache_without_kv_sharing(): kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, available_memory) assert kv_cache_config.num_blocks == num_expected_blocks - assert len(kv_cache_config.tensors) == 2 - assert kv_cache_config.tensors[layer_0].size == available_memory // 2 - assert kv_cache_config.tensors[layer_1].size == available_memory // 2 + assert len(kv_cache_config.kv_cache_tensors) == 2 + assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 + assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 max_context_len =\ estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) @@ -530,9 +531,9 @@ def test_init_kv_cache_without_kv_sharing(): # important: override tensor size to prevent large mem alloc during test # this will only allocate 2 block worth of memory (2 * 32kb) kv_cache_config.num_blocks = 1 - for layer in kv_cache_config.tensors: - kv_cache_config.tensors[layer].size =\ - kv_cache_spec[layer].page_size_bytes + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + kv_cache_tensor.size = ( + kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes) runner.initialize_kv_cache(kv_cache_config) @@ -589,10 +590,10 @@ def test_init_kv_cache_with_kv_sharing_valid(): kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, available_memory) assert kv_cache_config.num_blocks == num_expected_blocks - assert len(kv_cache_config.tensors) == 1 + assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache # compared to no KV sharing - assert kv_cache_config.tensors[layer_0].size == available_memory + assert kv_cache_config.kv_cache_tensors[0].size == available_memory max_context_len =\ estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) @@ -602,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): # important: override tensor size to prevent large mem alloc during test # this will only allocate 1 block worth of memory (32kb) kv_cache_config.num_blocks = 1 - kv_cache_config.tensors[layer_0].size =\ + kv_cache_config.kv_cache_tensors[0].size =\ kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) diff --git a/vllm/config.py b/vllm/config.py index f6ca9328b8a1..15e1b530dc9e 
100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2104,6 +2104,12 @@ class SchedulerConfig:
     default scheduler. Can be a class directly or the path to a class of form
     "mod.custom_class"."""
 
+    disable_hybrid_kv_cache_manager: bool = False
+    """If set to True, the KV cache manager will allocate the same size of
+    KV cache for all attention layers even if there are multiple types of
+    attention layers, like full attention and sliding window attention.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -4463,6 +4469,21 @@ def __post_init__(self):
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]
 
+        if (envs.VLLM_USE_V1
+                and not self.scheduler_config.disable_hybrid_kv_cache_manager):
+            # The warning should only be printed for hybrid models. As we
+            # can't know whether the model is hybrid at this point, we don't
+            # log the warning here and will log it later.
+            if not (current_platform.is_cuda() or current_platform.is_rocm()):
+                # Hybrid KV cache manager is not supported on non-GPU platforms.
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
+            if self.kv_transfer_config is not None:
+                # Hybrid KV cache manager is not compatible with KV transfer.
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
+            if self.kv_events_config is not None:
+                # Hybrid KV cache manager is not compatible with KV events.
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
+
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index b1c4b27a0ca4..d1e554f6d128 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -387,6 +387,9 @@ class EngineArgs:
         bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
 
+    disable_hybrid_kv_cache_manager: bool = (
+        SchedulerConfig.disable_hybrid_kv_cache_manager)
+
     guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
     guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
     guided_decoding_disable_any_whitespace: bool = \
@@ -849,6 +852,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             **scheduler_kwargs["disable_chunked_mm_input"])
         scheduler_group.add_argument("--scheduler-cls",
                                      **scheduler_kwargs["scheduler_cls"])
+        scheduler_group.add_argument(
+            "--disable-hybrid-kv-cache-manager",
+            **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
 
         # vLLM arguments
         vllm_kwargs = get_kwargs(VllmConfig)
@@ -1174,6 +1180,8 @@ def create_engine_config(
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
+            disable_hybrid_kv_cache_manager=self.
+ disable_hybrid_kv_cache_manager, ) lora_config = LoRAConfig( diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 5118e4d8e614..3b2a4f936000 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -7,8 +7,8 @@ from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, BlockStored, KVCacheEvent) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import (BlockHash, FreeKVCacheBlockQueue, - KVCacheBlock, +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, + FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens) from vllm.v1.request import Request @@ -27,6 +27,7 @@ class BlockPool: Args: num_gpu_blocks: The number of blocks in the pool. enable_caching: Whether to enable prefix caching. + enable_kv_cache_events: Whether to enable kv cache events. """ def __init__( @@ -56,7 +57,7 @@ def __init__( # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. - self.cached_block_hash_to_block: dict[BlockHash, dict[ + self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[ int, KVCacheBlock]] = defaultdict(dict) # To represent a placeholder block with block_id=0. @@ -68,22 +69,29 @@ def __init__( self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue: list[KVCacheEvent] = [] - def get_cached_block(self, - block_hash: BlockHash) -> Optional[KVCacheBlock]: - """Get a cached block by the block hash, or None if cache miss. + def get_cached_block( + self, block_hash: BlockHash, + kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: + """Get the cached block by the block hash for each group in + `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. Args: block_hash: The hash value of the block. + kv_cache_group_ids: The ids of the KV cache groups. Returns: - The cached block if it exists, or None. + The cached blocks if exists, or None. """ - cached_blocks = self.cached_block_hash_to_block.get(block_hash) - if not cached_blocks: - return None - first_block_id = next(iter(cached_blocks)) - return cached_blocks[first_block_id] + cached_blocks = [] + for group_id in kv_cache_group_ids: + cached_blocks_one_group = self.cached_block_hash_to_block.get( + BlockHashWithGroupId(block_hash, group_id)) + if not cached_blocks_one_group: + return None + first_block_id = next(iter(cached_blocks_one_group)) + cached_blocks.append(cached_blocks_one_group[first_block_id]) + return cached_blocks def cache_full_blocks( self, @@ -93,6 +101,7 @@ def cache_full_blocks( num_cached_blocks: int, num_full_blocks: int, block_size: int, + kv_cache_group_id: int, hash_fn: Callable, ) -> None: """Cache a list of full blocks for prefix caching. @@ -112,6 +121,7 @@ def cache_full_blocks( num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. + kv_cache_group_id: The id of the KV cache group. hash_fn: The hash function to use for block hashes. 
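To make the new per-group lookup concrete, here is a minimal sketch (simplified stand-in types and toy data, not taken from the patch) of how keying the cache by (block hash, group id) makes a lookup miss unless every requested group has a cached block:

from collections import defaultdict
from typing import NamedTuple, Optional

class BlockHash(NamedTuple):             # simplified stand-in for vLLM's BlockHash
    hash_value: int

class BlockHashWithGroupId(NamedTuple):  # hash plus kv cache group id, as in the patch
    block_hash: BlockHash
    group_id: int

# cached_block_hash_to_block maps (hash, group_id) -> {block_id: block}
cache: dict[BlockHashWithGroupId, dict[int, str]] = defaultdict(dict)
cache[BlockHashWithGroupId(BlockHash(42), 0)][3] = "block-3-of-group-0"
cache[BlockHashWithGroupId(BlockHash(42), 1)][7] = "block-7-of-group-1"

def get_cached_block(block_hash: BlockHash,
                     kv_cache_group_ids: list[int]) -> Optional[list[str]]:
    # Mirrors the idea of the new BlockPool.get_cached_block: every requested
    # group must hit, otherwise the whole lookup is a miss.
    hits = []
    for group_id in kv_cache_group_ids:
        blocks = cache.get(BlockHashWithGroupId(block_hash, group_id))
        if not blocks:
            return None
        hits.append(next(iter(blocks.values())))
    return hits

assert get_cached_block(BlockHash(42), [0, 1]) is not None
assert get_cached_block(BlockHash(42), [0, 1, 2]) is None  # group 2 misses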
""" if num_cached_blocks == num_full_blocks: @@ -126,7 +136,7 @@ def cache_full_blocks( else: prev_block = blocks[num_cached_blocks - 1] assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.hash_value + prev_block_hash_value = prev_block.block_hash.get_hash_value() parent_block_hash = prev_block_hash_value new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events @@ -138,8 +148,9 @@ def cache_full_blocks( # The block hash may already be computed in # "get_computed_blocks" if the tokens are not generated by # this request (either the prompt tokens or the previously - # generated tokens with preemption). In this case we simply - # reuse the block hash. + # generated tokens with preemption), or by other + # single_type_managers with the same block_size. + # In this case we simply reuse the block hash. block_hash = new_block_hashes[i] else: # Otherwise compute the block hash and cache it in the request @@ -166,8 +177,11 @@ def cache_full_blocks( block_hashes.append(block_hash) # Update and added the full block to the cache. - blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + block_hash_with_group_id = BlockHashWithGroupId( + block_hash, kv_cache_group_id) + blk.block_hash = block_hash_with_group_id + self.cached_block_hash_to_block[block_hash_with_group_id][ + blk.block_id] = blk if new_hashes is not None: new_hashes.append(block_hash.hash_value) prev_block_hash_value = block_hash.hash_value @@ -237,12 +251,16 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: del self.cached_block_hash_to_block[block_hash] if self.enable_kv_cache_events: + # FIXME (Chen): Not sure whether we should return `hash_value` + # or `(hash_value, group_id)` here. But it's fine now because + # we disable hybrid kv cache manager when kv cache event is + # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.hash_value])) + BlockRemoved(block_hashes=[block_hash.get_hash_value()])) return True return False - def touch(self, blocks: list[KVCacheBlock]) -> None: + def touch(self, blocks: list[list[KVCacheBlock]]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -250,12 +268,13 @@ def touch(self, blocks: list[KVCacheBlock]) -> None: Args: blocks: A list of blocks to touch. """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0 and not block.is_null: - self.free_block_queue.remove(block) - block.incr_ref() + for blocks_per_group in blocks: + for block in blocks_per_group: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0 and not block.is_null: + self.free_block_queue.remove(block) + block.incr_ref() def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: """Free a list of blocks. 
The blocks should be ordered by their diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py new file mode 100644 index 000000000000..993ce4b484f9 --- /dev/null +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -0,0 +1,358 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Callable, Optional + +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.single_type_kv_cache_manager import ( + FullAttentionManager, SingleTypeKVCacheManager, + get_manager_for_kv_cache_spec) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig +from vllm.v1.request import Request + + +class KVCacheCoordinator(ABC): + """ + Coordinate the KV cache of different KV cache groups. + """ + + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + caching_hash_fn: Callable, + enable_kv_cache_events: bool, + ): + self.kv_cache_config = kv_cache_config + self.max_model_len = max_model_len + + self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, + enable_kv_cache_events) + self.single_type_managers: list[SingleTypeKVCacheManager] = [] + + # Needs special handling for find_longest_cache_hit if eagle is enabled + self.use_eagle = use_eagle + + for i in range(len(self.kv_cache_config.kv_cache_groups)): + kv_cache_spec = self.kv_cache_config.kv_cache_groups[ + i].kv_cache_spec + self.single_type_managers.append( + get_manager_for_kv_cache_spec( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + kv_cache_group_id=i, + caching_hash_fn=caching_hash_fn, + )) + + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[list[KVCacheBlock]]) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks. + """ + num_blocks_to_allocate = 0 + for i, manager in enumerate(self.single_type_managers): + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i]) + return num_blocks_to_allocate + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[list[KVCacheBlock]]) -> None: + """ + Add the new computed blocks to the request. + + Args: + request_id: The request ID. + new_computed_blocks: The new computed blocks just hitting the + prefix cache. + """ + for i, manager in enumerate(self.single_type_managers): + manager.save_new_computed_blocks(request_id, + new_computed_blocks[i]) + + def allocate_new_blocks(self, request_id: str, + num_tokens: int) -> list[list[KVCacheBlock]]: + """ + Allocate new blocks for the request to give it at least `num_tokens` + token slots. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + + Returns: + The new allocated blocks. + """ + new_blocks = [] + for manager in self.single_type_managers: + new_blocks.append( + manager.allocate_new_blocks(request_id, num_tokens)) + return new_blocks + + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], + num_computed_tokens: int) -> None: + """ + Cache the blocks for the request. 
+
+        Args:
+            request: The request.
+            block_hashes: The block hashes of the request.
+            num_computed_tokens: The total number of tokens that need to be
+                cached (including tokens that are already cached).
+        """
+        for manager in self.single_type_managers:
+            manager.cache_blocks(request, block_hashes, num_computed_tokens)
+
+    def free(self, request_id: str) -> None:
+        """
+        Free the blocks for the request.
+
+        Args:
+            request_id: The request ID.
+        """
+        for manager in self.single_type_managers:
+            manager.free(request_id)
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> list[int]:
+        """
+        Get the number of common prefix blocks for a request.
+
+        Args:
+            request_id: The request ID.
+            num_running_requests: The number of requests in the RUNNING state.
+
+        Returns:
+            The number of common prefix blocks for each kv cache group.
+        """
+        num_blocks_per_group = [
+            manager.get_num_common_prefix_blocks(request_id,
+                                                 num_running_requests)
+            for manager in self.single_type_managers
+        ]
+        return num_blocks_per_group
+
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
+        """
+        Remove the blocks that are no longer needed from `blocks` and replace
+        the removed blocks with null_block.
+
+        Args:
+            request_id: The request ID.
+            num_computed_tokens: The number of tokens that have been computed.
+        """
+        for manager in self.single_type_managers:
+            manager.remove_skipped_blocks(request_id, num_computed_tokens)
+
+    def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
+        """
+        Get the blocks for the request.
+        """
+        return [
+            manager.req_to_blocks[request_id]
+            for manager in self.single_type_managers
+        ]
+
+    @abstractmethod
+    def find_longest_cache_hit(
+            self, block_hashes: list[BlockHash],
+            max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
+        pass
+
+
+class UnitaryKVCacheCoordinator(KVCacheCoordinator):
+    """
+    KV cache coordinator for models with only one KV cache group. This is the
+    case for models with only one KV cache type, e.g., all attention layers
+    use full attention or all attention layers use sliding window attention.
+    """
+
+    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
+                 use_eagle: bool, enable_caching: bool,
+                 caching_hash_fn: Callable, enable_kv_cache_events: bool):
+        super().__init__(kv_cache_config, max_model_len, use_eagle,
+                         enable_caching, caching_hash_fn,
+                         enable_kv_cache_events)
+        self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
+            0].kv_cache_spec
+        self.block_size = self.kv_cache_spec.block_size
+        assert len(self.kv_cache_config.kv_cache_groups) == 1, (
+            "UnitaryKVCacheCoordinator assumes only one kv cache group")
+
+    def find_longest_cache_hit(
+            self, block_hashes: list[BlockHash],
+            max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
+        hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
+            block_hashes=block_hashes,
+            max_length=max_cache_hit_length,
+            kv_cache_group_ids=[0],
+            block_pool=self.block_pool,
+            kv_cache_spec=self.kv_cache_spec,
+            use_eagle=self.use_eagle,
+        )
+        return hit_blocks, len(hit_blocks[0]) * self.block_size
+
+
+class HybridKVCacheCoordinator(KVCacheCoordinator):
+    """
+    KV cache coordinator for hybrid models with multiple KV cache types, and
+    thus multiple kv cache groups.
+    To simplify `find_longest_cache_hit`, it only supports the combination of
+    two types of KV cache groups, and one of them must be full attention.
+    May extend to more general cases in the future.
+ """ + + def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, + use_eagle: bool, enable_caching: bool, + caching_hash_fn: Callable, enable_kv_cache_events: bool): + super().__init__(kv_cache_config, max_model_len, use_eagle, + enable_caching, caching_hash_fn, + enable_kv_cache_events) + self.verify_and_split_kv_cache_groups() + + def verify_and_split_kv_cache_groups(self) -> None: + """ + Verifies that the model has exactly two types of KV cache groups, and + one of them is full attention. Then, split the kv cache groups into full + attention groups and other groups. + """ + full_attention_type_id: Optional[str] = None + other_type_id: Optional[str] = None + self.full_attention_group_ids: list[int] = [] + self.other_group_ids: list[int] = [] + for i, g in enumerate(self.kv_cache_config.kv_cache_groups): + if isinstance(g.kv_cache_spec, FullAttentionSpec): + if full_attention_type_id is None: + full_attention_type_id = g.kv_cache_spec.type_id + else: + assert full_attention_type_id == g.kv_cache_spec.type_id, ( + "HybridKVCacheCoordinator assumes exactly one type of " + "full attention groups now.") + self.full_attention_group_ids.append(i) + else: + if other_type_id is None: + other_type_id = g.kv_cache_spec.type_id + else: + assert other_type_id == g.kv_cache_spec.type_id, ( + "HybridKVCacheCoordinator assumes " + "exactly one other type of groups now.") + self.other_group_ids.append(i) + + assert full_attention_type_id is not None, ( + "HybridKVCacheCoordinator assumes exactly one type of full " + "attention groups now.") + assert other_type_id is not None, ( + "HybridKVCacheCoordinator assumes exactly one type of other " + "groups now.") + + self.full_attention_manager_cls = FullAttentionManager + self.other_attention_cls = self.single_type_managers[ + self.other_group_ids[0]].__class__ + + self.full_attention_spec = self.kv_cache_config.kv_cache_groups[ + self.full_attention_group_ids[0]].kv_cache_spec + self.other_spec = self.kv_cache_config.kv_cache_groups[ + self.other_group_ids[0]].kv_cache_spec + + self.full_attention_block_size = self.full_attention_spec.block_size + self.other_block_size = self.other_spec.block_size + assert self.other_block_size % self.full_attention_block_size == 0, ( + "KVCacheCoordinator assumes the block_size of full attention " + "layers is divisible by other layers now.") + + def find_longest_cache_hit( + self, + block_hashes: list[BlockHash], + max_cache_hit_length: int, + ) -> tuple[list[list[KVCacheBlock]], int]: + """ + Find the longest cache hit for the request. + + Args: + block_hashes: The block hashes of the request. + max_cache_hit_length: The maximum length of the cache hit. + + Returns: + A tuple containing: + - A list of the cache hit blocks for each single type manager. + - The number of tokens of the longest cache hit. + """ + # First, find the longest cache hit for full attention. + hit_blocks_full_attn = ( + self.full_attention_manager_cls.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=self.full_attention_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.full_attention_spec, + use_eagle=self.use_eagle, + )) + hit_length = len( + hit_blocks_full_attn[0]) * self.full_attention_block_size + + # Next, find the cache hit for the other attention WITHIN + # the cache hit of full attention. 
+        hit_blocks_other_attn = (
+            self.other_attention_cls.find_longest_cache_hit(
+                block_hashes=block_hashes,
+                max_length=hit_length,
+                kv_cache_group_ids=self.other_group_ids,
+                block_pool=self.block_pool,
+                kv_cache_spec=self.other_spec,
+                use_eagle=self.use_eagle,
+            ))
+        hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
+
+        # NOTE: the prefix cache hit length must be a multiple of block_size
+        # as we don't support partial block cache hits yet. The cache hit
+        # length of the other attention is ensured to be a multiple of the
+        # block size of the full attention layers in the current
+        # implementation, because hit_length is a multiple of the other
+        # attention's block size, and the other attention's block size is a
+        # multiple of the full attention's block size (verified in
+        # `verify_and_split_kv_cache_groups`).
+        assert hit_length % self.full_attention_block_size == 0
+
+        # Truncate the full attention cache hit to the length of the
+        # cache hit of the other attention.
+        for i in range(len(hit_blocks_full_attn)):
+            del hit_blocks_full_attn[i][hit_length //
+                                        self.full_attention_block_size:]
+
+        # Merge the hit blocks of full attention and other attention.
+        hit_blocks = hit_blocks_other_attn
+        for group_id, blocks in enumerate(hit_blocks_full_attn):
+            # NOTE: there is only one full attention group in most cases. So
+            # the time complexity of insert is fine.
+            hit_blocks.insert(group_id, blocks)
+        return hit_blocks, hit_length
+
+
+def get_kv_cache_coordinator(
+        kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
+        enable_caching: bool, caching_hash_fn: Callable,
+        enable_kv_cache_events: bool) -> KVCacheCoordinator:
+    if len(kv_cache_config.kv_cache_groups) == 1:
+        return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
+                                         use_eagle, enable_caching,
+                                         caching_hash_fn,
+                                         enable_kv_cache_events)
+    else:
+        return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
+                                        use_eagle, enable_caching,
+                                        caching_hash_fn,
+                                        enable_kv_cache_events)
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 91999d30035b..fc701215ba5d 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -8,11 +8,9 @@
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
 from vllm.utils import sha256
-from vllm.v1.core.block_pool import BlockPool
+from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
 from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
                                          hash_request_tokens)
-from vllm.v1.core.single_type_kv_cache_manager import (
-    get_manager_for_kv_cache_spec)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request, RequestStatus
@@ -22,16 +20,24 @@
 
 @dataclass
 class KVCacheBlocks:
-    blocks: list[KVCacheBlock]
+    """
+    The allocation result of KVCacheManager; works as the interface between
+    the Scheduler and KVCacheManager, to hide KVCacheManager's internal data
+    structure from the Scheduler.
+    """
+    blocks: list[list[KVCacheBlock]]
+    """
+    blocks[i][j] refers to the j-th block of the i-th kv_cache_group.
+    We don't use the block of tokens as the outer dimension because it assumes
+    all kv_cache_groups have the same number of blocks, which is true for now
+    but will be broken if we want to give different block_size to different
+    kv_cache_groups in the future.
+ """ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" - return KVCacheBlocks(self.blocks + other.blocks) - - @classmethod - def create_empty(cls) -> "KVCacheBlocks": - """Creates a new KVCacheBlocks instance with no blocks.""" - return cls([]) + return KVCacheBlocks( + [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)]) def get_block_ids(self) -> list[list[int]]: """ @@ -39,15 +45,20 @@ def get_block_ids(self) -> list[list[int]]: Returns: list[list[int]]: A two-level list where - * the outer list corresponds to KV cache groups (only 1 group now) + * the outer list corresponds to KV cache groups * each inner list contains the block_ids of the blocks in that group """ - return [[block.block_id for block in self.blocks]] + block_ids = [] + for group in self.blocks: + block_ids.append([blk.block_id for blk in group]) + return block_ids def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" + assert len(self.blocks) == 1, "Only one group is supported" return [ - block.block_id for block in self.blocks if block.block_hash is None + block.block_id for block in self.blocks[0] + if block.block_hash is None ] @@ -63,12 +74,6 @@ def __init__( log_stats: bool = False, enable_kv_cache_events: bool = False, ) -> None: - assert len(kv_cache_config.kv_cache_groups) == 1, ( - "KVCacheManager does not support hybrid models with more than 1 " - "kv cache group") - kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec - self.block_size = kv_cache_spec.block_size - self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len self.enable_caching = enable_caching @@ -77,17 +82,24 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, - enable_kv_cache_events) - - self.single_type_manager = get_manager_for_kv_cache_spec( - kv_cache_spec=kv_cache_spec, - block_pool=self.block_pool, + assert len( + set(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups) + ) == 1, "Only one block size is supported for now" + self.block_size = kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size + + self.coordinator = get_kv_cache_coordinator( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, use_eagle=self.use_eagle, - num_kv_cache_groups=1, + enable_caching=enable_caching, caching_hash_fn=self.caching_hash_fn, + enable_kv_cache_events=enable_kv_cache_events, ) + self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) + self.block_pool = self.coordinator.block_pool + self.kv_cache_config = kv_cache_config # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of @@ -133,7 +145,7 @@ def get_computed_blocks(self, # When the request requires prompt logprobs, we skip prefix caching. if (not self.enable_caching or request.sampling_params.prompt_logprobs is not None): - return KVCacheBlocks.create_empty(), 0 + return self.create_empty_block_list(), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. @@ -154,20 +166,16 @@ def get_computed_blocks(self, # num_computed_tokens to be block-size aligned. Removing this limitation # could slightly improve performance in the future. 
max_cache_hit_length = request.num_tokens - 1 - - computed_blocks = self.single_type_manager.find_longest_cache_hit( - block_hashes, max_cache_hit_length) - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size + computed_blocks, num_new_computed_tokens = ( + self.coordinator.find_longest_cache_hit(block_hashes, + max_cache_hit_length)) if self.log_stats: assert self.prefix_cache_stats is not None self.prefix_cache_stats.queries += request.num_tokens - self.prefix_cache_stats.hits += num_computed_tokens + self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(computed_blocks), num_computed_tokens + return KVCacheBlocks(computed_blocks), num_new_computed_tokens def allocate_slots( self, @@ -220,7 +228,9 @@ def allocate_slots( if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks else: - new_computed_block_list = [] + new_computed_block_list = [ + [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) + ] # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -228,8 +238,8 @@ def allocate_slots( # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. - self.single_type_manager.remove_skipped_blocks( - request.request_id, request.num_computed_tokens) + self.coordinator.remove_skipped_blocks(request.request_id, + request.num_computed_tokens) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits @@ -238,12 +248,12 @@ def allocate_slots( num_tokens_need_slot = min( num_computed_tokens + num_new_tokens + num_lookahead_tokens, self.max_model_len) - num_blocks_to_allocate = ( - self.single_type_manager.get_num_blocks_to_allocate( - request_id=request.request_id, - num_tokens=num_tokens_need_slot, - new_computed_blocks=new_computed_block_list, - )) + + num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_tokens_need_slot, + new_computed_blocks=new_computed_block_list, + ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): # Cannot allocate new blocks @@ -253,16 +263,16 @@ def allocate_slots( if self.enable_caching: self.block_pool.touch(new_computed_block_list) else: - assert not new_computed_block_list, ( + assert all(not blocks for blocks in new_computed_block_list), ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - self.single_type_manager.save_new_computed_blocks( - request.request_id, new_computed_block_list) + self.coordinator.save_new_computed_blocks(request.request_id, + new_computed_block_list) - new_blocks = self.single_type_manager.allocate_new_blocks( + new_blocks = self.coordinator.allocate_new_blocks( request.request_id, num_tokens_need_slot) # P/D: delay caching blocks if we have to recv from @@ -273,7 +283,7 @@ def allocate_slots( # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. 
- self.single_type_manager.cache_blocks( + self.coordinator.cache_blocks( request, self.req_to_block_hashes[request.request_id], num_computed_tokens + num_new_tokens - num_draft_tokens) @@ -287,7 +297,7 @@ def free(self, request: Request) -> None: Args: request: The request to free the blocks. """ - self.single_type_manager.free(request.request_id) + self.coordinator.free(request.request_id) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -345,10 +355,8 @@ def get_num_common_prefix_blocks( group. """ assert request.status == RequestStatus.RUNNING - return [ - self.single_type_manager.get_num_common_prefix_blocks( - request.request_id, num_running_requests) - ] + return self.coordinator.get_num_common_prefix_blocks( + request.request_id, num_running_requests) def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. @@ -368,6 +376,15 @@ def take_events(self) -> list[KVCacheEvent]: def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" - assert request_id in self.single_type_manager.req_to_blocks - return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id] - ).get_block_ids() + return KVCacheBlocks( + self.coordinator.get_blocks(request_id)).get_block_ids() + + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], + num_computed_tokens: int) -> None: + """Cache the blocks for the request.""" + self.coordinator.cache_blocks(request, block_hashes, + num_computed_tokens) + + def create_empty_block_list(self) -> KVCacheBlocks: + """Creates a new KVCacheBlocks instance with no blocks.""" + return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)]) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ad3c21f794b9..6d4bcfe64a35 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """KV-Cache Utilities.""" + import os -from collections import deque +from collections import defaultdict, deque from collections.abc import Iterable, Sequence from dataclasses import dataclass from typing import Any, Callable, NamedTuple, Optional @@ -33,6 +34,18 @@ class BlockHash(NamedTuple): extra_keys: Optional[Any] = None +class BlockHashWithGroupId(NamedTuple): + # The hash value for the contents (e.g., token_ids) of a block without group + # ID. The value is the same for blocks representing the same tokens but for + # different groups. + block_hash: BlockHash + # The KV cache group ID. + group_id: int + + def get_hash_value(self) -> int: + return self.block_hash.hash_value + + # The hash seed for the first block of the prefix block sequence. # # Even if the hash function is the builtin hash(), we use sha256 to generate @@ -44,7 +57,7 @@ class BlockHash(NamedTuple): # This aligns with the behavior of Python's hash() function, which also uses # a random seed if PYTHONHASHSEED is not set. NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( - 'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED')) + "PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED")) class PrefixCachingMetrics: @@ -118,7 +131,7 @@ class KVCacheBlock: ref_cnt: int = 0 # The hash of the block composed of (block hash, tuple of token IDs). # It is only available when the block is full. 
-    _block_hash: Optional[BlockHash] = None
+    _block_hash: Optional[BlockHashWithGroupId] = None
 
     # Used to construct a doubly linked list for free blocks.
     # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
@@ -135,11 +148,11 @@ def decr_ref(self):
         self.ref_cnt -= 1
 
     @property
-    def block_hash(self) -> Optional[BlockHash]:
+    def block_hash(self) -> Optional[BlockHashWithGroupId]:
         return self._block_hash
 
     @block_hash.setter
-    def block_hash(self, block_hash: BlockHash):
+    def block_hash(self, block_hash: BlockHashWithGroupId):
         assert self.block_hash is None, (
             "The block already has a hash. This should not happen.")
         self._block_hash = block_hash
@@ -151,10 +164,10 @@ def reset_hash(self):
     def __repr__(self) -> str:
         # Use block_id instead of KVCacheBlock object to avoid calling __repr__
        # on KVCacheBlock object recursively.
-        prev_block_id = self.prev_free_block.block_id \
-            if self.prev_free_block else None
-        next_block_id = self.next_free_block.block_id \
-            if self.next_free_block else None
+        prev_block_id = (self.prev_free_block.block_id
+                         if self.prev_free_block else None)
+        next_block_id = (self.next_free_block.block_id
+                         if self.next_free_block else None)
         return (f"KVCacheBlock(block_id={self.block_id}, "
                 f"ref_cnt={self.ref_cnt}, "
                 f"_block_hash={self._block_hash}, "
@@ -570,20 +583,20 @@ def create_kv_cache_group_specs(
         kv_cache_spec: dict[str, KVCacheSpec],
         grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
     """
-    Create KVCacheGroupSpec object for each kv cache group layer.
-    The layers in the same group should share the same
-    KVCacheSpec.
-
-    Args:
-        kv_cache_spec:
-            A mapping from each layer name to its corresponding KVCacheSpec.
-        grouped_layer_names:
-            A list of kv cache groups, where each element is a list of layer
-            names that belong to the same group and should share the same
-            KVCacheSpec.
-    Returns:
-        A list of KVCacheGroupSpec objects, one for each group.
-    """
+    Create KVCacheGroupSpec object for each kv cache group layer.
+    The layers in the same group should share the same
+    KVCacheSpec.
+
+    Args:
+        kv_cache_spec:
+            A mapping from each layer name to its corresponding KVCacheSpec.
+        grouped_layer_names:
+            A list of kv cache groups, where each element is a list of layer
+            names that belong to the same group and should share the same
+            KVCacheSpec.
+    Returns:
+        A list of KVCacheGroupSpec objects, one for each group.
+    """
     kv_cache_groups = []
     for layer_names_one_group in grouped_layer_names:
         layer_specs = [
@@ -628,6 +641,37 @@ def get_max_concurrency_for_kv_cache_config(
     return max_concurrency
 
 
+def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
+                   available_memory: int, page_size: int) -> int:
+    """
+    Get the number of kv cache blocks.
+
+    Args:
+        vllm_config: The global VllmConfig
+        num_layers: The number of layers
+        available_memory: Memory available for KV cache in bytes.
+        page_size: The page size of the KV cache.
+
+    Returns:
+        The number of kv cache blocks.
+    """
+    num_blocks = int(available_memory // page_size // num_layers)
+    num_blocks = max(num_blocks, 0)
+    if vllm_config.cache_config.num_gpu_blocks_override is not None:
+        num_gpu_blocks_override = \
+            vllm_config.cache_config.num_gpu_blocks_override
+        logger.info(
+            "Overriding num_gpu_blocks=%d with "
+            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
+        num_blocks = num_gpu_blocks_override
+    return num_blocks
+
+
+def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
+    """
+    Get the page size of the KV cache.
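A small worked example of the sizing arithmetic behind `get_uniform_page_size` and `get_num_blocks` (hypothetical numbers; the factor of 2 accounts for the K and V caches of a non-MLA layer):

# Hypothetical sizing, just to show the arithmetic used by the helpers above.
block_size = 16          # tokens per block
num_kv_heads = 2
head_size = 64
dtype_bytes = 4          # float32

# page_size: bytes needed by one block of one layer (K and V).
page_size = 2 * block_size * num_kv_heads * head_size * dtype_bytes
assert page_size == 16384

num_layers = 2
available_memory = 32 * num_layers * page_size   # room for 32 blocks per layer

num_blocks = available_memory // page_size // num_layers
assert num_blocks == 32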
+ """ + page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values()) + assert len(page_sizes) == 1 + return page_sizes.pop() + + def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: @@ -644,32 +688,24 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, The generated KVCacheConfig """ - page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} - assert len(page_sizes) == 1 - page_size = page_sizes.pop() - - num_blocks = int(available_memory // page_size // len(kv_cache_spec)) - num_blocks = max(num_blocks, 0) - - if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) - num_blocks = num_gpu_blocks_override + page_size = get_uniform_page_size(kv_cache_spec) + num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), + available_memory, page_size) per_layer_size = page_size * num_blocks # All layers have the same KV cache spec, so we create one kv cache group # for all layers. grouped_layer_names = [list(kv_cache_spec.keys())] + # Each layer uses a separate Tensor to store its KV cache. + kv_cache_tensors = [ + KVCacheTensor(size=per_layer_size, shared_by=[layer_name]) + for layer_name in kv_cache_spec + ] + kv_cache_config = KVCacheConfig( num_blocks=num_blocks, - tensors={ - layer_name: KVCacheTensor(size=per_layer_size) - for layer_name in kv_cache_spec - }, + kv_cache_tensors=kv_cache_tensors, kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, grouped_layer_names), ) @@ -685,17 +721,185 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, return kv_cache_config +def is_kv_cache_page_size_uniform( + kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same page size. + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + True if all layers have the same page size, False otherwise. + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + return len(page_sizes) == 1 + + +def _get_kv_cache_config_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for hybrid models with multiple + attention types but still with a uniform page size (physical memory per + block per layer) for all layers. + + Detailed explanation about kv cache management of hybrid models: + The layers in the models are repeated with some patterns, e.g., a model + with 10 full attention layers and 20 sliding window attention layers can be + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + The KVCacheManager allocates different block tables for each of the 3 layers + in the pattern, and repeats each of them 10 times to generate the + block_table for the 30 layers in the model. + Therefore, we can group the layers in the model into 3 kv_cache_groups, each + of which contains 10 layers in the model. + The KVCacheManager allocates the block_table for each group based on its + kv_cache spec, and the model runner applies the block table to each layer + in the group. + For example: + 1. A model only uses full attention. 
The pattern is
+    (num_hidden_layers * full), so there is only one group and the block table
+    is shared by all layers. It is already handled by
+    `_get_kv_cache_config_uniform_type`.
+    2. A model with 10 full attention layers and 20 sliding window
+    attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
+    there are 3 kv_cache_groups, each of which represents 10 layers.
+
+    To simplify the implementation, we make the following assumptions:
+    1. Physical memory per block: Must be the same across all KV cache groups.
+    Breaking this assumption is non-trivial due to memory fragmentation
+    concerns when allocating blocks of different sizes.
+    2. Tokens per block (block_size): Currently, we directly use
+    `CacheConfig.block_size` for all layers. It can be extended to vary by KV
+    cache group, but within each KV cache group, all layers must share the
+    same block size.
+    3. Physical memory per token per layer: This property is decided by the
+    model config. Currently we only support models that have the same physical
+    memory per token per layer for all layers. Can be relaxed with a simple
+    extension, but we still need to keep physical memory per block the same
+    for all groups.
+    4. Number of layers per group: Currently assumed to be the same for all
+    groups. Can be relaxed with a simple extension, but we still need to keep
+    physical memory per block the same for all groups.
+    5. Attention type within groups: All layers in a group must share the same
+    attention type. One exception is that, when
+    `--disable-hybrid-kv-cache-manager` is true, the single group for full
+    attention layers may also include attention layers using sliding window or
+    LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more
+    details.
+    6. Support for multiple attention types: The design of most components is
+    general to an arbitrary number of attention types. But
+    `find_longest_cache_hit` only supports either a single attention type, or
+    two types where one of them is full attention. The general implementation
+    of this function is feasible, but we don't know how to implement it
+    cleanly yet.
+
+    As we assume tokens per block, physical memory per token per layer, and
+    the number of layers per group are the same now, we can ensure that
+    physical memory per block is the same for all groups.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The KVCacheSpec of each attention layer in the model
+        available_memory: Memory available for KV cache in bytes.
+    Returns:
+        The generated KVCacheConfig
+    """
+    # Group all layers by type_id.
+    # E.g., 2 full attention layers and 3 sliding window attention layers,
+    # -> (full.0, full.1), (sw.0, sw.1, sw.2).
+    same_type_layers: dict[str, list[str]] = defaultdict(list)
+    for layer_name, layer_spec in kv_cache_spec.items():
+        same_type_layers[layer_spec.type_id].append(layer_name)
+
+    # Split each group into smaller groups, to make the number of layers in
+    # each group identical. Add padding to the last group of each type if
+    # necessary.
+    # E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
+    # split to 3 groups with 2 layers each:
+    # (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
+    # FIXME(Chen): At the moment of writing this code (2025-06-02), all
+    # open-source hybrid models follow an n:1 pattern between different
+    # attention types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1
+    # between local and full), so we can use the "1" in the n:1 pattern as
+    # the group size, which is the minimum number of layers among all
+    # attention types.
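A self-contained sketch of the grouping-with-padding strategy described above, using the docstring's own example of 2 full-attention and 3 sliding-window layers (hypothetical layer names; the real implementation follows in `_get_kv_cache_config_uniform_page_size`):

from collections import defaultdict

# The docstring's example: 2 full-attention layers and 3 sliding-window layers.
layer_types = {
    "full.0": "full", "full.1": "full",
    "sw.0": "sw", "sw.1": "sw", "sw.2": "sw",
}

same_type_layers: dict[str, list[str]] = defaultdict(list)
for name, type_id in layer_types.items():
    same_type_layers[type_id].append(name)

# The group size is the "1" of the n:1 pattern: the smallest per-type count.
group_size = min(len(layers) for layers in same_type_layers.values())
grouped_layers = [
    layers[i:i + group_size]
    for layers in same_type_layers.values()
    for i in range(0, len(layers), group_size)
]
assert grouped_layers == [["full.0", "full.1"], ["sw.0", "sw.1"], ["sw.2"]]

# Memory pool i is shared by the i-th layer of every group (the last sliding
# window group is padded, so pool 1 simply has one layer fewer).
shared_by = [[group[i] for group in grouped_layers if i < len(group)]
             for i in range(group_size)]
assert shared_by == [["full.0", "sw.0", "sw.2"], ["full.1", "sw.1"]]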
Need a better + # strategy if we want to support more complex patterns (e.g., 20 full + 30 + # sw, where the group size should be 10). + group_size = min([len(layers) for layers in same_type_layers.values()]) + grouped_layers = [] + for layers in same_type_layers.values(): + num_padding_layers = group_size - len(layers) % group_size + if num_padding_layers != group_size: + logger.warning( + "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa + num_padding_layers, + num_padding_layers / len(layers) * 100, + ) + for i in range(0, len(layers), group_size): + grouped_layers.append(layers[i:i + group_size]) + kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec, + grouped_layers) + + # Determine how model runners should initialize the KV cache tensors. + # We will have group_size memory pools, each is shared by one layer from + # each group. As layers of different groups have different block table, + # they will use different parts of the shared Tensor. + # The memory layout in the example will be: + # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 + # full.1, sw.1: share another Tensor with size=available_memory//2 + page_size = get_uniform_page_size(kv_cache_spec) + num_blocks = get_num_blocks(vllm_config, group_size, available_memory, + page_size) + per_memory_pool_size = page_size * num_blocks + kv_cache_tensors = [] + for i in range(group_size): + shared_by = [] + for j in range(len(kv_cache_groups)): + if i < len(grouped_layers[j]): + shared_by.append(grouped_layers[j][i]) + kv_cache_tensors.append( + KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) + + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=kv_cache_groups, + ) + + # Print the KV cache size and maximum concurrency. + num_tokens = num_blocks // len( + grouped_layers) * vllm_config.cache_config.block_size + num_tokens_str = f"{num_tokens:,}" + logger.info("GPU KV cache size: %s tokens", num_tokens_str) + max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_concurrency = get_max_concurrency_for_kv_cache_config( + vllm_config, kv_cache_config) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, max_concurrency) + return kv_cache_config + + def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ - Only models with one type of KV cache are supported yet. This function tries - to convert the KV cache specs to one type if the model is a hybrid model - with multiple type of KV cache. It will convert all SlidingWindowSpec to - FullAttentionSpec if both types are present. + This function tries to convert the KV cache specs to one type if the model + is a hybrid model with multiple type of KV cache. It will convert all + SlidingWindowSpec to FullAttentionSpec if both types are present. Args: kv_cache_spec: The kv cache spec of each attention layer in the model """ + def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + type_ids = set(layer_spec.type_id + for layer_spec in kv_cache_spec.values()) + return len(type_ids) > 1 + + if not is_hybrid(kv_cache_spec): + return + + logger.warning( + "Hybrid KV cache manager is disabled for this hybrid model, " + "This means we do not enable any optimizations for saving KV cache " + "memory (e.g., dropping the KV cache outside the sliding window). 
" + "The compute of layers like sliding window is still saved.") + has_full_attention = any( isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) has_sliding_window = any( @@ -712,13 +916,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): sliding_window=spec.sliding_window, ) + if is_hybrid(kv_cache_spec): + raise ValueError("Hybrid KV cache manager is disabled but failed to " + "convert the KV cache specs to one unified type.") + -def get_kv_cache_config(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def get_kv_cache_config( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: """ - Generates the KV cache configuration for a model - TODO: support hybrid models with more than one type of KV cache. + Generates the KV cache configuration for a model. Args: vllm_config: The global VllmConfig @@ -728,14 +937,25 @@ def get_kv_cache_config(vllm_config: VllmConfig, Returns: The generated KVCacheConfigs """ - unify_hybrid_kv_cache_specs(kv_cache_spec) check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + + if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: + unify_hybrid_kv_cache_specs(kv_cache_spec) + if is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # Model contains multiple attention types, but KV cache of all layers + # have the same physical memory per block per layer. Split the layers + # into groups with the same number of layers, and thus same total page + # size. + return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) raise NotImplementedError diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 32d03b311a4e..c97f18b61f5b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -18,7 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -377,7 +377,8 @@ def schedule(self) -> SchedulerOutput: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. 
else: - new_computed_blocks = KVCacheBlocks.create_empty() + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -1010,7 +1011,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: num_computed_tokens = len(block_ids) * self.block_size if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 - self.kv_cache_manager.single_type_manager.cache_blocks( + self.kv_cache_manager.cache_blocks( request, self.kv_cache_manager.req_to_block_hashes[request.request_id], num_computed_tokens, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index a529cde097f5..98d758f820ad 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -22,8 +22,7 @@ def __init__( self, kv_cache_spec: KVCacheSpec, block_pool: BlockPool, - use_eagle: bool, - num_kv_cache_groups: int, + kv_cache_group_id: int, caching_hash_fn: Callable, ) -> None: """ @@ -31,9 +30,7 @@ def __init__( Args: kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. - use_eagle: Whether to use eagle. - num_kv_cache_groups: The number of kv cache groups managed by this - manager. + kv_cache_group_id: The id of the kv cache group of this manager. caching_hash_fn: The caching hash function. """ @@ -41,9 +38,6 @@ def __init__( self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool - # Needs special handling for find_longest_cache_hit if eagle is enabled - self.use_eagle = use_eagle - # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. @@ -56,8 +50,8 @@ def __init__( # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.num_kv_cache_groups = num_kv_cache_groups self.caching_hash_fn = caching_hash_fn + self.kv_cache_group_id = kv_cache_group_id def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, @@ -86,8 +80,7 @@ def get_num_blocks_to_allocate( num_evictable_computed_blocks = sum( blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks) - return ((num_new_blocks + num_evictable_computed_blocks) * - self.num_kv_cache_groups) + return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( self, request_id: str, @@ -130,8 +123,7 @@ def allocate_new_blocks(self, request_id: str, if num_new_blocks <= 0: return [] else: - new_blocks = self.block_pool.get_new_blocks( - num_new_blocks * self.num_kv_cache_groups) + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) return new_blocks @@ -156,12 +148,19 @@ def cache_blocks(self, request: Request, block_hashes: list[BlockHash], num_cached_blocks=num_cached_blocks, num_full_blocks=num_full_blocks, block_size=self.block_size, + kv_cache_group_id=self.kv_cache_group_id, hash_fn=self.caching_hash_fn, ) self.num_cached_block[request.request_id] = num_full_blocks def free(self, request_id: str) -> None: + """ + Free the blocks for the request. + + Args: + request_id: The request ID. + """ # Default to [] in case a request is freed (aborted) before alloc. 
req_blocks = self.req_to_blocks.pop(request_id, []) @@ -188,12 +187,22 @@ def get_num_common_prefix_blocks(self, request_id: str, raise NotImplementedError + @classmethod @abstractmethod - def find_longest_cache_hit(self, block_hashes: list[BlockHash], - max_length: int) -> list[KVCacheBlock]: + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> list[list[KVCacheBlock]]: """ Get the longest cache hit prefix of the blocks that is not longer than - `max_length`. If no cache hit is found, return an empty list. + `max_length`. The prefix should be a common prefix hit for all the + kv cache groups in `kv_cache_group_ids`. If no cache hit is found, + return an empty list. If eagle is enabled, drop the last matched block to force recompute the last block to get the required hidden states for eagle drafting head. Need to be customized for each attention type. @@ -201,12 +210,20 @@ def find_longest_cache_hit(self, block_hashes: list[BlockHash], Args: block_hashes: The block hashes of the request. max_length: The maximum length of the cache hit prefix. + kv_cache_group_ids: The ids of the kv cache groups. + block_pool: The block pool. + kv_cache_spec: The kv cache spec. + use_eagle: Whether to use eagle. Returns: - A list of cached blocks with skipped blocks replaced by null block. + A list of cached blocks with skipped blocks replaced by null block + for each kv cache group in `kv_cache_group_ids`. + Return a list of length `len(kv_cache_group_ids)`, where the i-th + element is a list of cached blocks for the i-th kv cache group + in `kv_cache_group_ids`. For example, sliding window manager should return a list like - [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and - sliding window 8. + [[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4 + and sliding window 8 and len(kv_cache_group_ids) = 1. """ raise NotImplementedError @@ -215,11 +232,9 @@ def find_longest_cache_hit(self, block_hashes: list[BlockHash], def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks`. The removed - blocks should be replaced by null_block. Return the removed blocks in - eviction order, where the first returned block should be evicted first. - Don't free the removed blocks in this function. Need to be customized - for each attention type. + Remove the blocks that are no longer needed from `blocks` and free the + blocks. The removed blocks should be replaced by null_block. + Need to be customized for each attention type. Args: request_id: The request ID. 
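A toy sketch of the per-group return contract described in the docstring above, using plain ints and strings in place of `BlockHash` and `KVCacheBlock` and a full-attention-style left-to-right scan (not part of the patch):

from typing import Optional

# Toy stand-ins: block hashes are ints, cached blocks are strings.
cached = {
    (100, 0): "g0-blk-A", (100, 1): "g1-blk-A",   # hash 100 cached for both groups
    (200, 0): "g0-blk-B",                          # hash 200 only cached for group 0
}

def get_cached_block(block_hash: int,
                     group_ids: list[int]) -> Optional[list[str]]:
    blocks = [cached.get((block_hash, g)) for g in group_ids]
    return None if any(b is None for b in blocks) else blocks

def find_longest_cache_hit(block_hashes: list[int],
                           group_ids: list[int]) -> list[list[str]]:
    # Walk the hash chain left to right and stop at the first miss; the
    # result is one hit list per kv cache group.
    hits: list[list[str]] = [[] for _ in group_ids]
    for h in block_hashes:
        blocks = get_cached_block(h, group_ids)
        if blocks is None:
            break
        for j, blk in enumerate(blocks):
            hits[j].append(blk)
    return hits

assert find_longest_cache_hit([100, 200], [0, 1]) == [["g0-blk-A"], ["g1-blk-A"]]
assert find_longest_cache_hit([100, 200], [0]) == [["g0-blk-A", "g0-blk-B"]]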
@@ -230,21 +245,36 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): - def find_longest_cache_hit(self, block_hashes: list[BlockHash], - max_length: int) -> list[KVCacheBlock]: - computed_blocks: list[KVCacheBlock] = [] - max_num_blocks = max_length // self.block_size + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> list[list[KVCacheBlock]]: + assert isinstance(kv_cache_spec, FullAttentionSpec), ( + "FullAttentionManager can only be used for full attention groups") + computed_blocks: list[list[KVCacheBlock]] = [ + [] for _ in range(len(kv_cache_group_ids)) + ] + max_num_blocks = max_length // kv_cache_spec.block_size for i in range(max_num_blocks): block_hash = block_hashes[i] # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block(block_hash): - computed_blocks.append(cached_block) + if cached_block := block_pool.get_cached_block( + block_hash, kv_cache_group_ids): + for j in range(len(kv_cache_group_ids)): + computed_blocks[j].append(cached_block[j]) else: break - if self.use_eagle and len(computed_blocks) > 0: - computed_blocks.pop() + if use_eagle and len(computed_blocks[0]) > 0: + for j in range(len(kv_cache_group_ids)): + computed_blocks[j].pop() return computed_blocks def remove_skipped_blocks(self, request_id: str, @@ -267,45 +297,58 @@ def get_num_common_prefix_blocks(self, request_id: str, class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - use_eagle: bool, **kwargs) -> None: - super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs) + **kwargs) -> None: + super().__init__(kv_cache_spec, block_pool, **kwargs) self.sliding_window = kv_cache_spec.sliding_window + self._null_block = block_pool.null_block + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> list[list[KVCacheBlock]]: + assert isinstance(kv_cache_spec, SlidingWindowSpec), ( + "SlidingWindowManager can only be used for sliding window groups") + # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window - self.sliding_window_contiguous_blocks = cdiv( - (kv_cache_spec.sliding_window - 1), self.block_size) - if self.use_eagle: + sliding_window_contiguous_blocks = cdiv( + kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size) + if use_eagle: # Need to drop the last matched block if eagle is enabled. For # sliding window layer, we achieve this by increasing the number of # contiguous blocks needed for prefix cache hit by one and dropping # the last matched block. 
- self.sliding_window_contiguous_blocks += 1 - self._null_block = block_pool.null_block + sliding_window_contiguous_blocks += 1 - def find_longest_cache_hit(self, block_hashes: list[BlockHash], - max_length: int) -> list[KVCacheBlock]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(max_num_blocks) to # O(max_num_blocks / sliding_window_contiguous_blocks + # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. - max_num_blocks = max_length // self.block_size - computed_blocks = [self._null_block] * max_num_blocks + max_num_blocks = max_length // kv_cache_spec.block_size + computed_blocks = [[block_pool.null_block] * max_num_blocks + for _ in range(len(kv_cache_group_ids))] num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): - if cached_block := self.block_pool.get_cached_block( - block_hashes[i]): - computed_blocks[i] = cached_block + if cached_block := block_pool.get_cached_block( + block_hashes[i], kv_cache_group_ids): + for j in range(len(kv_cache_group_ids)): + computed_blocks[j][i] = cached_block[j] num_contiguous_blocks += 1 - if (num_contiguous_blocks - >= self.sliding_window_contiguous_blocks): + if (num_contiguous_blocks >= sliding_window_contiguous_blocks): # Trim the trailing blocks. # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. - del computed_blocks[i + num_contiguous_blocks:] + for j in range(len(kv_cache_group_ids)): + del computed_blocks[j][i + num_contiguous_blocks:] match_found = True break else: @@ -313,9 +356,11 @@ def find_longest_cache_hit(self, block_hashes: list[BlockHash], if not match_found: # The first `num_contiguous_blocks` is a cache hit even if # `num_contiguous_blocks < sliding_window_contiguous_blocks`. - del computed_blocks[num_contiguous_blocks:] - if self.use_eagle and len(computed_blocks) > 0: - computed_blocks.pop() + for j in range(len(kv_cache_group_ids)): + del computed_blocks[j][num_contiguous_blocks:] + if use_eagle and len(computed_blocks[0]) > 0: + for j in range(len(kv_cache_group_ids)): + computed_blocks[j].pop() return computed_blocks def remove_skipped_blocks(self, request_id: str, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index cf2eb3b95569..e938f3bfc671 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -157,11 +157,10 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass class KVCacheTensor: """ - A dataclass for specifying how the workers should initialize the KV cache - for a layer. Only contains the size of KV cache for that layer for now. Will - be extended to support multiple layers sharing the same memory pool. + A class for specifying how the workers should initialize the KV cache. """ - size: int # The size of KV cache Tensor in bytes + size: int # size of the KV cache tensor in bytes + shared_by: list[str] # layer names that share the same KV cache tensor @dataclass @@ -183,27 +182,13 @@ class KVCacheConfig: """ """The number of KV cache blocks""" num_blocks: int - """layer_name -> how to initialize KV cache for that layer""" - tensors: dict[str, KVCacheTensor] + """How the model runner should initialize the KV cache tensors for each layer""" + kv_cache_tensors: list[KVCacheTensor] """ The kv cache groups of the model.
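The sliding-window variant searches from the right for a run of contiguous hits long enough to cover the window, padding everything before it with the null block. A toy sketch of that search (plain Python with hypothetical names; the real method additionally handles eagle and multiple groups):

```python
def longest_sliding_window_hit(hashes, cached, max_blocks, contiguous_needed):
    NULL = None  # stand-in for block_pool.null_block
    computed = [NULL] * max_blocks
    num_contiguous = 0
    match_found = False
    # Search from right to left and stop at the first sufficient run.
    for i in range(max_blocks - 1, -1, -1):
        block = cached.get(hashes[i])
        if block is not None:
            computed[i] = block
            num_contiguous += 1
            if num_contiguous >= contiguous_needed:
                del computed[i + num_contiguous:]  # trim trailing blocks
                match_found = True
                break
        else:
            num_contiguous = 0
    if not match_found:
        # Keep only the leading hit; shorter than the window is still usable.
        del computed[num_contiguous:]
    return computed

cached = {"h2": 8, "h3": 3}
print(longest_sliding_window_hit(["h0", "h1", "h2", "h3"], cached, 4, 2))
# -> [None, None, 8, 3], matching the example in the code comment above
```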
- The layers in the models are repeated with some patterns, e.g., a model - with 10 full attention layers and 20 sliding window attention layers can be - regarded as repeating the pattern (1 * full, 2 * sw) 10 times. - The KVCacheManager allocates different block tables for each of the 3 layers - in the pattern, and repeats each of them 10 times to generate the - block_table for the 30 layers in the model. - Therefore, we can group the layers in the model into 3 groups, each of which - contains 10 layers in the model. - The KVCacheManager allocates the block_table for each group based on its - kv_cache spec, and the model runner applies the block table to each layer - in the group. - For example: - 1. A model only uses full attention. The pattern is - (num_hidden_layers * full), so there is only one group and the block table - is shared by all layers. - 2. (WIP) A model with 10 full attention layers and 20 sliding window - attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so - there are 3 groups, each of which represents 10 layers in the model. + For models with only one type of attention, there is only one group that + contains all layers. + For models with multiple types of attention, there will be multiple groups, + see `_get_kv_cache_config_uniform_page_size` for more details. """ kv_cache_groups: list[KVCacheGroupSpec] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6ea6bb020ed7..def80c5421c5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2085,33 +2085,58 @@ def may_reinitialize_input_batch(self, block_sizes=block_sizes, ) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + def _allocate_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ - Initialize KV cache based on `kv_cache_config`. + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. + Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer + kv_cache_config: The KV cache config + Returns: + dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + tensor = torch.zeros(kv_cache_tensor.size, + dtype=torch.int8, + device=self.device) + for layer_name in kv_cache_tensor.shared_by: + kv_cache_raw_tensors[layer_name] = tensor + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + layer_names.update(group.layer_names) + assert layer_names == set(kv_cache_raw_tensors.keys( + )), "Some layers are not correctly initialized" + return kv_cache_raw_tensors + + def _reshape_kv_cache_tensors( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: """ - self.kv_cache_config = kv_cache_config - self.may_reinitialize_input_batch(kv_cache_config) - self.initialize_attn_backend(kv_cache_config) + Reshape the KV cache tensors to the desired shape and dtype. + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
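The new `_allocate_kv_cache_tensors` allocates one flat int8 buffer per `KVCacheTensor` and points every layer listed in `shared_by` at the same storage. A rough sketch of that idea with a stand-in dataclass (not the vLLM classes), runnable on CPU:

```python
import torch
from dataclasses import dataclass, field

@dataclass
class FakeKVCacheTensor:              # illustrative stand-in for KVCacheTensor
    size: int                         # bytes
    shared_by: list[str] = field(default_factory=list)

def allocate(kv_cache_tensors: list[FakeKVCacheTensor]) -> dict[str, torch.Tensor]:
    raw: dict[str, torch.Tensor] = {}
    for t in kv_cache_tensors:
        buf = torch.zeros(t.size, dtype=torch.int8)  # one allocation per tensor
        for layer_name in t.shared_by:
            raw[layer_name] = buf                    # all sharers see the same storage
    return raw

tensors = [FakeKVCacheTensor(size=1024, shared_by=["full.0", "sliding.1"])]
raw = allocate(tensors)
assert raw["full.0"].data_ptr() == raw["sliding.1"].data_ptr()  # same memory
```

Sharing one buffer between layers is what lets groups with different block layouts reuse the same physical memory pool.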
+ """ kv_caches: dict[str, torch.Tensor] = {} - - for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - # `num_blocks` is the number of blocks the model runner can use. - # `kv_cache_config.num_blocks` is the number of blocks that - # KVCacheManager may allocate. - # Since different GPUs may have different number of layers and - # different memory capacities, `num_blocks` can be different on - # different GPUs, and `kv_cache_config.num_blocks` is set to - # the min of all `num_blocks`. Verify it here. - assert num_blocks >= kv_cache_config.num_blocks + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + for layer_name in kv_cache_group_spec.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, @@ -2137,13 +2162,29 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = torch.zeros( - kv_cache_shape, dtype=dtype, - device=self.device).permute(*inv_order) + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape).permute( + *inv_order) else: - # TODO: add new branches when introducing more types of - # KV cache specs. - raise ValueError("Unknown KV cache spec type.") + raise NotImplementedError + return kv_caches + + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) # Setup `kv_cache_config` and `kv_caches` for models # with cross-layer KV sharing @@ -2154,17 +2195,30 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_caches, ) + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. 
+ Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + self.kv_cache_config = kv_cache_config + self.may_reinitialize_input_batch(kv_cache_config) + self.initialize_attn_backend(kv_cache_config) + kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) - if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 73c445d14e38..20de924624b7 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1366,14 +1366,20 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: assert self.block_table_cpu.dtype == self.input_batch.block_table[ 0].get_cpu_tensor().dtype - kv_caches: dict[str, torch.Tensor] = {} + kv_cache_sizes = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in " + "TPU.") + kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + kv_caches: dict[str, torch.Tensor] = {} for kv_cache_group in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + tensor_size = kv_cache_sizes[layer_name] + assert tensor_size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa if isinstance(kv_cache_spec, AttentionSpec): if self.use_spmd: num_kv_heads = kv_cache_spec.num_kv_heads
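On TPU each `KVCacheTensor` must be owned by exactly one layer, and the per-layer block count falls out of its byte size. A short sketch of that bookkeeping with an illustrative stand-in tuple (not the TPU runner's types):

```python
from collections import namedtuple

KVTensor = namedtuple("KVTensor", ["size", "shared_by"])  # illustrative stand-in

def num_blocks_per_layer(kv_cache_tensors, page_size_bytes):
    blocks = {}
    for tensor in kv_cache_tensors:
        # On TPU each tensor is owned by exactly one layer (no cross-layer sharing).
        assert len(tensor.shared_by) == 1
        assert tensor.size % page_size_bytes == 0
        blocks[tensor.shared_by[0]] = tensor.size // page_size_bytes
    return blocks

print(num_blocks_per_layer([KVTensor(size=4096, shared_by=["layer_0"])], 1024))
# -> {'layer_0': 4}
```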