[torch.compile][TPU] Make @support_torch_compile work for XLA backend #15782
Changes from 4 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| import vllm.envs as envs | ||
| from vllm.attention.backends.abstract import AttentionType | ||
| from vllm.attention.layer import Attention | ||
| from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher | ||
| from vllm.config import VllmConfig | ||
| from vllm.forward_context import set_forward_context | ||
| from vllm.logger import init_logger | ||
|
|
@@ -608,11 +609,10 @@ def execute_model( | |
| hidden_states = self.model( | ||
| input_ids=input_ids, | ||
| positions=self.position_ids, | ||
| kv_caches=self.kv_caches, | ||
| inputs_embeds=inputs_embeds, | ||
| ) | ||
| selected_token_ids = self.model.sample_from_hidden( | ||
| hidden_states, tpu_sampling_metadata) | ||
| selected_token_ids = self.sample_from_hidden(hidden_states, | ||
| tpu_sampling_metadata) | ||
| # Remove padding on cpu and keep dynamic op outside of xla graph. | ||
| selected_token_ids = selected_token_ids.cpu()[:num_reqs] | ||
|
|
||
|
|
@@ -713,17 +713,15 @@ def load_model(self) -> None: | |
| "get_tensor_model_parallel_rank", | ||
| return_value=xm_tp_rank): | ||
| model = get_model(vllm_config=self.vllm_config) | ||
| model = model.eval() | ||
| # Sync all pending XLA execution during model initialization and weight | ||
| # loading. | ||
| xm.mark_step() | ||
| xm.wait_device_ops() | ||
| model = ModelWrapperV1(model) | ||
| self.model = torch.compile(model, | ||
| backend="openxla", | ||
| fullgraph=True, | ||
| dynamic=False) | ||
| self.model = model | ||
| self.sampler = TPUSampler() | ||
|
|
||
| @torch.no_grad() | ||
| def _dummy_run(self, kv_caches, num_tokens: int) -> None: | ||
| def _dummy_run(self, num_tokens: int) -> None: | ||
| if self.is_multimodal_model: | ||
| input_ids = None | ||
| inputs_embeds = torch.zeros((num_tokens, self.hidden_size), | ||
|
|
@@ -774,7 +772,6 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: | |
| with set_forward_context(attn_metadata, self.vllm_config, 0): | ||
| out = self.model(input_ids=input_ids, | ||
| positions=position_ids, | ||
| kv_caches=kv_caches, | ||
| inputs_embeds=inputs_embeds) | ||
| self._hidden_states_dtype = out.dtype | ||
|
|
||
|
|
@@ -786,7 +783,7 @@ def capture_model(self) -> None: | |
| start = time.perf_counter() | ||
| for num_tokens in self.num_tokens_paddings: | ||
| logger.info(" -- num_tokens: %d", num_tokens) | ||
| self._dummy_run(self.kv_caches, num_tokens) | ||
| self._dummy_run(num_tokens) | ||
| xm.mark_step() | ||
| xm.wait_device_ops() | ||
| end = time.perf_counter() | ||
|
|
@@ -815,8 +812,7 @@ def capture_model(self) -> None: | |
| from_input_batch(self.input_batch, indices) | ||
| logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, | ||
| num_reqs_to_sample) | ||
| out = self.model.sample_from_hidden(dummy_hidden, | ||
| sampling_meta) | ||
| out = self.sample_from_hidden(dummy_hidden, sampling_meta) | ||
| out = out.cpu() | ||
| if num_reqs_to_sample >= self.max_num_reqs: | ||
| break | ||
|
|
@@ -874,79 +870,47 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: | |
| self.vllm_config.compilation_config.static_forward_context, | ||
| self.kv_caches) | ||
|
|
||
|
|
||
| class ModelWrapperV1(nn.Module): | ||
|
|
||
| def __init__(self, model: nn.Module): | ||
| super().__init__() | ||
| self.model = model | ||
| self.sampler = TPUSampler() | ||
|
|
||
| def sample( | ||
| self, logits: torch.Tensor, | ||
| sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput: | ||
| sampler_out = self.sampler(logits, sampling_metadata) | ||
| return sampler_out | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| kv_caches: list[torch.Tensor], | ||
| inputs_embeds: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| """Executes the forward pass of the model. | ||
|
|
||
| Args: | ||
| input_ids: The input token IDs of shape [num_tokens]. | ||
| positions: The input position IDs of shape [num_tokens]. | ||
| kv_caches: The key and value caches. They can be None during the | ||
| memory profiling at initialization. | ||
| inputs_embeds: The input embeddings of shape [num_tokens, | ||
| hidden_size]. It is used for multimodal models. | ||
| """ | ||
|
|
||
| hidden_states = self.model( | ||
| input_ids=input_ids, | ||
| positions=positions, | ||
| inputs_embeds=inputs_embeds, | ||
| ) | ||
|
|
||
| return hidden_states | ||
| def reset_dynamo_cache(self): | ||
| if self.is_multimodal_model: | ||
| compiled_model = self.model.language_model.model | ||
| else: | ||
| compiled_model = self.model.model | ||
|
Review discussion on this hunk:
- "AFAIK `language_model` is not a stable attribute to reference; it is based on the HF model definition. Maybe @ywang96 @DarkLight1337 would know a stable interface to access the language model backbone?"
- "I think a potential solution is to add a query function."
- "Yeah we can add"
- "Something like #16007?"
- "Please leave an assert here for now as we will address later @lsy323"
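Following the last comment in that thread, a minimal sketch (not the PR's final code) of what the requested assert could look like in `reset_dynamo_cache`, assuming the attribute layout shown in the hunk above; whether `language_model` is a stable interface is exactly what is under discussion here.

```python
def reset_dynamo_cache(self):
    if self.is_multimodal_model:
        # `language_model` follows the HF model definition and is not a
        # guaranteed interface; fail loudly until a stable accessor exists.
        assert hasattr(self.model, "language_model"), (
            "Multimodal model does not expose a `language_model` backbone")
        compiled_model = self.model.language_model.model
    else:
        compiled_model = self.model.model

    if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
        logger.info("Clear dynamo cache and cached dynamo bytecode.")
        torch._dynamo.eval_frame.remove_from_cache(
            compiled_model.original_code_object)
        compiled_model.compiled_codes.clear()
```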
||
| if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): | ||
| logger.info("Clear dynamo cache and cached dynamo bytecode.") | ||
| torch._dynamo.eval_frame.remove_from_cache( | ||
| compiled_model.original_code_object) | ||
| compiled_model.compiled_codes.clear() | ||
|
|
||
| def sample_from_hidden( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| sampling_metadata: TPUSupportedSamplingMetadata, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Sample with xla-friendly function. This function is to be traced | ||
| separately from `forward` for lighter compilation overhead. | ||
| """ | ||
| Sample with xla-friendly function. This function is to be traced | ||
| separately for lighter compilation overhead. | ||
| """ | ||
| # Tensor `sample_hidden_states` is of fixed pre-compiled size. | ||
| sample_hidden_states = \ | ||
| hidden_states[sampling_metadata.indices_do_sample] | ||
| logits = self.compute_logits(sample_hidden_states) | ||
| # SamplingMetadata here for pruning output in LogitsProcessor, disabled. | ||
| logits = self.model.compute_logits(sample_hidden_states, None) | ||
|
Review discussion on this hunk:
- "Can we re-add the pruning comment here? Just in case it slips through in the future."
- "Sorry, may I ask what is the 'pruning comment'?"
- "Thanks! Updated"
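The same `sample_from_hidden` hunk, a few rows below, also uses a `torch.where` trick for greedy sampling. Here is a standalone sketch of that pattern, with toy tensors standing in for vLLM's logits and sampling metadata:

```python
import torch

# Toy stand-ins; shapes and vocab size are illustrative only.
logits = torch.randn(4, 32000)
sampled_token_ids = torch.randint(0, 32000, (4, 1))
all_greedy = torch.tensor(True)  # scalar flag, like sampling_metadata.all_greedy

# Both branches are traced into the graph; torch.where selects the result at
# run time, so the scalar predicate does not introduce Python-level control
# flow that would force a recompilation.
out_tokens = torch.where(all_greedy,
                         torch.argmax(logits, dim=-1, keepdim=True),
                         sampled_token_ids)
```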
||
|
|
||
| def sample( | ||
| logits: torch.Tensor, | ||
| sampling_metadata: TPUSupportedSamplingMetadata | ||
| ) -> SamplerOutput: | ||
| sampler_out = self.sampler(logits, sampling_metadata) | ||
| return sampler_out | ||
|
|
||
| # Optimized greedy sampling branch, tracing both paths in a single pass | ||
| # NOTE all_greedy is a scalar, this is just an optimized if/else. | ||
| out_tokens = torch.where(sampling_metadata.all_greedy, | ||
| torch.argmax(logits, dim=-1, keepdim=True), | ||
| self.sample(logits, sampling_metadata)\ | ||
| .sampled_token_ids) | ||
| out_tokens = torch.where( | ||
| sampling_metadata.all_greedy, | ||
| torch.argmax(logits, dim=-1, keepdim=True), | ||
| sample(logits, sampling_metadata).sampled_token_ids) | ||
| return out_tokens | ||
|
|
||
| def compute_logits(self, | ||
| hidden_states: torch.Tensor) -> Optional[torch.Tensor]: | ||
| # SamplingMetadata here for pruning output in LogitsProcessor, disabled | ||
| logits = self.model.compute_logits(hidden_states, None) | ||
| return logits | ||
|
|
||
| def get_multimodal_embeddings(self, *args, **kwargs): | ||
| return self.model.get_multimodal_embeddings(*args, **kwargs) | ||
|
|
||
| def get_input_embeddings(self, *args, **kwargs): | ||
| return self.model.get_input_embeddings(*args, **kwargs) | ||
|
Comment on lines -980 to -984:
- "Have you tested that multimodal inference still works and these are called correctly?"
- "I only ran TPU CI to test the change; multimodal is not tested. Can you provide a script to test the multimodal path? I can test it on this PR."
- "then"
- "I am adding tests in another pr btw"
- "sg I'll rebase to your PR after :)"
- "Hi @NickLucche Thank you for providing the testing cmd for multimodal, I made this PR working for However, the client script"
- "Main is working for me, let me try your PR"
- "Still really slow but it works the same on this PR on my side, thanks!"
- "@NickLucche Thank you so much for trying! I set up a new conda env on my end and tried as well. I also found it's really slow, was about to ask you lol."
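For anyone reproducing the multimodal check discussed above, an illustrative client snippet against a vLLM OpenAI-compatible server; the base URL, model name, and image URL are placeholders, not the exact command exchanged in this thread:

```python
from openai import OpenAI

# Placeholders: adjust base_url/model to match the running `vllm serve` instance.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",  # any multimodal model served by vLLM
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/sample.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)
```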
||
|
|
||
|
|
||
| def _get_padded_number(n: int, multiple: int) -> int: | ||
| return ((n + multiple - 1) // multiple) * multiple | ||
|
|
||
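A quick worked example of the padding helper above; the function body is taken from the diff, the asserts are illustrative:

```python
def _get_padded_number(n: int, multiple: int) -> int:
    # Round n up to the nearest multiple of `multiple`.
    return ((n + multiple - 1) // multiple) * multiple

assert _get_padded_number(5, 8) == 8    # 5 tokens padded up to one bucket of 8
assert _get_padded_number(16, 8) == 16  # already a multiple, unchanged
assert _get_padded_number(17, 8) == 24  # spills into the next bucket
```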
Review discussion:
- "I didn't mind the wrapper too much; I think it was grouping a few related functions nicely. Still, if it benefits performance I am OK with that."
- "The wrapper was added in V0 for `torch.compile`. Now we have the `@support_torch_compile` decorator, which wraps the model with `torch.compile` already, so we don't need the wrapper anymore."
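A minimal sketch of the decorator-based pattern described in this reply, assuming the `support_torch_compile` decorator from `vllm.compilation.decorators`; the toy model is illustrative, not the model class touched by this PR:

```python
import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyModel(nn.Module):
    """Toy model: the decorator arranges for forward() to be compiled with
    the backend configured in vllm_config (e.g. openxla on TPU), so no
    manual torch.compile(...) wrapper is needed around the model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.embed = nn.Embedding(1000, 64)
        self.proj = nn.Linear(64, 64)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(input_ids))
```

Instantiating such a class requires passing a `VllmConfig`, which is where the compilation level and backend come from; that is why the manual `torch.compile(model, backend="openxla", ...)` call removed in `load_model` above is no longer needed.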