Skip to content

Commit b9f1d42

Browse files
authored
[v1][Bugfix] Only cache blocks that are not in the prefix cache (#14073)
1 parent b28246f commit b9f1d42

File tree

2 files changed

+9
-22
lines changed

2 files changed

+9
-22
lines changed

vllm/v1/core/block_pool.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,34 +107,20 @@ def cache_full_blocks(
107107
assert prev_block.block_hash is not None
108108
prev_block_hash_value = prev_block.block_hash.hash_value
109109

110-
# Find the first uncached block.
111-
# FIXME: num_cached_blocks should be corrected by the caller
112-
# so this should never happen.
113-
offset = 0
114-
for blk in new_full_blocks:
115-
if blk.block_hash is None:
116-
break
117-
else:
118-
prev_block_hash_value = blk.block_hash.hash_value
119-
offset += 1
120-
else:
121-
# All blocks are cached.
122-
return
123-
124-
for i, blk in enumerate(new_full_blocks[offset:]):
125-
blk_idx = num_cached_blocks + offset + i
110+
for i, blk in enumerate(new_full_blocks):
126111
assert blk.block_hash is None
127112

128-
if i + offset < len(new_block_hashes):
113+
if i < len(new_block_hashes):
129114
# The block hash may already be computed in
130115
# "get_computed_blocks" if the tokens are not generated by
131116
# this request (either the prompt tokens or the previously
132117
# generated tokens with preemption). In this case we simply
133118
# reuse the block hash.
134-
block_hash = new_block_hashes[i + offset]
119+
block_hash = new_block_hashes[i]
135120
else:
136121
# Otherwise compute the block hash and cache it in the request
137122
# in case it will be preempted in the future.
123+
blk_idx = num_cached_blocks + i
138124
start_token_idx = blk_idx * block_size
139125
end_token_idx = (blk_idx + 1) * block_size
140126
block_tokens = request.all_token_ids[

vllm/v1/core/kv_cache_manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
# This is used to track the number of cached blocks for each request.
6666
# This is only used to track the RUNNING requests, we do not track the
6767
# data for preempted ones.
68-
self.num_cached_block: Dict[str, int] = defaultdict(int)
68+
self.num_cached_block: Dict[str, int] = {}
6969
self.prefix_cache_stats = PrefixCacheStats()
7070

7171
@property
@@ -224,9 +224,10 @@ def allocate_slots(
224224
if not self.enable_caching:
225225
return new_blocks
226226

227-
# FIXME: `num_cached_blocks` is not correct when the prefix cache
228-
# of a new request is hit.
229-
num_cached_blocks = self.num_cached_block[request.request_id]
227+
# Use `new_computed_blocks` for a new request, and `num_cached_block`
228+
# for a running request.
229+
num_cached_blocks = self.num_cached_block.get(request.request_id,
230+
len(new_computed_blocks))
230231
# Speculated tokens might be rejected in the future, so we do
231232
# not cache any speculated tokens. We only cache blocks with
232233
# generated (accepted) tokens.

0 commit comments

Comments
 (0)