@@ -198,7 +198,7 @@ def schedule(self) -> SchedulerOutput:
         encoder_compute_budget = self.max_num_encoder_input_tokens
         # Spec decode-related.
         scheduled_spec_decode_tokens: dict[str, list[int]] = {}
-
+        total_num_spec_tokens = 0
         # For logging.
         scheduled_timestamp = time.monotonic()

@@ -286,6 +286,9 @@ def schedule(self) -> SchedulerOutput:
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
                     preempted_req.num_preemptions += 1
+                    # both sync and async scheduling don't use spec_token_ids
+                    # in waiting queue, so we can just clear it here.
+                    preempted_req.spec_token_ids.clear()
                     if self.log_stats:
                         preempted_req.record_event(
                             EngineCoreEventType.PREEMPTED, scheduled_timestamp
@@ -311,13 +314,17 @@ def schedule(self) -> SchedulerOutput:
             # Speculative decode related.
             if request.spec_token_ids:
                 num_scheduled_spec_tokens = (
-                    num_new_tokens + request.num_computed_tokens - request.num_tokens
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
                 )
                 if num_scheduled_spec_tokens > 0:
+                    total_num_spec_tokens += num_scheduled_spec_tokens
                     # Trim spec_token_ids list to num_scheduled_spec_tokens.
                     del request.spec_token_ids[num_scheduled_spec_tokens:]
                     scheduled_spec_decode_tokens[request.request_id] = (
-                        request.spec_token_ids
+                        request.spec_token_ids.copy()
                     )

             # Encoder-related.
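
A quick sanity check of the arithmetic in the hunk above (a minimal sketch with illustrative values; it assumes the common decode case where num_computed_tokens trails num_tokens by the one token that still needs a forward pass, no async output placeholders are pending, and num_new_tokens covers that token plus the drafts before any budget capping):

    num_tokens = 20                  # prompt + generated tokens so far
    num_computed_tokens = 19         # last sampled token not yet run through the model
    num_output_placeholders = 0      # non-zero only with async scheduling
    spec_token_ids = [7, 8, 9]       # draft tokens proposed for this step
    num_new_tokens = 1 + len(spec_token_ids)   # next token + drafts = 4

    num_scheduled_spec_tokens = (
        num_new_tokens + num_computed_tokens - num_tokens - num_output_placeholders
    )
    assert num_scheduled_spec_tokens == 3      # all three drafts get scheduled

If the token budget trims num_new_tokens, the same formula yields a smaller count, the trailing drafts are deleted from spec_token_ids, and a copy of the trimmed list is what gets stored in scheduled_spec_decode_tokens.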
@@ -631,6 +638,7 @@ def schedule(self) -> SchedulerOutput:
             free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
             structured_output_request_ids=structured_output_request_ids,
             grammar_bitmask=grammar_bitmask,
+            total_num_scheduled_spec_tokens=total_num_spec_tokens,
         )

         # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -959,19 +967,11 @@ def update_from_output(
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id)
             )
             if scheduled_spec_token_ids:
-                num_draft_tokens = len(scheduled_spec_token_ids)
-                num_accepted = len(generated_token_ids) - 1
-                num_rejected = num_draft_tokens - num_accepted
-                # num_computed_tokens represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections. If some tokens are rejected,
-                # num_computed_tokens is decreased by the number of rejected
-                # tokens.
-                request.num_computed_tokens -= num_rejected
-                spec_decoding_stats = self.make_spec_decoding_stats(
+                spec_decoding_stats = self._update_computed_tokens(
+                    request,
+                    scheduled_spec_token_ids,
+                    generated_token_ids,
                     spec_decoding_stats,
-                    num_draft_tokens=num_draft_tokens,
-                    num_accepted_tokens=num_accepted,
                 )

             stopped = False
@@ -1085,6 +1085,29 @@ def update_from_output(

         return engine_core_outputs

+    def _update_computed_tokens(
+        self,
+        request: Request,
+        scheduled_spec_token_ids: list[int],
+        generated_token_ids: list[int],
+        spec_decoding_status: SpecDecodingStats | None,
+    ):
+        num_draft_tokens = len(scheduled_spec_token_ids)
+        num_accepted = len(generated_token_ids) - 1
+        num_rejected = num_draft_tokens - num_accepted
+        # num_computed_tokens represents the number of tokens
+        # processed in the current step, considering scheduled
+        # tokens and rejections. If some tokens are rejected,
+        # num_computed_tokens is decreased by the number of rejected
+        # tokens.
+        request.num_computed_tokens -= num_rejected
+        spec_decoding_stats = self.make_spec_decoding_stats(
+            spec_decoding_status,
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted,
+        )
+        return spec_decoding_stats
+
     def _update_request_with_output(
         self,
         request: Request,
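
A worked example of the rejection accounting the new _update_computed_tokens helper factors out (illustrative values only; the names mirror the diff):

    scheduled_spec_token_ids = [11, 12, 13]   # three draft tokens were scheduled
    generated_token_ids = [11, 12, 99]        # drafts 11 and 12 accepted, then 99 sampled

    num_draft_tokens = len(scheduled_spec_token_ids)   # 3
    num_accepted = len(generated_token_ids) - 1        # 2: the final token is the target
                                                       # model's own sample, not a draft
    num_rejected = num_draft_tokens - num_accepted     # 1

    # num_computed_tokens was advanced for every scheduled token, so the one
    # rejected draft is rolled back: request.num_computed_tokens -= num_rejected

Rolling num_computed_tokens back by num_rejected means the positions after the last accepted token count as not yet computed, so they are rescheduled on the next step.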