[Spec Decode] Efficient padded speculation #24539
Conversation
Code Review
This pull request introduces an efficient implementation of padded speculative decoding for EAGLE, aiming to improve performance by moving token bookkeeping to the GPU and avoiding synchronization. The changes are substantial and primarily affect the gpu_model_runner and eagle proposer. While the overall approach and implementation of the GPU-side logic seem correct, I've identified a critical issue in vllm/v1/spec_decode/eagle.py where unclamped token positions are used for calculating KV cache slot mappings. This could lead to out-of-bounds memory access and should be addressed.
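To illustrate the concern (not the PR's actual code), a minimal sketch of clamping positions before deriving slot mappings; the names positions, block_table, and max_valid_pos are assumptions:

```python
import torch

# Hypothetical sketch (not the PR's code): clamp draft-token positions to the
# request's valid range before turning them into KV-cache slot indices, so
# padded/rejected draft slots cannot index out of bounds.
def clamped_slot_mapping(positions: torch.Tensor,    # [num_tokens] token positions
                         block_table: torch.Tensor,  # [max_blocks] block ids for the request
                         block_size: int,
                         max_valid_pos: int) -> torch.Tensor:
    pos = positions.clamp(min=0, max=max_valid_pos)
    block_idx = pos // block_size
    block_offset = pos % block_size
    return block_table[block_idx] * block_size + block_offset
```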
Thanks for doing this! I did a very quick pass; I'll do another later today or tomorrow. This is really great work! Overall this is looking good; at first blush I do think we should start thinking about the spec-decode abstractions and whether we can move more out of the gpu model runner. Right now the line in input prep between what is the responsibility of the proposer and what is the responsibility of the gpu model runner feels blurry; that's not really the fault of this PR, but it feels like a natural time to start thinking about it. (Feel free to push back; I haven't had a chance to think that deeply about it yet.)
It is unclear to me if this change needs to be opt-in and/or feature-flagged. Maintaining the existing pathway has a chance to net a few % speedup for some cases in the future, but increases the complexity and is not likely to be usable by any attention backends besides FlashAttention/TreeAttention. I am somewhat in favor of integrating this PR as-is and reverting if necessary, but I am open to adding a flag until we get a few more features in so that we can make a decision based on the fully-featured benchmarks.
I think we should maintain the existing pathways and make this feature-flagged for now, until we can test the performance implications on something like tree spec decode, where the rejected tokens are not contiguous and are more numerous; cc @TheEpicDolphin @luccafong. I think it would probably make sense to turn it on by default (assuming we are relatively confident there's no/negligible perf regression) when not using tree spec decode, though, so we can start stressing the pathway.
This pull request has merge conflicts that must be resolved before it can be merged.
I think this PR is too complicated. I'm not sure if it pays off, given that the perf gain is limited.
@WoosukKwon the goal of this PR is not to introduce a raw speedup. Instead, it introduces foundational logic required to enable speculative decoding with new backends and features. To list a few:
We could simplify the situation by just making it the only pathway (per @benchislett's initial suggestion, i.e. completely remove unpadded speculation) if that makes it more palatable @WoosukKwon. I generally was in favor of landing it feature-flagged since it allows:
However, I can be convinced that always committing to padded speculation to simplify the code could make sense (especially in light of de-prioritizing tree attention). I do believe pretty strongly, though, that padded speculation is needed for the reasons @benchislett already mentioned.
Re-requesting review from @WoosukKwon after an offline discussion. I understand if this is too much complexity to introduce without comprehensive review and careful refactoring, but I still feel that it is a powerful feature that we will need at some point. In the immediate term, we need to make a decision quickly to unblock FlashInfer support. If this PR is still decided to be too complex, we can merge another approach that will get it working. Something like #24045 with https://github.com/vllm-project/vllm/pull/24539/files#diff-a4809a837fbf535a8f0999b11087a53ec1c53948b50c0a1fe64396bc86de9461R282-R338 would probably be sufficient to get it functional and compatible with piecewise graphs.
Can you expand on "Something like #24045"? I'm not sure this addresses the issue of ending up with query lengths like
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
disable_padded_drafter_batch: bool = False
Is this flag necessary?
We want to maintain both pathways so that this PR's changes can be disabled if needed, since it is not clear what the perf will look like once new backend support is added and piecewise CUDA graphs are enabled for them. This flag makes it easy to compare the perf, so we can easily decide whether this change should be rolled back (or have its default swapped) in the future.
def prepare_next_token_ids_cpu(
        self, sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int]) -> torch.Tensor:
    """
    This function is used to prepare the inputs for speculative decoding.
    It calculates the next token ids for each request based on the sampled
    token ids from the CPU. If a request has no sampled token ids (e.g.,
    during the initial decoding steps), it falls back to using the request
    state to get the next token id.
    """
    req_ids = gpu_input_batch.req_ids
    next_token_ids: list[int] = []
    for i, token_ids in enumerate(sampled_token_ids):
        if token_ids:
            # Common case.
            next_token_id = token_ids[-1]
        else:
            # Partial prefill (rare case).
            # Get the next token id from the request state.
            req_id = req_ids[i]
            req_state = requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       num_scheduled_tokens[req_id])
            next_token_id = req_state.get_token_id(seq_len)
        next_token_ids.append(next_token_id)
    next_token_ids = torch.tensor(next_token_ids,
                                  dtype=torch.int32,
                                  device=self.input_ids.device)
    return next_token_ids
def prepare_next_token_ids_padded(self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int) -> \
            tuple[torch.Tensor, torch.Tensor]:
Can we write this in Triton, instead of Python + PyTorch? I think that will make things simpler and more efficient.
Yes, the GPU pathway can be combined into a single kernel. That's a TODO item for me; see line 428 below. I did not include this as I thought it would add more complexity. It should be a fairly easy follow-up.
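For a rough sense of what such a fused kernel could look like (a sketch only, not the planned follow-up; tensor names, the validity-mask convention, and shapes are all assumptions):

```python
import triton
import triton.language as tl

# A loose sketch of fusing the bookkeeping into one launch: each program
# handles one request, finds its last accepted token, and falls back to a
# per-request backup id if nothing was accepted. All names are assumptions.
@triton.jit
def _next_token_ids_kernel(sampled_ptr, valid_ptr, backup_ptr, out_ptr,
                           stride_req, num_tokens, BLOCK: tl.constexpr):
    req = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    in_range = offs < num_tokens
    valid = tl.load(valid_ptr + req * stride_req + offs, mask=in_range, other=0)
    # Index of the last accepted token for this request (-1 if none).
    last_idx = tl.max(tl.where(valid != 0, offs, -1), axis=0)
    chosen = tl.load(sampled_ptr + req * stride_req + tl.maximum(last_idx, 0))
    backup = tl.load(backup_ptr + req)
    tl.store(out_ptr + req, tl.where(last_idx >= 0, chosen, backup))

# Example launch (one program per request):
#   _next_token_ids_kernel[(num_reqs,)](
#       sampled, valid, backup, out, sampled.stride(0), num_tokens,
#       BLOCK=triton.next_power_of_2(num_tokens))
```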
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
@benchislett sorry, I was a bit late seeing this, but I'm just wondering why these comments were removed, in particular the warning about CUDA-specific torch-internal details. These generator offsets aren't public APIs or applicable to all torch backend impls (e.g. CPU).
oh, my bad. I think I just swept them up with the other comments when I was refactoring. I can add these back if you'd like
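For reference, a hedged sketch of the generator-rewind pattern the removed comment warned about; the offset delta per sample is an assumption, and get_offset/set_offset are CUDA-generator internals rather than a backend-agnostic API:

```python
import torch

# Hedged illustration only: rewinding a CUDA generator after an unwanted sample.
# get_offset()/set_offset() are CUDA-generator internals (not meaningful for,
# e.g., the CPU backend), which is exactly what the removed comment warned about.
generator = torch.Generator(device="cuda")
# ... a token is sampled for a partial prefill and must be treated as unsampled ...
STATE_PER_SAMPLE = 4  # assumption about how much Philox state one sample consumes
generator.set_offset(generator.get_offset() - STATE_PER_SAMPLE)
```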
Purpose
This PR implements padded speculative decoding as described in #21984. The implementation is the same as the one from #22684 (which itself came from #20078), with some cleanup and compatibility with latest main.
This PR does not add support for any new attention backends, such as #21984 or #24045. It will not take much work to enable for each new backend, but I think it is best to separate concerns into different PRs for this support. This PR represents the common core feature of padded speculation only.
When tested on Llama 3.1 8B-Instruct with EAGLE3 on 1xB200, this PR introduces a small 2% speedup on both TTFT and TPOT. Performance profiling confirmed that the implementation is free of GPU->CPU sync during the EAGLE drafting phase, but this performance gain is offset by the fact that CUDA graphs for the draft model are not working (see #23679) and that many small gpu operations are added in order to do sampled-token bookkeeping fully on the GPU.
Performance is expected to improve once new attention backends are supported, CUDA graph support for draft models is re-introduced, and a custom kernel is added to fuse the many small operations used in the bookkeeping (see #20078 for example).
The feature is now the default, with opt-out via "disable_padded_drafter_batch" in the speculative_config. Performance is identical and behaviour is unchanged when disable_padded_drafter_batch is set.
Test Plan
Ran locally with the setup described above and saw consistent, reasonable outputs in both cases. No new functionality is introduced so existing testing should be sufficient.
To reproduce benchmarks, the server can be launched with this configuration:
disable_padded_drafter_batch can be set to true to measure the baseline behaviour. When omitted, it will default to false (the new behaviour).
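The exact benchmark command is not reproduced above; as a hedged sketch only, toggling the flag via the offline API could look like this (model and drafter checkpoints are placeholders, not the benchmarked configuration):

```python
from vllm import LLM

# Placeholder model/drafter names; the PR's benchmark used
# Llama 3.1 8B-Instruct with an EAGLE3 drafter on 1xB200.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # assumed drafter checkpoint
        "num_speculative_tokens": 3,
        # True -> baseline (unpadded) behaviour; False/omitted -> padded speculation.
        "disable_padded_drafter_batch": True,
    },
)
print(llm.generate("The capital of France is")[0].outputs[0].text)
```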