Commit 04b3625

support async_scheduling for spec-decode
Signed-off-by: Ronald1995 <[email protected]>
1 parent 8e67b25 commit 04b3625

8 files changed: +360 -85 lines changed


vllm/config/speculative.py

Lines changed: 16 additions & 19 deletions
@@ -3,7 +3,7 @@
 
 import ast
 import hashlib
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, get_args
 
 from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
@@ -30,28 +30,23 @@
 
 logger = init_logger(__name__)
 
-SpeculativeMethod = Literal[
-    "ngram",
-    "eagle",
-    "eagle3",
-    "medusa",
-    "mlp_speculator",
-    "draft_model",
-    "deepseek_mtp",
-    "ernie_mtp",
-    "qwen3_next_mtp",
-    "mimo_mtp",
-    "longcat_flash_mtp",
-    "mtp",
-]
-MTP_MODEL_TYPES = (
+MTPModelTypes = Literal[
     "deepseek_mtp",
     "mimo_mtp",
     "glm4_moe_mtp",
     "ernie_mtp",
     "qwen3_next_mtp",
     "longcat_flash_mtp",
-)
+    "mtp",
+]
+EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
+SpeculativeMethod = Literal[
+    "ngram",
+    "medusa",
+    "mlp_speculator",
+    "draft_model",
+    EagleModelTypes,
+]
 
 
 @config
@@ -224,7 +219,7 @@ def __post_init__(self):
        # can not be detected, it will be considered as the "draft_model" by
        # default.
 
-        if self.method in MTP_MODEL_TYPES:
+        if self.method in get_args(MTPModelTypes):
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
@@ -338,7 +333,9 @@ def __post_init__(self):
            self.method = "medusa"
        elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
            self.method = "mlp_speculator"
-        elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
+        elif self.draft_model_config.hf_config.model_type in get_args(
+            MTPModelTypes
+        ):
            self.method = "mtp"
            if self.num_speculative_tokens > 1:
                logger.warning(
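
Note: the Literal refactor keeps the old membership checks working because PEP 586 flattens nested Literals, so get_args() on an alias yields every member string. A minimal standalone sketch (abbreviated member lists, not part of this commit):

from typing import Literal, get_args

MTPModelTypes = Literal["deepseek_mtp", "mimo_mtp", "mtp"]   # abbreviated
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]

print(get_args(EagleModelTypes))
# -> ('eagle', 'eagle3', 'deepseek_mtp', 'mimo_mtp', 'mtp')

method = "deepseek_mtp"
assert method in get_args(MTPModelTypes)   # mirrors the __post_init__ check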

vllm/engine/arg_utils.py

Lines changed: 9 additions & 5 deletions
@@ -68,6 +68,7 @@
 from vllm.config.observability import DetailedTraceModules
 from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
 from vllm.config.scheduler import SchedulerPolicy
+from vllm.config.speculative import EagleModelTypes
 from vllm.config.utils import get_field
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
@@ -1424,12 +1425,15 @@ def create_engine_config(
                "Async scheduling is not supported with pipeline-parallel-size > 1."
            )
 
-        # Currently, async scheduling does not support speculative decoding.
-        # TODO(woosuk): Support it.
-        if self.speculative_config is not None:
+        # Currently, async scheduling only supports EAGLE/MTP speculative
+        # decoding.
+        # TODO(woosuk): Support other kinds of speculative decoding.
+        if self.speculative_config is not None and self.speculative_config.get(
+            "method"
+        ) not in get_args(EagleModelTypes):
            raise ValueError(
-                "Currently, speculative decoding is not supported with "
-                "async scheduling."
+                "Currently, async scheduling is only supported "
+                "with EAGLE/MTP kinds of speculative decoding."
            )
 
        # Forward the deprecated CLI args to the EPLB config.
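
Note: a usage sketch of a configuration this check now accepts. The model names are placeholders, and it assumes the async_scheduling and speculative_config engine arguments touched by this commit are exposed on the offline LLM entry point.

from vllm import LLM

llm = LLM(
    model="<target-model>",                  # placeholder
    speculative_config={
        "method": "eagle",                   # EAGLE/MTP methods pass the new check
        "model": "<eagle-draft-model>",      # placeholder
        "num_speculative_tokens": 3,
    },
    async_scheduling=True,
)

With async scheduling enabled, any other speculative method (e.g. "ngram") still hits the ValueError above.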

vllm/v1/core/sched/async_scheduler.py

Lines changed: 48 additions & 4 deletions
@@ -5,6 +5,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 logger = init_logger(__name__)
 
@@ -15,15 +16,24 @@ def _update_after_schedule(
        scheduler_output: SchedulerOutput,
    ) -> None:
        super()._update_after_schedule(scheduler_output)
+        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
        for req_id in scheduler_output.num_scheduled_tokens:
            request = self.requests[req_id]
+            spec_tokens = len(spec_decode_tokens.get(req_id, []))
            if (
                request.num_computed_tokens
-                == request.num_tokens + request.num_output_placeholders
+                == request.num_tokens + request.num_output_placeholders + spec_tokens
            ):
-                # The request will generate a new token in this scheduling step.
-                # TODO(woosuk): Support speculative decoding.
-                request.num_output_placeholders += 1
+                # The request will generate a new token plus num_spec_tokens
+                # in this scheduling step.
+                request.num_output_placeholders += 1 + spec_tokens
+                # Add placeholders for the new tokens in spec_token_ids,
+                # because the actual token ids are not known yet. Use -1 as
+                # the placeholder and set the length of spec_token_ids to
+                # self.num_spec_tokens; the actual spec token ids are filled
+                # in by the worker process.
+                if self.num_spec_tokens > 0:
+                    request.spec_token_ids = [-1] * self.num_spec_tokens
 
    def _update_request_with_output(
        self,
@@ -45,3 +55,37 @@ def _update_request_with_output(
            request, request.num_computed_tokens - request.num_output_placeholders
        )
        return new_token_ids, stopped
+
+    def _update_computed_tokens(
+        self,
+        request: Request,
+        scheduled_spec_token_ids: list[int],
+        generated_token_ids: list[int],
+        spec_decoding_stats: SpecDecodingStats | None,
+    ):
+        """Update the computed tokens for each request after spec decoding.
+
+        The sync scheduler only needs to revert num_computed_tokens by the
+        number of rejected tokens; the async scheduler also needs to revert
+        num_output_placeholders by the number of rejected tokens.
+        """
+        num_draft_tokens = len(scheduled_spec_token_ids)
+        num_accepted = len(generated_token_ids) - 1
+        num_rejected = num_draft_tokens - num_accepted
+        # When spec decoding is enabled, num_output_placeholders
+        # is increased by num_spec_tokens in _update_after_schedule.
+        # Update num_output_placeholders here to reflect the actual number
+        # of accepted output tokens.
+        request.num_output_placeholders -= num_rejected
+        # num_computed_tokens represents the number of tokens
+        # processed in the current step, considering scheduled
+        # tokens and rejections. If some tokens are rejected,
+        # num_computed_tokens is decreased by the number of rejected
+        # tokens.
+        request.num_computed_tokens -= num_rejected
+        spec_decoding_stats = self.make_spec_decoding_stats(
+            spec_decoding_stats,
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted,
+        )
+        return spec_decoding_stats
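
Note: a worked sketch of the placeholder bookkeeping with illustrative numbers (standalone toy code, not vLLM internals). With num_spec_tokens = 3, a decode step reserves one real output plus three draft placeholders; if only one draft is accepted, the two rejected positions are walked back from both counters.

class ToyRequest:
    def __init__(self):
        self.num_output_placeholders = 0
        self.num_computed_tokens = 100        # arbitrary starting point

req = ToyRequest()
num_spec_tokens = 3

# _update_after_schedule: reserve 1 new token + 3 draft placeholders.
req.num_output_placeholders += 1 + num_spec_tokens         # 0 -> 4
spec_token_ids = [-1] * num_spec_tokens                     # filled in by the worker later

# After execution, the model accepted 1 draft, so generated_token_ids
# holds 2 tokens (1 accepted draft + 1 bonus token).
generated_token_ids = [111, 222]
num_draft_tokens = num_spec_tokens
num_accepted = len(generated_token_ids) - 1                 # 1
num_rejected = num_draft_tokens - num_accepted              # 2

# _update_computed_tokens: walk back the rejected positions.
req.num_output_placeholders -= num_rejected                 # 4 -> 2
req.num_computed_tokens -= num_rejected                     # rewound by 2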

vllm/v1/core/sched/output.py

Lines changed: 5 additions & 0 deletions
@@ -173,3 +173,8 @@ class SchedulerOutput:
 
    # KV Cache Connector metadata.
    kv_connector_metadata: KVConnectorMetadata | None = None
+
+    # Total number of scheduled speculative tokens across all requests.
+    # This is needed when async_scheduling and speculative decoding are
+    # used together.
+    total_num_scheduled_spec_tokens: int = 0

vllm/v1/core/sched/scheduler.py

Lines changed: 38 additions & 15 deletions
@@ -194,7 +194,7 @@ def schedule(self) -> SchedulerOutput:
        encoder_compute_budget = self.max_num_encoder_input_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
-
+        total_num_spec_tokens = 0
        # For logging.
        scheduled_timestamp = time.monotonic()
 
@@ -282,6 +282,9 @@ def schedule(self) -> SchedulerOutput:
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0
                preempted_req.num_preemptions += 1
+                # Neither sync nor async scheduling uses spec_token_ids in
+                # the waiting queue, so we can just clear it here.
+                preempted_req.spec_token_ids.clear()
                if self.log_stats:
                    preempted_req.record_event(
                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
@@ -307,13 +310,17 @@ def schedule(self) -> SchedulerOutput:
            # Speculative decode related.
            if request.spec_token_ids:
                num_scheduled_spec_tokens = (
-                    num_new_tokens + request.num_computed_tokens - request.num_tokens
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
                )
                if num_scheduled_spec_tokens > 0:
+                    total_num_spec_tokens += num_scheduled_spec_tokens
                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                    scheduled_spec_decode_tokens[request.request_id] = (
-                        request.spec_token_ids
+                        request.spec_token_ids.copy()
                    )
 
            # Encoder-related.
@@ -632,6 +639,7 @@ def schedule(self) -> SchedulerOutput:
            free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
            structured_output_request_ids=structured_output_request_ids,
            grammar_bitmask=grammar_bitmask,
+            total_num_scheduled_spec_tokens=total_num_spec_tokens,
        )
 
        # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -960,19 +968,11 @@ def update_from_output(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id)
            )
            if scheduled_spec_token_ids:
-                num_draft_tokens = len(scheduled_spec_token_ids)
-                num_accepted = len(generated_token_ids) - 1
-                num_rejected = num_draft_tokens - num_accepted
-                # num_computed_tokens represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections. If some tokens are rejected,
-                # num_computed_tokens is decreased by the number of rejected
-                # tokens.
-                request.num_computed_tokens -= num_rejected
-                spec_decoding_stats = self.make_spec_decoding_stats(
+                spec_decoding_stats = self._update_computed_tokens(
+                    request,
+                    scheduled_spec_token_ids,
+                    generated_token_ids,
                    spec_decoding_stats,
-                    num_draft_tokens=num_draft_tokens,
-                    num_accepted_tokens=num_accepted,
                )
 
            stopped = False
@@ -1088,6 +1088,29 @@ def update_from_output(
 
        return engine_core_outputs
 
+    def _update_computed_tokens(
+        self,
+        request: Request,
+        scheduled_spec_token_ids: list[int],
+        generated_token_ids: list[int],
+        spec_decoding_stats: SpecDecodingStats | None,
+    ):
+        num_draft_tokens = len(scheduled_spec_token_ids)
+        num_accepted = len(generated_token_ids) - 1
+        num_rejected = num_draft_tokens - num_accepted
+        # num_computed_tokens represents the number of tokens
+        # processed in the current step, considering scheduled
+        # tokens and rejections. If some tokens are rejected,
+        # num_computed_tokens is decreased by the number of rejected
+        # tokens.
+        request.num_computed_tokens -= num_rejected
+        spec_decoding_stats = self.make_spec_decoding_stats(
+            spec_decoding_stats,
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted,
+        )
+        return spec_decoding_stats
+
    def _update_request_with_output(
        self,
        request: Request,
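
Note: a plausible reason for the new request.spec_token_ids.copy() is that the scheduler output may be consumed while the request's own list is trimmed, cleared, or refilled with placeholders later in the step, so the output needs an independent snapshot. A minimal standalone sketch of the aliasing hazard a copy avoids (plain Python, not vLLM code):

spec_token_ids = [11, 22, 33]

scheduled = {"req-0": spec_token_ids}             # aliased (old behavior)
spec_token_ids[:] = [-1, -1, -1]                  # later in-place reset
print(scheduled["req-0"])                         # [-1, -1, -1]  (drafts lost)

spec_token_ids = [11, 22, 33]
scheduled = {"req-0": spec_token_ids.copy()}      # snapshot (new behavior)
spec_token_ids[:] = [-1, -1, -1]
print(scheduled["req-0"])                         # [11, 22, 33]  (drafts preserved)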

vllm/v1/engine/core.py

Lines changed: 5 additions & 1 deletion
@@ -195,6 +195,7 @@ def __init__(
        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
+        self.async_scheduling = vllm_config.scheduler_config.async_scheduling
 
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
@@ -329,7 +330,10 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)
 
    def post_step(self, model_executed: bool) -> None:
-        if self.use_spec_decode and model_executed:
+        # When using async scheduling, we can't get the draft token ids in
+        # advance, so they are updated in the worker process and don't need
+        # to be updated here.
+        if self.use_spec_decode and model_executed and not self.async_scheduling:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:

vllm/v1/worker/gpu_input_batch.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,10 @@ class CachedRequestState:
 
    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None
+    # These are used when both async_scheduling and spec_decode are enabled.
+    prev_num_draft_len: int = 0
+    prev_sampled_tokens: torch.Tensor | None = None
+    prev_draft_tokens: torch.Tensor | None = None
 
    def __post_init__(self):
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(

0 commit comments
