Commit a55d38f

support async_scheduling for spec-decode
Signed-off-by: Ronald1995 <[email protected]>
1 parent ba09652 commit a55d38f

File tree

9 files changed: +382 -91 lines changed

vllm/config/speculative.py

Lines changed: 16 additions & 19 deletions
@@ -3,7 +3,7 @@
 
 import ast
 import hashlib
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, get_args
 
 from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
@@ -30,28 +30,23 @@
 
 logger = init_logger(__name__)
 
-SpeculativeMethod = Literal[
-    "ngram",
-    "eagle",
-    "eagle3",
-    "medusa",
-    "mlp_speculator",
-    "draft_model",
-    "deepseek_mtp",
-    "ernie_mtp",
-    "qwen3_next_mtp",
-    "mimo_mtp",
-    "longcat_flash_mtp",
-    "mtp",
-]
-MTP_MODEL_TYPES = (
+MTPModelTypes = Literal[
     "deepseek_mtp",
     "mimo_mtp",
     "glm4_moe_mtp",
     "ernie_mtp",
     "qwen3_next_mtp",
     "longcat_flash_mtp",
-)
+    "mtp",
+]
+EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
+SpeculativeMethod = Literal[
+    "ngram",
+    "medusa",
+    "mlp_speculator",
+    "draft_model",
+    EagleModelTypes,
+]
 
 
 @config
@@ -224,7 +219,7 @@ def __post_init__(self):
         # can not be detected, it will be considered as the "draft_model" by
         # default.
 
-        if self.method in MTP_MODEL_TYPES:
+        if self.method in get_args(MTPModelTypes):
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
@@ -338,7 +333,9 @@ def __post_init__(self):
                self.method = "medusa"
            elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                self.method = "mlp_speculator"
-            elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
+            elif self.draft_model_config.hf_config.model_type in get_args(
+                MTPModelTypes
+            ):
                self.method = "mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
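
For reference, the `in get_args(MTPModelTypes)` checks above work because Python flattens nested Literal types, so get_args returns the plain tuple of member strings. A minimal standalone sketch, with the member lists trimmed for brevity (not the full set from this diff):

from typing import Literal, get_args

# Trimmed stand-ins for the types defined above.
MTPModelTypes = Literal["deepseek_mtp", "qwen3_next_mtp", "mtp"]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]

# Nested Literals are flattened, so get_args() yields plain strings.
print(get_args(MTPModelTypes))    # ('deepseek_mtp', 'qwen3_next_mtp', 'mtp')
print(get_args(EagleModelTypes))  # ('eagle', 'eagle3', 'deepseek_mtp', 'qwen3_next_mtp', 'mtp')

# Membership checks like the ones in __post_init__:
assert "qwen3_next_mtp" in get_args(MTPModelTypes)
assert "eagle3" in get_args(EagleModelTypes)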

vllm/engine/arg_utils.py

Lines changed: 11 additions & 5 deletions
@@ -69,6 +69,7 @@
 from vllm.config.observability import DetailedTraceModules
 from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
 from vllm.config.scheduler import SchedulerPolicy
+from vllm.config.speculative import EagleModelTypes
 from vllm.config.utils import get_field
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
@@ -1465,12 +1466,17 @@ def create_engine_config(
                "Async scheduling is not supported with pipeline-parallel-size > 1."
            )
 
-        # Currently, async scheduling does not support speculative decoding.
-        # TODO(woosuk): Support it.
-        if self.speculative_config is not None:
+        # Currently, async scheduling only supports EAGLE/MTP speculative
+        # decoding.
+        # TODO(woosuk): Support other kinds of speculative decoding.
+        if self.speculative_config is not None and (
+            self.speculative_config.get("method") not in get_args(EagleModelTypes)
+            or self.speculative_config.get("disable_padded_drafter_batch")
+        ):
            raise ValueError(
-                "Currently, speculative decoding is not supported with "
-                "async scheduling."
+                "Currently, async scheduling is only supported "
+                "with EAGLE/MTP speculative decoding, and "
+                "disable_padded_drafter_batch must be False."
            )
 
        # Forward the deprecated CLI args to the EPLB config.
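
A rough standalone sketch of the new guard, using a plain dict in place of the real speculative_config and a trimmed EagleModelTypes; validate_async_scheduling is a hypothetical helper name used only for illustration, not vLLM API:

from typing import Literal, get_args

EagleModelTypes = Literal["eagle", "eagle3", "mtp"]  # trimmed for brevity

def validate_async_scheduling(speculative_config: dict | None) -> None:
    # Hypothetical helper mirroring the check in create_engine_config.
    if speculative_config is not None and (
        speculative_config.get("method") not in get_args(EagleModelTypes)
        or speculative_config.get("disable_padded_drafter_batch")
    ):
        raise ValueError(
            "Currently, async scheduling is only supported with EAGLE/MTP "
            "speculative decoding, and disable_padded_drafter_batch must be False."
        )

validate_async_scheduling({"method": "eagle3", "disable_padded_drafter_batch": False})  # passes
# validate_async_scheduling({"method": "ngram"})  # would raise ValueError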

vllm/v1/core/sched/async_scheduler.py

Lines changed: 48 additions & 4 deletions
@@ -5,6 +5,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 logger = init_logger(__name__)
 
@@ -15,15 +16,24 @@ def _update_after_schedule(
        scheduler_output: SchedulerOutput,
    ) -> None:
        super()._update_after_schedule(scheduler_output)
+        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
        for req_id in scheduler_output.num_scheduled_tokens:
            request = self.requests[req_id]
+            spec_tokens = len(spec_decode_tokens.get(req_id, []))
            if (
                request.num_computed_tokens
-                == request.num_tokens + request.num_output_placeholders
+                == request.num_tokens + request.num_output_placeholders + spec_tokens
            ):
-                # The request will generate a new token in this scheduling step.
-                # TODO(woosuk): Support speculative decoding.
-                request.num_output_placeholders += 1
+                # The request will generate a new token plus up to
+                # num_spec_tokens draft tokens in this scheduling step.
+                request.num_output_placeholders += 1 + spec_tokens
+                # Add -1 placeholders to spec_token_ids because the actual
+                # draft token ids are not known yet. The length of
+                # spec_token_ids is set to self.num_spec_tokens, and the
+                # worker process will fill in the actual draft token ids
+                # later.
+                if self.num_spec_tokens > 0:
+                    request.spec_token_ids = [-1] * self.num_spec_tokens
 
    def _update_request_with_output(
        self,
@@ -45,3 +55,37 @@ def _update_request_with_output(
            request, request.num_computed_tokens - request.num_output_placeholders
        )
        return new_token_ids, stopped
+
+    def _update_computed_tokens(
+        self,
+        request: Request,
+        scheduled_spec_token_ids: list[int],
+        generated_token_ids: list[int],
+        spec_decoding_stats: SpecDecodingStats | None,
+    ):
+        """Update a request's token counts after spec decoding.
+
+        The sync scheduler only needs to roll back num_computed_tokens by
+        the number of rejected tokens; the async scheduler must also roll
+        back num_output_placeholders by the same amount.
+        """
+        num_draft_tokens = len(scheduled_spec_token_ids)
+        num_accepted = len(generated_token_ids) - 1
+        num_rejected = num_draft_tokens - num_accepted
+        # When spec decoding is enabled, num_output_placeholders was
+        # increased by num_spec_tokens in _update_after_schedule, so
+        # update it here to reflect the actual number of accepted
+        # output tokens.
+        request.num_output_placeholders -= num_rejected
+        # num_computed_tokens represents the number of tokens
+        # processed in the current step, considering scheduled
+        # tokens and rejections. If some tokens are rejected,
+        # num_computed_tokens is decreased by the number of rejected
+        # tokens.
+        request.num_computed_tokens -= num_rejected
+        spec_decoding_stats = self.make_spec_decoding_stats(
+            spec_decoding_stats,
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted,
+        )
+        return spec_decoding_stats
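
A toy walk-through of the bookkeeping above, assuming num_spec_tokens = 3, all three drafts get scheduled, and the verifier accepts two of them (all numbers are made up for illustration):

num_spec_tokens = 3

# _update_after_schedule: reserve one slot for the new token plus the drafts.
num_output_placeholders = 0
num_output_placeholders += 1 + num_spec_tokens   # -> 4
spec_token_ids = [-1] * num_spec_tokens          # placeholders; the worker fills them in

# _update_computed_tokens: the step produced 3 tokens
# (1 new token + 2 accepted drafts), so 1 draft was rejected.
generated_token_ids = [101, 102, 103]
num_draft_tokens = num_spec_tokens
num_accepted = len(generated_token_ids) - 1      # 2
num_rejected = num_draft_tokens - num_accepted   # 1
num_output_placeholders -= num_rejected          # -> 3, the tokens actually produced
print(num_output_placeholders)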

vllm/v1/core/sched/output.py

Lines changed: 5 additions & 0 deletions
@@ -174,3 +174,8 @@ class SchedulerOutput:
 
    # KV Cache Connector metadata.
    kv_connector_metadata: KVConnectorMetadata | None = None
+
+    # Total number of scheduled speculative tokens across all requests.
+    # This is needed when async_scheduling and speculative decoding are
+    # used together.
+    total_num_scheduled_spec_tokens: int = 0
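
For context, the new field is simply the sum of the per-request scheduled draft counts that the scheduler accumulates (see the scheduler.py change below). A tiny illustration with made-up request ids:

scheduled_spec_decode_tokens = {"req-0": [7, 8, 9], "req-1": [4]}
total_num_scheduled_spec_tokens = sum(
    len(ids) for ids in scheduled_spec_decode_tokens.values()
)
assert total_num_scheduled_spec_tokens == 4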

vllm/v1/core/sched/scheduler.py

Lines changed: 38 additions & 15 deletions
@@ -198,7 +198,7 @@ def schedule(self) -> SchedulerOutput:
        encoder_compute_budget = self.max_num_encoder_input_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
-
+        total_num_spec_tokens = 0
        # For logging.
        scheduled_timestamp = time.monotonic()
 
@@ -286,6 +286,9 @@ def schedule(self) -> SchedulerOutput:
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0
                preempted_req.num_preemptions += 1
+                # Neither sync nor async scheduling uses spec_token_ids in
+                # the waiting queue, so we can simply clear it here.
+                preempted_req.spec_token_ids.clear()
                if self.log_stats:
                    preempted_req.record_event(
                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
@@ -311,13 +314,17 @@ def schedule(self) -> SchedulerOutput:
            # Speculative decode related.
            if request.spec_token_ids:
                num_scheduled_spec_tokens = (
-                    num_new_tokens + request.num_computed_tokens - request.num_tokens
+                    num_new_tokens
+                    + request.num_computed_tokens
+                    - request.num_tokens
+                    - request.num_output_placeholders
                )
                if num_scheduled_spec_tokens > 0:
+                    total_num_spec_tokens += num_scheduled_spec_tokens
                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                    scheduled_spec_decode_tokens[request.request_id] = (
-                        request.spec_token_ids
+                        request.spec_token_ids.copy()
                    )
 
            # Encoder-related.
@@ -631,6 +638,7 @@ def schedule(self) -> SchedulerOutput:
            free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
            structured_output_request_ids=structured_output_request_ids,
            grammar_bitmask=grammar_bitmask,
+            total_num_scheduled_spec_tokens=total_num_spec_tokens,
        )
 
    # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -959,19 +967,11 @@ def update_from_output(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id)
            )
            if scheduled_spec_token_ids:
-                num_draft_tokens = len(scheduled_spec_token_ids)
-                num_accepted = len(generated_token_ids) - 1
-                num_rejected = num_draft_tokens - num_accepted
-                # num_computed_tokens represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections. If some tokens are rejected,
-                # num_computed_tokens is decreased by the number of rejected
-                # tokens.
-                request.num_computed_tokens -= num_rejected
-                spec_decoding_stats = self.make_spec_decoding_stats(
+                spec_decoding_stats = self._update_computed_tokens(
+                    request,
+                    scheduled_spec_token_ids,
+                    generated_token_ids,
                    spec_decoding_stats,
-                    num_draft_tokens=num_draft_tokens,
-                    num_accepted_tokens=num_accepted,
                )
 
            stopped = False
@@ -1085,6 +1085,29 @@ def update_from_output(
 
        return engine_core_outputs
 
+    def _update_computed_tokens(
+        self,
+        request: Request,
+        scheduled_spec_token_ids: list[int],
+        generated_token_ids: list[int],
+        spec_decoding_status: SpecDecodingStats | None,
+    ):
+        num_draft_tokens = len(scheduled_spec_token_ids)
+        num_accepted = len(generated_token_ids) - 1
+        num_rejected = num_draft_tokens - num_accepted
+        # num_computed_tokens represents the number of tokens
+        # processed in the current step, considering scheduled
+        # tokens and rejections. If some tokens are rejected,
+        # num_computed_tokens is decreased by the number of rejected
+        # tokens.
+        request.num_computed_tokens -= num_rejected
+        spec_decoding_stats = self.make_spec_decoding_stats(
+            spec_decoding_status,
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted,
+        )
+        return spec_decoding_stats
+
    def _update_request_with_output(
        self,
        request: Request,
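
Why request.spec_token_ids.copy() matters in the hunk above: without the copy, the entry stored in scheduled_spec_decode_tokens aliases the request's own list, so a later in-place mutation of that list (for example the new preempted_req.spec_token_ids.clear() on preemption) would also rewrite what was recorded as scheduled. A tiny aliasing demo:

spec_token_ids = [7, 8, 9]

aliased = spec_token_ids        # old behaviour: the scheduler output aliases the list
copied = spec_token_ids.copy()  # new behaviour: the scheduler output is independent

spec_token_ids.clear()          # e.g. the request is preempted later

print(aliased)  # []         -> the recorded draft tokens are gone
print(copied)   # [7, 8, 9]  -> the recorded draft tokens are preserved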

vllm/v1/engine/core.py

Lines changed: 5 additions & 1 deletion
@@ -198,6 +198,7 @@ def __init__(
        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
+        self.async_scheduling = vllm_config.scheduler_config.async_scheduling
 
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
@@ -330,7 +331,10 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
 
    def post_step(self, model_executed: bool) -> None:
-        if self.use_spec_decode and model_executed:
+        # When using async scheduling, we can't get the draft token ids in
+        # advance; the worker process updates them instead, so there is no
+        # need to update them here.
+        if self.use_spec_decode and model_executed and not self.async_scheduling:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
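
The gating condition in post_step reduces to a simple boolean; a minimal sketch, where should_take_draft_token_ids is a hypothetical helper name used only for illustration:

def should_take_draft_token_ids(
    use_spec_decode: bool, model_executed: bool, async_scheduling: bool
) -> bool:
    # With async scheduling, the worker process updates the draft token ids
    # itself, so the engine core skips take_draft_token_ids().
    return use_spec_decode and model_executed and not async_scheduling

assert should_take_draft_token_ids(True, True, False)
assert not should_take_draft_token_ids(True, True, True)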

vllm/v1/spec_decode/eagle.py

Lines changed: 20 additions & 6 deletions
@@ -185,6 +185,7 @@ def __init__(
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)
+        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
 
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
@@ -387,14 +388,27 @@ def propose(
            positions += 1
            exceeds_max_model_len = positions >= self.max_model_len
            clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
+            # When use_async_scheduling is enabled, we shouldn't use in-place
+            # operations, in case the tensors are modified by the next step's
+            # `prepare_input` of the main model.
+            if self.use_async_scheduling:
+                # Increment the sequence lengths.
+                common_attn_metadata.seq_lens = common_attn_metadata.seq_lens + 1
+                common_attn_metadata.seq_lens_cpu = (
+                    common_attn_metadata.seq_lens_cpu + 1
+                )
+                # For the requests that exceed the max model length, we set the
+                # sequence length to 1 to minimize their overheads in attention.
 
-            # Increment the sequence lengths.
-            common_attn_metadata.seq_lens += 1
-            common_attn_metadata.seq_lens_cpu += 1
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
+                common_attn_metadata.seq_lens = common_attn_metadata.seq_lens.masked_fill(exceeds_max_model_len, 1)
+            else:
+                # Increment the sequence lengths.
+                common_attn_metadata.seq_lens += 1
+                common_attn_metadata.seq_lens_cpu += 1
+                # For the requests that exceed the max model length, we set the
+                # sequence length to 1 to minimize their overheads in attention.
 
-            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+                common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
 
            common_attn_metadata.num_computed_tokens_cpu = (
                common_attn_metadata.seq_lens_cpu - 1
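
The in-place vs. out-of-place distinction above is the usual PyTorch aliasing issue: seq_lens += 1 mutates a buffer that the main model's next prepare_input may still read, while seq_lens + 1 allocates a fresh tensor. A minimal illustration, independent of the vLLM metadata objects themselves:

import torch

# In-place: every holder of the buffer sees the increment.
shared = torch.tensor([5, 9, 12])
other_view = shared
shared += 1
print(other_view)   # tensor([ 6, 10, 13])

# Out-of-place: the original buffer is left untouched.
shared = torch.tensor([5, 9, 12])
other_view = shared
bumped = shared + 1
print(other_view)   # tensor([ 5,  9, 12])
print(bumped)       # tensor([ 6, 10, 13])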

vllm/v1/worker/gpu_input_batch.py

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,10 @@ class CachedRequestState:
 
    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None
+    # These are used when both async_scheduling and spec_decode are enabled.
+    prev_num_draft_len: int = 0
+    prev_sampled_tokens: torch.Tensor | None = None
+    prev_draft_tokens: torch.Tensor | None = None
 
    def __post_init__(self):
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
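
A stripped-down sketch of the new bookkeeping fields in isolation; how the worker populates them is not shown in this hunk, so the assignments below are only a hypothetical example:

from dataclasses import dataclass

import torch

@dataclass
class CachedRequestStateSketch:
    # Stand-in keeping only the fields added in this hunk.
    req_id: str
    prev_num_draft_len: int = 0
    prev_sampled_tokens: torch.Tensor | None = None
    prev_draft_tokens: torch.Tensor | None = None

state = CachedRequestStateSketch(req_id="req-0")
state.prev_num_draft_len = 3
state.prev_sampled_tokens = torch.tensor([[101]])
state.prev_draft_tokens = torch.tensor([[7, 8, 9]])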

0 commit comments

Comments
 (0)