
Conversation

LucasWilkinson (Collaborator) commented Sep 17, 2025

Purpose

Support PIECEWISE cudagraphs with the Eagle head; an interim fix until #23679 can be refactored and landed. This should get us most of the performance of that approach with a lot less complexity while the GPU model runner is refactored.
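For readers unfamiliar with the term: "piecewise" CUDA graphs capture only the graph-safe segments of a forward pass (e.g. the dense projection/MLP pieces) and replay them around eagerly executed pieces such as attention. The sketch below is a minimal, hypothetical illustration of that idea in plain PyTorch; it is not vLLM's dispatch machinery, and every name in it (run_piecewise, the two Linear "pieces", the softmax standing in for attention) is made up for illustration.

```python
# Minimal sketch of "piecewise" CUDA graph execution in plain PyTorch.
# Hypothetical illustration only -- not vLLM's implementation.
import torch

def build_piecewise_runner():
    dev = "cuda"
    piece1 = torch.nn.Linear(1024, 1024, device=dev)  # graph-safe piece
    piece2 = torch.nn.Linear(1024, 1024, device=dev)  # graph-safe piece

    # Static buffers so graph replays see stable addresses.
    x_static = torch.zeros(8, 1024, device=dev)

    # Warm up on a side stream before capture, as CUDA graphs require.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        piece1(x_static)
        piece2(x_static)
    torch.cuda.current_stream().wait_stream(s)

    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.no_grad():
        with torch.cuda.graph(g1):
            y1_static = piece1(x_static)
        mid_static = torch.empty_like(y1_static)
        with torch.cuda.graph(g2):
            y2_static = piece2(mid_static)

    def run_piecewise(x: torch.Tensor) -> torch.Tensor:
        x_static.copy_(x)
        g1.replay()                                      # captured piece 1
        mid_static.copy_(torch.softmax(y1_static, -1))   # eager "attention" piece
        g2.replay()                                      # captured piece 2
        return y2_static.clone()

    return run_piecewise

if __name__ == "__main__":
    runner = build_piecewise_runner()
    print(runner(torch.randn(8, 1024, device="cuda")).shape)
```

Capturing the graph-safe pieces once and replaying them removes per-step launch overhead for the drafter, while the piece left eager keeps its flexibility.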


Test Plan

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --tokenizer-mode auto --speculative-config '{"method": "qwen3_next_mtp", "num_speculative_tokens": 2}' --port 3331 --max-num-batched-tokens 512 --gpu-memory-utilization 0.75

(vllm2) lwilkinson@H100-GPU17:~/code/vllm2$ python tests/evals/gsm8k/gsm8k_eval.py --port 3331
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [03:48<00:00,  5.77it/s]

Results:
Accuracy: 0.835
Invalid responses: 0.000
Total latency: 228.584 s
Questions per second: 5.770

Test Result



Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson changed the title from "Lwilkinson/eagle piecewise" to "[Spec-Decode] Support piecewise cudagraphs for Eagle head" Sep 17, 2025

mergify bot commented Sep 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 23, 2025
facebook-github-bot commented:

@zixi-qi has imported this pull request. If you are a Meta employee, you can view this in D83391264.

@mergify mergify bot removed the needs-rebase label Oct 1, 2025
benchislett (Collaborator) left a comment:

I checked on an updated development branch as well as with the current branch, and it looks like the CUDA graphs aren't actually running for MTP.

I ran like this:

nsys launch --cuda-event-trace=false -t nvtx,cuda --trace-fork-before-exec=true --cuda-graph-trace=node vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}' --max-model-len 2048 --max-num-seqs 128 --no-enable-prefix-caching --port 8049

In the nsys profile, the base model is running with piecewise graphs but the EAGLE head is not. I also checked with MTP on DSR1 and I observed the same issue there.

I did some light debugging and observed that the dummy run and the forward context both seem to be receiving the correct cudagraph mode, but for some reason it isn't being used.

mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 6, 2025
@mergify mergify bot removed the needs-rebase label Oct 7, 2025
LucasWilkinson (Collaborator, Author) commented Oct 7, 2025

I checked on an updated development branch as well as with the current branch, and it looks like the CUDA graphs aren't actually running for MTP.

I ran like this:

nsys launch --cuda-event-trace=false -t nvtx,cuda --trace-fork-before-exec=true --cuda-graph-trace=node vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}' --max-model-len 2048 --max-num-seqs 128 --no-enable-prefix-caching --port 8049

In the nsys profile, the base model is running with piecewise graphs but the EAGLE head is not. I also checked with MTP on DSR1 and I observed the same issue there.

I did some light debugging and observed that the dummy run and the forward context both seem to be receiving the correct cudagraph mode, but for some reason it isn't being used.

Hmm, I ran it again and got the same result (will share the trace over Slack; it's too large for GitHub).

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --tokenizer-mode auto --speculative-config '{"method": "qwen3_next_mtp", "num_speculative_tokens": 2}' --port 3331 --max-num-batched-tokens 512 --gpu-memory-utilization 0.75
vllm bench serve --base-url http://0.0.0.0:3331 --model Qwen/Qwen3-Next-80B-A3B-Instruct --dataset-name random --random-in 100 --random-out 2 --num-prompts 128 --profile

Your profile isn't ending before the first client request, right? I find nsys sometimes cuts off early.
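If it helps, nsys's standard --delay/--duration options can keep the collection window open past server startup so the benchmark traffic is included. The values below are illustrative, not the ones used in this run, and "..." stands for the serve arguments shown above:

nsys profile -t nvtx,cuda --cuda-graph-trace=node --delay 60 --duration 300 -o eagle_trace vllm serve ...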

@mergify mergify bot added the llama Related to Llama models label Oct 8, 2025
LucasWilkinson (Collaborator, Author) commented Oct 8, 2025

I checked on an updated development branch as well as with the current branch, and it looks like the CUDA graphs aren't actually running for MTP.

I ran like this:

nsys launch --cuda-event-trace=false -t nvtx,cuda --trace-fork-before-exec=true --cuda-graph-trace=node vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}' --max-model-len 2048 --max-num-seqs 128 --no-enable-prefix-caching --port 8049

In the nsys profile, the base model is running with piecewise graphs but the EAGLE head is not. I also checked with MTP on DSR1 and I observed the same issue there.

I did some light debugging and observed that the dummy run and the forward context both seem to be receiving the correct cudagraph mode, but for some reason it isn't being used.

Turns out this was unique to llama/deepseek: d78f30e / f79b9a9

@mergify mergify bot added the deepseek Related to DeepSeek models label Oct 8, 2025
benchislett (Collaborator) left a comment:

Thanks for fixing!

benchislett (Collaborator) commented:

Found the PR that removed torch compile for llama_eagle3 in the first place. Unclear why it was done. I'm in favor of merging and then monitoring/expanding the tests to cover rare cases as needed.

#22872

benchislett (Collaborator) commented:

Looks like enabling torch compile for llama_eagle3 might not work well with multimodal. I guess we don't have a test for this?
https://github.com/vllm-project/vllm/pull/22872/files#r2366260173

LucasWilkinson (Collaborator, Author) commented:

Found the PR that removed torch compile for llama_eagle3 in the first place. Unclear why it was done. I'm in favor of merging and then monitoring/expanding the tests to cover rare cases as needed.

Agreed

benchislett (Collaborator) commented:

Multimodal support patch looks good.

@LucasWilkinson LucasWilkinson merged commit 29255cf into vllm-project:main Oct 10, 2025
54 checks passed
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025

Labels

deepseek (Related to DeepSeek models), llama (Related to Llama models), ready (ONLY add when PR is ready to merge/full CI is needed), speculative-decoding, v1
