
Commit d7201ab

[Executorch][llm] Add support for ring kv cache and ring attention (#10832)
Pull Request resolved: #10608. Introduced CachePositionsManager to keep track of the original sequence position held by each slot of the ring kv cache; these positions are used to generate the mask. ghstack-source-id: 283404678 @exported-using-ghexport Differential Revision: [D73891427](https://our.internmc.facebook.com/intern/diff/D73891427/)
1 parent 756f86a commit d7201ab
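
The mask-generation code is not part of this diff, but the idea described in the commit message can be sketched: every ring-buffer slot remembers the original sequence position it currently holds, and a slot may be attended to only if it has been written and does not lie in the query's future. The helper below is a hypothetical illustration (not from the ExecuTorch sources), reusing the capacity-4 example from the docstring added in attention.py:

```python
import torch


def mask_from_cache_positions(cache_positions: torch.Tensor, query_pos: int) -> torch.Tensor:
    """Hypothetical helper: turn per-slot positions into an additive attention mask."""
    # A slot can be attended to if it was written (!= -1 sentinel) and its original
    # sequence position is not ahead of the current query position.
    valid = (cache_positions != -1) & (cache_positions <= query_pos)
    # 0.0 where attention is allowed, -inf where it must be masked out.
    return torch.where(
        valid,
        torch.zeros(cache_positions.shape, dtype=torch.float32),
        torch.full(cache_positions.shape, float("-inf")),
    )


# With cache_positions = [4, 1, 2, 3] (capacity 4, after wrap-around), a query at
# position 4 may attend to all four slots:
print(mask_from_cache_positions(torch.tensor([4, 1, 2, 3]), query_pos=4))
# tensor([0., 0., 0., 0.])
```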

File tree

3 files changed: +595 −0 lines changed


examples/models/llama/attention.py

Lines changed: 107 additions & 0 deletions
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Dict, Optional, Tuple, Type, TypedDict

 import torch
@@ -160,6 +161,112 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+class CacheUpdateStrategy(Enum):
+    RING_BUFFER = "RingBuffer"
+    INVALID = "Invalid"
+
+
+class CachePositionsManager(nn.Module):
+    def __init__(
+        self,
+        max_context_length: int,
+        cache_update_strategy: CacheUpdateStrategy = CacheUpdateStrategy.RING_BUFFER,
+    ):
+        super().__init__()
+        assert (
+            cache_update_strategy == CacheUpdateStrategy.RING_BUFFER
+        ), "Only RingBuffer is supported"
+        self.max_context_length = max_context_length
+        # cache_positions[slot] tracks the original sequence position held by each
+        # ring-buffer slot; unwritten slots are reset to a -1 sentinel during updates.
+        self.register_buffer(
+            "cache_positions",
+            torch.zeros((self.max_context_length), dtype=torch.long, device="cpu"),
+        )
+
+    def calculate_positions_and_update_indices(self, input_pos: torch.Tensor, seq_len):
+        """
+        Calculate the indices into k_cache and v_cache at which the k_val tensor
+        should be written. Given input_pos and the length of k_val along the
+        sequence dim, the input position may have to wrap around if it is smaller
+        than the cache capacity. If it is larger than the cache capacity, just pick
+        the last self.max_context_length entries.
+
+        Additionally, update the cache_positions buffer with the new indices.
+        Given the cache positions in the sequence dim, indicated by indices,
+        we can just update the cache_positions buffer using orig_indices.
+        For example, given a cache capacity of 4 and an update of length 3 with
+        start_pos = 2, we have
+            indices = [2, 3, 0]
+            orig_indices = [2, 3, 4]
+        so cache_positions after the update will be [4, 1, 2, 3].
+        Note that cache_positions[1] = 1 comes from a previous write to the cache.
+        The corner case is the cache positions before the cache rolls over.
+        For example, when start_pos = 0 and the update is of length 2, we have
+        filled positions 0 and 1 in the buffer while the rest are invalid. In this
+        case we have
+            indices = [0, 1]
+            orig_indices = [0, 1]
+        but cache_positions = [0, 1, 0, 0] would not be valid, hence we have to
+        make sure that invalid positions hold a sentinel value of -1.
+        """
+        start_pos = input_pos[0].item()
+        torch._check_is_size(start_pos)
+        # Unwrapped sequence positions of the incoming tokens; wrap them into the
+        # ring buffer by taking the remainder against the cache capacity.
+        orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos
+        indices = orig_indices % self.max_context_length
+
+        # Slots at or beyond start_pos that were never written are reset to -1.
+        full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
+        arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
+        cache_positions = torch.where(
+            arange_tensor < start_pos, self.cache_positions, full_t
+        )
+        self.cache_positions.copy_(cache_positions)
+        self.cache_positions.index_copy_(0, indices, orig_indices)
+
+        return indices
+
+
+class RingKVCache(KVCache):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_context_length: int,
+        n_heads: int,
+        head_dim: int,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
+    ):
+        super().__init__(
+            max_batch_size,
+            max_context_length,
+            n_heads,
+            head_dim,
+            enable_dynamic_shape,
+            dtype,
+        )
+        self.cache_positions_manager = CachePositionsManager(max_context_length)
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # input_pos: [S], k_val: [B, H, S, D]
+        seq_len = k_val.size(2)
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[0].item()
+            torch._check_is_size(start_pos)
+
+            self.k_cache.index_copy_(2, indices, k_val)
+            self.v_cache.index_copy_(2, indices, v_val)
+        else:
+            self.k_cache[:, :, indices] = k_val
+            self.v_cache[:, :, indices] = v_val
+
+        return self.k_cache, self.v_cache
+
+
 @register_attention("mha")
 class AttentionMHA(Attention):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
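
To make the wrap-around arithmetic concrete, here is a small standalone sketch (plain PyTorch, independent of the classes above; `write_positions` is a name invented for this example) that follows the same index computation as `calculate_positions_and_update_indices` and reproduces the docstring's capacity-4 scenario:

```python
import torch

CAPACITY = 4
# -1 is the sentinel for slots that have never been written.
cache_positions = torch.full((CAPACITY,), -1, dtype=torch.long)


def write_positions(start_pos: int, seq_len: int) -> torch.Tensor:
    # Unwrapped sequence positions of the tokens being written.
    orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos
    # Physical ring-buffer slots: wrap modulo the cache capacity.
    indices = orig_indices % CAPACITY
    # Slots at or beyond start_pos that were never written are reset to the sentinel.
    keep = torch.arange(CAPACITY, dtype=torch.long) < start_pos
    cache_positions.copy_(
        torch.where(keep, cache_positions, torch.full_like(cache_positions, -1))
    )
    cache_positions.index_copy_(0, indices, orig_indices)
    return indices


print(write_positions(0, 2), cache_positions)  # tensor([0, 1]) tensor([ 0,  1, -1, -1])
print(write_positions(2, 3), cache_positions)  # tensor([2, 3, 0]) tensor([4, 1, 2, 3])
```

The second write covers sequence positions 2, 3, and 4, so it wraps into slot 0, which is why slot 0 ends up holding position 4 while slot 1 still holds position 1 from the first write.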

examples/models/llama/tests/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -38,3 +38,14 @@ python_unittest(
         "//executorch/examples/models/llama:static_attention",
     ],
 )
+
+python_unittest(
+    name = "test_ring_kv_cache",
+    srcs = [
+        "test_ring_kv_cache.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+    ],
+)
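
The new python_unittest target points at test_ring_kv_cache.py, which is not included in this excerpt. Purely as an illustration, a minimal check of the kind such a test could perform against the RingKVCache API shown above might look like the following (the import path is an assumption based on the target's location):

```python
import unittest

import torch

# Assumed import path, derived from examples/models/llama/attention.py.
from executorch.examples.models.llama.attention import RingKVCache


class RingKVCacheWrapAroundTest(unittest.TestCase):
    def test_wrap_around(self):
        # Tiny cache: batch 1, 1 head, head_dim 4, capacity 4 positions.
        cache = RingKVCache(
            max_batch_size=1,
            max_context_length=4,
            n_heads=1,
            head_dim=4,
            enable_dynamic_shape=False,
        )
        # First write fills slots 0..2 with ones.
        k = torch.ones(1, 1, 3, 4)
        v = torch.ones(1, 1, 3, 4)
        cache.update(torch.tensor([0]), k, v)
        # Second write starts at position 3 and wraps into slots 3 and 0.
        k2 = torch.full((1, 1, 2, 4), 2.0)
        v2 = torch.full((1, 1, 2, 4), 2.0)
        k_out, _ = cache.update(torch.tensor([3]), k2, v2)
        self.assertTrue(torch.equal(k_out[0, 0, 3], torch.full((4,), 2.0)))
        self.assertTrue(torch.equal(k_out[0, 0, 0], torch.full((4,), 2.0)))
        # Slot 1 still holds the value from the first write.
        self.assertTrue(torch.equal(k_out[0, 0, 1], torch.ones(4)))


if __name__ == "__main__":
    unittest.main()
```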
