@@ -79,6 +79,7 @@ def __init__(
7979 block_sizes : list [int ], # The block_size of each kv cache group
8080 kernel_block_sizes : list [int ],
8181 logitsprocs : Optional [LogitsProcessors ] = None ,
82+ logitsprocs_need_output_token_ids : bool = False ,
8283 is_spec_decode : bool = False ,
8384 is_pooling_model : bool = False ,
8485 num_speculative_tokens : int = 0 ,
@@ -240,6 +241,7 @@ def __init__(
240241 # Store provided logitsprocs. If none are provided, initialize empty
241242 # data structure
242243 self .logitsprocs = logitsprocs or LogitsProcessors ()
244+ self .logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
243245
244246 # Store last speculative tokens for sampler.
245247 self .spec_token_ids : list [Optional [list [int ]]] = []
@@ -252,6 +254,11 @@ def __init__(
252254 # Cached reference to the GPU tensor of previously sampled tokens
253255 self .prev_sampled_token_ids : Optional [torch .Tensor ] = None
254256 self .prev_req_id_to_index : Optional [dict [str , int ]] = None
257+ # These are used to update output_token_ids with real sampled
258+ # ids from the prior step, if required by the current sampling params
259+ # (e.g. penalties).
260+ self .sampled_token_ids_cpu : Optional [torch .Tensor ] = None
261+ self .async_copy_ready_event : Optional [torch .cuda .Event ] = None
255262
256263 @property
257264 def req_ids (self ) -> list [str ]:
@@ -776,6 +783,19 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
776783 self ._make_prompt_token_ids_tensor () if needs_prompt_token_ids else None
777784 )
778785
786+ # Only set output_token_ids if required by the current requests'
787+ # sampling parameters.
788+ needs_output_token_ids = (
789+ not self .no_penalties
790+ or bool (self .bad_words_token_ids )
791+ or self .logitsprocs_need_output_token_ids
792+ )
793+ output_token_ids = (
794+ cast (list [list [int ]], self .req_output_token_ids )
795+ if needs_output_token_ids
796+ else []
797+ )
798+
779799 allowed_token_ids_mask : Optional [torch .Tensor ] = None
780800 if not self .no_allowed_token_ids :
781801 assert self .allowed_token_ids_mask is not None
@@ -798,7 +818,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
798818 frequency_penalties = self .frequency_penalties [:num_reqs ],
799819 presence_penalties = self .presence_penalties [:num_reqs ],
800820 repetition_penalties = self .repetition_penalties [:num_reqs ],
801- output_token_ids = cast ( list [ list [ int ]], self . req_output_token_ids ) ,
821+ output_token_ids = output_token_ids ,
802822 spec_token_ids = cast (list [list [int ]], self .spec_token_ids ),
803823 no_penalties = self .no_penalties ,
804824 allowed_token_ids_mask = allowed_token_ids_mask ,
@@ -859,6 +879,52 @@ def make_lora_inputs(
859879
860880 return prompt_lora_mapping , token_lora_mapping , active_lora_requests
861881
882+ def set_async_sampled_token_ids (
883+ self ,
884+ sampled_token_ids_cpu : torch .Tensor ,
885+ async_copy_ready_event : torch .cuda .Event ,
886+ ) -> None :
887+ """
888+ In async scheduling case, store ref to sampled_token_ids_cpu
889+ tensor and corresponding copy-ready event. Used to repair
890+ output_token_ids prior to sampling, if needed by logits processors.
891+ """
892+ if self .sampling_metadata .output_token_ids :
893+ self .sampled_token_ids_cpu = sampled_token_ids_cpu
894+ self .async_copy_ready_event = async_copy_ready_event
895+ else :
896+ self .sampled_token_ids_cpu = None
897+ self .async_copy_ready_event = None
898+
899+ def update_async_output_token_ids (self ) -> None :
900+ """
901+ In async scheduling case, update output_token_ids in sampling metadata
902+ from the prior step's sampled token ids once they've finished copying to CPU.
903+ This is called right before they are needed by the logits processors.
904+ """
905+ output_token_ids = self .sampling_metadata .output_token_ids
906+ if self .sampled_token_ids_cpu is None or not output_token_ids :
907+ # Output token ids not needed or not async scheduling.
908+ return
909+
910+ assert self .prev_req_id_to_index is not None
911+ sampled_token_ids = None
912+ for index , req_id in enumerate (self .req_ids ):
913+ prev_index = self .prev_req_id_to_index .get (req_id )
914+ if prev_index is None :
915+ continue
916+ req_output_token_ids = output_token_ids [index ]
917+ if not req_output_token_ids or req_output_token_ids [- 1 ] != - 1 :
918+ # Final output id is not a placeholder; some tokens must have
919+ # been discarded after a kv-load failure.
920+ continue
921+ if sampled_token_ids is None :
922+ assert self .async_copy_ready_event is not None
923+ self .async_copy_ready_event .synchronize ()
924+ sampled_token_ids = self .sampled_token_ids_cpu .squeeze (- 1 ).tolist ()
925+ # Replace placeholder token id with actual sampled id.
926+ req_output_token_ids [- 1 ] = sampled_token_ids [prev_index ]
927+
862928 @property
863929 def num_reqs (self ) -> int :
864930 return len (self .req_id_to_index )
0 commit comments