Commit 6dba3b6

faaany and Copilot authored
add XPU path in apply_rotary_pos_emb_flashattn for Keye-VL models (vllm-project#7)
* add xpu path

Signed-off-by: Lin, Fanli <[email protected]>

* use partial to create a function wrapper

Co-authored-by: Copilot <[email protected]>

---------

Signed-off-by: Lin, Fanli <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent c763367 commit 6dba3b6

1 file changed: +4 -0 lines changed


vllm/model_executor/models/keye.py

Lines changed: 4 additions & 0 deletions
@@ -345,6 +345,10 @@ def apply_rotary_pos_emb_flashatt(
         from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
     elif current_platform.is_rocm():
         from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
+    else:
+        # For XPU and other platforms, use PyTorch fallback
+        from vllm.model_executor.layers.rotary_embedding.common import apply_rotary_emb_torch
+        apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
 
     q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
     k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
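
Note on the wrapper: the call sites pass three positional arguments (`x`, `cos`, `sin`), while the fallback takes an extra `is_neox_style` flag, so the commit binds that flag once with `partial`. The diff does not add an import for `partial`, so the module presumably imports it from `functools` already. The sketch below illustrates the idea in plain Python on a single vector. `rotary_sketch` is a hypothetical stand-in written for this note, not vLLM's `apply_rotary_emb_torch`, and the assumption that NeoX style rotates the first half of the vector against the second half (rather than interleaved pairs) is labeled in the comments.

```python
import math
from functools import partial


def rotary_sketch(x, cos, sin, is_neox_style):
    """Hypothetical stand-in for a rotary helper, on one plain-list vector.

    x has even length d; cos and sin each have length d // 2.
    """
    half = len(x) // 2
    if is_neox_style:
        # Assumed NeoX layout: rotate the first half against the second half.
        x1, x2 = x[:half], x[half:]
    else:
        # Assumed GPT-J layout: rotate even-indexed against odd-indexed items.
        x1, x2 = x[0::2], x[1::2]
    o1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    o2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    if is_neox_style:
        return o1 + o2
    # Interleave the rotated values back into their original positions.
    return [v for pair in zip(o1, o2) for v in pair]


# Bind the keyword once so callers can use the three-argument form
# apply_rotary_emb(x, cos, sin), matching the call sites in the diff.
apply_rotary_emb = partial(rotary_sketch, is_neox_style=True)

angles = [math.pi / 2, 0.0]
cos = [math.cos(a) for a in angles]
sin = [math.sin(a) for a in angles]
rotated = apply_rotary_emb([1.0, 2.0, 3.0, 4.0], cos, sin)
```

With a quarter-turn on the first pair and no rotation on the second, `rotated` comes out approximately `[-3.0, 2.0, 1.0, 4.0]`, which is a quick way to sanity-check the chosen layout.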
