@@ -178,7 +178,7 @@ def __init__(
178178 self ._invalid_req_indices = invalid_req_indices
179179
180180 # Event on the copy stream so we can synchronize the non-blocking copy.
181- self ._async_copy_ready_event = torch .cuda .Event ()
181+ self .async_copy_ready_event = torch .cuda .Event ()
182182
183183 # Keep a reference to the device tensor to avoid it being
184184 # deallocated until we finish copying it to the host.
@@ -188,22 +188,22 @@ def __init__(
188188 default_stream = torch .cuda .current_stream ()
189189 with torch .cuda .stream (async_output_copy_stream ):
190190 async_output_copy_stream .wait_stream (default_stream )
191- self ._sampled_token_ids_cpu = self ._sampled_token_ids .to (
191+ self .sampled_token_ids_cpu = self ._sampled_token_ids .to (
192192 "cpu" , non_blocking = True
193193 )
194- self ._async_copy_ready_event .record ()
194+ self .async_copy_ready_event .record ()
195195
196196 def get_output (self ) -> ModelRunnerOutput :
197197 """Copy the device tensors to the host and return a ModelRunnerOutput.
198198
199199 This function blocks until the copy is finished.
200200 """
201- self ._async_copy_ready_event .synchronize ()
201+ self .async_copy_ready_event .synchronize ()
202202
203203 # Release the device tensor once the copy has completed
204204 del self ._sampled_token_ids
205205
206- valid_sampled_token_ids = self ._sampled_token_ids_cpu .tolist ()
206+ valid_sampled_token_ids = self .sampled_token_ids_cpu .tolist ()
207207 for i in self ._invalid_req_indices :
208208 valid_sampled_token_ids [i ].clear ()
209209
@@ -2185,6 +2185,9 @@ def _sample(
21852185 # Sample the next token and get logprobs if needed.
21862186 sampling_metadata = self .input_batch .sampling_metadata
21872187 if spec_decode_metadata is None :
2188+ # Update output token ids with tokens sampled in last step
2189+ # if async scheduling and required by current sampling params.
2190+ self .input_batch .update_async_output_token_ids ()
21882191 sampler_output = self .sampler (
21892192 logits = logits ,
21902193 sampling_metadata = sampling_metadata ,
@@ -2644,13 +2647,22 @@ def propose_draft_token_ids(sampled_token_ids):
26442647 if not self .use_async_scheduling :
26452648 return output
26462649
2647- return AsyncGPUModelRunnerOutput (
2650+ async_output = AsyncGPUModelRunnerOutput (
26482651 model_runner_output = output ,
26492652 sampled_token_ids = sampler_output .sampled_token_ids ,
26502653 invalid_req_indices = invalid_req_indices ,
26512654 async_output_copy_stream = self .async_output_copy_stream ,
26522655 )
26532656
2657+ # Save ref of sampled_token_ids CPU tensor if the batch contains
2658+ # any requests with sampling params that require output ids.
2659+ self .input_batch .set_async_sampled_token_ids (
2660+ async_output .sampled_token_ids_cpu ,
2661+ async_output .async_copy_ready_event ,
2662+ )
2663+
2664+ return async_output
2665+
26542666 def take_draft_token_ids (self ) -> Optional [DraftTokenIds ]:
26552667 if self ._draft_token_ids is None :
26562668 return None
0 commit comments