
Commit b6c9bcd

[Executorch][llm] Add ring buffer based kv cache and mask calculation to MHA
Differential Revision: D73891425
Pull Request resolved: #10609
1 parent e86ea6d commit b6c9bcd

File tree

8 files changed: +433 -102 lines changed

examples/models/llama/attention.py

Lines changed: 77 additions & 16 deletions
@@ -123,14 +123,12 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_context_len: int,
-        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_context_len = max_context_len
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -142,21 +140,12 @@ def forward(
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -236,21 +225,79 @@ def __init__(
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
+        self.window_size = max_context_length
+        """
+        Reason why we want the kv cache size to be twice the context length:
+        Sliding window attention without ring buffer
+        pos 0 1 2 3 4 5 6 7 8 9 10
+        0   x 0 0 0 0 0 0 0 0 0 0
+        1   x x 0 0 0 0 0 0 0 0 0
+        2   x x x 0 0 0 0 0 0 0 0
+        3   x x x x 0 0 0 0 0 0 0
+        4   0 x x x x 0 0 0 0 0 0
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        9   0 0 0 0 0 0 x x x x 0
+        10  0 0 0 0 0 0 0 x x x x
+
+        So when doing attention for pos = 5 and seq_len = 4 our attention
+        mask would be
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        Thus the token at pos = 5 is able to attend to tokens at pos 2, 3 and 4.
+        This is how training is done.
+
+        Now let's consider a ring kv cache of size 4. When we are at pos = 5,
+        before updating the kv cache, the state of the kv cache would be
+        [4 1 2 3]. That is, we evicted the token at pos = 0. During the
+        attention calculation at pos = 5 with seq_len = 4, we update the cache and
+        the positions in the cache become [8 5 6 7]. Note that 5 can now only attend
+        to itself, not to 2, 3 and 4 as it would during training.
+        Not having kept 2, 3 and 4 in the cache means we get divergent behavior.
+        The worst case is when the update is as long as the cache itself,
+        like in our case pos = 5, seq_len = 4.
+        Thus we need a larger cache. How much larger? As much as
+        the sliding window size, so twice the max_context_length.
+        How does that help? Let's see. At pos = 5 our cache would hold
+        [0, 1, 2, 3, 4, NA, NA, NA]. After the cache update we would have
+        [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out the token at pos = 0. However, the
+        current step still has access to the [pos - sliding_window_size, pos] tokens.
+
+        To make sure we don't over-attend, i.e. we don't have pos = 5
+        attend to pos = 1, the mask calculation has to account for the sliding window
+        size.
+        """
         super().__init__(
             max_batch_size,
-            max_context_length,
+            max_context_length * 2,
             n_heads,
             head_dim,
             enable_dynamic_shape,
             dtype,
         )
-        self.cache_positions_manager = CachePositionsManager(max_context_length)
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+        cache_positions = self.cache_positions_manager.cache_positions
+        delta = pos_q - cache_positions
+        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
+        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+        return attn_mask

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         seq_len = k_val.size(2)
+        assert seq_len <= self.k_cache.size(
+            2
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
         indices = self.cache_positions_manager.calculate_positions_and_update_indices(
             input_pos, seq_len
         )
@@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
         self.qk_norm_before_rope = args.qk_norm_before_rope
+        self.enable_dynamic_shape = args.enable_dynamic_shape

         if self.use_qk_norm:
             q_norm_dim = self.head_dim
@@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             head_dim=self.head_dim,
             n_rep=self.n_rep,
             max_context_len=self.max_context_len,
-            enable_dynamic_shape=args.enable_dynamic_shape,
         )

     def forward(
@@ -368,8 +415,22 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = self.mask.narrow(0, start_pos, seq_length)
+            else:
+                # mask is always 2D
+                attn_mask = self.mask[input_pos]
             k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+            if getattr(self.kv_cache, "is_ring_buffer", False):
+                attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
+                    input_pos[0].item(), seqlen
+                )
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
             return self.wo(output), None

         # grouped multiquery attention: expand out keys and values
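
Not part of the commit, but useful for checking the ring-buffer mask rule added above: the sketch below reimplements the logic of create_causal_mask_for_ring_buffer against an explicit cache_positions tensor (the cache_positions >= 0 check suggests unfilled slots hold a negative sentinel) and walks the pos = 5, seq_len = 4, window = 4 scenario from the docstring.

import torch


def ring_buffer_mask(
    start_pos: int, seq_len: int, cache_positions: torch.Tensor, window: int
) -> torch.Tensor:
    # Same rule as create_causal_mask_for_ring_buffer: a query at absolute
    # position pos_q may attend to a cache slot holding absolute position p
    # iff the slot is filled (p >= 0), p is not in the future (pos_q - p >= 0),
    # and p falls inside the sliding window (pos_q - p < window).
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    keep = (cache_positions >= 0) & (delta >= 0) & (delta < window)
    return torch.where(keep, torch.tensor(0.0), torch.tensor(float("-inf")))


# Docstring scenario: window = 4, cache size = 2 * window = 8. After updating
# the cache for positions 5..8, the slots hold absolute positions [8, 1, 2, 3, 4, 5, 6, 7].
cache_positions = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7])
mask = ring_buffer_mask(start_pos=5, seq_len=4, cache_positions=cache_positions, window=4)
print(mask)
# The row for pos = 5 keeps only the slots holding positions 2, 3, 4 and 5,
# which matches the training-time sliding-window mask.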

examples/models/llama/source_transformation/sdpa.py

Lines changed: 45 additions & 36 deletions
@@ -22,15 +22,11 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
-        max_context_len,
-        enable_dynamic_shape,
         use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
-        self.max_context_len = max_context_len
         self.use_attention_mask = use_attention_mask
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -42,16 +38,6 @@ def forward(
         seqlen,
         mask,
     ):
-        if self.use_attention_mask:
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[-1].item()
-                torch._check_is_size(start_pos)
-                torch._check(start_pos < self.max_context_len)
-                seq_length = q.size(2)
-                mask = mask.narrow(0, start_pos, seq_length)
-            else:
-                mask = mask[input_pos]
-
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -96,8 +82,6 @@ def _replace_sdpa_with_custom_op(
                 name,
                 SDPACustom(
                     child.dim,
-                    child.max_context_len,
-                    child.enable_dynamic_shape,
                     use_attention_mask=use_attention_mask,
                 ),
             )
@@ -133,12 +117,15 @@ class QuantizedSDPA(torch.nn.Module):
     zero points, we need to pass kv_cache to SDPA.
     """

-    def __init__(self, dim: int, kv_cache: QuantizedKVCache):
+    def __init__(
+        self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False
+    ):
         super().__init__()
         self.dim = dim
         self.quantized_dtype = torch.int8
         self.float_dtype = torch.float32
         self.kv_cache = kv_cache
+        self.use_attention_mask = use_attention_mask

     def forward(
         self,
@@ -176,22 +163,40 @@ def forward(
         v_scale_fp32 = self.kv_cache.v_cache_scales

         start_pos = input_pos[0].item()
-        output = torch.ops.llama.custom_quantized_sdpa(
-            q_quantized,
-            k_quantized,
-            v_quantized,
-            start_pos,
-            None,
-            0,
-            True,
-            None,
-            q_zero_point_int8,
-            q_scale_fp32,
-            k_zero_point_int8,
-            k_scale_fp32,
-            v_zero_point_int8,
-            v_scale_fp32,
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_quantized_sdpa(
+                q_quantized,
+                k_quantized,
+                v_quantized,
+                start_pos,
+                mask,
+                0,
+                False,
+                None,
+                q_zero_point_int8,
+                q_scale_fp32,
+                k_zero_point_int8,
+                k_scale_fp32,
+                v_zero_point_int8,
+                v_scale_fp32,
+            )
+        else:
+            output = torch.ops.llama.custom_quantized_sdpa(
+                q_quantized,
+                k_quantized,
+                v_quantized,
+                start_pos,
+                None,
+                0,
+                True,
+                None,
+                q_zero_point_int8,
+                q_scale_fp32,
+                k_zero_point_int8,
+                k_scale_fp32,
+                v_zero_point_int8,
+                v_scale_fp32,
+            )

         return output.view(bsz, seqlen, self.dim)

@@ -201,6 +206,7 @@ def _update_attention_module_with_quantized_sdpa(
 ):
     sdpa = getattr(module, "SDPA", None)
     assert sdpa is not None
+    # TODO: add support for SDPA with attention mask
     # pyre-ignore
     setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache))  # noqa: B010

@@ -254,7 +260,8 @@ def forward(
         seqlen,
         mask,
     ):
-        attn_mask = mask[None, None, input_pos]
+        # Input mask is already sliced, however it is 2D
+        attn_mask = mask[None, None]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -310,7 +317,8 @@ def forward(
         """
         k = repeat_kv(k, self.n_rep)
         v = repeat_kv(v, self.n_rep)
-        attn_mask = mask[input_pos]
+        # Mask is already sliced as needed
+        attn_mask = mask

         scale_factor = 1 / math.sqrt(q.size(-1))
         attn_weight = q @ k.transpose(-2, -1) * scale_factor
@@ -391,7 +399,8 @@ def forward(
         seqlen,
         mask,
     ):
-        attn_mask = mask[None, None, input_pos]
+        # Input mask is already sliced, however it is 2D
+        attn_mask = mask[None, None]

         if self.n_rep > 1:
             k = k.repeat_interleave(self.n_rep, dim=1)
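
Since SDPACustom and the eager SDPA variants no longer slice the mask themselves, callers are expected to hand them a mask already sliced to the current step, as Attention.forward now does. Below is a rough, self-contained sketch of that caller-side slicing, assuming full_mask is the usual additive causal mask; slice_step_mask is a name used here for illustration only, not an API from the commit.

import torch


def slice_step_mask(
    full_mask: torch.Tensor,  # (max_context_len, cache_len) additive causal mask
    input_pos: torch.Tensor,  # positions of the tokens in the current step
    seq_len: int,
    enable_dynamic_shape: bool,
) -> torch.Tensor:
    if enable_dynamic_shape:
        start_pos = input_pos[-1].item()
        torch._check_is_size(start_pos)
        # narrow() keeps the slice export-friendly under dynamic shapes
        return full_mask.narrow(0, start_pos, seq_len)
    # static shapes: plain row indexing, result stays 2D
    return full_mask[input_pos]


# Example: an 8x8 causal mask, decoding a single token at position 3.
full_mask = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)
step_mask = slice_step_mask(full_mask, torch.tensor([3]), seq_len=1, enable_dynamic_shape=False)
# step_mask has shape (1, 8): row 3 of the causal mask, ready to be passed as
# the `mask` argument of SDPA / SDPACustom.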

examples/models/llama/source_transformation/test_quantized_sdpa.py

Lines changed: 6 additions & 12 deletions
@@ -31,7 +31,7 @@ def __init__(
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
-        self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len, enable_dynamic_shape)
+        self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len)
         self.kv_cache = None

     def forward(self, x, freqs_cos, freqs_sin, **kwargs):
@@ -159,15 +159,9 @@ def test_forward_functionality(self):
         k_quantized, v_quantized = model.attention.kv_cache.update(input_pos, k, v)

         # Run the forward pass with the quantized SDPA
-        try:
-            output = model.attention.SDPA(
-                input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
-            )
+        output = model.attention.SDPA(
+            input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
+        )

-            # Verify the output shape
-            self.assertEqual(output.shape, (bsz, seqlen, self.dim))
-        except Exception:
-            # If the forward pass fails, it might be due to missing custom ops
-            self.skipTest(
-                "Custom ops not available, skipping forward functionality test"
-            )
+        # Verify the output shape
+        self.assertEqual(output.shape, (bsz, seqlen, self.dim))

examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k_val, v_val = self._init_kv()
-        self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
-        self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
+        self.float_sdpa = SDPACustom(self.dim)
+        self.quantized_sdpa = SDPACustom(self.dim)
         k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)

examples/models/llama/tests/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -49,3 +49,14 @@ python_unittest(
         "//executorch/examples/models/llama:llama_transformer",
     ],
 )
+
+python_unittest(
+    name = "test_ring_attention",
+    srcs = [
+        "test_ring_attention.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+    ],
+)
