Commit 0eee271

support async_scheduling for spec-decode
Signed-off-by: Ronald1995 <[email protected]>
1 parent ba09652 commit 0eee271

File tree

9 files changed: +431 -93 lines changed

vllm/config/speculative.py

Lines changed: 16 additions & 19 deletions
@@ -3,7 +3,7 @@
 
 import ast
 import hashlib
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, get_args
 
 from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
@@ -30,28 +30,23 @@
 
 logger = init_logger(__name__)
 
-SpeculativeMethod = Literal[
-    "ngram",
-    "eagle",
-    "eagle3",
-    "medusa",
-    "mlp_speculator",
-    "draft_model",
-    "deepseek_mtp",
-    "ernie_mtp",
-    "qwen3_next_mtp",
-    "mimo_mtp",
-    "longcat_flash_mtp",
-    "mtp",
-]
-MTP_MODEL_TYPES = (
+MTPModelTypes = Literal[
     "deepseek_mtp",
     "mimo_mtp",
     "glm4_moe_mtp",
     "ernie_mtp",
     "qwen3_next_mtp",
     "longcat_flash_mtp",
-)
+    "mtp",
+]
+EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
+SpeculativeMethod = Literal[
+    "ngram",
+    "medusa",
+    "mlp_speculator",
+    "draft_model",
+    EagleModelTypes,
+]
 
 
 @config
@@ -224,7 +219,7 @@ def __post_init__(self):
         # can not be detected, it will be considered as the "draft_model" by
         # default.
 
-        if self.method in MTP_MODEL_TYPES:
+        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
@@ -338,7 +333,9 @@ def __post_init__(self):
                self.method = "medusa"
            elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                self.method = "mlp_speculator"
-            elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
+            elif self.draft_model_config.hf_config.model_type in get_args(
+                MTPModelTypes
+            ):
                self.method = "mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
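
Note on the typing refactor above: nested Literal types are flattened (PEP 586), so get_args(EagleModelTypes) also yields every MTP variant, which is what the membership checks in __post_init__ and in arg_utils.py rely on. A minimal sketch of that behavior, using a shortened member list rather than the full set defined in speculative.py:

from typing import Literal, get_args

# Shortened stand-ins for the aliases defined above (illustration only).
MTPModelTypes = Literal["deepseek_mtp", "mimo_mtp", "mtp"]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]

# Nested Literal types are flattened, so get_args() exposes every variant,
# including the MTP ones nested inside EagleModelTypes.
print(get_args(EagleModelTypes))
# ('eagle', 'eagle3', 'deepseek_mtp', 'mimo_mtp', 'mtp')

# Membership checks like `self.method in get_args(MTPModelTypes)` therefore
# work directly on plain strings.
assert "deepseek_mtp" in get_args(EagleModelTypes)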

vllm/engine/arg_utils.py

Lines changed: 19 additions & 6 deletions
@@ -69,6 +69,7 @@
 from vllm.config.observability import DetailedTraceModules
 from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
 from vllm.config.scheduler import SchedulerPolicy
+from vllm.config.speculative import EagleModelTypes
 from vllm.config.utils import get_field
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
@@ -1465,13 +1466,25 @@ def create_engine_config(
                 "Async scheduling is not supported with pipeline-parallel-size > 1."
             )
 
-        # Currently, async scheduling does not support speculative decoding.
-        # TODO(woosuk): Support it.
+        # Currently, async scheduling only supports EAGLE/MTP kinds of
+        # speculative decoding.
+        # TODO(woosuk): Support other kinds of speculative decoding.
         if self.speculative_config is not None:
-            raise ValueError(
-                "Currently, speculative decoding is not supported with "
-                "async scheduling."
-            )
+            if self.speculative_config.get("method") not in get_args(
+                EagleModelTypes
+            ):
+                raise ValueError(
+                    "Currently, async scheduling is only supported "
+                    "with EAGLE/MTP kinds of speculative decoding."
+                )
+            elif self.speculative_config.get("disable_padded_drafter_batch"):
+                raise ValueError(
+                    "Async scheduling for EAGLE/MTP speculative decoding "
+                    "is enabled, but disable_padded_drafter_batch=True is "
+                    "not supported in this case. Please set "
+                    "disable_padded_drafter_batch=False to keep padded "
+                    "drafter batches enabled."
+                )
 
         # Forward the deprecated CLI args to the EPLB config.
         if self.num_redundant_experts is not None:
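
Taken on its own, the new check boils down to two guards: the configured method must be one of the EAGLE/MTP variants, and padded drafter batches must stay enabled. A standalone sketch of that logic; the helper name and the shortened EagleModelTypes are made up for illustration, while the dict-style config access mirrors the diff:

from typing import Literal, get_args

EagleModelTypes = Literal["eagle", "eagle3", "deepseek_mtp", "mtp"]  # shortened


def check_async_scheduling_spec_config(speculative_config: dict | None) -> None:
    """Hypothetical helper mirroring the validation in create_engine_config."""
    if speculative_config is None:
        # Async scheduling without speculative decoding needs no extra checks.
        return
    if speculative_config.get("method") not in get_args(EagleModelTypes):
        raise ValueError(
            "Currently, async scheduling is only supported with "
            "EAGLE/MTP kinds of speculative decoding."
        )
    if speculative_config.get("disable_padded_drafter_batch"):
        raise ValueError(
            "disable_padded_drafter_batch=True is not supported with "
            "async scheduling; please set it to False."
        )


# Accepted: an EAGLE3 config with padded drafter batches left enabled.
check_async_scheduling_spec_config({"method": "eagle3"})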

vllm/v1/core/sched/async_scheduler.py

Lines changed: 59 additions & 7 deletions
@@ -5,6 +5,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 logger = init_logger(__name__)
 
@@ -15,15 +16,25 @@ def _update_after_schedule(
         scheduler_output: SchedulerOutput,
     ) -> None:
         super()._update_after_schedule(scheduler_output)
+        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
         for req_id in scheduler_output.num_scheduled_tokens:
             request = self.requests[req_id]
+            cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, []))
             if (
                 request.num_computed_tokens
-                == request.num_tokens + request.num_output_placeholders
+                == request.num_tokens
+                + request.num_output_placeholders
+                + cur_num_spec_tokens
             ):
-                # The request will generate a new token in this scheduling step.
-                # TODO(woosuk): Support speculative decoding.
-                request.num_output_placeholders += 1
+                # The request will generate a new token plus num_spec_tokens
+                # in this scheduling step.
+                request.num_output_placeholders += 1 + cur_num_spec_tokens
+                # Add placeholders for the new tokens in spec_token_ids.
+                # The actual token ids are not known yet, so -1 is used as
+                # a placeholder and the length of spec_token_ids is set to
+                # self.num_spec_tokens. The actual spec token ids are
+                # filled in by the worker process.
+                request.spec_token_ids = [-1] * self.num_spec_tokens
 
     def _update_request_with_output(
         self,
@@ -34,9 +45,13 @@ def _update_request_with_output(
         new_token_ids, stopped = super()._update_request_with_output(
             request, new_token_ids
         )
-
-        # Update the number of output placeholders.
-        request.num_output_placeholders -= len(new_token_ids)
+        # num_output_placeholders == 0 happens when a request is preempted:
+        # a preempted request is put back on the waiting queue and its
+        # num_output_placeholders is reset to 0, so there is no need to
+        # revert num_output_placeholders in that case.
+        if request.num_output_placeholders > 0:
+            # Update the number of output placeholders.
+            request.num_output_placeholders -= len(new_token_ids)
         assert request.num_output_placeholders >= 0
 
         # Cache the new tokens. Preempted requests should be skipped.
@@ -45,3 +60,40 @@ def _update_request_with_output(
             request, request.num_computed_tokens - request.num_output_placeholders
         )
         return new_token_ids, stopped
+
+    def _update_computed_tokens(
+        self,
+        request: Request,
+        num_draft_tokens: int,
+        num_accepted: int,
+        num_rejected: int,
+        spec_decoding_stats: SpecDecodingStats | None,
+    ):
+        """Update the computed tokens for each request, which is necessary
+        for spec decoding. In the sync scheduler, we only need to revert
+        num_computed_tokens by num_rejected tokens, but in the async
+        scheduler we also need to revert num_output_placeholders by
+        num_rejected tokens.
+        """
+        # num_computed_tokens == 0 happens when a request is preempted:
+        # a preempted request is put back on the waiting queue and its
+        # num_computed_tokens is reset to 0,
+        # so there is no need to revert num_computed_tokens in that case.
+        if request.num_computed_tokens > 0:
+            # When spec decoding is enabled, num_output_placeholders
+            # is increased by num_spec_tokens in _update_after_schedule.
+            # Update num_output_placeholders here to reflect the actual
+            # number of accepted output tokens.
+            request.num_output_placeholders -= num_rejected
+            # num_computed_tokens represents the number of tokens
+            # processed in the current step, considering scheduled
+            # tokens and rejections. If some tokens are rejected,
+            # num_computed_tokens is decreased by the number of rejected
+            # tokens.
+            request.num_computed_tokens -= num_rejected
+        spec_decoding_stats = self.make_spec_decoding_stats(
+            spec_decoding_stats,
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted,
+        )
+        return spec_decoding_stats
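
The bookkeeping above can be traced with a toy example: with num_spec_tokens = 2, the scheduler optimistically reserves 1 + 2 output placeholders and writes -1 placeholders into spec_token_ids at schedule time; when the verification result comes back with one draft rejected, both num_output_placeholders and num_computed_tokens are rolled back by that one rejected token. A self-contained sketch with plain ints (no vLLM types; all numbers are made up):

class ToyRequest:
    """Minimal stand-in for vllm.v1.request.Request, for illustration only."""

    def __init__(self) -> None:
        self.num_output_placeholders = 0
        self.num_computed_tokens = 103  # pretend this schedule advanced it already
        self.spec_token_ids: list[int] = []


NUM_SPEC_TOKENS = 2
req = ToyRequest()

# _update_after_schedule: reserve placeholders for 1 new token + the drafts,
# and pad spec_token_ids with -1 because the real draft ids are only
# produced later, in the worker process.
cur_num_spec_tokens = 2
req.num_output_placeholders += 1 + cur_num_spec_tokens
req.spec_token_ids = [-1] * NUM_SPEC_TOKENS

# _update_computed_tokens: one of the two drafts was rejected, so both
# counters are rolled back by num_rejected.
num_draft_tokens, num_accepted = 2, 1
num_rejected = num_draft_tokens - num_accepted
req.num_output_placeholders -= num_rejected
req.num_computed_tokens -= num_rejected

assert req.num_output_placeholders == 2  # the bonus token + 1 accepted draft
assert req.num_computed_tokens == 102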

vllm/v1/core/sched/output.py

Lines changed: 5 additions & 0 deletions
@@ -174,3 +174,8 @@ class SchedulerOutput:
 
     # KV Cache Connector metadata.
     kv_connector_metadata: KVConnectorMetadata | None = None
+
+    # Total number of speculative scheduled tokens for all requests.
+    # This is needed when async_scheduling and speculative decoding are
+    # used together.
+    total_num_scheduled_spec_tokens: int = 0

vllm/v1/core/sched/scheduler.py

Lines changed: 37 additions & 11 deletions
@@ -198,7 +198,7 @@ def schedule(self) -> SchedulerOutput:
         encoder_compute_budget = self.max_num_encoder_input_tokens
         # Spec decode-related.
         scheduled_spec_decode_tokens: dict[str, list[int]] = {}
-
+        total_num_spec_tokens = 0
         # For logging.
         scheduled_timestamp = time.monotonic()
 
@@ -285,7 +285,11 @@ def schedule(self) -> SchedulerOutput:
                 self.encoder_cache_manager.free(preempted_req)
                 preempted_req.status = RequestStatus.PREEMPTED
                 preempted_req.num_computed_tokens = 0
+                preempted_req.num_output_placeholders = 0
                 preempted_req.num_preemptions += 1
+                # Neither sync nor async scheduling uses spec_token_ids while
+                # a request sits in the waiting queue, so just clear it here.
+                preempted_req.spec_token_ids.clear()
                 if self.log_stats:
                     preempted_req.record_event(
                         EngineCoreEventType.PREEMPTED, scheduled_timestamp
@@ -311,9 +315,13 @@ def schedule(self) -> SchedulerOutput:
             # Speculative decode related.
             if request.spec_token_ids:
                 num_scheduled_spec_tokens = (
-                    num_new_tokens + request.num_computed_tokens - request.num_tokens
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
                 )
                 if num_scheduled_spec_tokens > 0:
+                    total_num_spec_tokens += num_scheduled_spec_tokens
                     # Trim spec_token_ids list to num_scheduled_spec_tokens.
                     del request.spec_token_ids[num_scheduled_spec_tokens:]
                     scheduled_spec_decode_tokens[request.request_id] = (
@@ -631,6 +639,7 @@ def schedule(self) -> SchedulerOutput:
             free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
             structured_output_request_ids=structured_output_request_ids,
             grammar_bitmask=grammar_bitmask,
+            total_num_scheduled_spec_tokens=total_num_spec_tokens,
         )
 
         # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -962,16 +971,12 @@ def update_from_output(
                 num_draft_tokens = len(scheduled_spec_token_ids)
                 num_accepted = len(generated_token_ids) - 1
                 num_rejected = num_draft_tokens - num_accepted
-                # num_computed_tokens represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections. If some tokens are rejected,
-                # num_computed_tokens is decreased by the number of rejected
-                # tokens.
-                request.num_computed_tokens -= num_rejected
-                spec_decoding_stats = self.make_spec_decoding_stats(
+                spec_decoding_stats = self._update_computed_tokens(
+                    request,
+                    num_draft_tokens,
+                    num_accepted,
+                    num_rejected,
                     spec_decoding_stats,
-                    num_draft_tokens=num_draft_tokens,
-                    num_accepted_tokens=num_accepted,
                 )
 
             stopped = False
@@ -1085,6 +1090,27 @@ def update_from_output(
 
         return engine_core_outputs
 
+    def _update_computed_tokens(
+        self,
+        request: Request,
+        num_draft_tokens: int,
+        num_accepted: int,
+        num_rejected: int,
+        spec_decoding_stats: SpecDecodingStats | None,
+    ):
+        # num_computed_tokens represents the number of tokens
+        # processed in the current step, considering scheduled
+        # tokens and rejections. If some tokens are rejected,
+        # num_computed_tokens is decreased by the number of rejected
+        # tokens.
+        request.num_computed_tokens -= num_rejected
+        spec_decoding_stats = self.make_spec_decoding_stats(
+            spec_decoding_stats,
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted,
+        )
+        return spec_decoding_stats
+
     def _update_request_with_output(
         self,
         request: Request,
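
The extra `- request.num_output_placeholders` term exists because, under async scheduling, num_computed_tokens has already been advanced past the placeholder outputs reserved by a previous step, so the old expression would over-count by exactly that amount. A quick algebraic check with made-up numbers, assuming num_new_tokens for a running request is sized as num_tokens_with_spec + num_output_placeholders - num_computed_tokens:

# Made-up steady-state numbers for one running request with 2 draft tokens.
num_tokens = 100             # tokens whose ids are already known
spec_token_ids = [-1, -1]    # placeholder drafts written by the async scheduler
num_output_placeholders = 3  # 1 pending output + 2 pending drafts
num_computed_tokens = 103    # already advanced past those placeholders

num_tokens_with_spec = num_tokens + len(spec_token_ids)
num_new_tokens = num_tokens_with_spec + num_output_placeholders - num_computed_tokens

num_scheduled_spec_tokens = (
    num_new_tokens
    + num_computed_tokens
    - num_tokens
    - num_output_placeholders
)
# The placeholder terms cancel, leaving exactly the number of draft tokens.
assert num_scheduled_spec_tokens == len(spec_token_ids) == 2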

vllm/v1/engine/core.py

Lines changed: 5 additions & 1 deletion
@@ -198,6 +198,7 @@ def __init__(
         self.step_fn = (
             self.step if self.batch_queue is None else self.step_with_batch_queue
         )
+        self.async_scheduling = vllm_config.scheduler_config.async_scheduling
 
     def _initialize_kv_caches(
         self, vllm_config: VllmConfig
@@ -330,7 +331,10 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
 
     def post_step(self, model_executed: bool) -> None:
-        if self.use_spec_decode and model_executed:
+        # When async scheduling is used, the draft token ids can't be taken
+        # in advance here; they are updated in the worker process instead,
+        # so there is no need to update them in this method.
+        if self.use_spec_decode and model_executed and not self.async_scheduling:
             # Take the draft token ids.
             draft_token_ids = self.model_executor.take_draft_token_ids()
             if draft_token_ids is not None:

vllm/v1/spec_decode/eagle.py

Lines changed: 20 additions & 6 deletions
@@ -185,6 +185,7 @@ def __init__(
             device=device,
             dtype=torch.int32,
         ).repeat(max_batch_size, 1)
+        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
 
     def _get_positions(self, num_tokens: int):
         if self.uses_mrope:
@@ -387,14 +388,27 @@ def propose(
             positions += 1
             exceeds_max_model_len = positions >= self.max_model_len
             clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
+            # When use_async_scheduling is enabled, avoid in-place operations
+            # here, since these tensors may be modified by the main model's
+            # next-step `prepare_input` before this step finishes.
+            if self.use_async_scheduling:
+                # Increment the sequence lengths.
+                common_attn_metadata.seq_lens = common_attn_metadata.seq_lens + 1
+                common_attn_metadata.seq_lens_cpu = (
+                    common_attn_metadata.seq_lens_cpu + 1
+                )
+                # For the requests that exceed the max model length, we set the
+                # sequence length to 1 to minimize their overheads in attention.
 
-            # Increment the sequence lengths.
-            common_attn_metadata.seq_lens += 1
-            common_attn_metadata.seq_lens_cpu += 1
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
+                common_attn_metadata.seq_lens = common_attn_metadata.seq_lens.masked_fill(exceeds_max_model_len, 1)
+            else:
+                # Increment the sequence lengths.
+                common_attn_metadata.seq_lens += 1
+                common_attn_metadata.seq_lens_cpu += 1
+                # For the requests that exceed the max model length, we set the
+                # sequence length to 1 to minimize their overheads in attention.
 
-            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+                common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
 
             common_attn_metadata.num_computed_tokens_cpu = (
                 common_attn_metadata.seq_lens_cpu - 1
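
The reason the async path switches to out-of-place ops is aliasing: the seq_lens tensors may still be read (or rewritten) by the main model's prepare_input for the next step, so mutating them in place from the drafter would corrupt that state. A minimal PyTorch sketch of the difference, independent of any vLLM types:

import torch

# Pretend `shared` is a metadata buffer the main model will read next step.
shared = torch.tensor([5, 7, 9])
seq_lens = shared

# In-place update: the shared buffer is mutated as well.
seq_lens += 1
assert torch.equal(shared, torch.tensor([6, 8, 10]))

# Out-of-place update: a new tensor is created; the shared buffer is untouched.
shared = torch.tensor([5, 7, 9])
seq_lens = shared
seq_lens = seq_lens + 1
assert torch.equal(shared, torch.tensor([5, 7, 9]))
assert torch.equal(seq_lens, torch.tensor([6, 8, 10]))

# The same distinction holds for masked_fill_ (in-place) vs masked_fill,
# which returns a new tensor and must be assigned back to take effect.
exceeds_max_model_len = torch.tensor([False, True, False])
seq_lens = seq_lens.masked_fill(exceeds_max_model_len, 1)
assert torch.equal(seq_lens, torch.tensor([6, 1, 10]))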

vllm/v1/worker/gpu_input_batch.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,11 @@ class CachedRequestState:
 
     lora_request: LoRARequest | None = None
     prompt_embeds: torch.Tensor | None = None
+    # These are used when both async_scheduling and spec_decode are enabled.
+    prev_num_draft_len: int = 0
+    prev_sampled_tokens: torch.Tensor | None = None
+    prev_draft_tokens: torch.Tensor | None = None
+    resumed_from_preemption: bool = False
 
     def __post_init__(self):
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
