Skip to content

Commit 9803a4b

Browse files
committed
[BugFix] Make penalties and bad_words work with async scheduling
Signed-off-by: Nick Hill <[email protected]>
1 parent bb6d8c2 commit 9803a4b

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

vllm/v1/worker/gpu_input_batch.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,40 @@ def __init__(
252252
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
253253
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
254254
self.prev_req_id_to_index: Optional[dict[str, int]] = None
255+
# These are used to update output_token_ids with real sampled
256+
# ids from prior step, if required by current sampling params
257+
# (e.g. penalties).
258+
self.sampled_token_ids_cpu: Optional[torch.Tensor] = None
259+
self.async_copy_ready_event: Optional[torch.cuda.Event] = None
260+
261+
def set_async_sampled_token_ids(
    self,
    sampled_token_ids_cpu: torch.Tensor,
    async_copy_ready_event: torch.cuda.Event,
) -> None:
    """Stash refs to the async-copied sampled token ids from the last step.

    The tensor/event pair is retained only when the current batch's
    sampling metadata actually consumes output_token_ids (e.g. penalties
    or bad-words); otherwise both refs are cleared so the device/host
    buffers are not kept alive needlessly.
    """
    # output_token_ids is populated only when some request's sampling
    # params require it, so its truthiness decides whether to keep refs.
    keep = bool(self.sampling_metadata.output_token_ids)
    self.sampled_token_ids_cpu = sampled_token_ids_cpu if keep else None
    self.async_copy_ready_event = async_copy_ready_event if keep else None
272+
273+
def update_async_output_token_ids(self) -> None:
    """Backfill each request's last output token id with the token that
    was actually sampled in the previous async-scheduled step.

    With async scheduling, output_token_ids holds a placeholder for the
    prior step's token; sampling params such as penalties and bad-words
    need the real value, so it is copied in from the async D2H tensor
    saved by set_async_sampled_token_ids() before sampling runs.
    """
    output_token_ids = self.sampling_metadata.output_token_ids
    # Nothing to do if no prior-step tensor was saved or no request in
    # the current batch requires output token ids.
    if self.sampled_token_ids_cpu is None or not output_token_ids:
        return

    assert self.prev_req_id_to_index is not None
    sampled_token_ids = None
    for index, req_id in enumerate(self.req_ids):
        prev_index = self.prev_req_id_to_index.get(req_id)
        if prev_index is None:
            # Request was not present in the previous step's batch.
            continue
        if sampled_token_ids is None:
            # Lazily block on the async copy only once we know at least
            # one request actually needs the prior step's sampled ids.
            assert self.async_copy_ready_event is not None
            self.async_copy_ready_event.synchronize()
            # BUGFIX: use reshape(-1) rather than squeeze(). A full
            # squeeze() collapses a (1, 1) tensor to 0-d when the batch
            # has a single request, making tolist() return an int and
            # breaking the indexing below; reshape(-1) always yields a
            # flat list with one id per request. (Assumes one sampled
            # token per request here — this path is only taken when
            # spec decode is not in use.)
            sampled_token_ids = self.sampled_token_ids_cpu.reshape(-1).tolist()
        output_token_ids[index][-1] = sampled_token_ids[prev_index]
255289

256290
@property
257291
def req_ids(self) -> list[str]:
@@ -777,6 +811,15 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
777811
else:
778812
prompt_token_ids = None
779813

814+
# Only set output_token_ids if required by the current requests'
815+
# sampling parameters.
816+
needs_output_token_ids = not self.no_penalties or bool(self.bad_words_token_ids)
817+
output_token_ids = (
818+
cast(list[list[int]], self.req_output_token_ids)
819+
if needs_output_token_ids
820+
else []
821+
)
822+
780823
allowed_token_ids_mask: Optional[torch.Tensor] = None
781824
if not self.no_allowed_token_ids:
782825
assert self.allowed_token_ids_mask is not None
@@ -799,7 +842,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
799842
frequency_penalties=self.frequency_penalties[:num_reqs],
800843
presence_penalties=self.presence_penalties[:num_reqs],
801844
repetition_penalties=self.repetition_penalties[:num_reqs],
802-
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
845+
output_token_ids=output_token_ids,
803846
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
804847
no_penalties=self.no_penalties,
805848
allowed_token_ids_mask=allowed_token_ids_mask,

vllm/v1/worker/gpu_model_runner.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(
178178
self._invalid_req_indices = invalid_req_indices
179179

180180
# Event on the copy stream so we can synchronize the non-blocking copy.
181-
self._async_copy_ready_event = torch.cuda.Event()
181+
self.async_copy_ready_event = torch.cuda.Event()
182182

183183
# Keep a reference to the device tensor to avoid it being
184184
# deallocated until we finish copying it to the host.
@@ -188,22 +188,22 @@ def __init__(
188188
default_stream = torch.cuda.current_stream()
189189
with torch.cuda.stream(async_output_copy_stream):
190190
async_output_copy_stream.wait_stream(default_stream)
191-
self._sampled_token_ids_cpu = self._sampled_token_ids.to(
191+
self.sampled_token_ids_cpu = self._sampled_token_ids.to(
192192
"cpu", non_blocking=True
193193
)
194-
self._async_copy_ready_event.record()
194+
self.async_copy_ready_event.record()
195195

196196
def get_output(self) -> ModelRunnerOutput:
197197
"""Copy the device tensors to the host and return a ModelRunnerOutput.
198198
199199
This function blocks until the copy is finished.
200200
"""
201-
self._async_copy_ready_event.synchronize()
201+
self.async_copy_ready_event.synchronize()
202202

203203
# Release the device tensor once the copy has completed
204204
del self._sampled_token_ids
205205

206-
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
206+
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
207207
for i in self._invalid_req_indices:
208208
valid_sampled_token_ids[i].clear()
209209

@@ -2185,6 +2185,9 @@ def _sample(
21852185
# Sample the next token and get logprobs if needed.
21862186
sampling_metadata = self.input_batch.sampling_metadata
21872187
if spec_decode_metadata is None:
2188+
# Update output token ids with tokens sampled in last step
2189+
# if async scheduling and required by current sampling params.
2190+
self.input_batch.update_async_output_token_ids()
21882191
sampler_output = self.sampler(
21892192
logits=logits,
21902193
sampling_metadata=sampling_metadata,
@@ -2644,13 +2647,22 @@ def propose_draft_token_ids(sampled_token_ids):
26442647
if not self.use_async_scheduling:
26452648
return output
26462649

2647-
return AsyncGPUModelRunnerOutput(
2650+
async_output = AsyncGPUModelRunnerOutput(
26482651
model_runner_output=output,
26492652
sampled_token_ids=sampler_output.sampled_token_ids,
26502653
invalid_req_indices=invalid_req_indices,
26512654
async_output_copy_stream=self.async_output_copy_stream,
26522655
)
26532656

2657+
# Save ref of sampled_token_ids CPU tensor if the batch contains
2658+
# any requests with sampling params that require output ids.
2659+
self.input_batch.set_async_sampled_token_ids(
2660+
async_output.sampled_token_ids_cpu,
2661+
async_output.async_copy_ready_event,
2662+
)
2663+
2664+
return async_output
2665+
26542666
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
26552667
if self._draft_token_ids is None:
26562668
return None

0 commit comments

Comments
 (0)