[torch.compile][TPU] Make @support_torch_compile work for XLA backend #15782
Changes from all commits: 176bc3e, 14d33b2, c27943d, c14518f, cdec999, a8dbbf1
```diff
@@ -15,6 +15,7 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
```
```diff
@@ -647,11 +648,10 @@ def execute_model(
         hidden_states = self.model(
             input_ids=input_ids,
             positions=self.position_ids,
-            kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_hidden(hidden_states,
+                                                     tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
```
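The comment about removing padding on the CPU is the key XLA constraint here: the compiled graph only ever sees the fixed padded shape, and the dynamic slice down to the real batch size happens on the host. A minimal standalone illustration of that pattern (toy sizes, plain PyTorch, not vLLM code):

```python
import torch

# The compiled graph produces a tensor of fixed, pre-padded size; slicing it
# down to the actual number of requests is done after .cpu(), so the XLA
# graph never contains a dynamically shaped op.
padded_token_ids = torch.arange(16).reshape(16, 1)  # fixed padded size (e.g. 16)
num_reqs = 5                                        # real batch size this step
selected_token_ids = padded_token_ids.cpu()[:num_reqs]
print(selected_token_ids.shape)  # torch.Size([5, 1])
```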
```diff
@@ -751,17 +751,15 @@ def load_model(self) -> None:
                     "get_tensor_model_parallel_rank",
                     return_value=xm_tp_rank):
                 model = get_model(vllm_config=self.vllm_config)
-        model = model.eval()
+        # Sync all pending XLA execution during model initialization and weight
+        # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        model = ModelWrapperV1(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = model
+        self.sampler = TPUSampler()

     @torch.no_grad()
-    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
```
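The new comment documents the standard torch_xla synchronization idiom used here. A self-contained sketch of that idiom (assumes a TPU/XLA device is available via torch_xla; the tensor work is only a stand-in for weight loading):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Lazy XLA tensors: this queues work without executing it immediately.
weights = torch.randn(1024, 1024, device=device) @ torch.randn(
    1024, 1024, device=device)

# Cut the pending lazy graph and block until the device has finished,
# so subsequent compilation/warm-up starts from a clean state.
xm.mark_step()
xm.wait_device_ops()
```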
```diff
@@ -812,7 +810,6 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
-                             kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype
```
```diff
@@ -824,7 +821,7 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
-            self._dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(num_tokens)
         xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
```
```diff
@@ -855,8 +852,7 @@ def capture_model(self) -> None:
                 from_input_batch(self.input_batch, indices)
             logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
```
```diff
@@ -910,79 +906,48 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)

-
-class ModelWrapperV1(nn.Module):
-
-    def __init__(self, model: nn.Module):
-        super().__init__()
-        self.model = model
-        self.sampler = TPUSampler()
-
-    def sample(
-        self, logits: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
-        sampler_out = self.sampler(logits, sampling_metadata)
-        return sampler_out
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Executes the forward pass of the model.
-
-        Args:
-            input_ids: The input token IDs of shape [num_tokens].
-            positions: The input position IDs of shape [num_tokens].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
-            inputs_embeds: The input embeddings of shape [num_tokens,
-                hidden_size]. It is used for multimodal models.
-        """
-
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            inputs_embeds=inputs_embeds,
-        )
-
-        return hidden_states
+    def reset_dynamo_cache(self):
+        if self.is_multimodal_model:
+            assert hasattr(self.model, "language_model")
+            compiled_model = self.model.language_model.model
+        else:
+            compiled_model = self.model.model
+        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+            logger.info("Clear dynamo cache and cached dynamo bytecode.")
+            torch._dynamo.eval_frame.remove_from_cache(
+                compiled_model.original_code_object)
+            compiled_model.compiled_codes.clear()

     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
-        Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
-        """
+        Sample with xla-friendly function. This function is to be traced
+        separately for lighter compilation overhead.
+        """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
-        logits = self.compute_logits(sample_hidden_states)
+        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        logits = self.model.compute_logits(sample_hidden_states, None)
```
Review comment on the `compute_logits` call above:

Collaborator: Can we re-add the pruning comment here? Just in case it slips through in the future.

Author: Sorry, may I ask what the 'pruning comment' is?

Collaborator: (pointed to the comment line in question)

Author: Thanks! Updated.
```diff
+        def sample(
+            logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata
+        ) -> SamplerOutput:
+            sampler_out = self.sampler(logits, sampling_metadata)
+            return sampler_out
+
         # Optimized greedy sampling branch, tracing both paths in a single pass
         # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.all_greedy,
-                                 torch.argmax(logits, dim=-1, keepdim=True),
-                                 self.sample(logits, sampling_metadata)\
-                                     .sampled_token_ids)
+        out_tokens = torch.where(
+            sampling_metadata.all_greedy,
+            torch.argmax(logits, dim=-1, keepdim=True),
+            sample(logits, sampling_metadata).sampled_token_ids)
         return out_tokens

-    def compute_logits(self,
-                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
-        logits = self.model.compute_logits(hidden_states, None)
-        return logits
-
-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
-
-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)
```
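The NOTE in the diff is the core trick: because `all_greedy` is a scalar tensor, both the greedy and the sampled branch are traced into a single graph and `torch.where` selects the result, instead of a Python `if` that would force separate compilations. A minimal standalone illustration of that pattern (plain PyTorch, toy shapes, with `torch.multinomial` standing in for vLLM's TPU sampler):

```python
import torch

# Both branches are computed and traced; a scalar boolean tensor picks the
# output. Shapes and the multinomial stand-in sampler are illustrative only.
logits = torch.randn(4, 32)              # [num_reqs, vocab_size]
all_greedy = torch.tensor(True)          # scalar flag, kept as a tensor

greedy_tokens = torch.argmax(logits, dim=-1, keepdim=True)
sampled_tokens = torch.multinomial(torch.softmax(logits, dim=-1),
                                   num_samples=1)

# torch.where with a scalar condition keeps both paths in one traced graph,
# avoiding a Python-level branch and a recompilation per flag value.
out_tokens = torch.where(all_greedy, greedy_tokens, sampled_tokens)
print(out_tokens.shape)  # torch.Size([4, 1])
```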
Review comment on removed lines 980 to 984 (`get_multimodal_embeddings` / `get_input_embeddings`):

Member: Have you tested that multimodal inference still works and these are called correctly?

Author: I only ran TPU CI to test the change; multimodal is not tested. Can you provide a script to test multimodal? I can test it on this PR.

Collaborator: (shared a multimodal serving command and a client script)

Collaborator: I am adding tests in another PR, btw.

Author: sg, I'll rebase onto your PR after :)

Author: Hi @NickLucche, thank you for providing the testing cmd for multimodal. I made this PR work for [...]. However, the client script [...]

Collaborator: Main is working for me, let me try your PR.

Collaborator: Still really slow, but it works the same on this PR on my side, thanks!

Author: @NickLucche Thank you so much for trying! I set up a new conda env on my end and tried as well. I also found it's really slow, was about to ask you lol.
```diff
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
```
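For reference, this helper simply rounds `n` up to the nearest multiple, for example:

```python
def _get_padded_number(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple

assert _get_padded_number(13, 8) == 16  # rounds up to the next multiple of 8
assert _get_padded_number(16, 8) == 16  # exact multiples are left unchanged
```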
Collaborator: I didn't mind the wrapper too much; I think it was grouping a few related functions nicely. Still, if it benefits performance I am OK with that.

Author: The wrapper was added in V0 for `torch.compile`. Now we have the `@support_torch_compile` decorator that wraps the model with `torch.compile` already, therefore we don't need the wrapper anymore.
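For context, a rough self-contained sketch of the before/after idea. This is a toy stand-in, not vLLM's actual `support_torch_compile` implementation (the real decorator integrates with `TorchCompileWrapperWithCustomDispatcher`, as the `reset_dynamo_cache` check in this diff suggests, and also handles dynamic-shape configuration); the model class, sizes, and compile flags below are illustrative only.

```python
import torch
import torch.nn as nn


def support_torch_compile_sketch(cls):
    """Toy class decorator: compile forward once at construction time, so the
    runner only keeps a plain reference to the model and never calls
    torch.compile itself (as the removed ModelWrapperV1 + load_model code did)."""
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # On TPU the backend would be "openxla"; the default backend is used here.
        self.forward = torch.compile(self.forward, fullgraph=True, dynamic=False)

    cls.__init__ = __init__
    return cls


@support_torch_compile_sketch
class ToyModel(nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = ToyModel()
out = model(torch.randn(2, 16))  # call sites stay unchanged: no explicit compile
print(out.shape)                 # torch.Size([2, 16])
```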