Commit eda8111

Martin Yuan authored and facebook-github-bot committed
Enable SDPA without kv cache (#8950)
Summary: The SDPA custom op has been decoupled from the KV cache by kimishpatel. Update the Llama model definition so that the SDPA op is applied both with and without a KV cache.

Reviewed By: kimishpatel

Differential Revision: D70593177
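In practice, the SDPA wrapper module can now be called with input_pos=None when no KV cache is present. A rough usage sketch follows; the constructor and call signatures are taken from the diff below, while the tensor shapes and the causal-mask construction are assumptions for illustration only:

import torch

from executorch.examples.models.llama.attention import SDPA

bsz, seqlen, n_heads, n_kv_heads, head_dim, max_context_len = 1, 4, 8, 4, 64, 16

sdpa = SDPA(
    dim=n_heads * head_dim,
    head_dim=head_dim,
    n_rep=n_heads // n_kv_heads,
    max_context_len=max_context_len,
    enable_dynamic_shape=False,
)

# q/k/v are assumed to already carry rotary embeddings, as noted in the attention module.
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# Additive causal mask of the size the model keeps around (assumed layout).
mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)

# New path: no KV cache, so input_pos is None and the mask is sliced to [:seqlen, :seqlen].
out = sdpa(None, q, k, v, bsz, seqlen, mask)  # expected shape: (bsz, seqlen, n_heads * head_dim)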
1 parent 337d73d commit eda8111

File tree: 2 files changed (+102, −34 lines)


examples/models/llama/attention.py

Lines changed: 21 additions & 34 deletions

@@ -133,23 +133,26 @@ def __init__(

     def forward(
         self,
-        input_pos: torch.Tensor,
+        input_pos: Optional[torch.Tensor],
         q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
         k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
         v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
+        if input_pos is None:  # No kv cache
+            attn_mask = mask[:seqlen, :seqlen]
         else:
-            attn_mask = mask[None, None, input_pos]
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True

@@ -218,13 +221,13 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 self.head_dim,
                 args.enable_dynamic_shape,
             )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_context_len=self.max_context_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
+        self.SDPA = SDPA(
+            dim=self.n_local_heads * self.head_dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
+            max_context_len=self.max_context_len,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )

     def forward(
         self,

@@ -257,21 +260,5 @@ def forward(
         if self.use_kv_cache:
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output), None
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
-        assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-
-        return output, None
+        output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+        return self.wo(output), None
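For reference, the eager-mode computation that the deleted non-KV-cache branch used to perform inline, and that now happens inside the SDPA module through the new "input_pos is None" branch, boils down to the sketch below. It is reconstructed from the removed lines above; the function name and argument list are illustrative, not part of the codebase:

import torch
import torch.nn.functional as F


def sdpa_no_kv_cache_reference(
    q: torch.Tensor,  # (bs, n_local_heads, seqlen, head_dim), rotary embeddings applied
    k: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim), rotary embeddings applied
    v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
    causal_mask: torch.Tensor,  # (max_context_len, max_context_len) additive causal mask
    n_rep: int,  # n_local_heads // n_local_kv_heads
    bsz: int,
    seqlen: int,
) -> torch.Tensor:
    # Grouped multi-query attention: expand keys and values to match the query heads.
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    # No KV cache: take the plain causal slice for the current sequence length.
    mask = causal_mask[:seqlen, :seqlen]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
    # (bs, n_heads, seqlen, head_dim) -> (bs, seqlen, n_heads * head_dim)
    return out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)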

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+import unittest
+import torch
+from executorch.examples.models.llama.attention import AttentionMHA, ModelArgs, Rope, KVCache, SDPA
+
+
+class TestAttentionMHA(unittest.TestCase):
+
+    def create_mock_args(self):
+        return ModelArgs(
+            use_kv_cache=True,
+            n_heads=8,
+            n_kv_heads=4,
+            head_dim=64,
+            max_batch_size=2,
+            max_context_len=16,
+            dim=512,
+            attention_qkv_bias=False,
+            enable_dynamic_shape=False,
+        )
+
+    def test_attentionmha_init(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        self.assertEqual(attn.n_heads, 8)
+        self.assertEqual(attn.n_kv_heads, 4)
+        self.assertEqual(attn.n_local_heads, 8)
+        self.assertEqual(attn.n_local_kv_heads, 4)
+        self.assertEqual(attn.head_dim, 64)
+        self.assertEqual(attn.dim, 512)
+        self.assertEqual(attn.mask.shape, (16, 16))  # Causal mask shape check
+        self.assertTrue(attn.use_kv_cache)
+
+        if attn.use_kv_cache:
+            self.assertIsInstance(attn.kv_cache, KVCache)
+        self.assertIsInstance(attn.SDPA, SDPA)
+
+    def test_attentionmha_forward(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+        input_pos = torch.tensor([0, 1, 2, 3])
+
+        output, _ = attn.forward(x, freqs_cos, freqs_sin, input_pos=input_pos)
+
+        self.assertEqual(output.shape, (bsz, seqlen, dim))
+
+    def test_attentionmha_forward_no_kv_cache(self):
+        args = self.create_mock_args()
+        args.use_kv_cache = False  # Disable KV cache for this test
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+
+        output, _ = attn.forward(x, freqs_cos, freqs_sin)
+
+        self.assertEqual(output.shape, (bsz, seqlen, dim))
+
+    def test_attentionmha_invalid_kv_cache(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+
+        # No input_pos provided, should raise assertion error
+        with self.assertRaises(AssertionError):
+            attn.forward(x, freqs_cos, freqs_sin)
