[Feature] Support sequence parallelism #14908
Conversation
Thanks for your contribution, @cascade812! Hi @robertgshaw2-redhat, @tlrmchlsmth, could you take a look at this? Based on my micro-benchmark, collective-matmul optimization can greatly improve the performance of multi-chip inference on TPU. To enable collective-matmul, we depend on vLLM supporting megatron-style sequence parallelism; the TPU compiler can then automatically convert the ag-matmul and matmul-rs patterns into collective-matmul.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
I'll take a look! I'm excited about and in favor of sequence parallel support in general. @yaochengji could you explain why this helps the gemm-rs and ag-gemm rewrite? I don't really see that as sequence parallel outside of the rms_norm.
Thanks, @tlrmchlsmth. It is because of how sequence parallelism decomposes the collectives: with TP only, each block ends in a matmul followed by an all-reduce; with SP enabled, that all-reduce becomes a reduce-scatter before the norm and an all-gather before the next matmul. matmul-rs and ag-matmul are clean patterns for optimizations, including the TPU compiler, to detect.
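To make the two patterns concrete, here is a minimal sketch of one MLP block under plain TP versus TP + SP. It assumes an already-initialized tensor-parallel process group and made-up weight names (w_up column-parallel, w_down row-parallel); it is an illustration, not code from this PR:

import torch
import torch.distributed as dist

def mlp_tp_only(x, w_up, w_down, tp_group):
    # TP only: every rank sees the full [num_tokens, hidden] activation.
    h = x @ w_up                        # column-parallel matmul
    y = h @ w_down                      # row-parallel matmul (partial sums)
    dist.all_reduce(y, group=tp_group)  # matmul -> all-reduce
    return y                            # full activation, replicated on each rank

def mlp_with_sp(x_shard, w_up, w_down, tp_group):
    # SP: activations between blocks are sharded along the token dimension.
    tp = dist.get_world_size(group=tp_group)
    full = torch.empty(x_shard.shape[0] * tp, x_shard.shape[1],
                       dtype=x_shard.dtype, device=x_shard.device)
    dist.all_gather_into_tensor(full, x_shard, group=tp_group)  # ag -> matmul
    h = full @ w_up
    y = h @ w_down
    out = torch.empty_like(x_shard)
    dist.reduce_scatter_tensor(out, y, group=tp_group)          # matmul -> rs
    return out  # token shard; the following rms_norm runs on the shard only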
our previous plan for sequence parallel is to make it a compilation pass, without changing the linear/embedding layers. that thread somehow got lost.
cc @bnellnm
Thanks for your comment, @youkaichao! If we don't change the linear/embedding layers, we have to match various patterns in a compilation pass. Changing the linear or embedding layers offers a potentially more generalized approach. Furthermore, the number of layers requiring modification is relatively small.
@cascade812 I know that currently not all layers support SP. Could you print a readable message to vLLM users if a specific layer doesn't support sequence parallelism?
Mega job here thanks!
Can we get some benchmark numbers to go along with this PR?
There were too many problems with PyTorch, limitations of the kernels, and issues with piecewise graphs, so the project was put on hold.
I'm working on enabling collective-matmul optimization on TPU. My changes to vLLM will be based on this PR; I'll share benchmark numbers later.
I left some inline comments. Generally I also think this should be done as a pass in torch inductor or a similar compiler layer. I'm pretty sure these changes are making assumptions about the model definition that may not be valid.
forward_context = try_get_forward_context()
if (forward_context is not None
        and forward_context.enable_sequence_parallel):
The forward context isn't available outside of model initialization, so you'll have to do self.enable_sequence_parallel = forward_context.enable_sequence_parallel in __init__; otherwise you won't actually be using sequence parallel when running inference on the model (unless you're using CUDA graphs).
I think this is a pretty tricky footgun, so we should address this - (cc @youkaichao)
The forward context is available; it's set before calling self.model().
enable_sequence_parallel = (
    self.vllm_config.parallel_config.enable_sequence_parallel
    and num_tokens %
    self.vllm_config.parallel_config.tensor_parallel_size == 0)
Decisions like this should go in vllm/config.py
This logic is placed outside of vllm/config.py because torch.distributed.reduce_scatter_tensor only works when the reduce-scattered dimension (in this case, the token dimension) is divisible by the parallel size. Since num_tokens changes every iteration, it doesn't seem reasonable to put this check in the config.
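For illustration only (this is not a vLLM helper; it assumes an initialized TP process group), the constraint looks like this:

import torch
import torch.distributed as dist

def reduce_scatter_tokens(y: torch.Tensor, tp_group) -> torch.Tensor:
    # reduce_scatter_tensor splits dim 0 evenly across the group, so the
    # token dimension must be divisible by the TP size. num_tokens is only
    # known per iteration, hence the runtime check instead of a config check.
    tp = dist.get_world_size(group=tp_group)
    assert y.shape[0] % tp == 0, "num_tokens must be divisible by TP size"
    out = torch.empty(y.shape[0] // tp, *y.shape[1:],
                      dtype=y.dtype, device=y.device)
    dist.reduce_scatter_tensor(out, y, group=tp_group)
    return out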
vllm/v1/worker/gpu_model_runner.py
with set_forward_context(
        attn_metadata,
        self.vllm_config,
        enable_sequence_parallel=enable_sequence_parallel):
OK, I see we're setting the forward_context here now. @youkaichao thoughts on this?
looks quite intrusive to me 👀
@youkaichao Since we need num_tokens to make the decision before reduce-scatter (see the explanation above), do you have any suggestions for a better approach?
Thanks for all the comments and reviews! I'll address accordingly.
@yaochengji Do you mean adding a message to all unsupported layers? That would involve many layers. Do you have any suggestions on how to implement this efficiently?
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Usually we have a base class, and the sequence parallelism support could be flagged there. NVM, I'm fine with users needing to be aware of which layers support sequence parallelism before it is fully supported.
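For what it's worth, a minimal sketch of that base-class idea (hypothetical class and attribute names, not vLLM's actual layer hierarchy):

import logging

logger = logging.getLogger(__name__)

class LayerBase:
    # Layers that implement the sequence-parallel path override this flag.
    supports_sequence_parallel: bool = False

    def check_sequence_parallel(self, sp_enabled: bool) -> None:
        # Print a readable message instead of failing silently when SP is
        # requested but this layer only implements the plain-TP path.
        if sp_enabled and not self.supports_sequence_parallel:
            logger.warning(
                "%s does not support sequence parallelism; it will run with "
                "plain tensor parallelism instead.", type(self).__name__)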
I think you only need to match …
There's no …
We don't need to change the op, but we still need to detect it, because we will decompose …
hey @cascade812 - is it okay if I push a few changes to your branch?
Absolutely! Your contributions are more than welcome. Thanks!
@robertgshaw2-redhat @tlrmchlsmth @youkaichao what are your thoughts on this? If you think a compilation pass is the better choice, I'm willing to help too!
I didn’t see this message! It would be awesome for you to work on that! |
Great! I'll work on it and keep you posted on the progress. |
This pull request has merge conflicts that must be resolved before it can be merged.
Hey @cascade812, I have some thoughts/pointers on this (sorry, meant to send these along earlier). First, here is an earlier PR that attempted to do this via an inductor pass with pattern matching: #9886.

Another problem is that the pattern matching is very brittle and would need to be extended to support different models. @yaochengji raised this issue, and from my past experience I think this is a very valid concern. In an effort to make it more flexible, some of us have discussed adding sentinel no-op operations (e.g. a …).

BTW, are you on the vLLM Slack? If not, please join, as it's easier to discuss there!
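As a rough sketch of the sentinel idea (assuming PyTorch >= 2.4; the op name and namespace are made up for illustration), a no-op marker can be a tiny pass-through custom op that a compiler pass can anchor on instead of matching brittle surrounding patterns:

import torch

@torch.library.custom_op("demo::sp_marker", mutates_args=())
def sp_marker(x: torch.Tensor) -> torch.Tensor:
    # Pass-through at runtime; exists only so a pass can find this point.
    return x.clone()

@sp_marker.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation for tracing and torch.compile.
    return torch.empty_like(x)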
I noticed that the layer_norm is a custom op.
Not sure about this. Often we make things custom ops when torch.compile has trouble with them. @youkaichao do you know? |
enable_sequence_parallel = (
    self.vllm_config.parallel_config.enable_sequence_parallel
    and num_tokens is not None and num_tokens %
    self.vllm_config.parallel_config.tensor_parallel_size == 0)
It may be beneficial to log a warning the first time this occurs, so that it's clear from the logs when sequence parallelism is not being enabled due to the token count, even though the config argument for sequence parallelism is True.
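A minimal sketch of that warn-once behavior (illustrative names, not the PR's code):

import logging

logger = logging.getLogger(__name__)
_warned_sp_fallback = False

def maybe_warn_sp_fallback(num_tokens: int, tp_size: int) -> None:
    # Log once when sequence parallelism is configured but skipped for this
    # batch because num_tokens is not divisible by the TP size.
    global _warned_sp_fallback
    if not _warned_sp_fallback and num_tokens % tp_size != 0:
        logger.warning(
            "Sequence parallelism is enabled in the config but disabled for "
            "this batch: num_tokens=%d is not divisible by "
            "tensor_parallel_size=%d.", num_tokens, tp_size)
        _warned_sp_fallback = True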
Closing in favor of #16155, which has been merged!
Support sequence parallel with TP on models like Llama.
In this PR, I modified RowParallelLinear, ColumnParallelLinear, LogitsProcessor, and VocabParallelEmbedding to support SP.
Below are the TODOs: