Merged
Changes from 37 commits
Commits (40)
28c7f27
Remove `kv_cache` and `attn_metadata` from `Attention`
hmellor Feb 19, 2025
1fe2b0d
Remove `attn_metadata` from `MambaMixer` 1 & 2
hmellor Feb 19, 2025
153d253
Remove `kv_caches` and `attn_metadata` from `forward` call
hmellor Feb 19, 2025
eb30940
Remove `kv_caches` and `attn_metadata` from new model docs
hmellor Feb 19, 2025
7a75753
Remove `kv_caches` and `attn_metadata` from model interface
hmellor Feb 19, 2025
7ddfd1f
Remove args from a batch of models
hmellor Feb 19, 2025
f8794e9
Remove args from another batch of models
hmellor Feb 19, 2025
f81cad0
Remove `attn_metadata` from a couple more places
hmellor Feb 19, 2025
6beb1b1
Attempt fix HPU model runner
hmellor Feb 19, 2025
c784070
Update CPU model runners
hmellor Feb 19, 2025
72450ae
Update V1 GPU model runner
hmellor Feb 19, 2025
fdda9c6
Update draft model runner
hmellor Feb 19, 2025
f9a1ee8
Update enc dec model runner
hmellor Feb 19, 2025
b91538a
Update remaining non-device model runners
hmellor Feb 19, 2025
59f01be
Allow `kv_caches` to be passed to `execute_model`
hmellor Feb 19, 2025
778910f
Update XPU model runner
hmellor Feb 19, 2025
c7cd852
Update V1 GPU model runner
hmellor Feb 19, 2025
334d2b3
Update OpenVINO model runner
hmellor Feb 19, 2025
0735ed9
Update Neuron model runner
hmellor Feb 19, 2025
5a8a73d
Add unused `kv_caches` arg to runners to limit scope of PR
hmellor Feb 19, 2025
3b9a35b
Update TPU V0 and V1
hmellor Feb 19, 2025
bb094d2
Update HPU model runner
hmellor Feb 19, 2025
46d8fab
Make `kv_caches` optional in `HPUModelRunner.execute_model`
hmellor Feb 19, 2025
39ad6d4
Make linter happy
hmellor Feb 19, 2025
164ee32
Fix whisper test
hmellor Feb 20, 2025
f6c8e2a
Add `kv_caches` back to remaining `*ModelRunner.execute_model()`
hmellor Feb 20, 2025
c917880
Fix kernel tests
hmellor Feb 20, 2025
6a29698
Kick CI
hmellor Feb 20, 2025
cd1e845
Merge branch 'main' into remove-unused-attn-args
hmellor Feb 20, 2025
f8b4d36
Fix missing import
hmellor Feb 20, 2025
39742a3
Fix call to `execute_model` in encoder decoder model runner
hmellor Feb 20, 2025
cc087b0
Fix call to `execute_model` in XPU model runner
hmellor Feb 20, 2025
6f703ba
Fix call to `execute_model` in multi-step model runner
hmellor Feb 20, 2025
d0ee431
Fix V1 TPU model runner
hmellor Feb 20, 2025
29cff77
Fix multi-step model runner
hmellor Feb 20, 2025
7e0c808
Merge branch 'main' into remove-unused-attn-args
hmellor Feb 21, 2025
5d84b99
Deprecate args in `Attention.forward` instead
hmellor Feb 21, 2025
8925e30
Revert "Deprecate args in `Attention.forward` instead"
hmellor Feb 22, 2025
b7ec2d9
Merge branch 'main' into remove-unused-attn-args
hmellor Feb 24, 2025
a775d1c
Fix `mllama` KV cache access
hmellor Feb 24, 2025
2 changes: 0 additions & 2 deletions docs/source/contributing/model/basic.md
@@ -74,8 +74,6 @@ def forward(
     self,
     input_ids: torch.Tensor,
     positions: torch.Tensor,
-    kv_caches: List[torch.Tensor],
-    attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
     ...
 ```
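For new-model authors, the practical effect is that attention sub-layers are now called with just the projected tensors. Below is a minimal sketch of a decoder attention block under the new interface, mirroring the model diffs later in this PR; the `qkv_proj`, `rotary_emb`, `attn` and `o_proj` modules are illustrative and not part of the docs example:

```python
import torch


class ExampleAttentionBlock(torch.nn.Module):
    # Illustrative only: assumes qkv_proj, rotary_emb, attn (a vLLM Attention
    # module) and o_proj were built in __init__, as in the models in this PR.

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # No kv_cache / attn_metadata arguments: the Attention module reads its
        # cache from self.kv_cache and the metadata from the forward context.
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
```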
2 changes: 0 additions & 2 deletions docs/source/contributing/model/multimodal.md
@@ -16,8 +16,6 @@ Further update the model as follows:
     self,
     input_ids: torch.Tensor,
     positions: torch.Tensor,
-    kv_caches: List[torch.Tensor],
-    attn_metadata: AttentionMetadata,
 + pixel_values: torch.Tensor,
 ) -> SamplerOutput:
 ```
14 changes: 3 additions & 11 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(
-            reshaped_query, packed_qkv.key, packed_qkv.value,
-            torch.tensor([],
-                         dtype=torch.float32,
-                         device=packed_qkv.query.device), attn_metadata)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_decoder_self_attention_test(
@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
     & attn_metadata
     '''
     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
     with set_forward_context(attn_metadata, vllm_config):
@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
-                            kv_cache, attn_metadata)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_encoder_decoder_cross_attention_test(
@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
     assert decoder_test_params.packed_qkvo.packed_qkv is not None

     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     if cross_test_params is None:
         key = None
         value = None
@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query, key, value, kv_cache,
-                            attn_metadata)
+        return attn.forward(reshaped_query, key, value)


 @pytest.fixture(autouse=True)
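As a usage note for the kernel tests above: attention metadata now reaches the backend through the forward context rather than through `Attention.forward`, so the call pattern is a context manager plus a three-argument forward. A minimal sketch, assuming the `attn`, `attn_metadata`, `vllm_config`, `query`, `key` and `value` objects built by the surrounding test fixtures:

```python
from vllm.forward_context import set_forward_context

# attn, attn_metadata, vllm_config and the q/k/v tensors are assumed to be
# provided by the test fixtures, as constructed elsewhere in this module.
with set_forward_context(attn_metadata, vllm_config):
    # No kv_cache / attn_metadata arguments anymore.
    output = attn.forward(query, key, value)
```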
11 changes: 8 additions & 3 deletions vllm/attention/layer.py
@@ -15,7 +15,7 @@
                               QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import deprecate_args, direct_register_custom_op


 class Attention(nn.Module):
@@ -148,13 +148,18 @@ def __init__(
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

+    @deprecate_args(
+        4,
+        additional_message=
+        "In Attention, kv_cache is accessed via self.kv_cache and "
+        "attn_metadata is accessed via forward context.")
     def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
     ) -> torch.Tensor:
         # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
         # directly, use `self.kv_cache` and
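In the revision shown here ("Changes from 37 commits") the old positional arguments are still accepted but deprecated; note that the commit list above includes a later revert of this deprecation approach (8925e30), so the final merged state may differ. A hedged sketch of the two call forms, assuming an already constructed `Attention` module `attn` and tensors `q`, `k`, `v` plus legacy `kv_cache` / `attn_metadata` objects:

```python
# New form: the layer reads the KV cache from self.kv_cache and the attention
# metadata from the forward context set up by the model runner.
out = attn(q, k, v)

# Legacy form: still runs in this revision, but @deprecate_args(4, ...) emits a
# deprecation warning pointing callers at self.kv_cache / the forward context.
out = attn(q, k, v, kv_cache, attn_metadata)
```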
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -7,6 +7,7 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -130,14 +131,14 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
         ) if use_rms_norm else None

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(self, hidden_states: torch.Tensor,
-                     attn_metadata: AttentionMetadata,
                      mamba_cache_params: MambaCacheParams):

+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
+
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
         hidden_states, gate = projected_states.chunk(2, dim=-2)
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -14,6 +14,7 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -365,17 +366,16 @@ def __init__(self,
                                  eps=rms_norm_eps)

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
     ):
+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
6 changes: 1 addition & 5 deletions vllm/model_executor/models/adapters.py
@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
         return cls

     # Lazy import
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
     from vllm.model_executor.layers.pooler import PoolingType
@@ -201,13 +200,10 @@ def forward(
             self,
             input_ids: torch.Tensor,
             positions: torch.Tensor,
-            kv_caches: list[torch.Tensor],
-            attn_metadata: AttentionMetadata,
             intermediate_tensors: Optional[IntermediateTensors] = None,
             inputs_embeds: Optional[torch.Tensor] = None,
         ) -> torch.Tensor:
-            hidden_states = super().forward(input_ids, positions, kv_caches,
-                                            attn_metadata,
+            hidden_states = super().forward(input_ids, positions,
                                             intermediate_tensors,
                                             inputs_embeds)
             logits, _ = self.score(hidden_states)
24 changes: 5 additions & 19 deletions vllm/model_executor/models/arctic.py
@@ -5,7 +5,7 @@
 import torch
 from torch import nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -282,13 +282,11 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -335,16 +333,12 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual_input = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual_input + hidden_states

@@ -399,8 +393,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -412,11 +404,8 @@
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(positions, hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.norm(hidden_states)
@@ -457,13 +446,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

5 changes: 0 additions & 5 deletions vllm/model_executor/models/aria.py
@@ -9,7 +9,6 @@
 from transformers.models.aria.modeling_aria import AriaCrossAttention
 from transformers.models.aria.processing_aria import AriaProcessor

-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
@@ -626,8 +625,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -643,8 +640,6 @@
         hidden_states = self.language_model(
             input_ids,
             positions,
-            kv_caches,
-            attn_metadata,
             intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
24 changes: 5 additions & 19 deletions vllm/model_executor/models/baichuan.py
@@ -20,13 +20,13 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -182,14 +182,12 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -232,8 +230,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -246,8 +242,6 @@
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -301,8 +295,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -316,13 +308,10 @@
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )
         if not get_pp_group().is_last_rank:
@@ -387,13 +376,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
