@@ -7,6 +7,7 @@
 
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
@@ -171,13 +172,12 @@ def split_kv_cache(
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +193,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +210,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if attn_metadata.is_prompt:
@@ -296,8 +296,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -329,8 +329,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.
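For context, a minimal sketch of the calling convention this diff moves to: the per-call `k_scale`/`v_scale` float arguments are dropped, and the backend instead reads the scales off the attention layer object passed into `forward`. The `DummyAttentionLayer` class and the helper below are hypothetical stand-ins for illustration; only the `_k_scale`/`_v_scale` attribute names and the `layer` parameter come from the diff.

    # Sketch only; the real AttentionLayer lives in
    # vllm.attention.backends.abstract, and the real call sites are the
    # cache-write and PagedAttention paths shown in the diff above.
    import torch


    class DummyAttentionLayer:
        """Hypothetical stand-in carrying per-layer KV-cache scales."""

        def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0):
            self._k_scale = k_scale
            self._v_scale = v_scale


    def scale_kv(key: torch.Tensor, value: torch.Tensor,
                 layer: DummyAttentionLayer):
        # Before: forward(..., k_scale=1.0, v_scale=1.0) threaded floats
        # through every call. After: the scales travel with the layer
        # object, so each call site just dereferences layer._k_scale /
        # layer._v_scale.
        return key * layer._k_scale, value * layer._v_scale


    layer = DummyAttentionLayer()
    k, v = scale_kv(torch.randn(2, 4), torch.randn(2, 4), layer)

One practical consequence visible in the diff: because the scales are now layer state rather than defaulted keyword arguments, the backend asserts `layer._k_scale == 1.0 and layer._v_scale == 1.0`, making it explicit that this IPEX path does not yet support non-unit KV-cache scaling.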