
Commit 24a709f

[BugFix] Make penalties and bad_words work with async scheduling
Signed-off-by: Nick Hill <[email protected]>
1 parent: 2c1c7df

2 files changed: 62 additions & 7 deletions
vllm/v1/worker/gpu_input_batch.py

Lines changed: 44 additions & 1 deletion
@@ -253,6 +253,40 @@ def __init__(
         self.prev_sampled_token_ids: Optional[torch.Tensor] = None
         self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
         self.prev_req_id_to_index: Optional[dict[str, int]] = None
+        # These are used to update output_token_ids with real sampled
+        # ids from prior step, if required by current sampling params
+        # (e.g. penalties).
+        self.sampled_token_ids_cpu: Optional[torch.Tensor] = None
+        self.async_copy_ready_event: Optional[torch.cuda.Event] = None
+
+    def set_async_sampled_token_ids(
+        self,
+        sampled_token_ids_cpu: torch.Tensor,
+        async_copy_ready_event: torch.cuda.Event,
+    ) -> None:
+        if self.sampling_metadata.output_token_ids:
+            self.sampled_token_ids_cpu = sampled_token_ids_cpu
+            self.async_copy_ready_event = async_copy_ready_event
+        else:
+            self.sampled_token_ids_cpu = None
+            self.async_copy_ready_event = None
+
+    def update_async_output_token_ids(self) -> None:
+        output_token_ids = self.sampling_metadata.output_token_ids
+        if self.sampled_token_ids_cpu is None or not output_token_ids:
+            return
+
+        assert self.prev_req_id_to_index is not None
+        sampled_token_ids = None
+        for index, req_id in enumerate(self.req_ids):
+            prev_index = self.prev_req_id_to_index.get(req_id)
+            if prev_index is None:
+                continue
+            if sampled_token_ids is None:
+                assert self.async_copy_ready_event is not None
+                self.async_copy_ready_event.synchronize()
+                sampled_token_ids = self.sampled_token_ids_cpu.squeeze().tolist()
+            output_token_ids[index][-1] = sampled_token_ids[prev_index]
 
     @property
     def req_ids(self) -> list[str]:
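With async scheduling, a step is dispatched before the previous step's sampled tokens have reached the CPU, so each request's output_token_ids ends in a placeholder that update_async_output_token_ids patches with the real token once the async copy lands. A minimal sketch of that patching in plain Python (toy request ids and a hypothetical placeholder value, not vLLM's actual code):

placeholder = -1  # hypothetical stand-in appended at schedule time

req_ids = ["req-a", "req-b"]
output_token_ids = [[5, 9, placeholder], [7, placeholder]]

# Row of each request in the previous step's batch, plus the real ids
# from that step once the async copy has landed on the CPU.
prev_req_id_to_index = {"req-a": 0, "req-b": 1}
prev_sampled = [42, 17]

for index, req_id in enumerate(req_ids):
    prev_index = prev_req_id_to_index.get(req_id)
    if prev_index is not None:
        output_token_ids[index][-1] = prev_sampled[prev_index]

assert output_token_ids == [[5, 9, 42], [7, 17]]

Note the lazy synchronization in the real method above: the event is waited on only once, and only after the first request that actually has a previous-step entry.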
@@ -777,6 +811,15 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
         )
 
+        # Only set output_token_ids if required by the current requests'
+        # sampling parameters.
+        needs_output_token_ids = not self.no_penalties or bool(self.bad_words_token_ids)
+        output_token_ids = (
+            cast(list[list[int]], self.req_output_token_ids)
+            if needs_output_token_ids
+            else []
+        )
+
         allowed_token_ids_mask: Optional[torch.Tensor] = None
         if not self.no_allowed_token_ids:
             assert self.allowed_token_ids_mask is not None
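The gate above leaves output_token_ids empty unless some request applies penalties or bad_words, which is what lets penalty-free batches skip the synchronization entirely. For context, a generic sketch of how presence/frequency penalties consume these ids; the function name and logic are illustrative, not vLLM's sampler internals:

import torch

def apply_presence_frequency_penalties(
    logits: torch.Tensor,                # [num_reqs, vocab_size]
    output_token_ids: list[list[int]],   # per-request generated tokens
    presence_penalty: float,
    frequency_penalty: float,
) -> torch.Tensor:
    for i, token_ids in enumerate(output_token_ids):
        if not token_ids:
            continue
        counts = torch.bincount(
            torch.tensor(token_ids), minlength=logits.shape[1]
        ).to(logits.dtype)
        # Penalize proportionally to occurrence count (frequency) and
        # once for having appeared at all (presence).
        logits[i] -= frequency_penalty * counts
        logits[i] -= presence_penalty * (counts > 0).to(logits.dtype)
    return logits

logits = apply_presence_frequency_penalties(torch.zeros(1, 6), [[2, 2, 5]], 0.5, 0.1)
assert torch.allclose(logits[0], torch.tensor([0.0, 0.0, -0.7, 0.0, 0.0, -0.6]))

If the last entry of output_token_ids is still the async-scheduling placeholder, these counts are computed over a wrong token, which is the inconsistency this commit addresses.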
@@ -799,7 +842,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             frequency_penalties=self.frequency_penalties[:num_reqs],
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
-            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
+            output_token_ids=output_token_ids,
             spec_token_ids=cast(list[list[int]], self.spec_token_ids),
             no_penalties=self.no_penalties,
             allowed_token_ids_mask=allowed_token_ids_mask,
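bad_words needs the same ids for a different reason: a multi-token bad word is enforced by masking its final token whenever a request's most recent outputs match the word's preceding tokens. A sketch of that generic logic, with an assumed dict-of-lists layout rather than vLLM's actual implementation:

import torch

def apply_bad_words(
    logits: torch.Tensor,                             # [num_reqs, vocab_size]
    bad_words_token_ids: dict[int, list[list[int]]],  # req index -> words
    output_token_ids: list[list[int]],
) -> None:
    for req_idx, bad_words in bad_words_token_ids.items():
        out = output_token_ids[req_idx]
        for word in bad_words:
            prefix = word[:-1]
            # Mask the word's last token if its prefix matches the tail
            # of what this request has generated so far.
            if not prefix or out[-len(prefix):] == prefix:
                logits[req_idx, word[-1]] = -float("inf")

logits = torch.zeros(1, 10)
apply_bad_words(logits, {0: [[4], [2, 3]]}, [[1, 2]])
assert logits[0, 4] == -float("inf")  # single-token word: always masked
assert logits[0, 3] == -float("inf")  # last output 2 matches prefix [2]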

vllm/v1/worker/gpu_model_runner.py

Lines changed: 18 additions & 6 deletions
@@ -178,7 +178,7 @@ def __init__(
         self._invalid_req_indices = invalid_req_indices
 
         # Event on the copy stream so we can synchronize the non-blocking copy.
-        self._async_copy_ready_event = torch.cuda.Event()
+        self.async_copy_ready_event = torch.cuda.Event()
 
         # Keep a reference to the device tensor to avoid it being
         # deallocated until we finish copying it to the host.
@@ -188,22 +188,22 @@ def __init__(
         default_stream = torch.cuda.current_stream()
         with torch.cuda.stream(async_output_copy_stream):
             async_output_copy_stream.wait_stream(default_stream)
-            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+            self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                 "cpu", non_blocking=True
             )
-            self._async_copy_ready_event.record()
+            self.async_copy_ready_event.record()
 
     def get_output(self) -> ModelRunnerOutput:
         """Copy the device tensors to the host and return a ModelRunnerOutput.
 
         This function blocks until the copy is finished.
         """
-        self._async_copy_ready_event.synchronize()
+        self.async_copy_ready_event.synchronize()
 
         # Release the device tensor once the copy has completed
         del self._sampled_token_ids
 
-        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
         for i in self._invalid_req_indices:
             valid_sampled_token_ids[i].clear()
 
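These renames expose the CPU tensor and ready event so the input batch can hold references to them across steps. The mechanism itself is standard PyTorch stream/event synchronization; a minimal self-contained sketch, assuming a CUDA device is available (illustrative variable names):

import torch

device_tensor = torch.randint(0, 100, (8, 1), device="cuda")

copy_stream = torch.cuda.Stream()
ready_event = torch.cuda.Event()

default_stream = torch.cuda.current_stream()
with torch.cuda.stream(copy_stream):
    # Make the copy stream wait for work already queued on the default
    # stream so the copy reads a fully written tensor.
    copy_stream.wait_stream(default_stream)
    host_tensor = device_tensor.to("cpu", non_blocking=True)
    ready_event.record()  # marks completion of the copy on this stream

# ... CPU-side work can overlap with the copy here ...

ready_event.synchronize()  # block until the copy has finished
print(host_tensor.tolist())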
@@ -2188,6 +2188,9 @@ def _sample(
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
+            # Update output token ids with tokens sampled in last step
+            # if async scheduling and required by current sampling params.
+            self.input_batch.update_async_output_token_ids()
             return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
@@ -2646,13 +2649,22 @@ def propose_draft_token_ids(sampled_token_ids):
         if not self.use_async_scheduling:
             return output
 
-        return AsyncGPUModelRunnerOutput(
+        async_output = AsyncGPUModelRunnerOutput(
             model_runner_output=output,
             sampled_token_ids=sampler_output.sampled_token_ids,
             invalid_req_indices=invalid_req_indices,
             async_output_copy_stream=self.async_output_copy_stream,
         )
 
+        # Save ref of sampled_token_ids CPU tensor if the batch contains
+        # any requests with sampling params that require output ids.
+        self.input_batch.set_async_sampled_token_ids(
+            async_output.sampled_token_ids_cpu,
+            async_output.async_copy_ready_event,
+        )
+
+        return async_output
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None
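This handoff runs once per step: after launching the asynchronous device-to-host copy, the runner passes the CPU tensor and ready event to the input batch, which keeps them only when some request's sampling params will read output ids in the next step. A toy illustration of that gating, with a stand-in class rather than vLLM's InputBatch:

from typing import Optional

import torch

class ToyBatch:
    def __init__(self, output_token_ids: list[list[int]]):
        self.output_token_ids = output_token_ids
        self.sampled_token_ids_cpu: Optional[torch.Tensor] = None

    def set_async_sampled_token_ids(self, cpu_ids: torch.Tensor) -> None:
        # Keep the reference only if something will read it next step;
        # otherwise drop it so no one ever waits on the copy event.
        self.sampled_token_ids_cpu = cpu_ids if self.output_token_ids else None

ids = torch.tensor([[7]])
penalty_batch = ToyBatch(output_token_ids=[[1, 2]])  # penalties active
greedy_batch = ToyBatch(output_token_ids=[])         # no penalties/bad_words
penalty_batch.set_async_sampled_token_ids(ids)
greedy_batch.set_async_sampled_token_ids(ids)
assert penalty_batch.sampled_token_ids_cpu is not None
assert greedy_batch.sampled_token_ids_cpu is None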
