55from vllm .v1 .core .sched .output import SchedulerOutput
66from vllm .v1 .core .sched .scheduler import Scheduler
77from vllm .v1 .request import Request , RequestStatus
8+ from vllm .v1 .spec_decode .metrics import SpecDecodingStats
89
910logger = init_logger (__name__ )
1011
def _update_after_schedule(
    self,
    scheduler_output: SchedulerOutput,
) -> None:
    """Reserve output placeholders for the tokens this step will produce.

    Extends the base scheduler's post-schedule bookkeeping with
    speculative decoding: besides the one token sampled this step,
    space is reserved for the draft (spec) tokens that were scheduled
    for each request.
    """
    super()._update_after_schedule(scheduler_output)
    # Mapping of req_id -> spec token ids scheduled for this step.
    spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
    for req_id in scheduler_output.num_scheduled_tokens:
        request = self.requests[req_id]
        cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, []))
        if (
            request.num_computed_tokens
            == request.num_tokens
            + request.num_output_placeholders
            + cur_num_spec_tokens
        ):
            # The request will generate one new token plus
            # cur_num_spec_tokens draft tokens in this scheduling step.
            request.num_output_placeholders += 1 + cur_num_spec_tokens
            # The actual spec token ids are not known yet, so fill
            # spec_token_ids with -1 placeholders; its length is set to
            # self.num_spec_tokens (presumably the per-step max number
            # of spec tokens — confirm against the class definition).
            # The worker process overwrites these with the real draft
            # token ids later.
            request.spec_token_ids = [-1] * self.num_spec_tokens
2839 def _update_request_with_output (
2940 self ,
@@ -34,9 +45,13 @@ def _update_request_with_output(
3445 new_token_ids , stopped = super ()._update_request_with_output (
3546 request , new_token_ids
3647 )
37-
38- # Update the number of output placeholders.
39- request .num_output_placeholders -= len (new_token_ids )
48+         # num_output_placeholders == 0 happens when a request is preempted:
49+         # a preempted request is added back to the waiting queue and
50+         # num_output_placeholders is reset to 0,
51+         # so there is no need to revert num_output_placeholders in that case.
52+ if request .num_output_placeholders > 0 :
53+ # Update the number of output placeholders.
54+ request .num_output_placeholders -= len (new_token_ids )
4055 assert request .num_output_placeholders >= 0
4156
4257 # Cache the new tokens. Preempted requests should be skipped.
@@ -45,3 +60,40 @@ def _update_request_with_output(
4560 request , request .num_computed_tokens - request .num_output_placeholders
4661 )
4762 return new_token_ids , stopped
63+
def _update_computed_tokens(
    self,
    request: Request,
    num_draft_tokens: int,
    num_accepted: int,
    num_rejected: int,
    spec_decoding_stats: SpecDecodingStats | None,
):
    """Roll back per-request counters after draft-token verification.

    In the sync scheduler only ``num_computed_tokens`` must be reverted
    by the number of rejected draft tokens; in the async scheduler
    ``num_output_placeholders`` must be reverted as well, because
    ``_update_after_schedule`` reserved a placeholder for every
    scheduled spec token.

    Returns the ``SpecDecodingStats`` updated with this request's
    draft/accepted token counts (created lazily if ``None``).
    """
    # num_computed_tokens == 0 happens when a request is preempted:
    # a preempted request is added back to the waiting queue and
    # num_computed_tokens is reset to 0,
    # so there is no need to revert num_computed_tokens in that case.
    if request.num_computed_tokens > 0:
        # When spec decoding is enabled, num_output_placeholders was
        # increased by the number of scheduled spec tokens in
        # _update_after_schedule. Shrink it here so it reflects only
        # the actually accepted output tokens.
        request.num_output_placeholders -= num_rejected
        # num_computed_tokens counts the tokens processed in the
        # current step; rejected draft tokens were not really
        # produced, so subtract them.
        request.num_computed_tokens -= num_rejected
    spec_decoding_stats = self.make_spec_decoding_stats(
        spec_decoding_stats,
        num_draft_tokens=num_draft_tokens,
        num_accepted_tokens=num_accepted,
    )
    return spec_decoding_stats
0 commit comments