[torch.compile][TPU] Make @support_torch_compile work for XLA backend #15782
Conversation
Signed-off-by: Siyuan Liu <[email protected]>
Signed-off-by: Siyuan Liu <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default; they only run a small and essential subset of CI tests to quickly catch errors. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
Slightly improved the throughput.
vllm/v1/worker/tpu_model_runner.py
Outdated
    self._hidden_states_dtype = out.dtype
    self.model(input_ids=input_ids,
               positions=position_ids,
               inputs_embeds=inputs_embeds)
Please keep the _hidden_states_dtype assignment
yep that's still needed
IMO updating _hidden_states_dtype is not needed; it is already initialized with the model dtype. Also, it seems we don't need _hidden_states_dtype at all, since it should be the same as the model dtype.
Here is more context https://github.com/vllm-project/vllm/pull/15714/files#r2019453889
I think this is out of scope, but yes, in principle you'd only need self.dtype. The dtype issue you linked has shown that nothing crashes at runtime if the output of an op doesn't match self.dtype, though.
Hence, if the same bug were to reappear, we would only notice the server recompiling, and we'd have to debug it again the way I did, painfully.
Thanks! Updated
vllm/v1/worker/tpu_model_runner.py
Outdated
    xm.wait_device_ops()
    xm.wait_device_ops()
Is there a reason to pull the device sync inside of the loop? IIRC we pulled it out since it made parallel compilation slightly quicker
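For readers following along, a rough sketch of the two placements being compared; token_buckets and run_dummy_forward are hypothetical placeholders, and xm.wait_device_ops() is torch_xla's call that blocks until queued device work has finished:

```python
import torch_xla.core.xla_model as xm

token_buckets = [16, 32, 64]          # hypothetical padded token sizes

def run_dummy_forward(num_tokens: int) -> None:
    ...                               # hypothetical: launch one dummy forward on the TPU

# Sync inside the loop: each bucket's compilation is waited on before the
# next one is launched.
for num_tokens in token_buckets:
    run_dummy_forward(num_tokens)
    xm.wait_device_ops()

# Sync once after the loop: all buckets are launched first, so the
# compilations can overlap, and there is a single wait at the end.
for num_tokens in token_buckets:
    run_dummy_forward(num_tokens)
xm.wait_device_ops()
```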
Thanks! Updated.
    def get_multimodal_embeddings(self, *args, **kwargs):
        return self.model.get_multimodal_embeddings(*args, **kwargs)

    def get_input_embeddings(self, *args, **kwargs):
        return self.model.get_input_embeddings(*args, **kwargs)
Have you tested that multimodal inference still works and these are called correctly?
I only ran TPU CI to test the change; multimodal is not tested. Can you provide a script to test multimodal? I can test it on this PR.
VLLM_USE_V1=1 vllm serve llava-hf/llava-1.5-7b-hf --max-model-len 4096 --max-num-seqs 8 --max-num-batched-tokens 512 --chat-template examples/template_llava.jinja
then python examples/online_serving/openai_chat_completion_client_for_multimodal.py
I am adding tests in another pr btw
sg I'll rebase to your PR after :)
Hi @NickLucche Thank you for providing the testing cmd for multimodal. I made this PR work with llava-hf/llava-1.5-7b-hf; the server cmd runs fine.
However, the client script python examples/online_serving/openai_chat_completion_client_for_multimodal.py would fail at HEAD.
Main is working for me, let me try your PR
Still really slow but it works the same on this PR on my side, thanks!
@NickLucche Thank you so much for trying! I set up a new conda env on my end and tried as well. I also found it's really slow, was about to ask you lol.
Thanks for the work!
Left a comment about MM, I think that's about the only thing to clarify on my side.
vllm/v1/worker/tpu_model_runner.py
Outdated
        return hidden_states

    def reset_dynamo_cache(self):
        # TODO(lsy323): Support multimodal models, the backbone language model
do you mean
        # TODO(lsy323): Support multimodal models, the backbone language model
        compiled_model = self.model.language_model if self.is_multimodal_model else self.model.model
I haven't tested multimodal yet; I plan to do it in another PR. Right now I have only run the tests in TPU CI.
I haven't tested multimodal yet, plan to do it in another PR
I think it's best if we double check here o/w we may inadvertently break MM
I'll rebase to the MM PR after it's merged to ensure this PR doesn't break MM
class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module):
I didn't mind the wrapper too much; I think it was grouping a few related functions nicely. Still, if it benefits performance I am OK with that.
The wrapper was added in V0 for torch.compile. Now we have the @support_torch_compile decorator, which already wraps the model with torch.compile, so we don't need the wrapper anymore.
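To make that concrete, here is a minimal sketch of the pattern, using a stand-in decorator rather than vLLM's actual support_torch_compile implementation: the decorated class compiles its own forward, so no external wrapper module like ModelWrapperV1 is needed.

```python
import torch
import torch.nn as nn

def support_torch_compile_sketch(cls):
    # Illustrative stand-in for the real decorator: compile the model's own
    # forward at construction time instead of wrapping the model in a
    # separate nn.Module (the old ModelWrapperV1 approach).
    orig_init = cls.__init__

    def patched_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # vLLM selects the XLA backend on TPU; the default backend is used
        # here so the sketch runs anywhere.
        self.forward = torch.compile(self.forward, dynamic=False)

    cls.__init__ = patched_init
    return cls

@support_torch_compile_sketch
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

out = ToyModel()(torch.randn(4, 8))  # forward is compiled on first call
```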
    sample_hidden_states = \
        hidden_states[sampling_metadata.indices_do_sample]
    logits = self.compute_logits(sample_hidden_states)
    logits = self.model.compute_logits(sample_hidden_states, None)
can we re-add the pruning comment here? Just in case it slips through in the future.
Sorry, may I ask what the 'pruning comment' is?
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
or something along those lines to indicate why the 2nd argument is None. That's because passing it could enable logits pruning, which is bad for XLA.
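As a small, self-contained illustration of the point (toy shapes and a hypothetical compute_logits_sketch, not vLLM's actual LogitsProcessor): with metadata, a processor may gather only the rows to sample, which makes the output shape data-dependent, whereas passing None keeps the logits shape static, which is what XLA wants.

```python
import torch
import torch.nn as nn

def compute_logits_sketch(lm_head: nn.Linear,
                          hidden_states: torch.Tensor,
                          sampling_metadata=None) -> torch.Tensor:
    logits = lm_head(hidden_states)
    if sampling_metadata is not None:
        # Pruning path: how many rows survive depends on runtime indices,
        # so the output shape is dynamic and XLA would recompile.
        logits = logits[sampling_metadata["indices_do_sample"]]
    return logits

lm_head = nn.Linear(16, 32, bias=False)
hidden = torch.randn(8, 16)
# Passing None disables pruning, mirroring the diff above: static (8, 32) logits.
logits = compute_logits_sketch(lm_head, hidden, None)
```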
Thanks! Updated
Signed-off-by: Siyuan Liu <[email protected]>
vllm/v1/worker/tpu_model_runner.py
Outdated
    def reset_dynamo_cache(self):
        # TODO(lsy323): Support multimodal models, the backbone language model
        # is stored in a different member.
        compiled_model = self.model.model
I wonder why we need to do model.model. Could you add a comment?
Is it possible that self.model doesn't have an attr model? E.g. if it's not annotated by support_torch_compile?
Yes, model.model is torch compiled
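A toy sketch of why the compiled module sits one level down (hypothetical class names; the decorator is applied to the inner backbone class, while the outer causal-LM wrapper exposes it as .model):

```python
import torch.nn as nn

class Backbone(nn.Module):
    # In vLLM this is the class annotated with @support_torch_compile,
    # so its forward is the dynamo-compiled one.
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class ToyForCausalLM(nn.Module):
    # What the model runner holds as self.model: it owns the lm_head and
    # exposes the compiled backbone as `.model`.
    def __init__(self):
        super().__init__()
        self.model = Backbone()
        self.lm_head = nn.Linear(8, 32)

runner_model = ToyForCausalLM()
compiled_model = runner_model.model   # hence model.model in the diff above
```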
Signed-off-by: Siyuan Liu <[email protected]>
        if self.is_multimodal_model:
            compiled_model = self.model.language_model.model
        else:
            compiled_model = self.model.model
AFAIK "language_model" is not a stable attribute to reference, it is based on the HF model definition. Maybe @ywang96 @DarkLight1337 would know a stable interface to access the language model backbone?
I think a potential solution is to add a query function get_backbone_lm for multimodal models, similar to the existing get_input_embeddings
Yeah we can add get_language_model to SupportsMultiModal interface
Something like #16007?
Please leave an assert here for now as we will address later @lsy323
LGTM, thanks!
For now I think we should aim to land this just working with language_model since there are other blockers/developments there. Afterwards we can migrate to use the get_language_model interface proposed by Nicolo.
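For reference, a rough sketch of what such a get_language_model accessor could look like (illustrative names only, not vLLM's actual SupportsMultiModal interface): the runner asks the model for its text backbone instead of reaching for the HF-specific language_model attribute.

```python
import torch.nn as nn

class SupportsMultiModalSketch:
    # Hypothetical interface method mirroring the proposal above.
    def get_language_model(self) -> nn.Module:
        raise NotImplementedError

class LlavaLikeModel(nn.Module, SupportsMultiModalSketch):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(8, 8)    # stand-in for the vision encoder
        self.language_model = nn.Linear(8, 8)  # stand-in for the LM backbone

    def get_language_model(self) -> nn.Module:
        return self.language_model

# The runner would then be independent of the HF attribute layout:
backbone = LlavaLikeModel().get_language_model()
```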
…vllm-project#15782) Signed-off-by: Siyuan Liu <[email protected]> Signed-off-by: mgoin <[email protected]> Co-authored-by: mgoin <[email protected]> Signed-off-by: Yang Wang <[email protected]>
…vllm-project#15782) Signed-off-by: Siyuan Liu <[email protected]> Signed-off-by: mgoin <[email protected]> Co-authored-by: mgoin <[email protected]>
…vllm-project#15782) Signed-off-by: Siyuan Liu <[email protected]> Signed-off-by: mgoin <[email protected]> Co-authored-by: mgoin <[email protected]>
…vllm-project#15782) Signed-off-by: Siyuan Liu <[email protected]> Signed-off-by: mgoin <[email protected]> Co-authored-by: mgoin <[email protected]> Signed-off-by: Mu Huai <[email protected]>
Make @support_torch_compile work for XLA backend. With the custom dispatcher, the overhead of dynamo guard evaluation is eliminated.
For the TPU backend, each model has two FX graphs/dynamo bytecodes, one of which comes from the profiling run. This breaks the assumption in the current @support_torch_compile implementation that each model has one FX graph/cached bytecode. Since the profiling graph won't be invoked after the profiling run, we clear the cached bytecode after the profiling run so that the assumption remains true.
Other changes:
- Removed ModelWrapperV1, which was used to wrap the model code and torch.compile the wrapped model. It's not needed anymore since we are reusing the compile decorator. With ModelWrapperV1 removed, the sampler logic is moved to a separate function.
Credit to @WoosukKwon for the idea of clearing the bytecode cache and to @youkaichao for the many helpful answers to torch dynamo related questions!
cc @youkaichao @alexm-redhat @miladm @NickLucche @WoosukKwon @yaochengji @robertgshaw2-redhat
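To make the bytecode-clearing idea concrete, here is a toy sketch (hypothetical module, with torch._dynamo.reset() as a coarse-grained stand-in for clearing only the backbone's cached bytecode as the PR does):

```python
import torch

class TinyBackbone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

compiled = torch.compile(TinyBackbone())

# Profiling run: leaves one compiled graph/bytecode entry behind.
compiled(torch.randn(4, 8))

# Clear the dynamo caches after profiling so the profiling graph is gone
# and the one-graph-per-model assumption holds again.
torch._dynamo.reset()

# Regular execution: compiles once more, then reuses a single cached graph.
compiled(torch.randn(16, 8))
```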