
Conversation

@cascade812 (Contributor) commented Mar 16, 2025

Support sequence parallelism with TP on models like Llama.

In this PR, I modified RowParallelLinear, ColumnParallelLinear, LogitsProcessor, and VocabParallelEmbedding to support SP.

Below are the remaining TODOs:

  • Support combination with PP
  • Support other layers


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@cascade812 changed the title from "support sequence parallel" to "[Feature] Support sequence parallelism" on Mar 16, 2025
@yaochengji (Collaborator)

Thanks for your contribution, @cascade812 !

Hi @robertgshaw2-redhat , @tlrmchlsmth , could you take a look at this?

Based on my micro-benchmarks, the collective-matmul optimization can greatly improve the performance of multi-chip inference on TPU. Enabling collective-matmul depends on vLLM supporting Megatron-style sequence parallelism; the TPU compiler can then automatically convert the ag-matmul and matmul-rs patterns into collective-matmul.

cc @bvrockwell @yarongmu-google

mergify bot commented Mar 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cascade812.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

The mergify bot added the needs-rebase label on Mar 18, 2025
Signed-off-by: cascade812 <[email protected]>
@tlrmchlsmth (Member)

I’ll take a look! I’m excited about and in favor of sequence parallel support in general.

@yaochengji could you explain why this helps the gemm-rs and ag-gemm rewrite? I don’t really see that as sequence parallel outside of the rms_norm

@yaochengji (Collaborator) commented Mar 20, 2025

Thanks, @tlrmchlsmth .

It's because of how sequence parallelism interacts with TP. With TP only, the model looks like:
matmul -> allreduce -> rms_norm (or other ops) -> matmul

With SP enabled, it becomes:
matmul -> rs -> rms_norm (or other ops) -> ag -> matmul

matmul-rs and ag-matmul are clean patterns for optimizers to detect, including the TPU compiler.
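To make the rewrite concrete, here is a minimal sketch of the two dataflows using raw torch.distributed collectives. This is an illustration only, not vLLM's actual layer code: mlp_block_tp, mlp_block_sp, tp_group, and rms_norm are hypothetical names, and x/w1/w2 stand in for the sharded activations and weights that the row/column-parallel linears would see.

import torch
import torch.distributed as dist

def mlp_block_tp(x, w1, w2, rms_norm, tp_group):
    # TP only: the row-parallel matmul yields a partial sum on each rank,
    # which is combined with a full allreduce before the norm.
    y = x @ w1                                      # matmul (partial result per rank)
    dist.all_reduce(y, group=tp_group)              # -> allreduce
    y = rms_norm(y)                                 # -> rms_norm (or other ops)
    return y @ w2                                   # -> matmul

def mlp_block_sp(x, w1, w2, rms_norm, tp_group):
    # SP: the allreduce is decomposed into reduce-scatter + all-gather,
    # and the norm runs on each rank's shard of the tokens in between.
    tp = dist.get_world_size(group=tp_group)
    y = x @ w1                                                    # matmul
    y_shard = torch.empty(y.shape[0] // tp, *y.shape[1:],
                          dtype=y.dtype, device=y.device)
    dist.reduce_scatter_tensor(y_shard, y, group=tp_group)        # -> rs
    y_shard = rms_norm(y_shard)                                   # -> rms_norm on 1/tp of the tokens
    y_full = torch.empty_like(y)
    dist.all_gather_into_tensor(y_full, y_shard, group=tp_group)  # -> ag
    return y_full @ w2                                            # -> matmul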

@youkaichao (Member) left a comment

Our previous plan for sequence parallelism was to make it a compilation pass, without changing the linear/embedding layers; that thread somehow got lost.

cc @bnellnm

@yaochengji (Collaborator)

Thanks for your comment, @youkaichao !

If we don't change the linear/embedding layers, we have to match many variants of the matmul -> allreduce -> rms_norm (or other ops) -> matmul pattern, even when only rms_norm is considered. Different kinds of hardware have different rms_norm implementations, and sometimes a single backend has several rms_norm implementations (e.g. NVIDIA GPUs).

Changing the linear and embedding layers offers a potentially more general approach, and the number of layers requiring modification is relatively small.

@yaochengji (Collaborator)

@cascade812 I know that currently not all layers support SP. Could you print a readable message to vLLM users when a specific layer doesn't support sequence parallelism?

@NickLucche (Collaborator) left a comment

Mega job here, thanks! Can we get some benchmark numbers to go along with this PR?

@bnellnm (Contributor) commented Mar 20, 2025

Our previous plan for sequence parallelism was to make it a compilation pass, without changing the linear/embedding layers; that thread somehow got lost.

cc @bnellnm

There were too many problems with PyTorch, limitations of the kernels, and issues with piecewise graphs, so the project was put on hold.

@yaochengji (Collaborator)

Mega job here, thanks! Can we get some benchmark numbers to go along with this PR?

I'm working on enabling the collective-matmul optimization on TPU. My change to vLLM will be based on this PR; I'll share the benchmark numbers later.

@tlrmchlsmth (Member) left a comment

I left some inline comments. Generally, I also think this should be done as a pass in Torch Inductor or a similar compiler layer. I'm pretty sure these changes are making assumptions about the model definition that may not be valid.

Comment on lines +1270 to +1272
forward_context = try_get_forward_context()
if (forward_context is not None
        and forward_context.enable_sequence_parallel):
Member

The forward context isn't available outside of model initialization, so you'll have to do self.enable_sequence_parallel = forward_context.enable_sequence_parallel in __init__; otherwise you won't actually be using sequence parallel while running inference on the model (unless you're using CUDA graphs).

I think this is a pretty tricky footgun, so we should address this - (cc @youkaichao)

@cascade812 (Contributor, Author)

The forward context is available; it's set before calling self.model().

Comment on lines +1323 to +1326
enable_sequence_parallel = (
    self.vllm_config.parallel_config.enable_sequence_parallel
    and num_tokens %
    self.vllm_config.parallel_config.tensor_parallel_size == 0)
Member

Decisions like this should go in vllm/config.py

@cascade812 (Contributor, Author)

This logic is placed outside of vllm/config.py because torch.distributed.reduce_scatter_tensor only works when the reduce-scattered dimension (in this case, the token dimension) is divisible by the parallel size. Since num_tokens changes every iteration, it doesn't seem reasonable to put this in the config.
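For illustration, a minimal sketch of the per-iteration check described above (maybe_reduce_scatter and tp_group are hypothetical names, not the PR's actual code):

import torch
import torch.distributed as dist

def maybe_reduce_scatter(y: torch.Tensor, tp_group) -> torch.Tensor:
    # reduce_scatter_tensor requires the scattered dim (tokens, dim 0) to be
    # divisible by the group size, so fall back to a plain allreduce otherwise.
    tp_size = dist.get_world_size(group=tp_group)
    num_tokens = y.shape[0]
    if num_tokens % tp_size == 0:
        out = torch.empty(num_tokens // tp_size, *y.shape[1:],
                          dtype=y.dtype, device=y.device)
        dist.reduce_scatter_tensor(out, y, group=tp_group)
        return out
    dist.all_reduce(y, group=tp_group)
    return y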

Comment on lines 1056 to 1059
with set_forward_context(
        attn_metadata,
        self.vllm_config,
        enable_sequence_parallel=enable_sequence_parallel):
Member

OK, I see we're setting the forward_context here now. @youkaichao thoughts on this?

Member

looks quite intrusive to me 👀

@cascade812 (Contributor, Author)

@youkaichao Since we need num_tokens to make the decision before reduce-scatter (see the explanation above), do you have any suggestions for a better approach?

@cascade812 (Contributor, Author)

Thanks for all the comments and reviews! I'll address accordingly.

I know that currently not all layers support SP. Could you print a readable message to vLLM users when a specific layer doesn't support sequence parallelism?

@yaochengji Do you mean adding a message to all unsupported layers? That would involve many layers. Do you have any suggestions on how to implement this efficiently?

The mergify bot added the ci/build label on Mar 22, 2025
@yaochengji (Collaborator)

Do you mean adding a message to all unsupported layers? That would involve many layers. Do you have any suggestions on how to implement this efficiently?

Usually we'd have a base class where a "sequence parallelism not implemented" warning could be put, but I just took a look at the vLLM code and realized it's not easy to implement that way.

Never mind, I'm fine with users needing to be aware of which layers support sequence parallelism until it is fully supported.

@youkaichao (Member)

we have to match many variants of the matmul -> allreduce -> rms_norm (or other ops) -> matmul pattern, even when only rms_norm is considered.

I think you only need to match matmul -> allreduce and allreduce -> matmul? rms_norm just takes an input and produces an output with the same shape, so you don't need to change the op.

@yaochengji (Collaborator) commented Mar 22, 2025

I think you only need to match matmul -> allreduce and allreduce -> matmul?

There's no allreduce -> matmul in 1D TP.

rms_norm just takes an input and produces an output with the same shape, so you don't need to change the op.

We don't need to change the op, but we still need to detect it, because we decompose the allreduce into reduce-scatter and allgather and move the allgather after rms_norm but before the next matmul.
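The reason the allgather can legally move past rms_norm is that rms_norm normalizes each token independently, so gathering the token shards before or after the norm gives the same result. A small single-process sketch of that identity (plain PyTorch, with torch.chunk/torch.cat standing in for the real collectives):

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Per-token RMS normalization: each row is normalized independently.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(8, 16)        # 8 tokens, "sharded" 2 ways along the token dim
shards = x.chunk(2, dim=0)    # stand-in for the reduce-scattered shards

# rms_norm(all_gather(shards)) == all_gather(rms_norm(shard)) per shard
gathered_then_norm = rms_norm(torch.cat(shards, dim=0))
norm_then_gathered = torch.cat([rms_norm(s) for s in shards], dim=0)
assert torch.allclose(gathered_then_norm, norm_then_gathered)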

@robertgshaw2-redhat (Collaborator)

hey @cascade812 - is it okay if I push a few changes to your branch?

@cascade812 (Contributor, Author)

hey @cascade812 - is it okay if I push a few changes to your branch?

Absolutely! Your contributions are more than welcome. Thanks!

@cascade812 (Contributor, Author)

@robertgshaw2-redhat @tlrmchlsmth @youkaichao what are your thoughts on this? If you think a compilation pass is the better choice, I'm willing to help too!

@mgoin self-requested a review on March 27, 2025 at 16:29
@tlrmchlsmth (Member)

@robertgshaw2-redhat @tlrmchlsmth @youkaichao what are your thoughts on this? If you think a compilation pass is the better choice, I'm willing to help too!

I didn’t see this message! It would be awesome for you to work on that!

@cascade812 (Contributor, Author) commented Mar 31, 2025

@robertgshaw2-redhat @tlrmchlsmth @youkaichao what are your thoughts on this? If you think a compilation pass is the better choice, I'm willing to help too!

I didn’t see this message! It would be awesome for you to work on that!

Great! I'll work on it and keep you posted on the progress.

mergify bot commented Apr 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cascade812.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

The mergify bot added the needs-rebase label on Apr 1, 2025
@tlrmchlsmth (Member)

@robertgshaw2-redhat @tlrmchlsmth @youkaichao what are your thoughts on this? If you think a compilation pass is the better choice, I'm willing to help too!

I didn’t see this message! It would be awesome for you to work on that!

Great! I'll work on it and keep you posted on the progress.

Hey @cascade812, I have some thoughts/pointers on this (sorry, meant to send these along earlier):

First, here is an earlier PR that attempted to do this via an inductor pass with pattern matching: #9886.
There are a couple of issues with that implementation: it has some deadlocking issues that aren't well understood, and some complexities when num_tokens % 4 != 0 (similar to things you need to deal with in this PR).

Another problem is that the pattern matching is very brittle and would need to be extended to support different models. @yaochengji raised this issue, and from my past experience I think this is a very valid concern.

In an effort to make it more flexible, some of us have discussed adding sentinel no-op operations (e.g. a begin_sp_region operation that does a clone at the end of RowParallelLinear's apply method, and an end_sp_region at the beginning of MergedColumnParallelLinear). Then we can find the regions fenced by these operations and do a rewrite there. This would let us be more robust and selective, and would neatly handle PP as well.
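For concreteness, here is a rough sketch of what such sentinel no-op ops could look like as PyTorch custom ops. This assumes PyTorch 2.4+'s torch.library.custom_op API; the op names come from the discussion above, but the registration and the pass described in the comments are hypothetical, not an agreed-upon design.

import torch

# Sentinel no-ops that survive tracing, so a compiler pass can locate the
# region between a RowParallelLinear output and the next ColumnParallelLinear.
@torch.library.custom_op("sp::begin_sp_region", mutates_args=())
def begin_sp_region(x: torch.Tensor) -> torch.Tensor:
    return x.clone()          # the clone keeps the op from being folded away

@begin_sp_region.register_fake
def _(x):
    return torch.empty_like(x)

@torch.library.custom_op("sp::end_sp_region", mutates_args=())
def end_sp_region(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@end_sp_region.register_fake
def _(x):
    return torch.empty_like(x)

# A pass over the traced FX graph can then find nodes whose target is
# torch.ops.sp.begin_sp_region / torch.ops.sp.end_sp_region and rewrite the
# fenced region (allreduce -> reduce-scatter ... all-gather) without brittle
# whole-pattern matching on rms_norm variants.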

BTW are you on the vllm slack? If not, please join as it's easier to discuss there!

@yaochengji (Collaborator)

@tlrmchlsmth

I noticed that the layer norm is a CustomOp class. Do you think we can make it a composite rms_norm op in the FX graph for all cases?

@CustomOp.register("rms_norm")

@tlrmchlsmth (Member)

@tlrmchlsmth

I noticed that the layer norm is a CustomOp class. Do you think we can make it a composite rms_norm op in the FX graph for all cases?

@CustomOp.register("rms_norm")

Not sure about this. Often we make things custom ops when torch.compile has trouble with them. @youkaichao do you know?

self.vllm_config.parallel_config.enable_sequence_parallel
and num_tokens is not None and num_tokens %
self.vllm_config.parallel_config.tensor_parallel_size == 0)


It may be beneficial to log a warning the first time this occurs, so that it's clear from the logs when sequence parallelism is not being enabled due to the sequence length even though the config argument for sequence parallelism is True.
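A possible warn-once sketch of that suggestion (illustrative only; check_sequence_parallel and the module-level flag are hypothetical, not vLLM's actual logging helpers):

import logging

logger = logging.getLogger(__name__)
_warned_sp_disabled = False   # module-level flag so the warning fires only once

def check_sequence_parallel(enable_sp: bool, num_tokens: int, tp_size: int) -> bool:
    # Return whether SP can be used for this batch; warn once if it can't.
    global _warned_sp_disabled
    usable = enable_sp and num_tokens % tp_size == 0
    if enable_sp and not usable and not _warned_sp_disabled:
        logger.warning(
            "Sequence parallelism is enabled in the config but disabled for this "
            "forward pass because num_tokens (%d) is not divisible by the tensor "
            "parallel size (%d).", num_tokens, tp_size)
        _warned_sp_disabled = True
    return usable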

@tlrmchlsmth (Member)

Closing in favor of #16155, which has been merged!
