
Conversation

@angelayi (Contributor) commented Oct 13, 2025

Purpose

Addresses #25277
Ports over the changes from @cascade812 in #21031, which enabled SP without needing to specify compile_sizes in CompilationConfig, and adds some special handling so that async TP/SP can also be enabled when just use_inductor_graph_partition=True is set.

Test Result

Here are the updated numbers:

For meta-llama/Llama-3.1-70B-Instruct (updated command):

vllm bench latency --model=meta-llama/Llama-3.1-70B-Instruct --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 8 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level": 3, "pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true}, "use_inductor_graph_partition": false, "splitting_ops":[], "cudagraph_mode": "FULL"}' --no-enable-prefix-caching --gpu_memory_utilization 0.6
(benchmark results screenshot)

For RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8:

vllm bench latency --model=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 8 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level": 3, "pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true}, "use_inductor_graph_partition": false, "splitting_ops":[], "custom_ops":["+quant_fp8"], "cudagraph_mode": "FULL"}' --no-enable-prefix-caching
(benchmark results screenshot)

For nvidia/Llama-3.3-70B-Instruct-FP8:

(benchmark results screenshot)

mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @angelayi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Oct 13, 2025
@angelayi changed the title from "Enable sequence parallelism for full cuda graph without specifying compile sizes" to "[compile] Enable sequence parallelism for full cuda graph without specifying compile sizes" Oct 13, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request enables sequence parallelism for full CUDA graph compilation without requiring compile_sizes. This is achieved by updating the applicability checks in SequenceParallelismPass and AsyncTPPass. The changes look good and align with the goal of the PR.

My main feedback is regarding code duplication in the is_applicable method, which is now identical in both SequenceParallelismPass and AsyncTPPass. I've suggested refactoring this into a shared helper method to improve maintainability. I've also included a minor simplification for the condition check in the method.

@angelayi force-pushed the angelayi/patch_21031 branch from ea0f81b to 0506835 on October 13, 2025 04:56
💡 Codex Review

# When sequence parallelism is enabled, the residual tensor from RMSNorm
# needs to be split along the sequence dimension. However, this dimension
# is symbolic during piecewise compilation, and splitting symbolic shapes
# is not supported.
#
# This pass is therefore only applied when the sequence dimension is
# concrete:
# 1. In full-graph compilation mode (no splitting ops are used).
# For this case we always pad num_tokens to be a multiple of
# tensor_parallel_size, so there's no need to check shape % tp_size == 0.
# 2. For specific shape provided during compilation (e.g., from
# `compile_sizes`), which must be divisible by the tensor-parallel
# size.
def is_applicable(self, shape: Optional[int]) -> bool:
    splitting_ops = self.compilation_config.splitting_ops
    if (
        splitting_ops is None
        or splitting_ops == []
        or self.compilation_config.use_inductor_graph_partition
    ):
        return True
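    # --- Not part of the excerpt above: the snippet is cut off after case 1.
    # --- Based on the docstring, case 2 presumably checks divisibility by the
    # --- tensor-parallel size, roughly like this (a sketch, not the PR's code):
    tp_size = get_tensor_model_parallel_world_size()  # assumed vllm.distributed helper
    return shape is not None and shape % tp_size == 0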

P1: Update residual shard checks for compile-size-free SP

Allowing SequenceParallelismPass.is_applicable to return true whenever splitting_ops is empty or use_inductor_graph_partition is enabled means sequence parallelism can now run without any static compile_sizes. The runtime helper is_residual_scattered_for_sp still assumes SP only runs for compile-time sizes and returns False when compile_sizes is empty, so downstream code copies full residual tensors instead of the sharded slices. When SP is enabled in the new full-graph scenario this causes shape mismatches during sync_and_slice_intermediate_tensors. Either the helper must be updated to mirror this new applicability condition or this pass should continue to require compile sizes.
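
A minimal sketch of the first option (updating the helper to mirror the new applicability condition), assuming is_residual_scattered_for_sp takes the vllm_config and the padded token count; the real signature and field names in vLLM may differ:

def is_residual_scattered_for_sp(vllm_config, num_input_tokens: int) -> bool:
    """Whether the RMSNorm residual is sharded across TP ranks for SP."""
    compilation_config = vllm_config.compilation_config
    if not compilation_config.pass_config.enable_sequence_parallelism:
        return False
    tp_size = vllm_config.parallel_config.tensor_parallel_size
    if tp_size == 1:
        return False
    # Mirror the new is_applicable condition: with no splitting ops or with
    # inductor graph partitioning, the SP pass always runs and num_input_tokens
    # is padded to a multiple of tp_size.
    if (
        not compilation_config.splitting_ops
        or compilation_config.use_inductor_graph_partition
    ):
        return num_input_tokens % tp_size == 0
    # Otherwise SP only ran for statically compiled sizes.
    return num_input_tokens in (compilation_config.compile_sizes or [])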


@mergify bot removed the needs-rebase label Oct 13, 2025
@angelayi force-pushed the angelayi/patch_21031 branch from 0506835 to 32d222a on October 13, 2025 04:57
Signed-off-by: angelayi <[email protected]>
@angelayi force-pushed the angelayi/patch_21031 branch from 32d222a to 3d563e2 on October 13, 2025 04:58
@ProExpertProg (Collaborator) left a comment


Nice work! Can you post the benchmarking numbers in this PR as well?

def is_applicable_for_shape(self, shape: int | None) -> bool:
    # only do replace for specific shapes
def is_applicable(self, shape: int | None) -> bool:
    # This pass is applied on top of the sequence parallelism pass.
Collaborator:
Fine for now, but @cascade812 isn't this pass technically fine no matter the shape? Obviously it won't match anything if sequence parallelism didn't run, but still.

Contributor:
Yes, removing this implementation would work as well, but it would trigger the matching logic for this pass, which could add some overhead.

@angelayi force-pushed the angelayi/patch_21031 branch 3 times, most recently from 3fe31bf to e8d80b0 on October 13, 2025 06:09
@angelayi force-pushed the angelayi/patch_21031 branch from e8d80b0 to 95a0ba8 on October 13, 2025 16:49
@ProExpertProg enabled auto-merge (squash) October 13, 2025 21:29
@github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Oct 13, 2025
@vllm-bot merged commit b59dd19 into vllm-project:main Oct 14, 2025
45 of 47 checks passed
@ProExpertProg (Collaborator) commented:
As a follow-up, it seems like we should track down the following performance issues:

  • On B200, async TP is slower than just sequence parallelism for quantized models. I wonder if this is because there are issues with the passes when quantization is present?
  • FULL_AND_PIECEWISE seems slower than FULL?

1994 pushed a commit to 1994/vllm that referenced this pull request Oct 14, 2025
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
…cifying compile sizes (vllm-project#26681)

Signed-off-by: angelayi <[email protected]>
Signed-off-by: Dhruvil Bhatt <[email protected]>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
@zou3519 (Collaborator) commented Oct 16, 2025

Btw, this PR causes the PostGradPassManager object to hold onto the CompilationConfig.

This is not good: Inductor configs need to be serializable and deepcopy-able, the PostGradPassManager is added to the Inductor config, and the CompilationConfig holds some nn.Modules. Internally, we saw deepcopy fail on the nn.Modules and OOM because of this.

@ProExpertProg (Collaborator) commented:
Yeah, we used to only save the PassConfig to avoid a cycle, but avoiding saving the CompilationConfig seems like a good idea as well. We can just save the relevant properties in __init__.
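
A rough sketch of that direction, with illustrative field names (the pass's actual state and base class are not shown here):

class SequenceParallelismPass:
    def __init__(self, vllm_config):
        compilation_config = vllm_config.compilation_config
        # Keep only plain values the pass needs; don't store compilation_config
        # itself, since the pass manager ends up inside the Inductor config and
        # the CompilationConfig holds nn.Modules that break deepcopy.
        self.splitting_ops = list(compilation_config.splitting_ops or [])
        self.use_inductor_graph_partition = (
            compilation_config.use_inductor_graph_partition
        )
        self.tp_size = vllm_config.parallel_config.tensor_parallel_size

    def is_applicable(self, shape: int | None) -> bool:
        if not self.splitting_ops or self.use_inductor_graph_partition:
            return True
        return shape is not None and shape % self.tp_size == 0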

@zou3519 (Collaborator) commented Oct 16, 2025

I'll make an issue to track this; @luccafong tracked it down (and has a fix somewhere).

@ProExpertProg (Collaborator) commented:
Fixed in #27041
