
Conversation

benchislett
Collaborator

@benchislett benchislett commented Sep 9, 2025

Purpose

This PR implements padded speculative decoding as described in #21984. The implementation is the same as the one from #22684 (which itself came from #20078), with some cleanup and compatibility with latest main.

This PR does not add support for any new attention backends, such as #21984 or #24045. Enabling each new backend should not take much work, but I think it is best to separate that support into follow-up PRs. This PR contains only the common core of padded speculation.

When tested on Llama 3.1 8B-Instruct with EAGLE3 on 1xB200, this PR introduces a small ~2% speedup on both TTFT and TPOT. Performance profiling confirmed that the implementation is free of GPU->CPU syncs during the EAGLE drafting phase, but this gain is offset by the fact that CUDA graphs for the draft model are not working (see #23679) and that many small GPU operations are added to perform sampled-token bookkeeping entirely on the GPU.
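
As an illustration of what "sampled-token bookkeeping fully on the GPU" means here, a minimal sketch (the function name and the convention that -1 marks rejected slots are assumptions for illustration, not the exact code in this PR):

import torch

def count_accepted_tokens(sampled_token_ids: torch.Tensor) -> torch.Tensor:
    # sampled_token_ids: (num_reqs, num_speculative_tokens + 1), with -1 in
    # rejected slots (assumed convention). The result stays on the device,
    # so the drafter can consume it without a GPU->CPU sync, i.e. no
    # .tolist() / .cpu() round trip on the hot path.
    return (sampled_token_ids >= 0).sum(dim=1)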

Performance is expected to improve once new attention backends are supported, CUDA graph support for draft models is re-introduced, and a custom kernel is added to fuse the many small operations used in the bookkeeping (see #20078 for an example).

The feature is now the default, with opt-out via "disable_padded_drafter_batch" in the speculative_config. Performance is identical and behaviour is unchanged when disable_padded_drafter_batch is set.

Test Plan

Ran locally with the setup described above and saw consistent, reasonable outputs in both cases. No new functionality is introduced, so existing testing should be sufficient.

To reproduce benchmarks, the server can be launched with this configuration:

VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4, "disable_padded_drafter_batch": false}' --max-model-len 2048 --no-enable-prefix-caching
  • disable_padded_drafter_batch can be set to true to measure the baseline behaviour. When omitted, it defaults to false (the new behaviour).

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an efficient implementation of padded speculative decoding for EAGLE, aiming to improve performance by moving token bookkeeping to the GPU and avoiding synchronization. The changes are substantial and primarily affect the gpu_model_runner and eagle proposer. While the overall approach and implementation of the GPU-side logic seem correct, I've identified a critical issue in vllm/v1/spec_decode/eagle.py where unclamped token positions are used for calculating KV cache slot mappings. This could lead to out-of-bounds memory access and should be addressed.
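
For context, the slot-mapping computation in question has roughly the following shape; this is a simplified, hedged sketch (tensor names and shapes are assumptions), where the clamp is what keeps padded or speculative positions from indexing past the allocated blocks:

import torch

def compute_slot_mapping(token_positions: torch.Tensor,  # (num_tokens,)
                         req_indices: torch.Tensor,      # (num_tokens,) block_table row per token
                         block_table: torch.Tensor,      # (num_reqs, max_blocks_per_req)
                         block_size: int,
                         max_model_len: int) -> torch.Tensor:
    # Clamp positions that run past the context limit so the block lookup
    # stays in bounds; the clamped slots belong to padded or rejected tokens
    # that are discarded downstream anyway.
    pos = token_positions.clamp(max=max_model_len - 1)
    block_numbers = block_table[req_indices, pos // block_size]
    return block_numbers * block_size + pos % block_size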

Collaborator

@LucasWilkinson LucasWilkinson left a comment


Thanks for doing this! I did a very quick pass and will do another later today or tomorrow. This is really great work! Overall this is looking good; at first blush I do think we should start thinking about the spec-decode abstractions and whether we can move more out of the GPU model runner. Right now the line between which input prep is the responsibility of the proposer and which is the responsibility of the GPU model runner feels blurry; that's not really the fault of this PR, but it feels like a natural time to start thinking about it. (Feel free to push back; I haven't had a chance to think that deeply about it yet.)

It is unclear to me if this change needs to be opt-in and/or feature-flagged. Maintaining the existing pathway has a chance to net a few % speedup for some cases in the future, but increases the complexity and is not likely to be usable by any attention backends besides FlashAttention/TreeAttention. I am somewhat in favor of integrating this PR as-is and reverting if necessary, but I am open to adding a flag until we get a few more features in so that we can make a decision based on the fully-featured benchmarks.

I think we should maintain the existing pathways and make it feature-flagged for now, until we can test the performance implications on something like tree spec decode where the rejected tokens are not contiguous and are more numerous; cc @TheEpicDolphin @luccafong. I think it would probably make sense to turn it on by default when not using tree spec decode, though (assuming we are relatively confident there's no/negligible perf regression), so we can start stressing the pathway.


mergify bot commented Sep 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 16, 2025
Collaborator

@WoosukKwon WoosukKwon left a comment


I think this PR is too complicated. I'm not sure if it pays off, given that the perf gain is limited.

@benchislett
Collaborator Author

@WoosukKwon the goal of this PR is not to introduce a raw speedup. Instead, it introduces foundational logic required to enable speculative decoding with new backends and features. To list a few:

  • FlashInfer and FlashInfer-MLA (among many other backends) have optimized decode kernels (trtllm-gen) with special support for speculative decoding, but they require the input tensors to have a regular shape (num_reqs, query_len, ...). This means that padding the batch is the only way to support these kernels. While it is possible to do the padding directly in the attention op, that approach is limited and does not support the other features below (a conceptual sketch of this padding follows the list).
  • Padded batching makes overlapped execution and full CUDA graph support much easier to implement. See [Core] Async Scheduling X Spec Decoding Compatibility #24799 and [Spec-decode] Refoctor cudagraphs for spec-decode;support uniform_alignment of cudagraph sizes. #23679, which both rely on this feature. Eliminating the CPU-GPU synchronization allows hiding a lot of metadata preparation work, which cannot be properly benchmarked at this time since there is currently only support for FlashAttention.
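
To make the first bullet concrete, here is a minimal, self-contained sketch of padding a ragged drafter batch into a dense (num_reqs, query_len) layout; the function and tensor names are hypothetical and this is not the PR's exact code:

import torch

def pad_drafter_inputs(input_ids: torch.Tensor,   # (total_tokens,) ragged tokens, concatenated per request
                       query_lens: torch.Tensor,  # (num_reqs,) accepted-token count per request
                       pad_len: int,              # uniform query length, e.g. num_speculative_tokens + 1
                       pad_id: int = 0) -> torch.Tensor:
    # Scatter the ragged per-request tokens into a dense (num_reqs, pad_len)
    # batch so the attention kernel sees a regular (num_reqs, query_len, ...)
    # shape; padded slots hold pad_id and are masked or discarded downstream.
    num_reqs = query_lens.shape[0]
    padded = torch.full((num_reqs, pad_len), pad_id,
                        dtype=input_ids.dtype, device=input_ids.device)
    starts = torch.cumsum(query_lens, 0) - query_lens       # exclusive prefix sum
    col = torch.arange(pad_len, device=input_ids.device)
    valid = col.unsqueeze(0) < query_lens.unsqueeze(1)      # (num_reqs, pad_len)
    src = (starts.unsqueeze(1) + col.unsqueeze(0))[valid]   # indices into the ragged buffer
    padded[valid] = input_ids[src]
    return padded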

@LucasWilkinson
Collaborator

LucasWilkinson commented Sep 17, 2025

We could simplify the situation by just making this the only pathway (per @benchislett's initial suggestion, i.e. completely removing unpadded speculation) if that makes it more palatable, @WoosukKwon. I was generally in favor of landing it feature-flagged since it:

  1. allows easier debugging if users run into issues (we can ask them to set --disable-padded-drafter-batch if issues crop up, to see if it's related to the padding)
  2. allows for easier benchmark comparisons to see how this impacts less common cases (i.e. cases with higher num-speculated tokens and thus a likely higher number of rejected tokens)
    • however, the speculator generally doesn't take that long compared to the main model, so I can definitely see justification for supporting just one pathway to simplify the code, given the differences are likely to be at most a couple of percent
  3. leaves a pathway for the tree-attention implementation, but since we are now de-prioritizing tree attention we don't need to weight this as heavily (if we end up prioritizing tree attention in the future, I think there's a pathway to making it compatible with padded speculation)

However, I can be convinced that committing to padded speculation everywhere to simplify the code could make sense (especially in light of de-prioritizing tree attention).

I do believe pretty strongly though that padded speculation is needed for the reasons @benchislett already mentioned.

@benchislett
Collaborator Author

Re-requesting review from @WoosukKwon after an offline discussion. I understand if this is too much complexity to introduce without comprehensive review and careful refactoring, but I still feel that it is a powerful feature that we will need at some point.

In the immediate term, we need to make a decision quickly to unblock FlashInfer support. If this PR is still judged too complex, we can merge another approach to get it working. Something like #24045 with https://github.com/vllm-project/vllm/pull/24539/files#diff-a4809a837fbf535a8f0999b11087a53ec1c53948b50c0a1fe64396bc86de9461R282-R338 would probably be sufficient to get it functional and compatible with piecewise graphs.

@LucasWilkinson
Collaborator

Can you expand on "Something like #24045"? I'm not sure this addresses the issue of ending up with query lengths like [1, 2, 2, 1] after rejections, which will force FlashMLA to run as [Decode, Prefill, Prefill, Prefill] and will be quite a bit slower for that specific MTP layer; granted, relative to the full model that may not be a concern.
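
A toy illustration of the concern (the dispatch rule here is a deliberately simplified assumption; real backends are more nuanced): treating a request as a decode only when its drafter query length matches a fixed uniform length means ragged post-rejection lengths push most requests onto the slower prefill path, while a padded batch keeps every query at num_speculative_tokens + 1.

def classify(query_lens: list[int], decode_len: int = 1) -> list[str]:
    # Hypothetical dispatch: "Decode" only for the expected uniform length.
    return ["Decode" if q == decode_len else "Prefill" for q in query_lens]

print(classify([1, 2, 2, 1]))                  # ['Decode', 'Prefill', 'Prefill', 'Prefill']
print(classify([2, 2, 2, 2], decode_len=2))    # padded batch: uniform decode-style queries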

disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
disable_padded_drafter_batch: bool = False
Collaborator


Is this flag necessary?

Collaborator Author


We want to maintain both pathways so that this PR's changes can be disabled if needed, since it is not clear what the perf will look like once new backend support is added and piecewise CUDA graphs are enabled for them. This flag makes it easy to compare the perf so we can easily decide whether this change should be rolled back (or have its default swapped) in the future.

Comment on lines +379 to +418
def prepare_next_token_ids_cpu(
        self, sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int]) -> torch.Tensor:
    """
    This function is used to prepare the inputs for speculative decoding.
    It calculates the next token ids for each request based on the sampled
    token ids from the CPU. If a request has no sampled token ids (e.g.,
    during the initial decoding steps), it falls back to using the request
    state to get the next token id.
    """
    req_ids = gpu_input_batch.req_ids
    next_token_ids: list[int] = []
    for i, token_ids in enumerate(sampled_token_ids):
        if token_ids:
            # Common case.
            next_token_id = token_ids[-1]
        else:
            # Partial prefill (rare case).
            # Get the next token id from the request state.
            req_id = req_ids[i]
            req_state = requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       num_scheduled_tokens[req_id])
            next_token_id = req_state.get_token_id(seq_len)
        next_token_ids.append(next_token_id)
    next_token_ids = torch.tensor(next_token_ids,
                                  dtype=torch.int32,
                                  device=self.input_ids.device)
    return next_token_ids

def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int) -> tuple[torch.Tensor, torch.Tensor]:
Collaborator


Can we write this in Triton, instead of Python + PyTorch? I think that will make things simpler and more efficient.

Collaborator Author


Yes, the GPU pathway can be combined into a single kernel. That's a TODO item for me; see line 428 below. I did not include it here as I thought it would add more complexity. It should be a fairly easy follow-up.
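
For reference, a minimal Triton sketch of the kind of fused kernel being discussed: it picks each request's last accepted token, falling back to a per-request backup token when every slot was rejected. The names, the -1 rejected-slot convention, and the backup-token tensor are assumptions for illustration, not the PR's code:

import torch
import triton
import triton.language as tl

@triton.jit
def _next_token_ids_kernel(
    sampled_ptr,          # (num_reqs, max_tokens) int32; -1 marks rejected slots
    backup_ptr,           # (num_reqs,) int32 fallback token per request
    out_ptr,              # (num_reqs,) int32 output
    max_tokens,
    stride_req,
    BLOCK: tl.constexpr,  # power of two >= max_tokens
):
    req = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < max_tokens
    toks = tl.load(sampled_ptr + req * stride_req + offs, mask=mask, other=-1)
    # Index of the last accepted (non-negative) token, or -1 if none survived.
    last_idx = tl.max(tl.where(toks >= 0, offs, -1), axis=0)
    picked = tl.load(sampled_ptr + req * stride_req + tl.maximum(last_idx, 0))
    backup = tl.load(backup_ptr + req)
    tl.store(out_ptr + req, tl.where(last_idx >= 0, picked, backup))

def fused_next_token_ids(sampled_token_ids: torch.Tensor,
                         backup_next_token_ids: torch.Tensor) -> torch.Tensor:
    num_reqs, max_tokens = sampled_token_ids.shape
    out = torch.empty(num_reqs, dtype=torch.int32,
                      device=sampled_token_ids.device)
    _next_token_ids_kernel[(num_reqs,)](
        sampled_token_ids, backup_next_token_ids, out,
        max_tokens, sampled_token_ids.stride(0),
        BLOCK=triton.next_power_of_2(max_tokens))
    return out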

@benchislett benchislett merged commit b7433ca into vllm-project:main Sep 18, 2025
41 checks passed
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 18, 2025
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
Comment on lines -1986 to -1988
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
Member


@benchislett sorry, I was a bit late seeing this, but I'm just wondering why these comments were removed, in particular the warning about CUDA-specific torch internal details. These generator offsets aren't public APIs or applicable to all torch backend impls (e.g. CPU).

Collaborator Author


Oh, my bad. I think I just swept them up with the other comments when I was refactoring. I can add these back if you'd like.

xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Labels

ready, speculative-decoding, v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants