@@ -253,6 +253,40 @@ def __init__(
         self.prev_sampled_token_ids: Optional[torch.Tensor] = None
         self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
         self.prev_req_id_to_index: Optional[dict[str, int]] = None
+        # These are used to update output_token_ids with the real sampled
+        # ids from the prior step, if required by the current sampling
+        # params (e.g. penalties).
+        self.sampled_token_ids_cpu: Optional[torch.Tensor] = None
+        self.async_copy_ready_event: Optional[torch.cuda.Event] = None
+
+    def set_async_sampled_token_ids(
+        self,
+        sampled_token_ids_cpu: torch.Tensor,
+        async_copy_ready_event: torch.cuda.Event,
+    ) -> None:
+        if self.sampling_metadata.output_token_ids:
+            self.sampled_token_ids_cpu = sampled_token_ids_cpu
+            self.async_copy_ready_event = async_copy_ready_event
+        else:
+            self.sampled_token_ids_cpu = None
+            self.async_copy_ready_event = None
+
+    def update_async_output_token_ids(self) -> None:
+        output_token_ids = self.sampling_metadata.output_token_ids
+        if self.sampled_token_ids_cpu is None or not output_token_ids:
+            return
+
+        assert self.prev_req_id_to_index is not None
+        sampled_token_ids = None
+        for index, req_id in enumerate(self.req_ids):
+            prev_index = self.prev_req_id_to_index.get(req_id)
+            if prev_index is None:
+                continue
+            if sampled_token_ids is None:
+                assert self.async_copy_ready_event is not None
+                self.async_copy_ready_event.synchronize()
+                sampled_token_ids = self.sampled_token_ids_cpu.squeeze().tolist()
+            output_token_ids[index][-1] = sampled_token_ids[prev_index]

     @property
     def req_ids(self) -> list[str]:
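A rough sketch of how a model runner might drive the two new hooks above (the helper names `copy_sampled_ids_to_cpu`, `prepare_next_step`, and `pinned_buf` are illustrative, not part of this diff): the sampled ids are copied off the GPU asynchronously into pinned host memory, a CUDA event marks when the copy finishes, and the batch only keeps the buffer/event pair if the current sampling params actually need real output token ids.

```python
import torch

def copy_sampled_ids_to_cpu(input_batch, sampled_token_ids_gpu: torch.Tensor) -> None:
    """Hypothetical runner-side step right after sampling."""
    # Pinned host buffer so the device-to-host copy can run asynchronously.
    pinned_buf = torch.empty(
        sampled_token_ids_gpu.shape,
        dtype=sampled_token_ids_gpu.dtype,
        device="cpu",
        pin_memory=True,
    )
    pinned_buf.copy_(sampled_token_ids_gpu, non_blocking=True)
    copy_ready = torch.cuda.Event()
    copy_ready.record()  # marks completion of the copy on the current stream

    # The batch drops both objects when output_token_ids is not needed
    # (no penalties, no bad words), so the next step never synchronizes.
    input_batch.set_async_sampled_token_ids(pinned_buf, copy_ready)

def prepare_next_step(input_batch) -> None:
    """Hypothetical runner-side step before the next batch is sampled."""
    # Replaces each carried-over request's placeholder last token with the
    # real id from the previous step, syncing on the event only if needed.
    input_batch.update_async_output_token_ids()
```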
@@ -777,6 +811,15 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
         )

+        # Only set output_token_ids if required by the current requests'
+        # sampling parameters.
+        needs_output_token_ids = not self.no_penalties or bool(self.bad_words_token_ids)
+        output_token_ids = (
+            cast(list[list[int]], self.req_output_token_ids)
+            if needs_output_token_ids
+            else []
+        )
+
         allowed_token_ids_mask: Optional[torch.Tensor] = None
         if not self.no_allowed_token_ids:
             assert self.allowed_token_ids_mask is not None
@@ -799,7 +842,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             frequency_penalties=self.frequency_penalties[:num_reqs],
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
-            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
+            output_token_ids=output_token_ids,
             spec_token_ids=cast(list[list[int]], self.spec_token_ids),
             no_penalties=self.no_penalties,
             allowed_token_ids_mask=allowed_token_ids_mask,
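A self-contained restatement of the gating condition introduced in `_make_sampling_metadata` (an illustrative standalone function, not code from the diff; the `bad_words_token_ids` structure is assumed to match the batch's field): output token ids are only materialized when a penalty or bad-words logits processor will actually read them, and the empty list in the other case is also what makes `set_async_sampled_token_ids` discard the async copy state.

```python
def needs_output_token_ids(no_penalties: bool, bad_words_token_ids: dict) -> bool:
    # Mirrors the diff: penalties or bad-words blocking require the real
    # per-request output token ids; everything else can skip them.
    return not no_penalties or bool(bad_words_token_ids)

# Greedy/temperature-only sampling: no need to track output ids.
assert not needs_output_token_ids(no_penalties=True, bad_words_token_ids={})
# Repetition/presence/frequency penalties enabled: ids are required.
assert needs_output_token_ids(no_penalties=False, bad_words_token_ids={})
```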