@@ -194,7 +194,7 @@ def schedule(self) -> SchedulerOutput:
194194        encoder_compute_budget  =  self .max_num_encoder_input_tokens 
195195        # Spec decode-related. 
196196        scheduled_spec_decode_tokens : dict [str , list [int ]] =  {}
197- 
197+          total_num_spec_tokens   =   0 
198198        # For logging. 
199199        scheduled_timestamp  =  time .monotonic ()
200200
@@ -282,6 +282,9 @@ def schedule(self) -> SchedulerOutput:
282282                preempted_req .status  =  RequestStatus .PREEMPTED 
283283                preempted_req .num_computed_tokens  =  0 
284284                preempted_req .num_preemptions  +=  1 
285+                 # both sync and async scheduling don't use spec_token_ids 
286+                 # in waiting queue, so we can just clear it here. 
287+                 preempted_req .spec_token_ids .clear ()
285288                if  self .log_stats :
286289                    preempted_req .record_event (
287290                        EngineCoreEventType .PREEMPTED , scheduled_timestamp 
@@ -307,13 +310,17 @@ def schedule(self) -> SchedulerOutput:
307310            # Speculative decode related. 
308311            if  request .spec_token_ids :
309312                num_scheduled_spec_tokens  =  (
310-                     num_new_tokens  +  request .num_computed_tokens  -  request .num_tokens 
313+                     num_new_tokens 
314+                     +  request .num_computed_tokens 
315+                     -  request .num_tokens 
316+                     -  request .num_output_placeholders 
311317                )
312318                if  num_scheduled_spec_tokens  >  0 :
319+                     total_num_spec_tokens  +=  num_scheduled_spec_tokens 
313320                    # Trim spec_token_ids list to num_scheduled_spec_tokens. 
314321                    del  request .spec_token_ids [num_scheduled_spec_tokens :]
315322                    scheduled_spec_decode_tokens [request .request_id ] =  (
316-                         request .spec_token_ids 
323+                         request .spec_token_ids . copy () 
317324                    )
318325
319326            # Encoder-related. 
@@ -632,6 +639,7 @@ def schedule(self) -> SchedulerOutput:
632639            free_encoder_mm_hashes = self .encoder_cache_manager .get_freed_mm_hashes (),
633640            structured_output_request_ids = structured_output_request_ids ,
634641            grammar_bitmask = grammar_bitmask ,
642+             total_num_scheduled_spec_tokens = total_num_spec_tokens ,
635643        )
636644
637645        # NOTE(Kuntai): this function is designed for multiple purposes: 
@@ -960,19 +968,11 @@ def update_from_output(
960968                scheduler_output .scheduled_spec_decode_tokens .get (req_id )
961969            )
962970            if  scheduled_spec_token_ids :
963-                 num_draft_tokens  =  len (scheduled_spec_token_ids )
964-                 num_accepted  =  len (generated_token_ids ) -  1 
965-                 num_rejected  =  num_draft_tokens  -  num_accepted 
966-                 # num_computed_tokens represents the number of tokens 
967-                 # processed in the current step, considering scheduled 
968-                 # tokens and rejections. If some tokens are rejected, 
969-                 # num_computed_tokens is decreased by the number of rejected 
970-                 # tokens. 
971-                 request .num_computed_tokens  -=  num_rejected 
972-                 spec_decoding_stats  =  self .make_spec_decoding_stats (
971+                 spec_decoding_stats  =  self ._update_computed_tokens (
972+                     request ,
973+                     scheduled_spec_token_ids ,
974+                     generated_token_ids ,
973975                    spec_decoding_stats ,
974-                     num_draft_tokens = num_draft_tokens ,
975-                     num_accepted_tokens = num_accepted ,
976976                )
977977
978978            stopped  =  False 
@@ -1088,6 +1088,29 @@ def update_from_output(
10881088
10891089        return  engine_core_outputs 
10901090
def _update_computed_tokens(
    self,
    request: Request,
    scheduled_spec_token_ids: list[int],
    generated_token_ids: list[int],
    spec_decoding_status: SpecDecodingStats | None,
) -> SpecDecodingStats | None:
    """Roll back rejected draft tokens and record speculative-decoding stats.

    Args:
        request: Request whose ``num_computed_tokens`` is adjusted in place.
        scheduled_spec_token_ids: Draft (speculative) token ids that were
            scheduled for this request in the step being processed.
        generated_token_ids: Token ids produced for this request in the
            step. All but one of them are counted as accepted drafts.
        spec_decoding_status: Stats accumulator forwarded unchanged to
            ``make_spec_decoding_stats``; may be ``None``.

    Returns:
        Whatever ``self.make_spec_decoding_stats`` returns for this step
        (presumably the updated accumulator, or ``None`` when stats are
        not collected -- confirm against that helper).
    """
    num_draft_tokens = len(scheduled_spec_token_ids)
    # One generated token is not counted as a draft acceptance, so the
    # accepted count is the number of generated tokens minus one.
    num_accepted = len(generated_token_ids) - 1
    num_rejected = num_draft_tokens - num_accepted
    # num_computed_tokens represents the number of tokens processed in the
    # current step, considering scheduled tokens and rejections. If some
    # tokens are rejected, num_computed_tokens is decreased by the number
    # of rejected tokens.
    request.num_computed_tokens -= num_rejected
    return self.make_spec_decoding_stats(
        spec_decoding_status,
        num_draft_tokens=num_draft_tokens,
        num_accepted_tokens=num_accepted,
    )
1113+ 
10911114    def  _update_request_with_output (
10921115        self ,
10931116        request : Request ,
0 commit comments