Skip to content

Commit bb58f7c

Browse files
Merge pull request vllm-project#9 from njhill/abstract-async-load
Suggestion: Generalize/streamline async loading (remote prefill) side
2 parents 5c3fc88 + 5fd2138 commit bb58f7c

File tree

5 files changed

+82
-137
lines changed

5 files changed

+82
-137
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def get_num_new_matched_tokens(
185185
self,
186186
request: "Request",
187187
num_computed_tokens: int,
188-
) -> int:
188+
) -> tuple[int, bool]:
189189
"""
190190
Get number of new tokens that can be loaded from the
191191
external KV cache beyond the num_computed_tokens.
@@ -198,6 +198,8 @@ def get_num_new_matched_tokens(
198198
Returns:
199199
the number of tokens that can be loaded from the
200200
external KV cache beyond what is already computed.
201+
true if the external KV cache tokens will be loaded
202+
asynchronously (between scheduler steps).
201203
"""
202204
pass
203205

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def get_num_new_matched_tokens(
9393
self,
9494
request: "Request",
9595
num_computed_tokens: int,
96-
) -> int:
96+
) -> tuple[int, bool]:
9797
"""
9898
Get number of new tokens that can be loaded from the
9999
external KV cache beyond the num_computed_tokens.
@@ -108,7 +108,7 @@ def get_num_new_matched_tokens(
108108
external KV cache beyond what is already computed.
109109
"""
110110
return self._lmcache_engine.get_num_new_matched_tokens(
111-
request, num_computed_tokens)
111+
request, num_computed_tokens), False
112112

113113
def update_state_after_alloc(self, request: "Request",
114114
blocks: "KVCacheBlocks",

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
104104
############################################################
105105
# Scheduler Side Methods
106106
############################################################
107-
def get_num_new_matched_tokens(self, request: "Request",
108-
num_computed_tokens: int) -> int:
107+
def get_num_new_matched_tokens(
108+
self, request: "Request",
109+
num_computed_tokens: int) -> tuple[int, bool]:
109110
assert self.connector_scheduler is not None
110111
return self.connector_scheduler.get_num_new_matched_tokens(
111112
request, num_computed_tokens)
@@ -170,23 +171,27 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
170171
# the scheduler. Used to make metadata passed to Worker.
171172
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
172173

173-
def get_num_new_matched_tokens(self, request: "Request",
174-
num_computed_tokens: int) -> int:
174+
def get_num_new_matched_tokens(
175+
self, request: "Request",
176+
num_computed_tokens: int) -> tuple[int, bool]:
175177
"""For remote prefill, allocate for all tokens."""
176178
if request.do_remote_prefill:
177179
assert num_computed_tokens % self.block_size == 0
178180
rounded_num_prompt_tokens = round_down(
179181
len(request.prompt_token_ids), self.block_size)
180-
return max(rounded_num_prompt_tokens - num_computed_tokens, 0)
182+
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
183+
return count, count > 0
181184

182-
return 0
185+
return 0, False
183186

184187
def update_state_after_alloc(self, request: "Request",
185188
blocks: "KVCacheBlocks",
186189
num_external_tokens: int):
187190
if request.do_remote_prefill and num_external_tokens > 0:
188191
self._reqs_need_recv[request.request_id] = (request,
189192
blocks.get_block_ids())
193+
# Only trigger a KV transfer once per request.
194+
request.do_remote_prefill = False
190195

191196
def build_connector_meta(
192197
self,

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def get_num_new_matched_tokens(
225225
self,
226226
request: "Request",
227227
num_computed_tokens: int,
228-
) -> int:
228+
) -> tuple[int, bool]:
229229
"""
230230
Get number of new tokens that can be loaded from the
231231
external KV cache beyond the num_computed_tokens.
@@ -248,7 +248,7 @@ def get_num_new_matched_tokens(
248248
# with the block granularity. And it expects the returned blocks and
249249
# num_computed_tokens to also be aligned with the block granularity.
250250
if not self._found_match_for_request(request):
251-
return 0
251+
return 0, False
252252

253253
logger.info("External Cache Hit!")
254254

@@ -257,7 +257,7 @@ def get_num_new_matched_tokens(
257257
num_tokens_to_check = align_to_block_size(
258258
len(request.prompt_token_ids) - 1, self._block_size)
259259

260-
return num_tokens_to_check - num_computed_tokens
260+
return num_tokens_to_check - num_computed_tokens, False
261261

262262
def update_state_after_alloc(self, request: "Request",
263263
blocks: "KVCacheBlocks",

vllm/v1/core/sched/scheduler.py

Lines changed: 63 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from collections import defaultdict, deque
77
from collections.abc import Iterable
8-
from typing import TYPE_CHECKING, Optional, Union
8+
from typing import Optional, Union
99

1010
from vllm import envs
1111
from vllm.config import VllmConfig
@@ -32,9 +32,6 @@
3232
from vllm.v1.spec_decode.metrics import SpecDecodingStats
3333
from vllm.v1.structured_output import StructuredOutputManager
3434

35-
if TYPE_CHECKING:
36-
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
37-
3835
logger = init_logger(__name__)
3936

4037

@@ -317,9 +314,12 @@ def schedule(self) -> SchedulerOutput:
317314

318315
# P/D: skip request if still waiting for remote kvs.
319316
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
320-
is_ready = self._update_waiting_for_remote_kv(
321-
request, skipped_waiting_requests)
322-
if not is_ready:
317+
is_ready = self._update_waiting_for_remote_kv(request)
318+
if is_ready:
319+
request.status = RequestStatus.WAITING
320+
else:
321+
self.waiting.popleft()
322+
skipped_waiting_requests.appendleft(request)
323323
continue
324324

325325
# Skip request if the structured output request is still waiting
@@ -350,55 +350,48 @@ def schedule(self) -> SchedulerOutput:
350350
request)
351351

352352
# Get externally-cached tokens if using a KVConnector.
353-
num_external_tokens = (
354-
0 if self.connector is None else
353+
num_external_tokens, load_kv_async = (
354+
(0, False) if self.connector is None else
355355
self.connector.get_num_new_matched_tokens(
356356
request, num_computed_tokens))
357357

358358
# Total computed tokens (local + external).
359359
num_computed_tokens += num_external_tokens
360360

361-
# P/D: if remote prefill, allocate memory and put request
362-
# into the WAITING_FOR_REMOTE_KV state.
363-
if request.do_remote_prefill and num_external_tokens > 0:
364-
new_blocks = self._allocate_and_set_waiting_for_remote_kv(
365-
request, num_external_tokens, new_computed_blocks,
366-
skipped_waiting_requests)
367-
if new_blocks is None:
368-
# Not enough KV cache space
369-
break
370-
continue
371-
372-
# Number of tokens to be scheduled.
373-
# We use `request.num_tokens` instead of
374-
# `request.num_prompt_tokens` to consider the resumed requests,
375-
# which have output tokens.
376-
num_new_tokens = request.num_tokens - num_computed_tokens
377-
if (0 < self.scheduler_config.long_prefill_token_threshold <
378-
num_new_tokens):
379-
num_new_tokens = (
380-
self.scheduler_config.long_prefill_token_threshold)
381-
num_new_tokens = min(num_new_tokens, token_budget)
382-
assert num_new_tokens > 0
383-
384-
# Schedule encoder inputs.
385-
if request.has_encoder_inputs:
386-
(encoder_inputs_to_schedule, num_new_tokens,
387-
new_encoder_budget) = self._try_schedule_encoder_inputs(
388-
request, num_computed_tokens, num_new_tokens,
389-
encoder_budget)
390-
if num_new_tokens == 0:
391-
# The request cannot be scheduled.
392-
break
361+
encoder_inputs_to_schedule = None
362+
new_encoder_budget = encoder_budget
363+
if not load_kv_async:
364+
# Number of tokens to be scheduled.
365+
# We use `request.num_tokens` instead of
366+
# `request.num_prompt_tokens` to consider the resumed
367+
# requests, which have output tokens.
368+
num_new_tokens = request.num_tokens - num_computed_tokens
369+
if (0 < self.scheduler_config.long_prefill_token_threshold
370+
< num_new_tokens):
371+
num_new_tokens = (
372+
self.scheduler_config.long_prefill_token_threshold)
373+
num_new_tokens = min(num_new_tokens, token_budget)
374+
assert num_new_tokens > 0
375+
376+
# Schedule encoder inputs.
377+
if request.has_encoder_inputs:
378+
(encoder_inputs_to_schedule, num_new_tokens,
379+
new_encoder_budget
380+
) = self._try_schedule_encoder_inputs(
381+
request, num_computed_tokens, num_new_tokens,
382+
encoder_budget)
383+
if num_new_tokens == 0:
384+
# The request cannot be scheduled.
385+
break
393386
else:
394-
encoder_inputs_to_schedule = None
395-
new_encoder_budget = encoder_budget
387+
num_new_tokens = 0
396388

397389
new_blocks = self.kv_cache_manager.allocate_slots(
398390
request,
399391
num_new_tokens + num_external_tokens,
400392
new_computed_blocks,
401393
num_lookahead_tokens=self.num_lookahead_tokens,
394+
delay_cache_blocks=load_kv_async,
402395
)
403396
if new_blocks is None:
404397
# The request cannot be scheduled.
@@ -415,6 +408,13 @@ def schedule(self) -> SchedulerOutput:
415408
)
416409

417410
self.waiting.popleft()
411+
if load_kv_async:
412+
# If loading async, allocate memory and put request
413+
# into the WAITING_FOR_REMOTE_KV state.
414+
skipped_waiting_requests.appendleft(request)
415+
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
416+
continue
417+
418418
if request.use_structured_output:
419419
structured_output_request_ids[
420420
request.request_id] = req_index
@@ -445,7 +445,7 @@ def schedule(self) -> SchedulerOutput:
445445
request.num_computed_tokens = num_computed_tokens
446446

447447
# Encoder-related.
448-
if not request.do_remote_prefill and encoder_inputs_to_schedule:
448+
if encoder_inputs_to_schedule:
449449
scheduled_encoder_inputs[request.request_id] = (
450450
encoder_inputs_to_schedule)
451451
# Allocate the encoder cache.
@@ -924,11 +924,7 @@ def shutdown(self) -> None:
924924
# P/D Related Methods
925925
########################################################################
926926

927-
def _update_waiting_for_remote_kv(
928-
self,
929-
request: Request,
930-
skipped_waiting_requests: deque[Request],
931-
) -> bool:
927+
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
932928
"""
933929
P/D: check if the request_id is finished_recving.
934930
@@ -937,87 +933,29 @@ def _update_waiting_for_remote_kv(
937933
on the worker side connector.
938934
939935
When the kv transfer is ready, we cache the blocks
940-
and update the request state to be in WAITING from
936+
and the request state will be moved back to WAITING from
941937
WAITING_FOR_REMOTE_KV.
942938
"""
943-
if request.request_id in self.finished_recving_kv_req_ids:
944-
# Now that the blocks are ready, actually cache them.
945-
blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
946-
num_computed_tokens = len(blocks) * self.block_size
947-
if num_computed_tokens == request.num_tokens:
948-
num_computed_tokens -= 1
949-
self.kv_cache_manager.cache_blocks(
950-
request,
951-
num_tokens=0,
952-
num_computed_tokens=num_computed_tokens,
953-
new_computed_block_list=[])
954-
955-
# Update the request state for scheduling.
956-
request.num_computed_tokens = num_computed_tokens
957-
request.status = RequestStatus.WAITING
958-
# NOTE(rob): only read the blocks from remote once.
959-
request.do_remote_prefill = False
960-
961-
# Set that we are ready.
962-
self.finished_recving_kv_req_ids.remove(request.request_id)
963-
is_ready = True
964-
else:
965-
self.waiting.popleft()
966-
skipped_waiting_requests.appendleft(request)
967-
is_ready = False
968-
969-
return is_ready
970-
971-
def _allocate_and_set_waiting_for_remote_kv(
972-
self,
973-
request: Request,
974-
num_external_tokens: int,
975-
new_computed_blocks: KVCacheBlocks,
976-
skipped_waiting_requests: deque[Request],
977-
) -> Optional[KVCacheBlocks]:
978-
"""
979-
P/D: allocate KV cache blocks for a request and put
980-
the request into the WAITING_FOR_REMOTE_KV state and
981-
update the KVConnector state.
982-
983-
The KV caches are allocated but NOT cached. This is
984-
to avoid another request getting a cache hit on a
985-
block that has not been written to. We will cache
986-
the blocks only after the recv is complete.
987-
988-
The update_state_after_alloc() function passes this
989-
request to the KVConnector, which triggers KVConnector
990-
to start a read_blocks transaction.
991-
"""
992-
993-
# Allocate slots for the external tokens, but skip
994-
# caching until after the KV transfer is done to avoid
995-
# cache hits on blocks that are still be written to.
996-
new_blocks = self.kv_cache_manager.allocate_slots(
939+
if request.request_id not in self.finished_recving_kv_req_ids:
940+
return False
941+
942+
# Now that the blocks are ready, actually cache them.
943+
blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
944+
num_computed_tokens = len(blocks) * self.block_size
945+
if num_computed_tokens == request.num_tokens:
946+
num_computed_tokens -= 1
947+
self.kv_cache_manager.cache_blocks(
997948
request,
998-
num_external_tokens,
999-
new_computed_blocks,
1000-
delay_cache_blocks=True)
1001-
if new_blocks is None:
1002-
return None
949+
num_tokens=0,
950+
num_computed_tokens=num_computed_tokens,
951+
new_computed_block_list=[])
1003952

1004-
self.waiting.popleft()
1005-
skipped_waiting_requests.appendleft(request)
1006-
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
1007-
1008-
# KVConnector: update internal state after allocation.
1009-
# This information is used to determine if a load is
1010-
# needed for this request.
1011-
assert self.connector is not None
1012-
self.connector.update_state_after_alloc(
1013-
request,
1014-
new_computed_blocks + new_blocks,
1015-
num_external_tokens,
1016-
)
1017-
# Only trigger a KV transfer once per request.
1018-
request.do_remote_prefill = False
953+
# Update the request state for scheduling.
954+
request.num_computed_tokens = num_computed_tokens
1019955

1020-
return new_blocks
956+
# Return that we are ready.
957+
self.finished_recving_kv_req_ids.remove(request.request_id)
958+
return True
1021959

1022960
def _set_finished_remote_decode(
1023961
self,

0 commit comments

Comments (0)