 import time
 from collections import defaultdict, deque
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union
 
 from vllm import envs
 from vllm.config import VllmConfig
@@ -32,9 +32,6 @@
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
 
-if TYPE_CHECKING:
-    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
-
 logger = init_logger(__name__)
 
 
@@ -317,9 +314,12 @@ def schedule(self) -> SchedulerOutput:
 
                 # P/D: skip request if still waiting for remote kvs.
                 if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
-                    is_ready = self._update_waiting_for_remote_kv(
-                        request, skipped_waiting_requests)
-                    if not is_ready:
+                    is_ready = self._update_waiting_for_remote_kv(request)
+                    if is_ready:
+                        request.status = RequestStatus.WAITING
+                    else:
+                        self.waiting.popleft()
+                        skipped_waiting_requests.appendleft(request)
                         continue
 
                 # Skip request if the structured output request is still waiting
@@ -350,55 +350,48 @@ def schedule(self) -> SchedulerOutput:
                         request)
 
                 # Get externally-cached tokens if using a KVConnector.
-                num_external_tokens = (
-                    0 if self.connector is None else
+                num_external_tokens, load_kv_async = (
+                    (0, False) if self.connector is None else
                     self.connector.get_num_new_matched_tokens(
                         request, num_computed_tokens))
 
                 # Total computed tokens (local + external).
                 num_computed_tokens += num_external_tokens
 
-                # P/D: if remote prefill, allocate memory and put request
-                # into the WAITING_FOR_REMOTE_KV state.
-                if request.do_remote_prefill and num_external_tokens > 0:
-                    new_blocks = self._allocate_and_set_waiting_for_remote_kv(
-                        request, num_external_tokens, new_computed_blocks,
-                        skipped_waiting_requests)
-                    if new_blocks is None:
-                        # Not enough KV cache space
-                        break
-                    continue
-
-                # Number of tokens to be scheduled.
-                # We use request.num_tokens instead of
-                # request.num_prompt_tokens to consider the resumed requests,
-                # which have output tokens.
-                num_new_tokens = request.num_tokens - num_computed_tokens
-                if (0 < self.scheduler_config.long_prefill_token_threshold <
-                        num_new_tokens):
-                    num_new_tokens = (
-                        self.scheduler_config.long_prefill_token_threshold)
-                num_new_tokens = min(num_new_tokens, token_budget)
-                assert num_new_tokens > 0
-
-                # Schedule encoder inputs.
-                if request.has_encoder_inputs:
-                    (encoder_inputs_to_schedule, num_new_tokens,
-                     new_encoder_budget) = self._try_schedule_encoder_inputs(
-                         request, num_computed_tokens, num_new_tokens,
-                         encoder_budget)
-                    if num_new_tokens == 0:
-                        # The request cannot be scheduled.
-                        break
+                encoder_inputs_to_schedule = None
+                new_encoder_budget = encoder_budget
+                if not load_kv_async:
+                    # Number of tokens to be scheduled.
+                    # We use request.num_tokens instead of
+                    # request.num_prompt_tokens to consider the resumed
+                    # requests, which have output tokens.
+                    num_new_tokens = request.num_tokens - num_computed_tokens
+                    if (0 < self.scheduler_config.long_prefill_token_threshold
+                            < num_new_tokens):
+                        num_new_tokens = (
+                            self.scheduler_config.long_prefill_token_threshold)
+                    num_new_tokens = min(num_new_tokens, token_budget)
+                    assert num_new_tokens > 0
+
+                    # Schedule encoder inputs.
+                    if request.has_encoder_inputs:
+                        (encoder_inputs_to_schedule, num_new_tokens,
+                         new_encoder_budget
+                         ) = self._try_schedule_encoder_inputs(
+                             request, num_computed_tokens, num_new_tokens,
+                             encoder_budget)
+                        if num_new_tokens == 0:
+                            # The request cannot be scheduled.
+                            break
                 else:
-                    encoder_inputs_to_schedule = None
-                    new_encoder_budget = encoder_budget
+                    num_new_tokens = 0
 
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
                     num_new_tokens + num_external_tokens,
                     new_computed_blocks,
                     num_lookahead_tokens=self.num_lookahead_tokens,
+                    delay_cache_blocks=load_kv_async,
                 )
                 if new_blocks is None:
                     # The request cannot be scheduled.
@@ -415,6 +408,13 @@ def schedule(self) -> SchedulerOutput:
                     )
 
                 self.waiting.popleft()
+                if load_kv_async:
+                    # If loading async, allocate memory and put request
+                    # into the WAITING_FOR_REMOTE_KV state.
+                    skipped_waiting_requests.appendleft(request)
+                    request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                    continue
+
                 if request.use_structured_output:
                     structured_output_request_ids[
                         request.request_id] = req_index
@@ -445,7 +445,7 @@ def schedule(self) -> SchedulerOutput:
                 request.num_computed_tokens = num_computed_tokens
 
                 # Encoder-related.
-                if not request.do_remote_prefill and encoder_inputs_to_schedule:
+                if encoder_inputs_to_schedule:
                     scheduled_encoder_inputs[request.request_id] = (
                         encoder_inputs_to_schedule)
                     # Allocate the encoder cache.
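
Note on the schedule() changes above: the connector now reports a pair, and the second element (load_kv_async) tells the scheduler to reserve slots with delay_cache_blocks=True, schedule zero new tokens this step, and park the request in WAITING_FOR_REMOTE_KVS until the transfer completes. The toy function below restates that decision in isolation; it is a sketch for illustration only, not vLLM code, and the names plan_waiting_request and Decision are made up.

from dataclasses import dataclass


@dataclass
class Decision:
    num_new_tokens: int   # new tokens to schedule for this step
    park_request: bool    # move the request to WAITING_FOR_REMOTE_KVS?


def plan_waiting_request(num_prompt_tokens: int,
                         num_local_computed_tokens: int,
                         num_external_tokens: int,
                         load_kv_async: bool,
                         token_budget: int) -> Decision:
    # Externally matched tokens count toward the computed total.
    num_computed_tokens = num_local_computed_tokens + num_external_tokens
    if load_kv_async:
        # Async remote load: slots are reserved (block caching delayed),
        # but nothing is scheduled until the KV transfer finishes.
        return Decision(num_new_tokens=0, park_request=True)
    num_new_tokens = min(num_prompt_tokens - num_computed_tokens, token_budget)
    return Decision(num_new_tokens=num_new_tokens, park_request=False)


# Synchronous connector hit: the remaining 512 prompt tokens run this step.
print(plan_waiting_request(1024, 0, 512, False, 8192))
# Async remote prefill: zero tokens scheduled; the request waits for the KVs.
print(plan_waiting_request(1024, 0, 1008, True, 8192))
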
@@ -924,11 +924,7 @@ def shutdown(self) -> None:
     # P/D Related Methods
     ########################################################################
 
-    def _update_waiting_for_remote_kv(
-        self,
-        request: Request,
-        skipped_waiting_requests: deque[Request],
-    ) -> bool:
+    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         """
         P/D: check if the request_id is finished_recving.
 
@@ -937,87 +933,29 @@ def _update_waiting_for_remote_kv(
         on the worker side connector.
 
         When the kv transfer is ready, we cache the blocks
-        and update the request state to be in WAITING from
+        and the request state will be moved back to WAITING from
         WAITING_FOR_REMOTE_KV.
         """
-        if request.request_id in self.finished_recving_kv_req_ids:
-            # Now that the blocks are ready, actually cache them.
-            blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
-            num_computed_tokens = len(blocks) * self.block_size
-            if num_computed_tokens == request.num_tokens:
-                num_computed_tokens -= 1
-            self.kv_cache_manager.cache_blocks(
-                request,
-                num_tokens=0,
-                num_computed_tokens=num_computed_tokens,
-                new_computed_block_list=[])
-
-            # Update the request state for scheduling.
-            request.num_computed_tokens = num_computed_tokens
-            request.status = RequestStatus.WAITING
-            # NOTE(rob): only read the blocks from remote once.
-            request.do_remote_prefill = False
-
-            # Set that we are ready.
-            self.finished_recving_kv_req_ids.remove(request.request_id)
-            is_ready = True
-        else:
-            self.waiting.popleft()
-            skipped_waiting_requests.appendleft(request)
-            is_ready = False
-
-        return is_ready
-
-    def _allocate_and_set_waiting_for_remote_kv(
-        self,
-        request: Request,
-        num_external_tokens: int,
-        new_computed_blocks: KVCacheBlocks,
-        skipped_waiting_requests: deque[Request],
-    ) -> Optional[KVCacheBlocks]:
-        """
-        P/D: allocate KV cache blocks for a request and put
-        the request into the WAITING_FOR_REMOTE_KV state and
-        update the KVConnector state.
-
-        The KV caches are allocated but NOT cached. This is
-        to avoid another request getting a cache hit on a
-        block that has not been written to. We will cache
-        the blocks only after the recv is complete.
-
-        The update_state_after_alloc() function passes this
-        request to the KVConnector, which triggers KVConnector
-        to start a read_blocks transaction.
-        """
-
-        # Allocate slots for the external tokens, but skip
-        # caching until after the KV transfer is done to avoid
-        # cache hits on blocks that are still be written to.
-        new_blocks = self.kv_cache_manager.allocate_slots(
+        if request.request_id not in self.finished_recving_kv_req_ids:
+            return False
+
+        # Now that the blocks are ready, actually cache them.
+        blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
+        num_computed_tokens = len(blocks) * self.block_size
+        if num_computed_tokens == request.num_tokens:
+            num_computed_tokens -= 1
+        self.kv_cache_manager.cache_blocks(
             request,
-            num_external_tokens,
-            new_computed_blocks,
-            delay_cache_blocks=True)
-        if new_blocks is None:
-            return None
+            num_tokens=0,
+            num_computed_tokens=num_computed_tokens,
+            new_computed_block_list=[])
 
-        self.waiting.popleft()
-        skipped_waiting_requests.appendleft(request)
-        request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
-
-        # KVConnector: update internal state after allocation.
-        # This information is used to determine if a load is
-        # needed for this request.
-        assert self.connector is not None
-        self.connector.update_state_after_alloc(
-            request,
-            new_computed_blocks + new_blocks,
-            num_external_tokens,
-        )
-        # Only trigger a KV transfer once per request.
-        request.do_remote_prefill = False
+        # Update the request state for scheduling.
+        request.num_computed_tokens = num_computed_tokens
 
-        return new_blocks
+        # Return that we are ready.
+        self.finished_recving_kv_req_ids.remove(request.request_id)
+        return True
 
     def _set_finished_remote_decode(
         self,
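
On the reworked _update_waiting_for_remote_kv(): every received block counts as computed (len(blocks) * block_size), but the total is capped at request.num_tokens - 1, which, as the assert num_new_tokens > 0 in schedule() suggests, leaves at least one prompt token to run locally so the request can still be scheduled for a forward pass. The helper below is only a toy restatement of that arithmetic; BLOCK_SIZE and computed_tokens_after_recv are illustrative names, not vLLM APIs.

BLOCK_SIZE = 16  # assumed block size, for the example only


def computed_tokens_after_recv(num_received_blocks: int,
                               num_prompt_tokens: int) -> int:
    # All remotely received blocks are treated as computed tokens.
    num_computed_tokens = num_received_blocks * BLOCK_SIZE
    if num_computed_tokens == num_prompt_tokens:
        # Full coverage: keep one token uncomputed so it can still be scheduled.
        num_computed_tokens -= 1
    return num_computed_tokens


# 64-token prompt fully covered by 4 blocks: 63 counted, 1 token left to run.
print(computed_tokens_after_recv(4, 64))
# 70-token prompt with 4 full blocks received: 64 counted, 6 tokens left.
print(computed_tokens_after_recv(4, 70))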