
Conversation

@BoyuanFeng (Contributor) commented Sep 4, 2025

Depends on pytorch/pytorch#162207 [landed and available in PyTorch 2.9].

Command:

vllm bench latency -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=true
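
For reference, a rough Python-API equivalent of the CLI flags above might look like the following sketch (the compilation_config keys are assumed to mirror the -O options; check your vLLM version for the exact spelling):

```python
# Sketch only: assumes the compilation_config dict keys mirror the -O flags above.
from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B",
    compilation_config={
        "cudagraph_mode": "PIECEWISE",         # piecewise CUDA graph capture
        "use_inductor_graph_partition": True,  # let Inductor partition the graph
    },
)
```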

Tests

Test1

test_simple_inductor_graph_partition checks that, when use_inductor_graph_partition=True, we see num_piecewise_graphs_seen=1 and num_backend_compilations=1, since Inductor partitions inside a single compiled graph. By contrast, test_simple_piecewise_compile checks that, when use_inductor_graph_partition=False, we see num_piecewise_graphs_seen=5 and num_backend_compilations=3, since we split at the FX graph level. For both tests, we assert num_cudagraph_captured=6.
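
A minimal sketch of the counter expectations described above (the real assertions live in vLLM's piecewise-compilation tests; the dict-style counter access here is illustrative):

```python
# Sketch only: mirrors the counts described in Test1, not the actual test code.
def check_counters(counters: dict, use_inductor_graph_partition: bool) -> None:
    if use_inductor_graph_partition:
        # Inductor partitions inside one compiled graph: a single piecewise
        # graph is seen and a single backend compilation happens.
        assert counters["num_piecewise_graphs_seen"] == 1
        assert counters["num_backend_compilations"] == 1
    else:
        # FX-level splitting: 5 pieces are seen, 3 of which are compiled.
        assert counters["num_piecewise_graphs_seen"] == 5
        assert counters["num_backend_compilations"] == 3
    # Either way, 6 CUDA graphs are captured.
    assert counters["num_cudagraph_captured"] == 6
```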

Test2

test_custom_compile_config checks that use_inductor_graph_partition=True, level=CompilationLevel.PIECEWISE, and cudagraph_mode=CUDAGraphMode.PIECEWISE work together.
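
A sketch of the option combination that test_custom_compile_config exercises (import paths and constructor signature are assumed; see the test for the authoritative version):

```python
# Sketch only: assumes these enums/fields are exposed from vllm.config as named.
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    cudagraph_mode=CUDAGraphMode.PIECEWISE,
    use_inductor_graph_partition=True,
)
```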

Test3

test_attention_quant_pattern checks that attention+FP8Quant fusion happens when use_inductor_graph_partition=True.

Benchmark

Model: meta-llama/Meta-Llama-3.1-8B
Hardware: B200

With the vLLM piecewise CUDAGraph backend:
[benchmark screenshot]

With Inductor graph partition:
[benchmark screenshot]

With Inductor graph partition, TTFT is 0.4% faster and TPOT is 2% slower.

Start Time

[startup time comparison screenshots]

Support Attention Fusion

Trace without attention fusion. We can see [prior cudagraph'ed kernels] -> vllm::scaled_fp8_quant_kernel_strided -> vllm::reshape_and_cache_flash_kernel -> fmhaSm100Kernel -> [(next cudagraph'ed kernels) vllm::scaled_fp8_quant_kernel_strided -> _ZN7cutlass13device_kernelINS_4gemm6kernel -> ...].
[trace screenshot]

Trace with attention fusion. We can see [prior cudagraph'ed kernels] -> vllm::scaled_fp8_quant_kernel -> vllm::reshape_and_cache_flash_kernel -> fmhaSm100Kernel -> [(next cudagraph'ed kernels) _ZN7cutlass13device_kernelINS_4gemm6kernel -> ...]. Note that the second vllm::scaled_fp8_quant_kernel_strided no longer appears at the start of the next cudagraph'ed region; its quantization has been fused into fmhaSm100Kernel.
[trace screenshot]

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces CUDAGraph partitioning by integrating custom wrappers with the torch inductor compiler. The core logic is added in vllm/compilation/backends.py. The review identifies a critical debugging statement (breakpoint()) that must be removed, as it will halt execution. Additionally, there are several large blocks of commented-out code and unused variables that should be cleaned up to improve code readability and maintainability.

@ProExpertProg (Collaborator) left a comment


We should add a flag to the config that enables this. It can be experimental for now, but we should add good documentation because it will likely stick around (even if on by default), since other platforms will reuse the existing CUDAGraph wrapper mechanism and piecewise splitting after Dynamo!
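
A minimal sketch of what such a flag could look like in CompilationConfig (the field name comes from this PR; the placement and docstring wording are illustrative):

```python
# Sketch only: illustrates the proposed flag, not the exact vLLM definition.
from dataclasses import dataclass


@dataclass
class CompilationConfig:
    use_inductor_graph_partition: bool = False
    """Experimental: let Inductor partition the compiled graph around
    cudagraph-unsafe ops and wrap each partition with vLLM's CUDAGraph
    wrapper, instead of splitting the FX graph before Inductor runs."""
```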

@BoyuanFeng force-pushed the bf/cg-partition branch 2 times, most recently from f5af8f6 to 1397e35, on September 5, 2025 05:33
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Sep 6, 2025
This PR adds an interface to allow users to specify custom cudagraph wrapper. User example: [vllm](vllm-project/vllm#24281)

Pull Request resolved: #162207
Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg
daisyden pushed a commit to daisyden/pytorch that referenced this pull request Sep 8, 2025
This PR adds an interface to allow users to specify custom cudagraph wrapper. User example: [vllm](vllm-project/vllm#24281)

Pull Request resolved: pytorch#162207
Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg
@ProExpertProg (Collaborator) left a comment


Looks good! Nice work. I left a few comments, mostly about comments and asserts. We should improve the documentation around this. If you want to accept some of my suggestions directly, that would also help me be added as a co-author :D.

We should add tests: one that just extends the current piecewise cudagraph tests, and one that also tests that attention fusion happened with this splitting method (and that it's not broken). We should also check the performance of attention fusion after this PR. Let me know if you need help with what commands to run.

Additionally, it might be nice to be able to pass a list of "splitting ops" to Inductor during compilation (as opposed to op declaration time). If we want to decide whether to exclude attention or fused_moe (or neither or both), we can't depend on torch._C.Tag.cudagraph_unsafe because that's a static property of the op. I guess for now we can depend on the old (current) splitting pathway but it might be nice in the future to use a list inside config.
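
A rough sketch of the config-driven idea (the splitting_ops key and op names are illustrative, not necessarily the current vLLM spelling):

```python
# Sketch only: decide at compile time which ops force a partition boundary,
# instead of relying on the static torch._C.Tag.cudagraph_unsafe tag.
splitting_ops = [
    "vllm.unified_attention_with_output",  # keep attention outside CUDA graphs
    # "vllm.moe_forward",                  # optionally also split around fused MoE
]

compilation_config = {
    "use_inductor_graph_partition": True,
    "splitting_ops": splitting_ops,  # hypothetical: forwarded to Inductor at compile time
}
```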

@ProExpertProg enabled auto-merge (squash) September 19, 2025 20:35
@github-actions (bot) added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 19, 2025
@ProExpertProg changed the title from "CUDAGraph partition integration" to "[torch.compile] CUDAGraph Inductor partition integration" Sep 19, 2025
auto-merge was automatically disabled September 19, 2025 22:35

Head branch was pushed to by a user without write access

@ProExpertProg enabled auto-merge (squash) September 19, 2025 22:39
@ProExpertProg merged commit 8945b00 into vllm-project:main Sep 20, 2025
44 checks passed
@fhl2000 mentioned this pull request Sep 20, 2025 (5 tasks)
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
This PR adds an interface to allow users to specify custom cudagraph wrapper. User example: [vllm](vllm-project/vllm#24281)

Pull Request resolved: pytorch#162207
Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
This PR adds an interface to allow users to specify custom cudagraph wrapper. User example: [vllm](vllm-project/vllm#24281)

Pull Request resolved: pytorch#162207
Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg
@ProExpertProg moved this from "To triage" to "Done" in torch.compile integration Sep 24, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…t#24281)

Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: boyuanfeng <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…t#24281)

Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: boyuanfeng <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: charlifu <[email protected]>
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
This PR adds an interface to allow users to specify custom cudagraph wrapper. User example: [vllm](vllm-project/vllm#24281)

Pull Request resolved: pytorch#162207
Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: boyuanfeng <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…t#24281)

Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: boyuanfeng <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
…t#24281)

Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: boyuanfeng <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>

Labels

deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), frontend, llama (Related to Llama models), multi-modality (Related to multi-modality, #4194), new-model (Requests to new models), performance (Performance-related issues), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), torch.compile, v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[RFC]: Address piecewise graph splitting and attention fusion incompatibility

5 participants