diff --git a/docs/source/contributing/model/basic.md b/docs/source/contributing/model/basic.md index 180fdd59e9a6..ad31995f76be 100644 --- a/docs/source/contributing/model/basic.md +++ b/docs/source/contributing/model/basic.md @@ -74,8 +74,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: ... ``` diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 14a59953ef48..990eac82d516 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -16,8 +16,6 @@ Further update the model as follows: self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, + pixel_values: torch.Tensor, ) -> SamplerOutput: ``` diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0d11e8652ce6..0a93f7ce9450 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -644,11 +644,7 @@ def _run_encoder_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward( - reshaped_query, packed_qkv.key, packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), attn_metadata) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) def _run_decoder_self_attention_test( @@ -682,7 +678,6 @@ def _run_decoder_self_attention_test( & attn_metadata ''' attn = test_rsrcs.attn - kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata, vllm_config): @@ -695,8 +690,7 @@ def _run_decoder_self_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, - kv_cache, attn_metadata) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) def _run_encoder_decoder_cross_attention_test( @@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test( assert decoder_test_params.packed_qkvo.packed_qkv is not None attn = test_rsrcs.attn - kv_cache = test_rsrcs.kv_cache if cross_test_params is None: key = None value = None @@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, key, value, kv_cache, - attn_metadata) + return attn.forward(reshaped_query, key, value) @pytest.fixture(autouse=True) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index e4df7ffc5885..bd7783cc3981 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,7 +7,7 @@ import torch.nn.functional as F import vllm.envs as envs -from vllm.attention import AttentionMetadata, AttentionType +from vllm.attention import AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context @@ -153,15 +153,10 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments - # directly, use `self.kv_cache` and - # `get_forward_context().attn_metadata` instead. if self.calculate_kv_scales: - ctx_attn_metadata = get_forward_context().attn_metadata - if ctx_attn_metadata.enable_kv_scales_calculation: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: self.calc_kv_scales(key, value) if self.use_output: output = torch.empty_like(query) @@ -177,14 +172,14 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() - ctx_attn_metadata = forward_context.attn_metadata + attn_metadata = forward_context.attn_metadata self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, key, value, self_kv_cache, - ctx_attn_metadata, + attn_metadata, output=output) else: torch.ops.vllm.unified_attention_with_output( @@ -193,10 +188,10 @@ def forward( else: if self.use_direct_call: forward_context = get_forward_context() - ctx_attn_metadata = forward_context.attn_metadata + attn_metadata = forward_context.attn_metadata self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, - self_kv_cache, ctx_attn_metadata) + self_kv_cache, attn_metadata) else: return torch.ops.vllm.unified_attention( query, key, value, self.layer_name) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 93c3cc91bb09..156e8752e96c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -7,6 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -130,14 +131,14 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): ) if use_rms_norm else None def forward_native(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): pass def forward_cuda(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams): + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) hidden_states, gate = projected_states.chunk(2, dim=-2) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 2bcf50e70713..b53a540ed662 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -14,6 +14,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.forward_context import get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -376,17 +377,16 @@ def __init__(self, eps=rms_norm_eps) def forward_native(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): pass def forward_cuda( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor] = None, ): + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata seq_len, _ = hidden_states.shape groups_time_state_size = self.n_groups * self.ssm_state_size diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 3e1daa773fc8..23d72d8e60f6 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T: return cls # Lazy import - from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.pooler import PoolingType @@ -201,13 +200,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: list[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = super().forward(input_ids, positions, kv_caches, - attn_metadata, + hidden_states = super().forward(input_ids, positions, intermediate_tensors, inputs_embeds) logits, _ = self.score(hidden_states) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 27df448e63f7..a700e739df77 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -5,7 +5,7 @@ import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -282,13 +282,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -335,16 +333,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual_input = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual_input + hidden_states @@ -399,8 +393,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -412,11 +404,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) @@ -457,13 +446,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index bff4100a1dee..656e9b037d96 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -9,7 +9,6 @@ from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.processing_aria import AriaProcessor -from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn @@ -626,8 +625,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -643,8 +640,6 @@ def forward( hidden_states = self.language_model( input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 2e51b9c9c0c7..4fb68e7b48da 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -20,13 +20,13 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -182,14 +182,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) if self.postion_embedding != "ALIBI": q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -232,8 +230,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -246,8 +242,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -301,8 +295,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -316,13 +308,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -379,13 +368,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 22ae1775c3d9..69da05884ded 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -1,17 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Bamba model.""" # Added by the IBM Team, 2024 -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import BambaConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -107,7 +107,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor] = None, @@ -120,8 +119,8 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, - mamba_cache_params, sequence_idx) + hidden_states = self.mamba(hidden_states, mamba_cache_params, + sequence_idx) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -215,15 +214,13 @@ def self_attention( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -231,8 +228,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], **kwargs, ): @@ -246,8 +241,6 @@ def forward( hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( @@ -312,8 +305,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -323,6 +314,7 @@ def forward( # proper continuous batching computation including # chunked prefill seq_idx = None + attn_metadata = get_forward_context().attn_metadata if attn_metadata.num_prefills > 0: seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) for i, (srt, end) in enumerate( @@ -348,9 +340,7 @@ def forward( num_attn = 0 for i in range(len(self.layers)): layer = self.layers[i] - kv_cache = None if isinstance(layer, BambaAttentionDecoderLayer): - kv_cache = kv_caches[num_attn] num_attn += 1 layer_mamba_cache_params = None @@ -361,8 +351,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=layer_mamba_cache_params, sequence_idx=seq_idx, @@ -440,8 +428,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -454,8 +440,7 @@ def forward(self, self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, *self._get_mamba_cache_shape()) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params, + hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 204c48d0d896..5d2a8cdcb97d 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -19,14 +19,14 @@ # limitations under the License. """PyTorch BART model.""" import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import BartConfig from transformers.utils import logging -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -181,14 +181,13 @@ def __init__( prefix=f"{prefix}.attn", attn_type=AttentionType.ENCODER) - def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Input shape: Batch x Time x Channel""" qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -261,14 +260,13 @@ def __init__( prefix=f"{prefix}.attn", attn_type=AttentionType.DECODER) - def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Input shape: Batch x Time x Channel""" qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -344,8 +342,6 @@ def __init__( def forward( self, decoder_hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" @@ -363,7 +359,7 @@ def forward( _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -411,23 +407,16 @@ def __init__( self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: r""" Args: hidden_states torch.Tensor of *encoder* input embeddings. - kv_cache: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Encoder layer output torch.Tensor """ residual = hidden_states - hidden_states = self.self_attn(hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -509,18 +498,12 @@ def __init__( def forward( self, decoder_hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: r""" Args: decoder_hidden_states torch.Tensor of *decoder* input embeddings. - kv_cache: - KV cache tensor - attn_metadata: - vLLM Attention metadata structure encoder_hidden_states torch.Tensor of *encoder* input embeddings. Returns: @@ -529,9 +512,7 @@ def forward( residual = decoder_hidden_states # Self Attention - hidden_states = self.self_attn(hidden_states=decoder_hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=decoder_hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -542,8 +523,6 @@ def forward( hidden_states = self.encoder_attn( decoder_hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -609,9 +588,8 @@ def __init__(self, self.layernorm_embedding = nn.LayerNorm(embed_dim) - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, input_ids: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: r""" Args: input_ids @@ -620,10 +598,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, provide it. positions Positions of *encoder* input sequence tokens. - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Decoder output torch.Tensor """ @@ -636,12 +610,8 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - for idx, encoder_layer in enumerate(self.layers): - hidden_states = encoder_layer( - hidden_states=hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, - ) + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states=hidden_states) return hidden_states @@ -693,9 +663,7 @@ def __init__( def forward(self, decoder_input_ids: torch.Tensor, decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor: r""" Args: decoder_input_ids @@ -706,10 +674,6 @@ def forward(self, decoder_input_ids: torch.Tensor, Positions of *decoder* input sequence tokens. encoder_hidden_states: Tensor of encoder output embeddings - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Decoder output torch.Tensor """ @@ -725,11 +689,9 @@ def forward(self, decoder_input_ids: torch.Tensor, # decoder layers - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: hidden_states = decoder_layer( decoder_hidden_states=hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -768,8 +730,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + encoder_positions: torch.Tensor) -> torch.Tensor: r""" Args: input_ids @@ -782,10 +743,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Indices of *encoder* input sequence tokens in the vocabulary. encoder_positions: Positions of *encoder* input sequence tokens. - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Model output torch.Tensor """ @@ -796,18 +753,14 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + positions=encoder_positions) # decoder outputs consists of # (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( decoder_input_ids=input_ids, decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + encoder_hidden_states=encoder_hidden_states) return decoder_outputs @@ -845,8 +798,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, *, encoder_input_ids: torch.Tensor, @@ -863,15 +814,11 @@ def forward( torch.Tensor of *encoder* input token ids. encoder_positions torch.Tensor of *encoder* position indices - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Output torch.Tensor """ return self.model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + encoder_positions) def compute_logits( self, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 4d0f5ac8ea5d..4ff69527653d 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import BertConfig -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -113,12 +114,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): def forward( self, hidden_states: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - for i in range(len(self.layer)): - layer = self.layer[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + for layer in self.layer: + hidden_states = layer(hidden_states) return hidden_states @@ -152,13 +150,8 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.output") - def forward( - self, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, - ): - attn_output = self.attention(hidden_states, kv_cache, attn_metadata) + def forward(self, hidden_states: torch.Tensor): + attn_output = self.attention(hidden_states) intermediate_output = self.intermediate(attn_output) output = self.output(intermediate_output, attn_output) return output @@ -191,10 +184,8 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - self_output = self.self(hidden_states, kv_cache, attn_metadata) + self_output = self.self(hidden_states) return self.output(self_output, hidden_states) @@ -246,12 +237,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - output = self.attn(q, k, v, kv_cache, attn_metadata) + output = self.attn(q, k, v) return output @@ -343,8 +332,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, @@ -352,13 +339,14 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: + attn_metadata = get_forward_context().attn_metadata assert hasattr(attn_metadata, "seq_lens_tensor") hidden_states = self.embeddings( input_ids=input_ids, seq_lens=attn_metadata.seq_lens_tensor, position_ids=position_ids, token_type_ids=token_type_ids) - return self.encoder(hidden_states, kv_caches, attn_metadata) + return self.encoder(hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -420,17 +408,13 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, position_ids=positions, - kv_caches=kv_caches, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors, - attn_metadata=attn_metadata) + intermediate_tensors=intermediate_tensors) def pooler( self, @@ -519,16 +503,12 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.bert(input_ids=input_ids, position_ids=positions, - kv_caches=kv_caches, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, - attn_metadata=attn_metadata, token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 0463a0b97d40..23bb3cd07f1d 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, +from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -9,7 +9,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, apply_chunking_to_forward) -from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig @@ -658,8 +657,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -708,8 +705,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 229677ae7d98..84b79613abc4 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -18,13 +18,13 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import BloomConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -126,13 +126,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: del position_ids # Unused. qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -193,8 +191,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) @@ -209,8 +205,6 @@ def forward( attention_output = self.self_attention( position_ids=position_ids, hidden_states=layernorm_output, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) attention_output = attention_output + residual layernorm_output = self.post_attention_layernorm(attention_output) @@ -266,8 +260,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -279,14 +271,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) @@ -322,14 +308,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 2d4dfab60730..e91399b2674d 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import cached_property -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, +from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -10,7 +10,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, ChameleonVQVAEConfig) -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -310,15 +310,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -372,8 +370,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -386,8 +382,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -447,8 +441,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -456,8 +448,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.input_layernorm(hidden_states) @@ -906,8 +896,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -921,13 +909,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -1028,8 +1013,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, @@ -1048,8 +1031,6 @@ def forward( hidden_states = self.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ecf417655452..6eca25212ee6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,13 +2,13 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from torch.nn import LayerNorm -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul @@ -108,19 +108,11 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(position_ids, q, k) - context_layer = self.attn( - q, - k, - v, - kv_cache, - attn_metadata, - ) + context_layer = self.attn(q, k, v) attn_output, _ = self.dense(context_layer) return attn_output @@ -215,8 +207,6 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # hidden_states: [num_tokens, h] # Layer norm at the beginning of the transformer layer. @@ -225,8 +215,6 @@ def forward( attention_output = self.self_attention( hidden_states=layernorm_output, position_ids=position_ids, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Residual connection. @@ -289,17 +277,10 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> Union[torch.Tensor, IntermediateTensors]: - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - hidden_states=hidden_states, - position_ids=position_ids, - kv_cache=kv_caches[i - self.start_layer], - attn_metadata=attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states=hidden_states, + position_ids=position_ids) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -350,8 +331,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -369,8 +348,6 @@ def forward( hidden_states = self.encoder( hidden_states=hidden_states, position_ids=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) return hidden_states @@ -494,12 +471,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 0ceefc3e93aa..b0cb4a62333a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -21,14 +21,14 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers import CohereConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -218,8 +218,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -227,7 +225,7 @@ def forward( q, k = self._apply_qk_norm(q, k) if self.v1 or self.sliding_window: q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -255,8 +253,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -265,8 +261,6 @@ def forward( hidden_states_attention = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states_mlp = self.mlp(hidden_states) # Add everything together @@ -311,8 +305,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -326,13 +318,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -389,13 +378,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index bb3f4f40dd21..7830dd4ce2ec 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.nn as nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -230,15 +230,13 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.Wqkv(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) hidden_states, _ = self.out_proj(attn_output) return hidden_states @@ -265,16 +263,12 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.norm_1(hidden_states) x = self.attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + x residual = hidden_states @@ -303,14 +297,10 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states, residual = self.norm_attn_norm( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.ffn(hidden_states) hidden_states = hidden_states + residual @@ -353,8 +343,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -366,14 +354,8 @@ def forward( else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - block = self.blocks[i] - hidden_states = block( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for block in self.blocks[self.start_layer:self.end_layer]: + hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) @@ -415,14 +397,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 9599e1df6a3c..c04e7a02bae2 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -248,13 +248,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -309,8 +307,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -323,8 +319,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -370,8 +364,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -384,11 +376,8 @@ def forward( else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -425,13 +414,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 1a051992a306..cac1b2b3b11c 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -69,8 +68,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_index: int = 0, @@ -88,8 +85,6 @@ def forward( hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=None) hidden_states = residual + hidden_states return self.shared_head(hidden_states) @@ -122,8 +117,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, @@ -131,8 +124,6 @@ def forward( return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( input_ids, positions, - kv_caches[spec_step_idx], - attn_metadata, previous_hidden_states, inputs_embeds, spec_step_idx, @@ -165,16 +156,14 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, previous_hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a4d52c613b3e..65f3bccbd77d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, @@ -283,8 +283,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] @@ -317,7 +315,7 @@ def forward( v = torch.nn.functional.pad( v, [0, self.qk_head_dim - self.v_head_dim], value=0).view(-1, self.num_local_heads * self.qk_head_dim) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output = attn_output.view( -1, self.num_local_heads, self.qk_head_dim)[..., :self.v_head_dim].reshape( @@ -455,8 +453,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: if self.q_lora_rank is not None: ckq = self.q_a_proj(hidden_states)[0] @@ -466,8 +462,7 @@ def forward( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, - attn_metadata) + return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe) class DeepseekV2DecoderLayer(nn.Module): @@ -536,8 +531,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -550,8 +543,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -612,8 +603,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -628,11 +617,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -669,13 +655,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 5f684fa295ad..4e2dda33bcab 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -13,7 +13,6 @@ from einops import rearrange, repeat from transformers import BatchFeature -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -595,8 +594,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object): @@ -614,8 +611,6 @@ def forward(self, hidden_states = self.language_model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index ab3f0dc07f4d..f2a2935e6c69 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -121,8 +120,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -140,8 +137,6 @@ def forward( input_ids=None, inputs_embeds=inputs_embeds, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, ) return hidden_states diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index e795c7e288c4..79939f6f40e4 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -24,12 +24,12 @@ # limitations under the License. """Inference-only Exaone model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -179,13 +179,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -225,14 +223,10 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: return self.attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) @@ -288,8 +282,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -301,8 +293,6 @@ def forward( hidden_states = self.attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -365,8 +355,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -381,13 +369,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] + for layer in self.h[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -471,14 +456,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + model_output = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 01b66a1c2a5f..7154ac2e6a5a 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -20,14 +20,14 @@ """PyTorch Falcon model.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from torch.nn import LayerNorm from transformers import FalconConfig as HF_FalconConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -190,8 +190,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, bias = self.query_key_value(hidden_states) if bias is not None: @@ -199,7 +197,7 @@ def forward( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.use_rotary: q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, bias = self.dense(attn_output) return attn_output, bias @@ -291,8 +289,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states @@ -306,8 +302,6 @@ def forward( attention_output, attention_bias = self.self_attention( positions=positions, hidden_states=attention_layernorm_out, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) if self.reduce_row_parallel_results and attention_bias is not None: attention_output += attention_bias @@ -384,8 +378,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -396,14 +388,8 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) @@ -450,14 +436,11 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 4a1ad5f4ee0c..06912bcfdc8a 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch import torch.nn as nn -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -50,8 +49,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + encoder_positions: torch.Tensor) -> torch.Tensor: r""" Args: input_ids @@ -64,10 +62,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Indices of *encoder* input sequence tokens in the vocabulary. encoder_positions: Positions of *encoder* input sequence tokens. - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Model output torch.Tensor """ @@ -78,18 +72,14 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + positions=encoder_positions) # decoder outputs consists of # (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( decoder_input_ids=input_ids, decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states, - kv_caches=kv_caches, - attn_metadata=attn_metadata) + encoder_hidden_states=encoder_hidden_states) return decoder_outputs @@ -122,8 +112,6 @@ def forward( positions: torch.Tensor, encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: r""" @@ -136,15 +124,11 @@ def forward( torch.Tensor of *encoder* input token ids. encoder_positions torch.Tensor of *encoder* position indices - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Output torch.Tensor """ return self.model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + encoder_positions) def compute_logits( self, @@ -213,8 +197,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, *, encoder_input_ids: torch.Tensor, @@ -231,15 +213,11 @@ def forward( torch.Tensor of *encoder* input token ids. encoder_positions torch.Tensor of *encoder* position indices - kv_caches: - Layer-wise list of KV cache tensors - attn_metadata: - vLLM Attention metadata structure Returns: Output torch.Tensor """ return self.language_model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + encoder_positions) def compute_logits( self, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 42a6aa979427..4f5519f325e0 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -25,7 +25,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.sampler import SamplerOutput @@ -351,8 +350,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -371,8 +368,6 @@ def forward( hidden_states = self.language_model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index d0589e60a72b..da17646c540f 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -16,13 +16,13 @@ # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" from functools import cache -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GemmaConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -183,13 +183,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -233,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -247,8 +243,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -298,8 +292,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -313,13 +305,10 @@ def forward( else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -370,13 +359,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 6ee257d65c50..cf744fc2b9d1 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -15,13 +15,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import Gemma2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -164,13 +164,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -220,8 +218,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -233,8 +229,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -284,8 +278,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -300,13 +292,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -415,13 +404,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 8fc5a797f824..48543c5642ea 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -4,7 +4,7 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from typing import List, Literal, Mapping, Optional, TypedDict, Union +from typing import Literal, Mapping, Optional, TypedDict, Union import torch from torch import nn @@ -15,7 +15,6 @@ from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention import AttentionMetadata from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -628,8 +627,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -645,8 +642,7 @@ def forward( vision_embeddings) input_ids = None - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 7ad9a24dcbbc..776c03f652bd 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPT2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed.parallel_state import ( @@ -92,12 +92,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -164,16 +162,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + attn_output = self.attn(hidden_states=hidden_states) # residual connection hidden_states = attn_output + residual @@ -222,8 +214,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: @@ -236,11 +226,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -279,14 +266,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 799edff46ea3..43f3d4f6dc9c 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -19,13 +19,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPTBigCodeConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -101,8 +101,6 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.split( @@ -112,7 +110,7 @@ def forward( ], dim=-1, ) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -173,16 +171,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + attn_output = self.attn(hidden_states=hidden_states, ) # residual connection hidden_states = attn_output + residual @@ -234,8 +226,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -246,11 +236,8 @@ def forward( else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -302,14 +289,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 815aba145d30..752aec0b223d 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -17,13 +17,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPTJConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -104,13 +104,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.out_proj(attn_output) return attn_output @@ -167,16 +165,12 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_output = self.attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) mlp_output = self.mlp(hidden_states) hidden_states = attn_output + mlp_output + residual @@ -217,8 +211,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -229,14 +221,8 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) @@ -273,14 +259,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 550ca3f7ca9e..4b30c7bb3035 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -17,13 +17,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GPTNeoXConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -104,13 +104,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -167,15 +165,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: attn_input = self.input_layernorm(hidden_states) attn_output = self.attention( position_ids=position_ids, hidden_states=attn_input, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) if self.use_parallel_residual: @@ -230,8 +224,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -242,14 +234,8 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layer_norm(hidden_states) @@ -285,14 +271,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.gpt_neox(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 2aeb179ee932..201e15d3a30f 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM Granite model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import GraniteConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -166,13 +166,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -233,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -242,8 +238,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states * self.residual_multiplier # Fully Connected @@ -300,8 +294,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -318,14 +310,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -405,13 +391,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 40df9c72c561..9b56874a8add 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GraniteMoe model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers.models.granitemoe import GraniteMoeConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -173,13 +173,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -226,8 +224,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -235,8 +231,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states * self.residual_multiplier residual = hidden_states @@ -287,8 +281,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -303,11 +295,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -377,13 +366,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 0f3a2ffe9a13..a20328289f92 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,15 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from array import array -from typing import List, Optional, Union +from typing import Optional, Union import torch import torch.nn as nn from xformers.ops.fmha.attn_bias import BlockDiagonalMask -from vllm.attention import AttentionMetadata from vllm.attention.backends.xformers import XFormersImpl from vllm.config import ModelConfig, VllmConfig +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.pooler import PoolerHead from vllm.model_executor.models.llama import LlamaForCausalLM @@ -217,13 +217,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: # Change attention to non-causal for pooling tasks. if self.runner_type == "pooling": + attn_metadata = get_forward_context().attn_metadata assert attn_metadata.prefill_metadata.attn_bias is None attn_metadata.prefill_metadata.attn_bias = [ BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) @@ -232,8 +231,6 @@ def forward( return super().forward( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, **kwargs, ) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 3a7e2a9a6a57..0a8763cf910c 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -25,7 +25,6 @@ from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, Idefics3Processor) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear @@ -563,8 +562,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -572,8 +569,6 @@ def forward( hidden_states = self.text_model( input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds, ) @@ -645,8 +640,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -664,8 +657,6 @@ def forward( hidden_states = self.model.text_model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index c5f7be135d71..22c9287509ed 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, - overload, runtime_checkable) +from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload, + runtime_checkable) import torch import torch.nn as nn @@ -11,7 +11,6 @@ from vllm.utils import supports_kw if TYPE_CHECKING: - from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.layers.sampler import SamplerOutput @@ -46,8 +45,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", ) -> T_co: ... @@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: if not callable(model_forward): return False - vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata") + vllm_kws = ("input_ids", "positions") missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw)) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index b21933dd5da7..41ca399b9efb 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -175,13 +175,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.wqkv(hidden_states) q, k, v = self.split_qkv(qkv) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.wo(attn_output) return output @@ -227,8 +225,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -241,8 +237,6 @@ def forward( hidden_states = self.attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -290,8 +284,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -305,15 +297,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -363,13 +348,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states @@ -466,13 +448,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) logits, _ = self.v_head(hidden_states) return logits diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 106c3b6b78cc..69b0caab8f8e 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -65,8 +64,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], visual_token_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -80,8 +77,6 @@ def forward( hidden_states = self.attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -113,8 +108,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None, @@ -129,13 +122,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, visual_token_mask=visual_token_mask, ) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 4a6007876776..52ddb279cca3 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,7 +17,6 @@ from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig @@ -929,8 +928,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -951,8 +948,6 @@ def forward( forward_kwargs = { "input_ids": input_ids, "positions": positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, "intermediate_tensors": intermediate_tensors, "inputs_embeds": inputs_embeds, } diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 72bcef5e2282..78fe6588eddc 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -21,12 +21,12 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -123,12 +123,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -200,16 +198,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + attn_output = self.attn(hidden_states=hidden_states, ) # residual connection hidden_states = attn_output + residual @@ -266,8 +258,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: @@ -285,11 +275,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -332,14 +319,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5530e3ca708c..14e56df6cadf 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Jamba model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import JambaConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -138,7 +137,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, **kwargs, @@ -150,8 +148,7 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, - mamba_cache_params) + hidden_states = self.mamba(hidden_states, mamba_cache_params) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -223,13 +220,11 @@ def self_attention( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -237,8 +232,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], **kwargs, ): @@ -252,8 +245,6 @@ def forward( hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( @@ -320,8 +311,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -339,12 +328,9 @@ def forward( kv_cache_index = 0 mamba_cache_index = 0 - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - kv_cache = None + for layer in self.layers[self.start_layer:self.end_layer]: layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache = kv_caches[kv_cache_index] kv_cache_index += 1 if isinstance(layer, JambaMambaDecoderLayer): current_state_layer = mamba_cache_index @@ -355,8 +341,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=layer_mamba_cache_params) if not get_pp_group().is_last_rank: @@ -429,8 +413,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -443,8 +425,7 @@ def forward(self, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params, + hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 011d0a7aafaa..a0aff9e609d9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union import torch from torch import nn from transformers import LlamaConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -197,13 +197,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -268,8 +266,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -280,9 +276,7 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states=hidden_states) # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -347,8 +341,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -363,11 +355,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -535,13 +524,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 19752ba703f4..72b1591306f2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.activation import get_act_fn @@ -658,8 +657,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -712,8 +709,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c39daec709fc..6a050d7798a2 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,6 @@ get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -508,8 +507,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -571,8 +568,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 2af3cc05080a..807d6977ed40 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -10,7 +10,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProcessor) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -443,8 +442,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -468,8 +465,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 8eb8071e6577..e57eea4286e9 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -13,7 +13,6 @@ get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -922,8 +921,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -955,8 +952,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ba88950ee898..9f1cd8c29a5a 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import MambaConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group @@ -64,7 +63,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, **kwargs, @@ -75,8 +73,7 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, attn_metadata, - mamba_cache_params) + hidden_states = self.mixer(hidden_states, mamba_cache_params) return hidden_states, residual @@ -125,7 +122,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -146,7 +142,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=mamba_cache_params.at_layer_idx( i - self.start_layer)) @@ -208,8 +203,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -222,9 +215,8 @@ def forward(self, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params, intermediate_tensors, - inputs_embeds) + hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 6366fc023682..266cdc243ac4 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """PyTorch MAMBA2 model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn @@ -10,6 +10,7 @@ from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( @@ -63,7 +64,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor], @@ -75,8 +75,8 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, attn_metadata, - mamba_cache_params, sequence_idx) + hidden_states = self.mixer(hidden_states, mamba_cache_params, + sequence_idx) return hidden_states, residual @@ -122,7 +122,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -142,6 +141,7 @@ def forward( # proper continuous batching computation including # chunked prefill seq_idx = None + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata if attn_metadata.num_prefills > 0: seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) for i, (srt, end) in enumerate( @@ -158,7 +158,6 @@ def forward( hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=mamba_cache_params.at_layer_idx( i - self.start_layer), @@ -224,8 +223,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): @@ -238,9 +235,8 @@ def forward(self, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params, intermediate_tensors, - inputs_embeds) + hidden_states = self.backbone(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 52ab89488785..bae756910d1a 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -23,13 +23,13 @@ # limitations under the License. """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -256,8 +256,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -265,7 +263,7 @@ def forward( q, k = q.float(), k.float() q, k = self.rotary_emb(positions, q, k) q, k = q.to(orig_dtype), k.to(orig_dtype) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -330,8 +328,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -340,8 +336,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states * \ (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) @@ -408,8 +402,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -423,13 +415,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -578,13 +567,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index b85306c40880..1b24c38cef1b 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -29,7 +29,7 @@ from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm @@ -129,8 +129,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: q, _ = self.q_a_proj(hidden_states) q = self.q_a_layernorm(q) @@ -170,7 +168,7 @@ def forward( v, [0, self.qk_head_dim - self.v_head_dim], value=0).view(-1, self.num_local_heads * self.qk_head_dim) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) attn_output = attn_output.view( -1, self.num_local_heads, self.qk_head_dim)[..., :self.v_head_dim].reshape( diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index aa8c193ed6a5..e354e5323327 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -33,7 +33,6 @@ from transformers.models.whisper.modeling_whisper import ( ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import MultiModalFieldConfig @@ -792,8 +791,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: Any, ) -> torch.Tensor: @@ -818,8 +815,6 @@ def forward( output = self.llm.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=vlm_embeddings, ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1f278b65740c..de67ac1af983 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -37,7 +37,6 @@ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, @@ -1029,8 +1028,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: Any, ) -> torch.Tensor: @@ -1050,8 +1047,6 @@ def forward( output = self.llm.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=vlm_embeddings, ) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b83b69fd2c2d..c8dea557e571 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import MixtralConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -175,13 +175,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -224,8 +222,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -238,8 +234,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -291,8 +285,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -306,11 +298,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -377,13 +366,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index fdc438917542..21b52d9f54c7 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import numpy as np import torch @@ -30,7 +30,7 @@ from torch import nn from transformers import MixtralConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -229,13 +229,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -274,8 +272,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -288,8 +284,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -333,8 +327,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -348,11 +340,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -390,13 +379,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 1f8f5b2eb136..459928fe3fb0 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -38,7 +38,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.selector import _Backend from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tp_group +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -416,11 +417,11 @@ def __init__(self, prefix: str = ""): super().__init__() - model_parallel_size = get_tensor_model_parallel_world_size() + tensor_parallel_size = get_tp_group().world_size self.embed_dim = config.hidden_size self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads - self.num_local_heads = self.num_heads // model_parallel_size + self.num_local_heads = self.num_heads // tensor_parallel_size self.q_size = self.num_local_heads * self.head_dim self.kv_size = self.num_local_heads * self.head_dim @@ -771,12 +772,13 @@ def __init__( ): super().__init__() self.config = config - self.model_parallel_size = get_tensor_model_parallel_world_size() + self.pipeline_parallel_rank = get_pp_group().rank_in_group + self.tensor_parallel_size = get_tp_group().world_size self.num_heads = self.config.num_attention_heads - self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_local_heads = self.num_heads // self.tensor_parallel_size self.num_key_value_heads = self.config.num_key_value_heads self.num_local_key_value_heads = \ - self.num_key_value_heads // self.model_parallel_size + self.num_key_value_heads // self.tensor_parallel_size self.dropout = config.dropout self.hidden_size = config.hidden_size self.head_dim = config.hidden_size // self.num_heads @@ -824,8 +826,6 @@ def forward( attention_mask: Optional[torch.Tensor], kv_range_for_decode: Optional[List[Tuple[int, int]]], cross_attention_states: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv_dec, _ = self.qkv_proj(hidden_states) q, _, _ = qkv_dec.split( @@ -846,14 +846,11 @@ def forward( q = self.q_norm(q) if attention_mask is not None: - output = self._attention_with_mask(q, k, v, kv_cache, - attention_mask, - kv_range_for_decode, - attn_metadata) + output = self._attention_with_mask(q, k, v, attention_mask, + kv_range_for_decode) else: output = self.attn( - q.view(-1, self.num_local_heads * self.head_dim), k, v, - kv_cache, attn_metadata) + q.view(-1, self.num_local_heads * self.head_dim), k, v) out, _ = self.o_proj(output) return out @@ -862,11 +859,11 @@ def _attention_with_mask( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - kv_cache: torch.Tensor, attention_mask: torch.Tensor, kv_range_for_decode: List[Tuple[int, int]], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: + kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: i = torch.ones(1, dtype=torch.float32) @@ -978,8 +975,6 @@ def forward( cross_attention_mask: torch.Tensor, kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: torch.Tensor, - kv_cache: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -989,8 +984,6 @@ def forward( attention_mask=cross_attention_mask, kv_range_for_decode=kv_range_for_decode, cross_attention_states=cross_attention_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = residual + self.cross_attn_attn_gate.tanh( @@ -1054,14 +1047,12 @@ def forward( kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, skip_cross_attention: bool, ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): if not skip_cross_attention: hidden_states = decoder_layer( @@ -1071,15 +1062,11 @@ def forward( kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask= full_text_row_masked_out_mask, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, ) elif isinstance(decoder_layer, LlamaDecoderLayer): hidden_states, residual = decoder_layer( positions=positions, hidden_states=hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, residual=None, ) hidden_states = hidden_states + residual @@ -1124,8 +1111,6 @@ def forward( kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, skip_cross_attention: bool, ) -> torch.Tensor: hidden_states = self.model( @@ -1135,8 +1120,6 @@ def forward( cross_attention_mask=cross_attention_mask, kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - kv_caches=kv_caches, - attn_metadata=attn_metadata, skip_cross_attention=skip_cross_attention, ) return hidden_states @@ -1353,10 +1336,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs: object, ) -> Union[Tuple, CausalLMOutputWithPast]: + attn_metadata = get_forward_context().attn_metadata if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: raise ValueError("Chunk prefill not supported") @@ -1410,8 +1392,6 @@ def forward( cross_attention_mask=cross_attention_mask, kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - kv_caches=kv_caches, - attn_metadata=attn_metadata, skip_cross_attention=skip_cross_attention, ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 6ce9fbda182f..cc4d38d8740b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -16,7 +16,7 @@ from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.attention.layer import MultiHeadAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -460,15 +460,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.q_norm is not None and self.k_norm is not None: q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -580,8 +578,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Self Attention @@ -594,8 +590,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states, residual = self.post_attention_layernorm( @@ -610,8 +604,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Self Attention @@ -619,8 +611,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.input_layernorm(hidden_states) @@ -841,8 +831,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -858,13 +846,10 @@ def forward( residual = intermediate_tensors["residual"] # Apply blocks one-by-one. - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -1643,8 +1628,6 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.LongTensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -1663,8 +1646,6 @@ def forward( hidden_states = self.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 676c960623ed..d716818f31c0 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -2,12 +2,12 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.nn as nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -125,8 +125,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: del position_ids # unused. qkv, _ = self.Wqkv(hidden_states) @@ -136,7 +134,7 @@ def forward( if self.qk_ln: q = self.q_ln(q) k = self.k_ln(k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -196,15 +194,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: x = self.norm_1(hidden_states) x = self.attn( position_ids=position_ids, hidden_states=x, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = hidden_states + x x = self.norm_2(hidden_states) @@ -253,8 +247,6 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -267,14 +259,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - block = self.blocks[i] - hidden_states = block( - position_ids, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for block in self.blocks[self.start_layer:self.end_layer]: + hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) @@ -306,14 +292,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index a42734edb39a..3b86b91465ca 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -27,7 +27,7 @@ import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -204,13 +204,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -269,8 +267,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -283,8 +279,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -343,8 +337,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -359,15 +351,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -444,13 +429,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3b470dfdd05b..4a341c97d6cd 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import OlmoConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -119,15 +119,13 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -212,14 +210,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Attention block. residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(positions, hidden_states, kv_cache, - attn_metadata) + hidden_states = self.self_attn(positions, hidden_states) hidden_states = hidden_states + residual # MLP block. @@ -263,8 +258,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -281,14 +274,9 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] # Apply blocks one-by-one. - for i in range(self.start_layer, self.end_layer): + for layer in self.layers[self.start_layer:self.end_layer]: # shape: (batch_size, seq_len, d_model) - hidden_states = self.layers[i]( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -332,16 +320,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index d06f894123ac..54cc851de934 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -24,12 +24,12 @@ """Inference-only OLMo2 model compatible with HuggingFace weights.""" from functools import partial -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.communication_op import tensor_model_parallel_all_gather @@ -153,14 +153,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -239,13 +237,10 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Attention block. residual = hidden_states - hidden_states = self.self_attn(positions, hidden_states, kv_cache, - attn_metadata) + hidden_states = self.self_attn(positions, hidden_states) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = hidden_states + residual @@ -287,8 +282,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], ) -> Union[torch.Tensor, IntermediateTensors]: """ @@ -307,14 +300,9 @@ def forward( assert isinstance(hidden_states, torch.Tensor) # Apply blocks one-by-one. - for i in range(self.start_layer, self.end_layer): + for layer in self.layers[self.start_layer:self.end_layer]: # shape: (batch_size, seq_len, d_model) - hidden_states = self.layers[i]( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -357,15 +345,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, ) return hidden_states diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index d6e24c6d67f3..e27ff5deace2 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -168,14 +168,12 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -222,8 +220,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -237,8 +233,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -283,8 +277,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -299,13 +291,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -347,13 +336,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index ad1d66902435..e4775478a54d 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import OPTConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -107,12 +107,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output @@ -164,17 +162,13 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: @@ -261,8 +255,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -277,11 +269,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -317,15 +306,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: return self.decoder(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) @@ -362,13 +347,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index f4f5cdff6437..6668ede91eec 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -5,13 +5,13 @@ # Copyright (c) OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -136,13 +136,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -189,8 +187,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -198,8 +194,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -247,8 +241,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -260,14 +252,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -303,13 +289,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 955a59953eb4..02d1861b8027 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, +from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch from torch import nn from transformers import PaliGemmaConfig -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) @@ -288,8 +287,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: @@ -306,8 +303,6 @@ def forward(self, hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 6a80bea348ea..db8d170a8c91 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -21,13 +21,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PersimmonConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -142,8 +142,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # [seq_length, 3 x hidden_size] qkv, _ = self.query_key_value(hidden_states) @@ -161,7 +159,7 @@ def forward( k = self._merge_heads(k) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -189,8 +187,6 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states @@ -200,8 +196,6 @@ def forward( hidden_states = self.self_attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -248,8 +242,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -261,13 +253,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - hidden_states = self.layers[i]( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) @@ -298,16 +285,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ): hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 1ca8cad22ad9..6ee80210c2b4 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -36,13 +36,13 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PhiConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -126,13 +126,11 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -186,16 +184,12 @@ def forward( self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) attn_outputs = self.self_attn( position_ids=position_ids, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) feed_forward_hidden_states = self.mlp(hidden_states) hidden_states = attn_outputs + feed_forward_hidden_states + residual @@ -234,8 +228,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -247,14 +239,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -304,13 +290,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 873e9d37771d..33984f54ae27 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers.configuration_utils import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -231,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: qkv, _ = self.query_key_value(hidden_states) @@ -248,7 +246,7 @@ def forward( v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.dense(attn_output) return output @@ -282,8 +280,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -291,8 +287,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -338,8 +332,6 @@ def forward( self, input_ids: torch.LongTensor, positions: Optional[torch.LongTensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: @@ -354,14 +346,8 @@ def forward( else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) @@ -438,16 +424,12 @@ def forward( self, input_ids: torch.LongTensor, positions: Optional[torch.LongTensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 207204df2055..61d63e104de4 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -23,7 +23,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, ProcessorMixin) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig @@ -672,8 +671,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object): @@ -691,8 +688,6 @@ def forward(self, hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 17369cb58e36..c35c7e9fcce7 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -22,13 +22,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only PhiMoE model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers.configuration_utils import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -357,13 +357,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -410,8 +408,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: residual = hidden_states @@ -422,8 +418,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = hidden_states + residual @@ -478,8 +472,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -494,13 +486,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -571,13 +560,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 273dc3b1cf75..87b1d50749a2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -16,7 +16,6 @@ from transformers.models.pixtral.modeling_pixtral import ( PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, @@ -270,8 +269,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -291,8 +288,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 9383cbae11bc..0d0c367e677e 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -15,13 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM/NASA Prithvi Geospatial model.""" -from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union +from typing import Iterable, Mapping, Optional, Set, Tuple, Union import torch import torch.nn as nn from transformers import BatchFeature -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (IsAttentionFree, @@ -181,8 +180,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 7c4627036203..96abfb9d1096 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -6,13 +6,13 @@ # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -124,13 +124,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.c_proj(attn_output) return output @@ -168,8 +166,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -181,8 +177,6 @@ def forward( hidden_states = self.attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -225,8 +219,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -241,13 +233,10 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.h[i] + for layer in self.h[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -373,12 +362,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7da6e558ff33..fe615c41aeaa 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -23,13 +23,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import Qwen2Config -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -170,13 +170,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -233,8 +231,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -247,8 +243,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -328,8 +322,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -343,13 +335,10 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) if not get_pp_group().is_last_rank: @@ -468,13 +457,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states @@ -553,12 +539,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors) + return self.model(input_ids, positions, intermediate_tensors) def pooler( self, diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ef31f18445fd..858cf28d2b87 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -992,8 +991,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -1047,8 +1044,6 @@ def forward( hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3df5dd2bdd41..f0dc8573ee14 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -22,8 +22,8 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from functools import cached_property -from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict, + Union) import torch import torch.nn as nn @@ -33,7 +33,6 @@ Qwen2AudioProcessor) from transformers.models.whisper import WhisperFeatureExtractor -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -380,8 +379,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -400,8 +397,6 @@ def forward( hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 35d9854a55d6..41536b34b2f2 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -23,14 +23,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, @@ -232,13 +232,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -296,8 +294,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -310,8 +306,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -358,8 +352,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -373,11 +365,8 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -416,13 +405,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index c6588a47d881..21cc9e8ed1c6 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -5,12 +5,11 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. """Inference-only Qwen2-RM model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -80,13 +79,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 31701abd3339..849ef7293bb7 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -24,8 +24,8 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import cached_property, partial -from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Set, Tuple, Type, TypedDict, Union) +from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set, + Tuple, Type, TypedDict, Union) import torch import torch.nn as nn @@ -38,7 +38,6 @@ Qwen2VLConfig, Qwen2VLVisionConfig) from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -1302,8 +1301,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -1354,8 +1351,6 @@ def forward( hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 56faa390fc5d..e0d8bf2fa3d2 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -22,7 +22,6 @@ from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -766,8 +765,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, @@ -783,7 +780,6 @@ def forward( vision_embeddings) input_ids = None - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 742e63a065b1..f86fa268072d 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import itertools -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import RobertaConfig -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -243,16 +242,12 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.roberta(input_ids=input_ids, position_ids=positions, - kv_caches=kv_caches, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, - attn_metadata=attn_metadata, token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index ad98f3b07034..0f9e517aeb55 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -23,13 +23,13 @@ # limitations under the License. """Inference-only Solar model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -172,13 +172,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -238,8 +236,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -252,8 +248,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) # Fully Connected @@ -315,8 +309,6 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -357,8 +349,6 @@ def forward( hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual, ) @@ -438,13 +428,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index a5d4432669f4..a15faec547b9 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -20,13 +20,13 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import StableLmConfig -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul @@ -147,13 +147,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -183,8 +181,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -192,8 +188,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -241,8 +235,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -254,14 +246,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - ) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) @@ -296,13 +282,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 01ea43666482..90098af9dde0 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -19,13 +19,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Starcoder2 model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch from torch import nn from transformers import Starcoder2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -118,13 +118,11 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -184,8 +182,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -193,8 +189,6 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -246,8 +240,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -259,11 +251,8 @@ def forward( else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata) + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) @@ -306,13 +295,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index b431abb76b69..1c3c443b2941 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -22,7 +22,7 @@ from transformers import AutoModel, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide @@ -59,7 +59,6 @@ def vllm_flash_attention_forward( # Transformers kwargs scaling: Optional[float] = None, # vLLM kwargs - attn_metadata: Optional[AttentionMetadata] = None, attention_instances: Optional[list[Attention]] = None, **kwargs): self_attn = attention_instances[module.layer_idx] @@ -68,12 +67,7 @@ def vllm_flash_attention_forward( hidden = query.shape[-2] query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) - return self_attn.forward( - query, - key, - value, - kv_cache=None, # argument not used - attn_metadata=attn_metadata), None + return self_attn.forward(query, key, value), None ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward @@ -251,8 +245,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: list[torch.Tensor], # argument not used - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -260,7 +252,6 @@ def forward( input_ids[None, ...], use_cache=False, position_ids=positions[None, ...], - attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, attention_instances=self.attention_instances, return_dict=False)[0][0, ...] # we remove batch dimension for now diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b99094e5d4ca..1dbba3c50b19 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -4,8 +4,8 @@ """PyTorch Ultravox model.""" import math from functools import cached_property -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) import torch import torch.utils.checkpoint @@ -16,8 +16,8 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder from vllm import envs -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -495,13 +495,13 @@ def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, - attn_metadata: Optional[AttentionMetadata] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: # TODO(ywang96): remove this block after v0 is deprecated. if not envs.VLLM_USE_V1: + attn_metadata = get_forward_context().attn_metadata merge_multimodal_embeddings_from_map( inputs_embeds, multimodal_embeddings, attn_metadata.multi_modal_placeholder_index_maps["audio"]) @@ -514,8 +514,6 @@ def get_input_embeddings( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, IntermediateTensors]: @@ -540,17 +538,12 @@ def forward(self, elif inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - # TODO(ywang96): remove attn_metadata from get_input_embeddings - # after v0 is deprecated inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings, - attn_metadata) + multimodal_embeddings) input_ids = None hidden_states = self.language_model.model(input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 2ad1731144ef..e5f77e08c403 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -10,7 +10,7 @@ WhisperProcessor) from transformers.models.whisper.modeling_whisper import sinusoids -from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention import Attention, AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -134,13 +134,11 @@ def _init_qkv( def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) @@ -196,8 +194,6 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): q, _ = self.q_proj(hidden_states) @@ -209,13 +205,7 @@ def forward( else: k = v = None - attn_output = self.attn( - q, - k, - v, - kv_cache, - attn_metadata, - ) + attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) @@ -285,16 +275,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -348,14 +332,10 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + hidden_states = self.self_attn(hidden_states=hidden_states) hidden_states = residual + hidden_states residual = hidden_states @@ -363,8 +343,6 @@ def forward( hidden_states = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -411,12 +389,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_positions.weight.copy_( sinusoids(*self.embed_positions.weight.shape)) - def forward( - self, - input_features: Union[torch.Tensor, List[torch.Tensor]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ): + def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]): hidden_states = [] for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) @@ -426,12 +399,8 @@ def forward( hidden_states.append(embeds) hidden_states = torch.cat(hidden_states) - for idx, encoder_layer in enumerate(self.layers): - hidden_states = encoder_layer( - hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, - ) + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) hidden_states = self.layer_norm(hidden_states) return hidden_states @@ -466,19 +435,15 @@ def forward( input_ids, positions: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ): inputs_embeds = self.get_input_embeddings(input_ids) positions = self.embed_positions(positions) hidden_states = inputs_embeds + positions - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: hidden_states = decoder_layer( hidden_states, encoder_hidden_states=encoder_hidden_states, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, ) hidden_states = self.layer_norm(hidden_states) @@ -505,36 +470,22 @@ def forward( input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_ids: Optional[torch.Tensor], positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> torch.Tensor: - encoder_outputs = self.get_encoder_outputs( - input_features, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + encoder_outputs = self.get_encoder_outputs(input_features) decoder_outputs = self.decoder( input_ids=input_ids, positions=positions, encoder_hidden_states=encoder_outputs, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) return decoder_outputs def get_encoder_outputs( self, input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, ) -> Optional[torch.Tensor]: if input_features is None: return None - return self.encoder( - input_features, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + return self.encoder(input_features) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -733,8 +684,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: audio_input = self._parse_and_validate_audio_input(**kwargs) @@ -742,31 +691,19 @@ def forward( input_features=audio_input["input_features"], input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) return decoder_outputs - def get_multimodal_embeddings( - self, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - **kwargs, - ) -> Optional[NestedTensors]: + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # TODO: This method does not obey the interface for SupportsMultiModal. # Refactor this once encoder/decoder support is implemented in V1. audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs( - audio_input["input_features"], - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + return self.model.get_encoder_outputs(audio_input["input_features"]) def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, - attn_metadata: Optional[AttentionMetadata] = None, ) -> torch.Tensor: # TODO: This method just returns the decoder sequence embeddings since # Whisper does not have encoder text tokens. Refactor this once diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 7353d3c53ae9..40ecc3481e6b 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -288,8 +288,6 @@ def execute_model( hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a7b9d4781183..1fbce3098a34 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -939,8 +939,6 @@ def execute_model( hidden_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=self.kv_caches, - attn_metadata=None, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) @@ -1137,11 +1135,8 @@ def _get_prompt_logprobs_dict( def _dummy_run( self, num_tokens: int, - kv_caches: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: model = self.model - if kv_caches is None: - kv_caches = self.kv_caches if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -1172,26 +1167,12 @@ def _dummy_run( hidden_states = model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=None, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states def profile_run(self) -> None: - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - dummy_kv_caches = [ - torch.tensor((), dtype=torch.float32, device=self.device) - for _ in range(self.num_attn_layers) - ] - # Profile with multimodal encoder & encoder cache. # TODO: handle encoder-decoder models once we support them. if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 @@ -1302,8 +1283,7 @@ def profile_run(self) -> None: with self.maybe_profile_with_lora(self.lora_config, num_scheduled_tokens): # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.max_num_tokens, - dummy_kv_caches) + hidden_states = self._dummy_run(self.max_num_tokens) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] logits = self.model.compute_logits(hidden_states, None) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e60268f04527..f7d72d26e045 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -13,11 +13,10 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr -from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingType @@ -623,7 +622,6 @@ def execute_model( assert self.model is not None selected_token_ids = self.model(prompt_data.input_tokens, prompt_data.input_positions, - prompt_data.attn_metadata, self.kv_caches) # In parallel to TPU execution, prepare the next iteration @@ -662,7 +660,6 @@ def execute_model( assert self.model is not None selected_token_ids = self.model(decode_data.input_tokens, decode_data.input_positions, - decode_data.attn_metadata, self.kv_caches) # Transfer sampled tokens from TPU to CPU @@ -839,7 +836,7 @@ def dummy_run( with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None - self.model(token_ids, position_ids, attn_metadata, kv_caches) + self.model(token_ids, position_ids, kv_caches) def capture_model(self) -> None: """Compile the model.""" @@ -963,7 +960,6 @@ def forward( self, token_ids: torch.Tensor, position_ids: torch.Tensor, - attn_metadata: AttentionMetadata, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -971,7 +967,6 @@ def forward( Args: token_ids: The input token IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. input_lens: The actual input lengths of shape [batch_size]. t: The sampling temperature of shape [batch_size]. p: The top-p probability of shape [batch_size]. @@ -980,7 +975,8 @@ def forward( memory profiling at initialization. """ # Skip this in memory profiling at initialization. - if attn_metadata is not None and kv_caches[0][0].numel() > 0: + if kv_caches[0][0].numel() > 0: + attn_metadata = get_forward_context().attn_metadata # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it @@ -1001,12 +997,7 @@ def forward( attn_metadata.slot_mapping = slot_mapping assert self.model is not None - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) + hidden_states = self.model(token_ids, position_ids) hidden_states = hidden_states.flatten(0, 1) logits = self.model.compute_logits(hidden_states, None) diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index 71e32c5f7aca..ac7c93e48395 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -297,10 +297,6 @@ def execute_model( model_input.encoder_input_tokens, "encoder_positions": model_input.encoder_input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), "intermediate_tensors": diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 9400893105d7..8407f073040e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -654,8 +654,6 @@ def execute_model( hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **execute_model_kwargs, **multimodal_kwargs, diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index c0744d63b8d0..1ceb2557c6b3 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -41,16 +41,6 @@ def execute_model( raise ValueError( "CPU worker does not support multi-step execution.") - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - model_executable = self.model cross_enc_kwargs = {} if model_input.token_type_ids is not None: @@ -60,10 +50,6 @@ def execute_model( model_input.input_tokens, "positions": model_input.input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), **cross_enc_kwargs, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index e2d338f75761..5f39f2fa4947 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -184,8 +184,6 @@ def execute_model( positions=model_input.input_positions, encoder_input_ids=model_input.encoder_input_tokens, encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), @@ -324,21 +322,11 @@ def profile_run(self) -> None: or encoder_dummy_data.multi_modal_placeholders) seqs.append(seq) - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) intermediate_tensors = None - self.execute_model(model_input, kv_caches, intermediate_tensors) + self.execute_model(model_input, None, intermediate_tensors) torch.cuda.synchronize() return diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f22526cfad70..d6eaf84e40f6 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -384,11 +384,12 @@ def forward(self, *args, **kwargs): if 'virtual_engine' in kwargs: virtual_engine = kwargs.pop('virtual_engine') input_ids = kwargs['input_ids'] - kwargs['attn_metadata'] = self._update_metadata( - kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), - input_ids.device, self.dtype) + attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'), + input_ids.size(0), + input_ids.size(1), + input_ids.device, self.dtype) LoraMask.setLoraMask(kwargs.pop('lora_mask')) - with set_forward_context(kwargs['attn_metadata'], self.vllm_config, + with set_forward_context(attn_metadata, self.vllm_config, virtual_engine): hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) @@ -1346,15 +1347,13 @@ def profile_run(self) -> None: max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] max_batch_size = min(self.max_num_batched_tokens // max_seq_len, self.scheduler_config.max_num_seqs) - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, - False, True) + self.warmup_scenario(max_batch_size, max_seq_len, True, False, True) return def warmup_scenario(self, batch_size, seq_len, is_prompt, - kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) @@ -1418,7 +1417,7 @@ def warmup_scenario(self, profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=True) + self.execute_model(inputs, None, warmup_mode=True) torch.hpu.synchronize() if profiler: profiler.step() @@ -1470,17 +1469,16 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): f"free_mem:{free_mem}") logger.info(msg) - def warmup_all_buckets(self, buckets, is_prompt, kv_caches): + def warmup_all_buckets(self, buckets, is_prompt): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + self.warmup_scenario(batch_size, seq_len, is_prompt) def warmup_graphs(self, strategy, buckets, is_prompt, - kv_caches, available_mem, starting_mem=0, total_batch_seq=0.001): @@ -1512,7 +1510,7 @@ def warmup_graphs(self, self.graphed_buckets.add(graphed_bucket) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + self.warmup_scenario(batch_size, seq_len, is_prompt) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) available_mem -= used_mem @@ -1542,8 +1540,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: graphs = graph == 't' if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) - self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, - True) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, True) raise AssertionError("Finished profiling") if self.skip_warmup: logger.info("Skipping warmup...") @@ -1608,9 +1605,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, - True, kv_caches) + True) self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, - False, kv_caches) + False) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1641,11 +1638,11 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( prompt_strategy, self.bucketing_global_state.prompt_buckets, - True, kv_caches, prompt_available_memory) + True, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( decode_strategy, self.bucketing_global_state.decode_buckets, - False, kv_caches, decode_available_memory) + False, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1656,7 +1653,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.warmup_graphs( prompt_strategy, self.bucketing_global_state.prompt_buckets, True, - kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1669,7 +1665,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: mem_post_decode, _, _ = self.warmup_graphs( decode_strategy, self.bucketing_global_state.decode_buckets, False, - kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) @@ -1982,7 +1977,6 @@ def execute_model( execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, - "kv_caches": kv_caches, "attn_metadata": self.trim_attn_metadata(attn_metadata), "intermediate_tensors": intermediate_tensors, "lora_mask": lora_mask, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1a78498ad124..86dcde234f86 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -26,7 +26,7 @@ from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, graph_capture) -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -1727,8 +1727,6 @@ def execute_model( hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), @@ -1913,8 +1911,6 @@ def capture( self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_inputs, **kwargs, ) @@ -1927,8 +1923,6 @@ def capture( output_hidden_or_intermediate_states = self.model( input_ids=input_ids, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, intermediate_tensors=intermediate_inputs, **kwargs, ) @@ -1976,13 +1970,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], **kwargs, ) -> torch.Tensor: - # KV caches are fixed tensors, so we don't need to copy them. - del kv_caches + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 90771e8ac75d..7ddf382079c6 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -476,7 +476,7 @@ def execute_model( # path for warm up runs if not model_input.is_multi_step: return self._base_model_runner.execute_model( - frozen_model_input, kv_caches, intermediate_tensors, num_steps) + frozen_model_input, None, intermediate_tensors, num_steps) # make sure we skip the sampler on the lask rank and only pythonize # if CPU is ahead. @@ -538,7 +538,7 @@ def execute_model( # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, - kv_caches, + None, intermediate_tensors, num_steps=1) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index f7a5ab9de9fa..5035ea20294c 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -346,10 +346,6 @@ def execute_model( input_tokens, "positions": input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - attn_metadata, **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, device=self.device), } diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index 4cbe5db44534..cbd5e2060cad 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -91,16 +91,6 @@ def execute_model( else: model_executable = self.model - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -121,8 +111,6 @@ def execute_model( hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index ecdf7aa88896..53541a2579ed 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -15,7 +15,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model @@ -275,8 +275,8 @@ def _dummy_run( torch._dynamo.mark_dynamic(p, 0) # Dummy run. with set_forward_context(attn_metadata, self.vllm_config, 0): - self.model(token_ids, position_ids, attn_metadata, input_lens, t, - p, num_samples, kv_caches) + self.model(token_ids, position_ids, input_lens, t, p, num_samples, + kv_caches) def warmup_model( self, @@ -679,8 +679,8 @@ def execute_model( self.vllm_config, model_input.virtual_engine): output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, - p, model_input.num_samples, + input_lens, t, p, + model_input.num_samples, kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -730,8 +730,8 @@ def execute_model( self.vllm_config, model_input.virtual_engine): output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, - p, model_input.num_samples, + input_lens, t, p, + model_input.num_samples, kv_caches) self.cached_step_outputs.append(output_token_ids) @@ -777,7 +777,6 @@ def forward( self, token_ids: torch.Tensor, position_ids: torch.Tensor, - attn_metadata: AttentionMetadata, input_lens: torch.Tensor, t: torch.Tensor, p: torch.Tensor, @@ -789,7 +788,6 @@ def forward( Args: token_ids: The input token IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len]. - attn_metadata: The Pallas attention metadata. input_lens: The actual input lengths of shape [batch_size]. t: The sampling temperature of shape [batch_size]. p: The top-p probability of shape [batch_size]. @@ -802,6 +800,7 @@ def forward( start_indicies = torch.arange( batch_size, dtype=torch.int32, device=input_lens.device) * seq_len logits_indices = start_indicies + input_lens - 1 + attn_metadata = get_forward_context().attn_metadata # FIXME(woosuk): This is a temporary hack to avoid using the existing # sampler and sampling metadata. @@ -833,12 +832,7 @@ def forward( slot_mapping = slot_mapping.flatten() attn_metadata.slot_mapping = slot_mapping - hidden_states = self.model( - token_ids, - position_ids, - kv_caches, - attn_metadata, - ) + hidden_states = self.model(token_ids, position_ids) hidden_states = hidden_states.flatten(0, 1) logits = self.model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9c726e1a107e..39957e661c47 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -484,15 +484,6 @@ def profile_run(self) -> None: multi_modal_placeholders=dummy_data.multi_modal_placeholders) seqs.append(seq) - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - ] * num_layers finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) @@ -502,7 +493,7 @@ def profile_run(self) -> None: batch_size=batch_size, dtype=self.model_config.dtype, device=self.device) - self.execute_model(model_input, kv_caches, intermediate_tensors) + self.execute_model(model_input, None, intermediate_tensors) torch.xpu.synchronize() return @@ -581,8 +572,6 @@ def execute_model( hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},