From 69c2f6835d2f6ba38b200ebc9c821863c80d0014 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 3 Mar 2025 23:07:16 +0000 Subject: [PATCH 1/3] TPU multimodal model support for ragged attention Signed-off-by: Michael Goin --- vllm/v1/worker/tpu_model_runner.py | 216 +++++++++++++++++++++++++---- vllm/v1/worker/tpu_worker.py | 2 +- 2 files changed, 191 insertions(+), 27 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 104e5a3dcfc5..bbf230f1939e 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -15,14 +15,18 @@ from vllm.attention.layer import Attention from vllm.config import VllmConfig from vllm.forward_context import get_forward_context, set_forward_context +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, NUM_QUERIES_PER_BLOCK, PallasAttentionBackend, PallasMetadata) +from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput @@ -74,6 +78,8 @@ def __init__( self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs + self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs, + NUM_QUERIES_PER_BLOCK) # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( @@ -84,6 +90,28 @@ def __init__( self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() + # Multi-modal data support + self.input_registry = INPUT_REGISTRY + self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope + # TODO: Support M-RoPE (e.g, Qwen2-VL) + assert not self.uses_mrope, "TPU does not support M-RoPE yet." + + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + ) + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: list[torch.Tensor] = [] + # req_id -> (input_id -> encoder_output) + self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + + # Request states. + self.requests: dict[str, CachedRequestState] = {} # Persistent batch. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, @@ -91,18 +119,9 @@ def __init__( max_num_blocks_per_req=self.max_num_blocks_per_req, device=self.device, pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), + vocab_size=model_config.get_vocab_size(), ) - # Request states. - self.requests: dict[str, CachedRequestState] = {} - - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} - - # KV caches for forward pass - self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = [] - # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. # Sometimes the numpy op is faster so we create both. @@ -164,6 +183,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and @@ -177,6 +197,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: if req_index is not None: removed_req_indices.append(req_index) + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests # or running requests that are not scheduled in this step. We remove @@ -426,6 +454,92 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs: list[MultiModalKwargs] = [] + req_input_ids: list[tuple[str, int]] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + for input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[input_id]) + req_input_ids.append((req_id, input_id)) + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, + # we process it separately to preserve item order. + # FIXME(ywang96): This is a hacky way to deal with multiple modalities + # in the same batch while still being able to benefit from batching + # multimodal inputs. The proper solution should be reordering the + # encoder outputs. + grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) + + encoder_outputs = [] + for grouped_mm_inputs in grouped_mm_inputs_list: + batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, each of shape + # (feature_size, hidden_size) in case the feature size is dynamic + # depending on the input multimodal items. + curr_group_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) + + for output in curr_group_outputs: + encoder_outputs.append(output) + + # Cache the encoder outputs. + for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + self.encoder_cache[req_id][input_id] = output + + def _gather_encoder_outputs( + self, + scheduler_output: "SchedulerOutput", + ) -> list[torch.Tensor]: + encoder_outputs: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] + encoder_outputs.append(encoder_output[start_idx:end_idx]) + return encoder_outputs + @torch.no_grad() def execute_model( self, @@ -434,16 +548,42 @@ def execute_model( # Update cached state self._update_states(scheduler_output) + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_encoder(scheduler_output) + encoder_outputs = self._gather_encoder_outputs(scheduler_output) + else: + encoder_outputs = [] + # Prepare inputs attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + if encoder_outputs: + inputs_embeds = self.model.get_input_embeddings( + self.input_ids, encoder_outputs) + else: + inputs_embeds = self.model.get_input_embeddings(self.input_ids) + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = self.input_ids + inputs_embeds = None + # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( - token_ids=self.input_ids, - position_ids=self.position_ids, + input_ids=input_ids, + positions=self.position_ids, kv_caches=self.kv_caches, + inputs_embeds=inputs_embeds, ) hidden_states = hidden_states[:total_num_scheduled_tokens] num_reqs = self.input_batch.num_reqs @@ -538,14 +678,21 @@ def load_model(self) -> None: fullgraph=True, dynamic=False) - def dummy_run( + def _dummy_run( self, kv_caches, num_tokens: int, ) -> None: - input_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device=self.device) + if self.is_multimodal_model: + input_ids = None + inputs_embeds = torch.zeros((num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) + else: + input_ids = torch.zeros((num_tokens), + dtype=torch.int32, + device=self.device) + inputs_embeds = None position_ids = torch.zeros(num_tokens, dtype=torch.int32, device=self.device) @@ -571,7 +718,10 @@ def dummy_run( num_seqs=num_tokens, ) - torch._dynamo.mark_dynamic(input_ids, 0) + if self.is_multimodal_model: + torch._dynamo.mark_dynamic(inputs_embeds, 0) + else: + torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) @@ -580,7 +730,12 @@ def dummy_run( with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None - self.model(input_ids, position_ids, kv_caches) + self.model( + input_ids=input_ids, + positions=position_ids, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + ) def capture_model(self) -> None: """Compile the model.""" @@ -590,7 +745,7 @@ def capture_model(self) -> None: start = time.perf_counter() num_tokens = 16 while True: - self.dummy_run(self.kv_caches, num_tokens) + self._dummy_run(self.kv_caches, num_tokens) logger.info(" -- num_tokens: %d", num_tokens) xm.mark_step() xm.wait_device_ops() @@ -647,17 +802,20 @@ def __init__(self, model: nn.Module): def forward( self, - token_ids: torch.Tensor, - position_ids: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, kv_caches: list[tuple[torch.Tensor, torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. Args: - token_ids: The input token IDs of shape [num_tokens]. - position_ids: The input position IDs of shape [num_tokens]. + input_ids: The input token IDs of shape [batch_size, seq_len]. + positions: The input position IDs of shape [batch_size, seq_len]. kv_caches: The key and value caches. They can be None during the memory profiling at initialization. + inputs_embeds: The input embeddings of shape [batch_size, seq_len, + hidden_size]. It is used for multimodal models. """ # Skip this in memory profiling at initialization. if kv_caches[0][0].numel() > 0: @@ -684,9 +842,9 @@ def forward( assert self.model is not None hidden_states = self.model( - token_ids, - position_ids, - kv_caches, + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, ) return hidden_states @@ -699,6 +857,12 @@ def compute_logits( logits = self.model.compute_logits(hidden_states, sampling_metadata) return logits + def get_multimodal_embeddings(self, *args, **kwargs): + return self.model.get_multimodal_embeddings(*args, **kwargs) + + def get_input_embeddings(self, *args, **kwargs): + return self.model.get_input_embeddings(*args, **kwargs) + def _get_padded_number(n: int, multiple: int) -> int: return ((n + multiple - 1) // multiple) * multiple diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index cbd2fe6edd81..41bcf425a184 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -124,7 +124,7 @@ def determine_available_memory(self) -> int: self.vllm_config.compilation_config.static_forward_context, runner_kv_caches) - self.model_runner.dummy_run( + self.model_runner._dummy_run( runner_kv_caches, num_tokens=self.scheduler_config.max_num_batched_tokens, ) From 71e4c8e8497aa2455ed8aa5755ac905340880545 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 4 Mar 2025 00:21:22 +0000 Subject: [PATCH 2/3] Fix bad comment Signed-off-by: Michael Goin --- vllm/v1/worker/tpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index bbf230f1939e..9a0bb39ebc8a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -810,11 +810,11 @@ def forward( """Executes the forward pass of the model and samples the next token. Args: - input_ids: The input token IDs of shape [batch_size, seq_len]. - positions: The input position IDs of shape [batch_size, seq_len]. + input_ids: The input token IDs of shape [num_tokens]. + positions: The input position IDs of shape [num_tokens]. kv_caches: The key and value caches. They can be None during the memory profiling at initialization. - inputs_embeds: The input embeddings of shape [batch_size, seq_len, + inputs_embeds: The input embeddings of shape [num_tokens, hidden_size]. It is used for multimodal models. """ # Skip this in memory profiling at initialization. From 66df269379fe2e1c7eb4239656a8c2a7b94c7a89 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 4 Mar 2025 19:20:06 +0000 Subject: [PATCH 3/3] Fix Signed-off-by: Michael Goin --- vllm/v1/worker/tpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 9a0bb39ebc8a..f9a3217fbef3 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -76,8 +76,8 @@ def __init__( self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens - self.max_num_reqs = scheduler_config.max_num_seqs + self.max_num_tokens = _get_padded_number( + scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK) self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs, NUM_QUERIES_PER_BLOCK) @@ -749,7 +749,7 @@ def capture_model(self) -> None: logger.info(" -- num_tokens: %d", num_tokens) xm.mark_step() xm.wait_device_ops() - if num_tokens >= self.scheduler_config.max_num_batched_tokens: + if num_tokens >= self.max_num_tokens: break num_tokens *= 2 end = time.perf_counter()