
Commit aaf996b

[Executorch][llm] Add ring buffer based kv cache and mask calculation to MHA
Pull Request resolved: #10609

Leveraging previous work, we now allow MHA to use a ring buffer KV cache. If the ring buffer cache is used, we query the mask from the KV cache and use that for SDPA instead of a precalculated mask. In the process we had to adjust the ring buffer implementation to keep the context of the full sliding window; see the code comments for details.

ghstack-source-id: 283404675
exported-using-ghexport

Differential Revision: [D73891425](https://our.internmc.facebook.com/intern/diff/D73891425/)
1 parent f8e7264 commit aaf996b
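For reference, the core of the mask-from-cache idea added here (see `create_causal_mask_for_ring_buffer` in the attention.py diff below) can be summarized in a small standalone sketch. This is illustrative only: `ring_buffer_mask`, its arguments, and the -1 sentinel for empty slots are stand-ins for the state tracked by `CachePositionsManager`, inferred from the `cache_positions >= 0` check in the diff.

import torch

def ring_buffer_mask(
    start_pos: int,
    seq_len: int,
    cache_positions: torch.Tensor,
    window_size: int,
) -> torch.Tensor:
    # cache_positions[i] holds the absolute token position stored in cache
    # slot i (negative, e.g. -1, for an empty slot). A query at position p may
    # attend to a slot holding position q iff 0 <= p - q < window_size.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions  # broadcasts to (seq_len, cache_size)
    valid = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
    mask = torch.full(valid.shape, float("-inf"))
    mask[valid] = 0.0  # additive mask: 0 where attention is allowed
    return mask

# Window of 4, cache of 8 slots; after writing position 8 over the evicted
# position 0, the cache holds absolute positions [8, 1, 2, 3, 4, 5, 6, 7].
cache_positions = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7])
print(ring_buffer_mask(start_pos=8, seq_len=1, cache_positions=cache_positions, window_size=4))
# Only the slots holding positions 5, 6, 7 and 8 get 0.0; the rest are -inf.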

8 files changed: +433 -102 lines changed


examples/models/llama/attention.py

Lines changed: 77 additions & 16 deletions
@@ -123,14 +123,12 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_context_len: int,
-        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_context_len = max_context_len
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -142,21 +140,12 @@ def forward(
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -236,21 +225,79 @@ def __init__(
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
+        self.window_size = max_context_length
+        """
+        Reason why we want the kv cache size to be twice the context length:
+        Sliding window attention without ring buffer
+        pos 0 1 2 3 4 5 6 7 8 9 10
+        0   x 0 0 0 0 0 0 0 0 0 0
+        1   x x 0 0 0 0 0 0 0 0 0
+        2   x x x 0 0 0 0 0 0 0 0
+        3   x x x x 0 0 0 0 0 0 0
+        4   0 x x x x 0 0 0 0 0 0
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        9   0 0 0 0 0 0 x x x x 0
+        10  0 0 0 0 0 0 0 x x x x
+
+        So when doing attention for pos = 5 and seq_len = 4, our attention
+        mask would be
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        Thus the token at pos = 5 is able to attend to tokens at pos 2, 3 and 4.
+        This is how training is done.
+
+        Now let's consider a ring kv cache of size 4. When we are at pos = 5,
+        before updating the kv cache, the state of the cache would be
+        [4 1 2 3]. That is, we evicted the token at pos = 0. During the attention
+        calculation at pos = 5 with seq_len = 4, we will update the cache and the
+        new positions in the cache will be [8 5 6 7]. Note that 5 can now only
+        attend to itself, not to 2, 3 and 4 as it would during training.
+        Not having kept 2, 3 and 4 in the cache means we will have divergent behavior.
+        The worst case is when the update is equal to the length of the cache,
+        as in our case of pos = 5, seq_len = 4.
+        Thus we need a larger cache. How much larger? As large as the sliding
+        window, so twice the max_context_length.
+        How does that help? At pos = 5 our cache would hold
+        [0, 1, 2, 3, 4, NA, NA, NA]. After the cache update we would have
+        [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out the token at pos = 0, but the
+        current step still has access to the tokens at [pos - sliding_window_size, pos].
+
+        To make sure we don't over-attend, i.e. we don't let pos = 5 attend to
+        pos = 1, the mask calculation has to account for the sliding window size.
+        """
         super().__init__(
             max_batch_size,
-            max_context_length,
+            max_context_length * 2,
             n_heads,
             head_dim,
             enable_dynamic_shape,
             dtype,
         )
-        self.cache_positions_manager = CachePositionsManager(max_context_length)
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+        cache_positions = self.cache_positions_manager.cache_positions
+        delta = pos_q - cache_positions
+        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
+        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+        return attn_mask

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         seq_len = k_val.size(2)
+        assert seq_len <= self.k_cache.size(
+            2
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
         indices = self.cache_positions_manager.calculate_positions_and_update_indices(
             input_pos, seq_len
         )
@@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
         self.qk_norm_before_rope = args.qk_norm_before_rope
+        self.enable_dynamic_shape = args.enable_dynamic_shape

         if self.use_qk_norm:
             q_norm_dim = self.head_dim
@@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             head_dim=self.head_dim,
             n_rep=self.n_rep,
             max_context_len=self.max_context_len,
-            enable_dynamic_shape=args.enable_dynamic_shape,
         )

     def forward(
@@ -368,8 +415,22 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = self.mask.narrow(0, start_pos, seq_length)
+            else:
+                # mask is always 2D
+                attn_mask = self.mask[input_pos]
             k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+            if getattr(self.kv_cache, "is_ring_buffer", False):
+                attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
+                    input_pos[0].item(), seqlen
+                )
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
             return self.wo(output), None

         # grouped multiquery attention: expand out keys and values
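To make the doubled-cache argument in the docstring above concrete, the scenario it describes (sliding window of 4, ring KV cache of 2 * 4 = 8 slots, an update at pos = 5 with seq_len = 4) can be checked with the hypothetical ring_buffer_mask sketch given under the commit message. The cache contents below are taken from the docstring; everything else is illustrative and not part of this diff.

import torch

window_size = 4  # sliding window, i.e. max_context_length in the docstring
# After processing positions 0..4 and then writing 5..8, slot 0 is overwritten
# with position 8 and the cache holds absolute positions:
cache_positions = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7])

mask = ring_buffer_mask(start_pos=5, seq_len=4, cache_positions=cache_positions, window_size=window_size)
# Row 0 (query pos 5) is 0.0 exactly at the slots holding positions 2, 3, 4
# and 5, and -inf elsewhere: the same window the token would see in training.
print(mask)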

examples/models/llama/source_transformation/sdpa.py

Lines changed: 45 additions & 36 deletions
@@ -22,15 +22,11 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
-        max_context_len,
-        enable_dynamic_shape,
         use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
-        self.max_context_len = max_context_len
         self.use_attention_mask = use_attention_mask
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -42,16 +38,6 @@ def forward(
         seqlen,
         mask,
     ):
-        if self.use_attention_mask:
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[-1].item()
-                torch._check_is_size(start_pos)
-                torch._check(start_pos < self.max_context_len)
-                seq_length = q.size(2)
-                mask = mask.narrow(0, start_pos, seq_length)
-            else:
-                mask = mask[input_pos]
-
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -96,8 +82,6 @@ def _replace_sdpa_with_custom_op(
                 name,
                 SDPACustom(
                     child.dim,
-                    child.max_context_len,
-                    child.enable_dynamic_shape,
                     use_attention_mask=use_attention_mask,
                 ),
             )
@@ -133,12 +117,15 @@ class QuantizedSDPA(torch.nn.Module):
     zero points, we need to pass kv_cache to SDPA.
     """

-    def __init__(self, dim: int, kv_cache: QuantizedKVCache):
+    def __init__(
+        self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False
+    ):
         super().__init__()
         self.dim = dim
         self.quantized_dtype = torch.int8
         self.float_dtype = torch.float32
         self.kv_cache = kv_cache
+        self.use_attention_mask = use_attention_mask

     def forward(
         self,
@@ -176,22 +163,40 @@ def forward(
         v_scale_fp32 = self.kv_cache.v_cache_scales

         start_pos = input_pos[0].item()
-        output = torch.ops.llama.custom_quantized_sdpa(
-            q_quantized,
-            k_quantized,
-            v_quantized,
-            start_pos,
-            None,
-            0,
-            True,
-            None,
-            q_zero_point_int8,
-            q_scale_fp32,
-            k_zero_point_int8,
-            k_scale_fp32,
-            v_zero_point_int8,
-            v_scale_fp32,
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_quantized_sdpa(
+                q_quantized,
+                k_quantized,
+                v_quantized,
+                start_pos,
+                mask,
+                0,
+                False,
+                None,
+                q_zero_point_int8,
+                q_scale_fp32,
+                k_zero_point_int8,
+                k_scale_fp32,
+                v_zero_point_int8,
+                v_scale_fp32,
+            )
+        else:
+            output = torch.ops.llama.custom_quantized_sdpa(
+                q_quantized,
+                k_quantized,
+                v_quantized,
+                start_pos,
+                None,
+                0,
+                True,
+                None,
+                q_zero_point_int8,
+                q_scale_fp32,
+                k_zero_point_int8,
+                k_scale_fp32,
+                v_zero_point_int8,
+                v_scale_fp32,
+            )

         return output.view(bsz, seqlen, self.dim)

@@ -201,6 +206,7 @@ def _update_attention_module_with_quantized_sdpa(
 ):
     sdpa = getattr(module, "SDPA", None)
     assert sdpa is not None
+    # TODO: add support for SDPA with attention mask
     # pyre-ignore
     setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache))  # noqa: B010

@@ -254,7 +260,8 @@ def forward(
         seqlen,
         mask,
     ):
-        attn_mask = mask[None, None, input_pos]
+        # Input mask is already sliced; however, it is 2D
+        attn_mask = mask[None, None]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -310,7 +317,8 @@ def forward(
         """
         k = repeat_kv(k, self.n_rep)
         v = repeat_kv(v, self.n_rep)
-        attn_mask = mask[input_pos]
+        # Mask is already sliced as needed
+        attn_mask = mask

         scale_factor = 1 / math.sqrt(q.size(-1))
         attn_weight = q @ k.transpose(-2, -1) * scale_factor
@@ -391,7 +399,8 @@ def forward(
         seqlen,
         mask,
     ):
-        attn_mask = mask[None, None, input_pos]
+        # Input mask is already sliced; however, it is 2D
+        attn_mask = mask[None, None]

         if self.n_rep > 1:
             k = k.repeat_interleave(self.n_rep, dim=1)
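A side effect of these changes is a simpler contract for all SDPA variants: the caller (the attention module) slices the 2D causal mask before the call, and the SDPA module only broadcasts it over batch and heads via mask[None, None]. Below is a minimal sketch of that broadcasting with purely illustrative shapes; none of these names come from the diff.

import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, cache_len, head_dim = 1, 4, 3, 8, 16
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, cache_len, head_dim)
v = torch.randn(bsz, n_heads, cache_len, head_dim)

# The caller passes an additive mask already sliced to (seq_len, cache_len).
mask_2d = torch.zeros(seq_len, cache_len)
attn_mask = mask_2d[None, None]  # (1, 1, seq_len, cache_len), broadcasts over bsz and heads

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(out.shape)  # torch.Size([1, 4, 3, 16])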

examples/models/llama/source_transformation/test_quantized_sdpa.py

Lines changed: 6 additions & 12 deletions
@@ -31,7 +31,7 @@ def __init__(
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
-        self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len, enable_dynamic_shape)
+        self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len)
         self.kv_cache = None

     def forward(self, x, freqs_cos, freqs_sin, **kwargs):
@@ -159,15 +159,9 @@ def test_forward_functionality(self):
         k_quantized, v_quantized = model.attention.kv_cache.update(input_pos, k, v)

         # Run the forward pass with the quantized SDPA
-        try:
-            output = model.attention.SDPA(
-                input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
-            )
+        output = model.attention.SDPA(
+            input_pos, q, k_quantized, v_quantized, bsz, seqlen, None
+        )

-            # Verify the output shape
-            self.assertEqual(output.shape, (bsz, seqlen, self.dim))
-        except Exception:
-            # If the forward pass fails, it might be due to missing custom ops
-            self.skipTest(
-                "Custom ops not available, skipping forward functionality test"
-            )
+        # Verify the output shape
+        self.assertEqual(output.shape, (bsz, seqlen, self.dim))

examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k_val, v_val = self._init_kv()
-        self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
-        self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
+        self.float_sdpa = SDPACustom(self.dim)
+        self.quantized_sdpa = SDPACustom(self.dim)
         k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)

examples/models/llama/tests/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -49,3 +49,14 @@ python_unittest(
         "//executorch/examples/models/llama:llama_transformer",
     ],
 )
+
+python_unittest(
+    name = "test_ring_attention",
+    srcs = [
+        "test_ring_attention.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+    ],
+)
