110 changes: 66 additions & 44 deletions vllm/attention/backends/torch_sdpa.py
@@ -9,15 +9,14 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu

if is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:

try:
from vllm._ipex_ops import ipex_ops
from vllm.attention.ops.ipex_attn import PagedAttention
use_ipex = True
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
use_ipex = False


class TorchSDPABackend(AttentionBackend):
@@ -69,6 +68,8 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
max_seqlen: Optional[int] = None
seqlen_q: Optional[torch.Tensor] = None

def __post_init__(self):
# Set during the execution of the first attention op.
@@ -179,42 +180,63 @@ def forward(
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)

if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks

query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)

start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[None, :, start:end, :],
key[None, :, start:end, :],
value[None, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
if use_ipex:
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype,
device=query.device)
# ipex-cpu provide varlen_attention API
Contributor review comment:
LGTM! Nice addition.

Suggested change
# ipex-cpu provide varlen_attention API
# ipex-cpu provides varlen_attention API

# which could perform better than torch.sdpa
ipex_ops.varlen_attention(query,
key,
value,
output,
attn_metadata.seqlen_q,
attn_metadata.seqlen_q,
attn_metadata.max_seqlen,
attn_metadata.max_seqlen,
pdropout=0.0,
softmax_scale=self.scale,
zero_tensors=False,
is_causal=True,
return_softmax=False,
gen_=None)
else:
if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks

query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)

start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[None, :, start:end, :],
key[None, :, start:end, :],
value[None, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
else:
# prefix-enabled attention
raise RuntimeError(
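For context, here is a minimal, self-contained sketch (not part of this diff) of the packed layout and per-sequence slicing that the fallback torch.sdpa path above operates on; the prompt lengths, head counts, and tensor values below are illustrative assumptions, not values from the PR.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Illustrative assumption: two prompts of lengths 3 and 5, packed along the
# token dimension into [num_tokens, num_heads, head_size] tensors, as the
# CPU backend does for prefill.
seq_lens = [3, 5]
num_heads, head_size = 4, 8
num_tokens = sum(seq_lens)
query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)

# SDPA expects [..., num_heads, num_tokens, head_size], hence the movedim
# calls in the diff.
query, key, value = (t.movedim(0, t.dim() - 2) for t in (query, key, value))

output = torch.empty(num_tokens, num_heads, head_size, dtype=query.dtype)
start = 0
for seq_len in seq_lens:
    end = start + seq_len
    # Causal attention inside one sequence; slicing keeps attention from
    # crossing sequence boundaries in the packed batch.
    sub_out = scaled_dot_product_attention(
        query[None, :, start:end, :],
        key[None, :, start:end, :],
        value[None, :, start:end, :],
        is_causal=True,
    ).squeeze(0).movedim(query.dim() - 2, 0)
    output[start:end, :, :] = sub_out
    start = end

print(output.shape)  # torch.Size([8, 4, 8])
```

The ipex path replaces this Python-level loop with a single varlen_attention call, which receives the same per-sequence boundaries through the cumulative seqlen_q offsets and max_seqlen.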
5 changes: 5 additions & 0 deletions vllm/worker/cpu_model_runner.py
@@ -198,11 +198,16 @@ def _prepare_prompt(
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
max_seqlen = max(seq_lens)
seqlen = torch.tensor([0] + seq_lens)
seqlen_q = torch.cumsum(seqlen, dim=0)

attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=None,
max_seqlen=max_seqlen,
seqlen_q=seqlen_q,
max_decode_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
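For reference, a small worked example (assumed prompt lengths, not from the PR) of what the three new lines compute: seqlen_q holds the cumulative token offsets at which each packed prompt starts, and max_seqlen is the longest prompt; these are the arguments the varlen_attention call uses to locate sequence boundaries.

```python
import torch

# Assumed example batch of three prompts with lengths 4, 2 and 7.
seq_lens = [4, 2, 7]

max_seqlen = max(seq_lens)              # 7
seqlen = torch.tensor([0] + seq_lens)   # tensor([0, 4, 2, 7])
seqlen_q = torch.cumsum(seqlen, dim=0)  # tensor([0, 4, 6, 13])

# seqlen_q[i]:seqlen_q[i + 1] spans the i-th prompt's rows in the packed
# query/key/value tensors; seqlen_q[-1] equals the total number of tokens.
print(max_seqlen, seqlen_q.tolist())    # 7 [0, 4, 6, 13]
```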