diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 0668e7168b5f..2286724ce4bd 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -15,6 +15,7 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -647,11 +648,10 @@ def execute_model(
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
-                kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_hidden(hidden_states,
+                                                     tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
@@ -751,17 +751,15 @@ def load_model(self) -> None:
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
             model = get_model(vllm_config=self.vllm_config)
-        model = model.eval()
+        # Sync all pending XLA execution during model initialization and weight
+        # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        model = ModelWrapperV1(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = model
+        self.sampler = TPUSampler()
 
     @torch.no_grad()
-    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -812,7 +810,6 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
-                             kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype
 
@@ -824,7 +821,7 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
-            self._dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(num_tokens)
         xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -855,8 +852,7 @@ def capture_model(self) -> None:
                 from_input_batch(self.input_batch, indices)
             logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
@@ -910,45 +906,17 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)
 
-
-class ModelWrapperV1(nn.Module):
-
-    def __init__(self, model: nn.Module):
-        super().__init__()
-        self.model = model
-        self.sampler = TPUSampler()
-
-    def sample(
-        self, logits: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
-        sampler_out = self.sampler(logits, sampling_metadata)
-        return sampler_out
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Executes the forward pass of the model.
-
-        Args:
-            input_ids: The input token IDs of shape [num_tokens].
-            positions: The input position IDs of shape [num_tokens].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
-            inputs_embeds: The input embeddings of shape [num_tokens,
-                hidden_size]. It is used for multimodal models.
-        """
-
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            inputs_embeds=inputs_embeds,
-        )
-
-        return hidden_states
+    def reset_dynamo_cache(self):
+        if self.is_multimodal_model:
+            assert hasattr(self.model, "language_model")
+            compiled_model = self.model.language_model.model
+        else:
+            compiled_model = self.model.model
+        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+            logger.info("Clear dynamo cache and cached dynamo bytecode.")
+            torch._dynamo.eval_frame.remove_from_cache(
+                compiled_model.original_code_object)
+            compiled_model.compiled_codes.clear()
 
     def sample_from_hidden(
         self,
@@ -956,33 +924,30 @@ def sample_from_hidden(
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
-            Sample with xla-friendly function. This function is to be traced
-            separately from `forward` for lighter compilation overhead.
-            """
+        Sample with xla-friendly function. This function is to be traced
+        separately for lighter compilation overhead.
+        """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
-        logits = self.compute_logits(sample_hidden_states)
+        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        logits = self.model.compute_logits(sample_hidden_states, None)
+
+        def sample(
+            logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata
+        ) -> SamplerOutput:
+            sampler_out = self.sampler(logits, sampling_metadata)
+            return sampler_out
+
         # Optimized greedy sampling branch, tracing both paths in a single pass
         # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.all_greedy,
-                                 torch.argmax(logits, dim=-1, keepdim=True),
-                                 self.sample(logits, sampling_metadata)\
-                                 .sampled_token_ids)
+        out_tokens = torch.where(
+            sampling_metadata.all_greedy,
+            torch.argmax(logits, dim=-1, keepdim=True),
+            sample(logits, sampling_metadata).sampled_token_ids)
         return out_tokens
 
-    def compute_logits(self,
-                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
-        logits = self.model.compute_logits(hidden_states, None)
-        return logits
-
-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
-
-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)
-
 
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 67902b41b284..73c43969b87b 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -157,13 +157,19 @@ def determine_available_memory(self) -> int:
                       runner_kv_caches)
 
         self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+            self.scheduler_config.max_num_batched_tokens)
 
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
 
+        # During the profiling run, the model runs without KV cache. After
+        # the profiling run, the model always runs with KV cache. Here we clear
+        # the dynamo cache and cached bytecode to ensure the model always has
+        # one compiled bytecode. Having one FX graph/cached bytecode per
+        # compiled model is required for `support_torch_compile` decorator to
+        # skip dynamo guard.
+        self.model_runner.reset_dynamo_cache()
+
         # Get the maximum amount of memory used by the model weights and
         # intermediate activations.
         m = xm.get_memory_info(self.device)