
Commit 00a9b41

njhill authored and xuebwang-amd committed
[BugFix] Make penalties and bad_words work with async scheduling (vllm-project#26467)
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent d5540b1 commit 00a9b41

4 files changed: +113 -14 lines changed
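For context on the fix: repetition/presence/frequency penalties and bad_words are applied against each request's output_token_ids, and with async scheduling the most recent entry is only a placeholder until the previous step's sampled ids finish copying from GPU to CPU. Below is a minimal illustrative sketch, not vLLM code (PLACEHOLDER and apply_presence_penalty are hypothetical names), of why an unrepaired placeholder breaks a penalty:

    import torch

    PLACEHOLDER = -1  # hypothetical marker for "token not yet copied to CPU"

    def apply_presence_penalty(logits: torch.Tensor,
                               output_token_ids: list[int],
                               penalty: float) -> torch.Tensor:
        # Penalize every token id that has already appeared in the output.
        for token_id in {t for t in output_token_ids if t >= 0}:
            logits[token_id] -= penalty
        return logits

    logits = torch.zeros(8)
    # Last-step token 5 is still a placeholder: it escapes the penalty entirely.
    print(apply_presence_penalty(logits.clone(), [2, PLACEHOLDER], 1.0))
    # After repairing the placeholder with the real sampled id, it is penalized.
    print(apply_presence_penalty(logits.clone(), [2, 5], 1.0))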

tests/v1/e2e/test_async_sched_and_preempt.py

Lines changed: 19 additions & 5 deletions
@@ -28,9 +28,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
     sampling_param_tests: list[dict[str, Any]] = [
         dict(),
         # dict(min_tokens=20),
-        # TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
-        # dict(repetition_penalty=0.1),
-        # dict(bad_words=[]),
+        dict(presence_penalty=-1.0),
+        dict(bad_words=["the", " the"]),
     ]

     default_params = dict(
@@ -42,9 +41,9 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
         # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")

-        outputs = []
+        outputs: list[tuple[str, list]] = []
         for test_preemption in [False, True]:
-            for executor in ["uni", "mp"]:
+            for executor in ["mp", "uni"]:
                 for async_scheduling in [False, True]:
                     cache_arg: dict[str, Any] = (
                         dict(num_gpu_blocks_override=32)
@@ -78,6 +77,21 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
                             ),
                         )
                     )
+
+                    if not outputs:
+                        # First check that the different parameter configs
+                        # actually result in different output.
+                        for other_test, params in zip(
+                            results[1:], sampling_param_tests[1:]
+                        ):
+                            with pytest.raises(AssertionError):
+                                check_outputs_equal(
+                                    outputs_0_lst=results[0],
+                                    outputs_1_lst=other_test,
+                                    name_0=f"baseline params={params}",
+                                    name_1=f"other params={params}",
+                                )
+
                     outputs.append((test_config, results))

     baseline_config, baseline_tests = outputs[0]
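The sanity check added above asserts that the different sampling-param configs really change the generated text by expecting the equality helper to fail. A standalone sketch of that inverted-assertion pattern, with generic stand-ins rather than the vLLM test utilities:

    import pytest

    def check_outputs_equal(outputs_0_lst: list[str], outputs_1_lst: list[str]) -> None:
        # Stand-in for the test helper: raises AssertionError when outputs differ.
        assert outputs_0_lst == outputs_1_lst

    def test_param_configs_differ():
        baseline = ["the cat sat on the mat"]
        with_penalty = ["a feline rested on a rug"]
        # The check is expected to fail, proving the params changed the output.
        with pytest.raises(AssertionError):
            check_outputs_equal(baseline, with_penalty)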

vllm/v1/core/sched/scheduler.py

Lines changed: 3 additions & 1 deletion
@@ -737,7 +737,9 @@ def _make_cached_request_data(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
             num_computed_tokens.append(req.num_computed_tokens)
-            num_output_tokens.append(req.num_output_tokens)
+            num_output_tokens.append(
+                req.num_output_tokens + req.num_output_placeholders
+            )

         return CachedRequestData(
             req_ids=req_ids,

vllm/v1/worker/gpu_input_batch.py

Lines changed: 67 additions & 1 deletion
@@ -79,6 +79,7 @@ def __init__(
         block_sizes: list[int],  # The block_size of each kv cache group
         kernel_block_sizes: list[int],
         logitsprocs: Optional[LogitsProcessors] = None,
+        logitsprocs_need_output_token_ids: bool = False,
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
         num_speculative_tokens: int = 0,
@@ -240,6 +241,7 @@ def __init__(
         # Store provided logitsprocs. If none are provided, initialize empty
         # data structure
         self.logitsprocs = logitsprocs or LogitsProcessors()
+        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

         # Store last speculative tokens for sampler.
         self.spec_token_ids: list[Optional[list[int]]] = []
@@ -252,6 +254,11 @@ def __init__(
         # Cached reference to the GPU tensor of previously sampled tokens
         self.prev_sampled_token_ids: Optional[torch.Tensor] = None
         self.prev_req_id_to_index: Optional[dict[str, int]] = None
+        # These are used to update output_token_ids with real sampled
+        # ids from prior step, if required by current sampling params
+        # (e.g. penalties).
+        self.sampled_token_ids_cpu: Optional[torch.Tensor] = None
+        self.async_copy_ready_event: Optional[torch.cuda.Event] = None

     @property
     def req_ids(self) -> list[str]:
@@ -776,6 +783,19 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
         )

+        # Only set output_token_ids if required by the current requests'
+        # sampling parameters.
+        needs_output_token_ids = (
+            not self.no_penalties
+            or bool(self.bad_words_token_ids)
+            or self.logitsprocs_need_output_token_ids
+        )
+        output_token_ids = (
+            cast(list[list[int]], self.req_output_token_ids)
+            if needs_output_token_ids
+            else []
+        )
+
         allowed_token_ids_mask: Optional[torch.Tensor] = None
         if not self.no_allowed_token_ids:
             assert self.allowed_token_ids_mask is not None
@@ -798,7 +818,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             frequency_penalties=self.frequency_penalties[:num_reqs],
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
-            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
+            output_token_ids=output_token_ids,
             spec_token_ids=cast(list[list[int]], self.spec_token_ids),
             no_penalties=self.no_penalties,
             allowed_token_ids_mask=allowed_token_ids_mask,
@@ -859,6 +879,52 @@ def make_lora_inputs(

         return prompt_lora_mapping, token_lora_mapping, active_lora_requests

+    def set_async_sampled_token_ids(
+        self,
+        sampled_token_ids_cpu: torch.Tensor,
+        async_copy_ready_event: torch.cuda.Event,
+    ) -> None:
+        """
+        In async scheduling case, store ref to sampled_token_ids_cpu
+        tensor and corresponding copy-ready event. Used to repair
+        output_token_ids prior to sampling, if needed by logits processors.
+        """
+        if self.sampling_metadata.output_token_ids:
+            self.sampled_token_ids_cpu = sampled_token_ids_cpu
+            self.async_copy_ready_event = async_copy_ready_event
+        else:
+            self.sampled_token_ids_cpu = None
+            self.async_copy_ready_event = None
+
+    def update_async_output_token_ids(self) -> None:
+        """
+        In async scheduling case, update output_token_ids in sampling metadata
+        from the prior step's sampled token ids once they've finished copying to CPU.
+        This is called right before they are needed by the logits processors.
+        """
+        output_token_ids = self.sampling_metadata.output_token_ids
+        if self.sampled_token_ids_cpu is None or not output_token_ids:
+            # Output token ids not needed or not async scheduling.
+            return
+
+        assert self.prev_req_id_to_index is not None
+        sampled_token_ids = None
+        for index, req_id in enumerate(self.req_ids):
+            prev_index = self.prev_req_id_to_index.get(req_id)
+            if prev_index is None:
+                continue
+            req_output_token_ids = output_token_ids[index]
+            if not req_output_token_ids or req_output_token_ids[-1] != -1:
+                # Final output id is not a placeholder, some tokens must have
+                # been discarded after a kv-load failure.
+                continue
+            if sampled_token_ids is None:
+                assert self.async_copy_ready_event is not None
+                self.async_copy_ready_event.synchronize()
+                sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
+            # Replace placeholder token id with actual sampled id.
+            req_output_token_ids[-1] = sampled_token_ids[prev_index]
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
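The two methods added above defer a blocking wait until the previous step's sampled ids are actually needed, then overwrite the trailing placeholder in each request's output ids. A simplified, self-contained sketch of that repair loop, using a threading.Event as a stand-in for the CUDA copy-ready event (all names here are illustrative, not the vLLM API):

    import threading

    PLACEHOLDER = -1

    def repair_output_token_ids(
        output_token_ids: list[list[int]],
        req_ids: list[str],
        prev_req_id_to_index: dict[str, int],
        sampled_token_ids_cpu: list[int],
        copy_ready: threading.Event,  # stand-in for torch.cuda.Event
    ) -> None:
        waited = False
        for index, req_id in enumerate(req_ids):
            prev_index = prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue  # request was not in the previous step's batch
            ids = output_token_ids[index]
            if not ids or ids[-1] != PLACEHOLDER:
                continue  # nothing to repair for this request
            if not waited:
                copy_ready.wait()  # block only once, and only if needed
                waited = True
            ids[-1] = sampled_token_ids_cpu[prev_index]

    # usage
    ready = threading.Event(); ready.set()
    outs = [[10, PLACEHOLDER], [7, 8]]
    repair_output_token_ids(outs, ["a", "b"], {"a": 0}, [42], ready)
    print(outs)  # [[10, 42], [7, 8]]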

vllm/v1/worker/gpu_model_runner.py

Lines changed: 24 additions & 7 deletions
@@ -178,7 +178,7 @@ def __init__(
         self._invalid_req_indices = invalid_req_indices

         # Event on the copy stream so we can synchronize the non-blocking copy.
-        self._async_copy_ready_event = torch.cuda.Event()
+        self.async_copy_ready_event = torch.cuda.Event()

         # Keep a reference to the device tensor to avoid it being
         # deallocated until we finish copying it to the host.
@@ -188,22 +188,22 @@ def __init__(
         default_stream = torch.cuda.current_stream()
         with torch.cuda.stream(async_output_copy_stream):
             async_output_copy_stream.wait_stream(default_stream)
-            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+            self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                 "cpu", non_blocking=True
             )
-            self._async_copy_ready_event.record()
+            self.async_copy_ready_event.record()

     def get_output(self) -> ModelRunnerOutput:
         """Copy the device tensors to the host and return a ModelRunnerOutput.

         This function blocks until the copy is finished.
         """
-        self._async_copy_ready_event.synchronize()
+        self.async_copy_ready_event.synchronize()

         # Release the device tensor once the copy has completed
         del self._sampled_token_ids

-        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
         for i in self._invalid_req_indices:
             valid_sampled_token_ids[i].clear()

@@ -349,6 +349,7 @@ def __init__(
         # solution, we initialize the input batch here, and re-initialize it
         # in `initialize_kv_cache` if the block_sizes here is different from
         # the block_sizes in the kv cache config.
+        custom_logitsprocs = model_config.logits_processors
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             # We need to use the encoder length for encoder-decoer
@@ -366,8 +367,11 @@ def __init__(
                 self.device,
                 self.pin_memory,
                 self.is_pooling_model,
-                self.vllm_config.model_config.logits_processors,
+                custom_logitsprocs,
             ),
+            # We currently don't know whether a particular custom logits processor
+            # uses output token ids so we set this conservatively.
+            logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
             is_pooling_model=self.is_pooling_model,
         )

@@ -2210,6 +2214,9 @@ def _sample(
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
+            # Update output token ids with tokens sampled in last step
+            # if async scheduling and required by current sampling params.
+            self.input_batch.update_async_output_token_ids()
             return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
@@ -2666,13 +2673,22 @@ def propose_draft_token_ids(sampled_token_ids):
         if not self.use_async_scheduling:
             return output

-        return AsyncGPUModelRunnerOutput(
+        async_output = AsyncGPUModelRunnerOutput(
            model_runner_output=output,
            sampled_token_ids=sampler_output.sampled_token_ids,
            invalid_req_indices=invalid_req_indices,
            async_output_copy_stream=self.async_output_copy_stream,
        )

+        # Save ref of sampled_token_ids CPU tensor if the batch contains
+        # any requests with sampling params that require output ids.
+        self.input_batch.set_async_sampled_token_ids(
+            async_output.sampled_token_ids_cpu,
+            async_output.async_copy_ready_event,
+        )
+
+        return async_output
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None
@@ -4198,6 +4214,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
             kernel_block_sizes=kernel_block_sizes,
             is_spec_decode=bool(self.vllm_config.speculative_config),
             logitsprocs=self.input_batch.logitsprocs,
+            logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
             is_pooling_model=self.is_pooling_model,
             num_speculative_tokens=(
                 self.vllm_config.speculative_config.num_speculative_tokens
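The async output path above performs the device-to-host copy of sampled ids on a dedicated stream and records a CUDA event, so the main stream is never blocked and consumers synchronize only when they read the CPU tensor. A minimal PyTorch sketch of that stream/event pattern (generic names, requires a CUDA device; not the vLLM implementation):

    import torch

    def async_copy_with_event(sampled_token_ids: torch.Tensor,
                              copy_stream: torch.cuda.Stream):
        """Start a non-blocking D2H copy on a side stream and return the
        CPU tensor plus an event to wait on before reading it."""
        ready = torch.cuda.Event()
        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(copy_stream):
            # Make the copy wait for the work that produced the tensor.
            copy_stream.wait_stream(default_stream)
            cpu_ids = sampled_token_ids.to("cpu", non_blocking=True)
            ready.record()
        return cpu_ids, ready

    if torch.cuda.is_available():
        stream = torch.cuda.Stream()
        ids = torch.randint(0, 100, (4, 1), device="cuda")
        cpu_ids, ready = async_copy_with_event(ids, stream)
        ready.synchronize()  # block only when the values are needed
        print(cpu_ids.squeeze(-1).tolist())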

0 commit comments
