[torch.compile][TPU] Make @support_torch_compile work for XLA backend #15782
Changes from 2 commits
Commits: 176bc3e, 14d33b2, c27943d, c14518f, cdec999, a8dbbf1
@@ -15,6 +15,7 @@
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
@@ -608,11 +609,10 @@ def execute_model(
        hidden_states = self.model(
            input_ids=input_ids,
            positions=self.position_ids,
            kv_caches=self.kv_caches,
            inputs_embeds=inputs_embeds,
        )
        selected_token_ids = self.model.sample_from_hidden(
            hidden_states, tpu_sampling_metadata)
        selected_token_ids = self.sample_from_hidden(hidden_states,
                                                     tpu_sampling_metadata)
        # Remove padding on cpu and keep dynamic op outside of xla graph.
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
@@ -713,17 +713,15 @@ def load_model(self) -> None:
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        # Sync all pending XLA execution during model initialization and weight
        # loading.
        xm.mark_step()
        xm.wait_device_ops()
        model = ModelWrapperV1(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)
        self.model = model
        self.sampler = TPUSampler()

    @torch.no_grad()
    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
    def _dummy_run(self, num_tokens: int) -> None:
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
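For context on what this load_model change does: the runner used to wrap the model in ModelWrapperV1 and compile that wrapper itself, whereas the new code keeps the plain model and relies on the model class being decorated with support_torch_compile (from vllm.compilation.decorators). The snippet below is only a rough sketch of the old, explicit-wrapping style using public torch APIs; ToyBackbone and its sizes are invented for illustration, and the "openxla" backend is only registered when torch_xla is installed.

```python
import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    """Stand-in for the model returned by get_model(); not vLLM code."""
    def __init__(self, hidden: int = 128):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

# Old style (removed in this hunk): the runner owns the compile call and must
# pick the backend and flags itself.
model = ToyBackbone().eval()
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=False)

# New style (sketched): the model class is decorated with @support_torch_compile,
# so the runner just keeps a reference to the model and calls it; vLLM's
# compilation config decides how and whether it is compiled.
```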
@@ -772,11 +770,9 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

        with set_forward_context(attn_metadata, self.vllm_config, 0):
            out = self.model(input_ids=input_ids,
                             positions=position_ids,
                             kv_caches=kv_caches,
                             inputs_embeds=inputs_embeds)
            self._hidden_states_dtype = out.dtype
            self.model(input_ids=input_ids,
                       positions=position_ids,
                       inputs_embeds=inputs_embeds)

    def capture_model(self) -> None:
        """Compile the model."""
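The mark_dynamic call kept in this hunk is what tells Dynamo not to specialize (and hence not re-trace) on the marked dimension. A tiny, generic illustration with the default backend follows; the function, sizes, and tensors are invented, and this sketch does not use the XLA backend.

```python
import torch

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

compiled = torch.compile(double)

x = torch.randn(8, 4)
torch._dynamo.mark_dynamic(x, 0)   # treat dim 0 as dynamic: don't specialize on 8
compiled(x)                        # compiles once with a symbolic batch dimension
compiled(torch.randn(16, 4))       # reuses the same compiled code, no re-trace
```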
@@ -786,9 +782,9 @@ def capture_model(self) -> None:
        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info(" -- num_tokens: %d", num_tokens)
            self._dummy_run(self.kv_caches, num_tokens)
            self._dummy_run(num_tokens)
            xm.mark_step()
        xm.wait_device_ops()
        xm.wait_device_ops()

        end = time.perf_counter()
        logger.info("Compilation finished in in %.2f [secs].", end - start)
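As a side note on why the dummy runs loop over num_tokens_paddings: with the XLA backend every distinct input shape produces a separate compiled graph, so each padded bucket has to be traced once up front. Below is a minimal, self-contained sketch of that warm-up pattern; the bucket sizes and the toy linear model are assumptions, not values from this PR, and it requires torch_xla.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 128).to(device)
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=False)

# Warm up one padded bucket at a time so serving never hits a cold compile.
for num_tokens in (16, 32, 64, 128):          # stand-in for num_tokens_paddings
    dummy = torch.zeros(num_tokens, 128, device=device)
    compiled(dummy)                           # traces/compiles this shape
    xm.mark_step()                            # cut the lazy-tensor graph here
xm.wait_device_ops()                          # block until device work finishes
```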
@@ -815,8 +811,7 @@ def capture_model(self) -> None:
                from_input_batch(self.input_batch, indices)
            logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                        num_reqs_to_sample)
            out = self.model.sample_from_hidden(dummy_hidden,
                                                sampling_meta)
            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
            out = out.cpu()
            if num_reqs_to_sample >= self.max_num_reqs:
                break
@@ -874,79 +869,45 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)


class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module):
Collaborator: I didn't mind the wrapper too much; I think it was grouping a few related functions nicely. Still, if it benefits performance I am OK with that.

Author: The wrapper was added in V0 for
        super().__init__()
        self.model = model
        self.sampler = TPUSampler()

    def sample(
        self, logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
        sampler_out = self.sampler(logits, sampling_metadata)
        return sampler_out

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: list[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Executes the forward pass of the model.

        Args:
            input_ids: The input token IDs of shape [num_tokens].
            positions: The input position IDs of shape [num_tokens].
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
            inputs_embeds: The input embeddings of shape [num_tokens,
                hidden_size]. It is used for multimodal models.
        """

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def reset_dynamo_cache(self):
        # TODO(lsy323): Support multimodal models, the backbone language model
        compiled_model = self.model.language_model if self.is_multimodal_model else self.model.model
I haven't tested multimodal yet, plan to do it in another PR. Right now I only tested the tests in TPU CI.

> I haven't tested multimodal yet, plan to do it in another PR

I think it's best if we double check here, otherwise we may inadvertently break MM.

I'll rebase to the MM PR after it's merged to ensure this PR doesn't break MM.
I wonder why we need to do model.model. Could you add a comment?

Is it possible self.model doesn't have an attr model? E.g. it's not annotated by support_torch_compile?

Yes, model.model is torch compiled.
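To make the attribute nesting in this exchange concrete: @support_torch_compile decorates the inner backbone class, so the compiled module hangs one level below the top-level model the runner holds. The class names in the sketch below are invented for illustration; only the nesting mirrors the discussion.

```python
import torch.nn as nn

class Backbone(nn.Module):
    """Stand-in for the class that @support_torch_compile decorates."""
    def forward(self, x):
        return x

class CausalLM(nn.Module):
    """Stand-in for the object returned by get_model()."""
    def __init__(self):
        super().__init__()
        self.model = Backbone()      # the compiled piece lives here

runner_model = CausalLM()
compiled_part = runner_model.model   # hence "model.model" in reset_dynamo_cache
```

A blunter alternative would be torch._dynamo.reset(), which drops all of Dynamo's compiled code; targeting the specific compiled submodule is more surgical.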
Can we re-add the pruning comment here? Just in case it slips through in the future.

Sorry, may I ask what the 'pruning comment' is?

# SamplingMetadata here for pruning output in LogitsProcessor, disabled

or something along this line to indicate why the 2nd argument is None. That's because it could enable logits pruning, which is bad for XLA.

Thanks! Updated.
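For readers who want to see what the re-added comment refers to, here is a hedged sketch of the call shape: compute_logits(hidden_states, sampling_metadata) follows vLLM's model interface, but the wrapper function and variable names below are made up.

```python
def compute_padded_logits(model, hidden_states):
    # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
    # Passing None keeps the logits at their padded, static shape; pruning to
    # only the rows being sampled would introduce dynamic shapes, and every
    # new shape forces XLA to recompile the graph.
    return model.compute_logits(hidden_states, None)
```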
Have you tested that multimodal inference still works and these are called correctly?

I only ran TPU CI to test the change; multimodal is not tested. Can you provide a script to test multimodal? I can test it on this PR.

VLLM_USE_V1=1 vllm serve llava-hf/llava-1.5-7b-hf --max-model-len 4096 --max-num-seqs 8 --max-num-batched-tokens 512 --chat-template examples/template_llava.jinja

then python examples/online_serving/openai_chat_completion_client_for_multimodal.py
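If the example client script is unavailable, a minimal stand-in request against the server started above would look like the sketch below. It assumes the OpenAI Python client and the default localhost:8000 endpoint, and the image URL is a placeholder to be replaced with a reachable one.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            # Replace with a real, reachable image URL.
            {"type": "image_url", "image_url": {"url": "https://example.com/some_image.jpg"}},
        ],
    }],
)
print(resp.choices[0].message.content)
```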
I am adding tests in another PR, btw.

Sounds good, I'll rebase to your PR after :)

Hi @NickLucche, thank you for providing the testing cmd for multimodal. I made this PR work for llava-hf/llava-1.5-7b-hf; the server cmd runs fine. However, the client script python examples/online_serving/openai_chat_completion_client_for_multimodal.py would fail at HEAD.

Main is working for me, let me try your PR.

Still really slow, but it works the same on this PR on my side, thanks!

@NickLucche Thank you so much for trying! I set up a new conda env on my end and tried as well. I also found it's really slow, was about to ask you lol.
Please keep the _hidden_states_dtype assignment.

Yep, that's still needed.

IMO updating _hidden_states_dtype is not needed; _hidden_states_dtype is already initialized with the model dtype. Also, it seems we don't need _hidden_states_dtype at all, since it should be the same as the model dtype.

Here is more context: https://github.com/vllm-project/vllm/pull/15714/files#r2019453889

I think this is out of scope, but yes, in principle you'd only need self.dtype. The dtype issue you linked has proven that nothing crashes at runtime should the output of an op not match self.dtype, though. Hence if the same bug were to re-appear, we would only notice the server recompiling and we'd have to debug it again the way I did, painfully.
Thanks! Updated
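As a toy illustration of the recompilation concern raised in this thread (not code from the PR; the function and names are invented): recording the dtype the dummy run actually produced, instead of trusting only the configured dtype, turns a silent upcast into a visible warning rather than an unexplained recompile.

```python
import torch

def record_hidden_dtype(hidden_states: torch.Tensor,
                        configured: torch.dtype) -> torch.dtype:
    # If some op in the graph silently upcasts (e.g. bfloat16 -> float32),
    # XLA won't crash; it will quietly re-trace graphs built for the old dtype.
    # Keeping the observed dtype around makes that drift easy to spot.
    if hidden_states.dtype != configured:
        print(f"hidden states are {hidden_states.dtype}, "
              f"but the model is configured for {configured}")
    return hidden_states.dtype
```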