
Conversation

@cascade812 (Contributor) commented Jul 16, 2025

Previously, compile_sizes had to be specified in CompilationConfig to enable sequence parallelism.
This PR removes that limitation for full CUDA graph compilation.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of these by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@cascade812 cascade812 changed the title enable sequence parallelism for full cuda graph without specifying compile sizes Enable sequence parallelism for full cuda graph without specifying compile sizes Jul 16, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request enables sequence parallelism for full CUDA graph compilation without specifying compile sizes. The type hint for splitting_ops was incorrect and has been corrected to list[str] to match the actual data type.
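
For reference, a minimal sketch of the corrected annotation; the field is shown in isolation with an assumed default, while vLLM's real CompilationConfig carries many more fields:

```python
from dataclasses import dataclass, field


@dataclass
class CompilationConfig:
    # Ops at which the graph is split for piecewise compilation. The values
    # are operator names (strings), hence list[str] rather than a bare list.
    splitting_ops: list[str] = field(default_factory=list)
```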

Signed-off-by: cascade812 <[email protected]>
@youkaichao (Member) left a comment

What if the number of tokens is not divisible by the TP size?

@youkaichao youkaichao requested a review from ProExpertProg July 16, 2025 09:04
@ProExpertProg (Collaborator) left a comment

I think it would be good to add some documentation for this behavior, and get some benchmarking results! And should we pad to multiples of tp size?

@cascade812 (Contributor, Author) commented

> What if the number of tokens is not divisible by the TP size?

When sequence parallelism is enabled, we always pad num_tokens to a multiple of tensor_parallel_size in gpu_model_runner.py.
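
As a minimal sketch of that padding rule (the helper name here is hypothetical; the real logic lives in gpu_model_runner.py and may differ in detail):

```python
def pad_to_multiple_of_tp(num_tokens: int, tp_size: int) -> int:
    """Round num_tokens up to the next multiple of tp_size so the hidden
    states can be split evenly across tensor-parallel ranks."""
    remainder = num_tokens % tp_size
    return num_tokens if remainder == 0 else num_tokens + (tp_size - remainder)


# e.g. 9 tokens with TP=4 pad to 12; 8 tokens are already divisible.
assert pad_to_multiple_of_tp(9, 4) == 12
assert pad_to_multiple_of_tp(8, 4) == 8
```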

Signed-off-by: cascade812 <[email protected]>
@cascade812 (Contributor, Author) commented

> I think it would be good to add some documentation for this behavior, and get some benchmarking results! And should we pad to multiples of tp size?

Added a detailed comment. And it already pads to multiples of the TP size for full graph.
As for benchmarks, sequence parallelism alone doesn't yield much performance improvement; it mainly lays the groundwork for subsequent fusion passes, such as GEMM + ReduceScatter and AllGather + GEMM fusions. Benchmarks for those fusions have been provided in the async TP PRs.
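
For context, a rough sketch of the pattern sequence parallelism sets up (illustrative functions only, not vLLM's actual pass API): the full-size all-reduce around the norm becomes a reduce-scatter before it and an all-gather after it, and those collectives sitting next to the GEMMs are what the later GEMM + ReduceScatter and AllGather + GEMM fusions absorb.

```python
import torch
import torch.distributed as dist


def block_tp_only(x, w_out, norm, residual):
    # Plain tensor parallelism: all-reduce the full activation, then run the
    # norm on all num_tokens rows on every rank.
    y = x @ w_out                       # partial result on each TP rank
    dist.all_reduce(y)                  # full [num_tokens, hidden] everywhere
    return norm(y + residual)


def block_with_sp(x, w_out, norm, residual_shard, tp_group):
    # Sequence parallelism: reduce-scatter shards the tokens, the norm runs on
    # num_tokens / tp_size rows, and an all-gather restores the full sequence.
    y = x @ w_out
    tp_size = dist.get_world_size(group=tp_group)
    shard = torch.empty(y.shape[0] // tp_size, y.shape[1],
                        dtype=y.dtype, device=y.device)
    dist.reduce_scatter_tensor(shard, y, group=tp_group)
    shard = norm(shard + residual_shard)
    out = torch.empty_like(y)
    dist.all_gather_into_tensor(out, shard, group=tp_group)
    return out
```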

@ProExpertProg (Collaborator) commented

Yes, but this improves the speedup of async TP because it now also works for shapes that weren't explicitly compiled. Could you do an end-to-end serving benchmark comparing async TP off (main), async TP on (main), and async TP on (this PR)?

@zou3519 (Collaborator) commented Jul 23, 2025

I'm deferring to @ProExpertProg and @youkaichao on this

@zou3519 zou3519 removed their request for review July 23, 2025 00:19
@cascade812 (Contributor, Author) commented Aug 10, 2025

@zou3519 I'm encountering an out-of-memory error for the KV cache when benchmarking the LLaMA 70B model on 4x H100 with full_cuda_graph=True and enable_sequence_parallelism=True on this PR.
Activation memory usage during profiling is very high: ~44 GB with the SP pass enabled vs. ~5 GB with it disabled.

I dumped a memory usage snapshot during profiling (see the screenshot below) and found the blow-up is due to buf290 not being released properly. This buffer should be freed after its only use at the line buf294 = torch.ops.vllm.all_gather.default(buf290, 0, 2, 'tp:0'); however, it isn't freed until the very end of execution.
This looks like a torch.compile bug. Do you have any insight into this issue?

```python
buf289 = torch.ops.vllm.reduce_scatter.default(buf288, 0, 2, 'tp:0')
del buf288
buf290 = buf289
assert_size_stride(buf290, (s0 // 2, 4096), (4096, 1))
del buf289
# Topologically Sorted Source Nodes: [], Original ATen: []
torch.ops._C.fused_add_rms_norm.default(input=buf290, residual=buf2, weight=arg77_1, epsilon=1e-05)
del arg77_1
# Topologically Sorted Source Nodes: [], Original ATen: []
buf294 = torch.ops.vllm.all_gather.default(buf290, 0, 2, 'tp:0')
...until the very end...
del buf290
```
[screenshot: memory usage snapshot]
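
For anyone reproducing this, a sketch of one way to capture a snapshot like the one above, using PyTorch's private (and version-dependent) memory-history API; run_profiling_step is a placeholder for whatever drives the compiled model:

```python
import torch

# Record allocator events so buffer lifetimes (e.g. buf290's) show up in the viewer.
torch.cuda.memory._record_memory_history(max_entries=100_000)
try:
    run_profiling_step()  # placeholder for the actual profiling / warmup run
finally:
    torch.cuda.memory._dump_snapshot("sp_profile_snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
# Load the .pickle file at https://pytorch.org/memory_viz to inspect it.
```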

@zou3519 (Collaborator) commented Aug 11, 2025

@cascade812 are you able to send me a tlparse of this? It generates an HTML page with all of the torch.compile logs that we can stare at (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs).

Otherwise I'll try to repro this on a machine.

The thing I'm curious about: are we sure that buf290 isn't used again, and what sits between buf294 = torch.ops.vllm.all_gather.default(buf290, 0, 2, 'tp:0') and the del buf290?

@zou3519 (Collaborator) commented Aug 13, 2025

Btw, @BoyuanFeng @eellison, any initial thoughts here? Inductor codegen doesn't seem to delete a buffer after its final use. I can get a tlparse later.

@BoyuanFeng (Contributor) commented

Is buf294 a slice of buf290? If buf294 (or some other view/slice) is used later, buf290 cannot be freed.

A tlparse or the generated output code (via TORCH_LOGS=output_code) would be helpful to investigate.

@cascade812 (Contributor, Author) commented

> Is buf294 a slice of buf290? If buf294 (or some other view/slice) is used later, buf290 cannot be freed.
>
> A tlparse or the generated output code (via TORCH_LOGS=output_code) would be helpful to investigate.

No, buf294 is not a slice of buf290. The all_gather operation allocates new memory for its output.

@BoyuanFeng (Contributor) commented

What is the command to repro? I can check the memory issue. Thanks!

@cascade812 (Contributor, Author) commented

> What is the command to repro? I can check the memory issue. Thanks!

@BoyuanFeng thanks! I just sent the repro steps and compilation results to you over Slack.

@ProExpertProg (Collaborator) commented

Btw, #23261 would help this pass as well.
