
Commit cc0b4ed

sxu authored and facebook-github-bot committed
Update static attention IO manager to use "smart mask" style update
Differential Revision: D72322014
1 parent 1572381 commit cc0b4ed


examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 150 additions & 69 deletions
@@ -16,32 +16,65 @@
 
 namespace example {
 
+enum class StaticAttentionUpdateStyle {
+  /**
+   * KV caches will have valid data at the end of the cache. New elements are
+   * added at the end and the pointers will shift forward (an extra copy is
+   * allocated) to maintain this invariant. This potentially allows shorter
+   * caches to be passed into the model by adjusting the start pointer.
+   */
+  SLIDING_CACHE,
+  /**
+   * I/O pointers do not change, which can enable persistent memory mapping
+   * between AP and NPU. Can also implement a circular cache by adjusting the
+   * attention mask accordingly.
+   */
+  SMART_MASK,
+};
+
 template <typename T, typename AllocatorT = std::allocator<T>>
 class StaticKVCache {
  public:
   /**
-   * Helper class to handle KV cache I/O. Assumes batch size 1, same context
-   * length and head dimension for each cache. Supports hybrid operation mixing
-   * prefill and decode. Create one instance for key caches and another one for
-   * value caches.
+   * Helper class to handle KV cache I/O. Assumes batch size 1, same length and
+   * head dimension for each cache. Supports multi-turn operation mixing prefill
+   * and decode by sharing the same cache between methods with different input
+   * lengths. Create one instance for key caches and another one for value
+   * caches.
    */
   StaticKVCache(
       size_t n_caches,
       size_t cache_len,
       size_t head_dim,
       size_t max_input_len = 1,
-      bool transpose = false)
+      bool transpose = false,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         head_dim_(head_dim),
-        transpose_(transpose) {
-    // Updates are appended at the end. Need one extra segment to support the
-    // sliding window.
-    data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_ + max_input_len_;
-    data_ = allocator_.allocate(data_size_);
-    ET_CHECK(data_ != nullptr);
-    reset();
+        transpose_(transpose),
+        style_(style),
+        input_ptrs_(n_caches_),
+        output_ptrs_(n_caches_) {
+    if (transpose_) {
+      throw std::runtime_error("Not implemented.");
+    }
+
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      // Allocates one extra copy to accommodate caches sliding forward.
+      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
+    } else {
+      cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
+    }
+    update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
+
+    cache_data_ = allocator_.allocate(cache_data_size_);
+    update_data_ = allocator_.allocate(update_data_size_);
+    ET_CHECK(cache_data_ != nullptr);
+    ET_CHECK(update_data_ != nullptr);
+    init_ptrs();
   }
 
   StaticKVCache(const StaticKVCache& other) = delete;
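
Note (illustrative, not part of the diff): the new constructor sizes two separate buffers depending on the update style. A minimal sketch of that sizing logic, with assumed example numbers in the trailing comment:

    // Sketch of how the constructor above sizes its buffers. SLIDING_CACHE
    // reserves one extra cache-sized copy so the window can slide forward;
    // SMART_MASK keeps exactly n_caches copies and never moves the pointers.
    #include <cstddef>

    struct CacheSizes {
      size_t cache_data; // elements backing the cache inputs
      size_t update_data; // elements backing the per-step update outputs
    };

    inline CacheSizes compute_cache_sizes(
        size_t n_caches,
        size_t cache_len,
        size_t head_dim,
        size_t max_input_len,
        bool sliding_cache) {
      CacheSizes s;
      s.cache_data =
          (sliding_cache ? n_caches + 1 : n_caches) * cache_len * head_dim;
      s.update_data = n_caches * max_input_len * head_dim;
      return s;
    }

    // E.g. 32 caches, cache_len 1024, head_dim 128, max_input_len 128:
    // SLIDING_CACHE -> 33 * 1024 * 128 cache elements; SMART_MASK -> 32 * 1024 * 128.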
@@ -50,23 +83,24 @@ class StaticKVCache {
   StaticKVCache& operator=(StaticKVCache&& other) = delete;
 
   ~StaticKVCache() {
-    allocator_.deallocate(data_, data_size_);
+    allocator_.deallocate(cache_data_, cache_data_size_);
+    allocator_.deallocate(update_data_, update_data_size_);
   }
 
   /**
    * Set up data pointers for the KV cache related inputs and outputs based on
    * the current state of the cache. Call StaticKVCache<T>::update or
-   * StaticKVCache<T>::reset first as needed before calling this function.
+   * StaticKVCache<T>::reset as needed before calling this function.
    */
   void prepare(
       torch::executor::Method& method,
       const std::vector<size_t>& inputIndices,
-      const std::vector<size_t>& outputIndices) {
-    ET_CHECK(inputIndices.size() == outputIndices.size());
+      const std::vector<size_t>& output_indices) {
+    ET_CHECK(inputIndices.size() == output_indices.size());
     auto methodMeta = method.method_meta();
     for (size_t i = 0; i < n_caches_; i++) {
       auto inIdx = inputIndices[i];
-      auto outIdx = outputIndices[i];
+      auto outIdx = output_indices[i];
       auto inMeta = methodMeta.input_tensor_meta(inIdx);
       auto outMeta = methodMeta.output_tensor_meta(outIdx);
       ET_CHECK(inMeta.ok());
@@ -106,74 +140,90 @@ class StaticKVCache {
   /**
    * Update the internal data pointers using the cache updates returned by the
    * model. The length of each individual update cannot exceed the max update
-   * length specified during the creation, and the total length cannot exceed
-   * the context length.
+   * length specified during creation, and the total length cannot exceed the
+   * cache length.
    */
   void update(
       torch::executor::Method& method,
-      const std::vector<size_t>& outputIndices,
+      const std::vector<size_t>& output_indices,
       size_t update_len) {
     if (valid_len_ + update_len > cache_len_) {
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      update_sliding_cache(method, output_indices, update_len);
     } else {
-      updateSeqDim(method, outputIndices, update_len);
+      update_smart_mask(method, output_indices, update_len);
     }
-    valid_len_ += update_len;
   }
 
   /**
    * Reset the cache. After this the cache contains no valid data and is ready
-   * for number of tokens up to the context length.
+   * for number of tokens up to the cache length.
    */
   void reset() {
     valid_len_ = 0;
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    } else {
-      initSeqDim();
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      init_ptrs();
     }
   }
 
  private:
-  void initSeqDim() {
-    auto cacheSize = cache_len_ * head_dim_;
+  void init_ptrs() {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] = data_ + i * cacheSize;
-      output_ptrs_[i] = input_ptrs_[i] + cacheSize;
+      input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
+      output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
     }
   }
 
-  void updateSeqDim(
+  void update_sliding_cache(
       torch::executor::Method& method,
-      const std::vector<size_t>& outputIndices,
+      const std::vector<size_t>& output_indices,
       size_t update_len) {
-    ET_CHECK(n_caches_ == outputIndices.size());
+    ET_CHECK(n_caches_ == output_indices.size());
     for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor = method.get_output(outputIndices[i]).toTensor();
-      ET_CHECK(
-          input_ptrs_[i] + cache_len_ * head_dim_ ==
-          updateTensor.mutable_data_ptr<T>());
-
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
+      std::copy(
+          output_ptrs_[i],
+          output_ptrs_[i] + update_len * head_dim_,
+          input_ptrs_[i] + cache_len_ * head_dim_);
       input_ptrs_[i] += update_len * head_dim_;
-      output_ptrs_[i] += update_len * head_dim_;
     }
+    valid_len_ += update_len;
+  }
+
+  void update_smart_mask(
+      torch::executor::Method& method,
+      const std::vector<size_t>& output_indices,
+      size_t update_len) {
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      std::copy(
+          output_ptrs_[i],
+          output_ptrs_[i] + update_len * head_dim_,
+          input_ptrs_[i] + valid_len_ * head_dim_);
+    }
+    valid_len_ += update_len;
   }
 
-  // std::vector<T> pool_;
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t head_dim_;
   bool transpose_;
+  StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
-  size_t data_size_;
-  T* data_;
+  size_t cache_data_size_;
+  T* cache_data_;
+  size_t update_data_size_;
+  T* update_data_;
   std::vector<T*> input_ptrs_;
   std::vector<T*> output_ptrs_;
   size_t valid_len_ = 0;
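
Note (illustrative, not part of the diff): the smart-mask path above appends each returned update into a fixed cache buffer at the current valid length, so no I/O pointer ever moves. A standalone sketch of that core copy, with hypothetical names and no ExecuTorch types:

    #include <algorithm>
    #include <cstddef>

    // Append update_len rows (head_dim elements each) after the valid_len rows
    // already present in the cache. The cache pointer never changes, which is
    // what allows a persistent AP/NPU memory mapping in SMART_MASK mode.
    template <typename T>
    void append_to_cache(
        T* cache, // fixed buffer of cache_len * head_dim elements
        const T* update, // model output of update_len * head_dim elements
        size_t head_dim,
        size_t update_len,
        size_t& valid_len) {
      std::copy(
          update, update + update_len * head_dim, cache + valid_len * head_dim);
      valid_len += update_len;
    }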
@@ -183,28 +233,30 @@ template <typename T, typename AllocatorT = std::allocator<T>>
 class StaticAttentionMask {
  public:
   /**
-   * Manages the attention mask in the same style of KV cache IO where valid
-   * data is at the end of the cache. The mask has shape (1, maxSeqLen,
-   * cache_len
-   * + maxSeqLen) where maxSeqLen is 1 for decode or the prefill length. Accepts
-   * zero_val and mask_val (which represents -inf) to support quantized mask.
+   * Manages the attention mask for StaticKVCache. Create one mask for each
+   * input length. Accepts zero_val and mask_val (which represents -inf) to
+   * support quantized masks.
    *
-   * This class manages the slice of the mask at [:, :, : (cache_len -
-   * validCacheLen)]. User can update the rest of the mask to implement causal
-   * masking for example.
+   * The mask shape is (1, input_len, cache_len + input_len). This class manages
+   * the slice of the mask at [:, :, :cache_len] to only allow valid cache
+   * elements to participate in the attention. The user can update the rest of
+   * the mask (to implement a causal mask, for example).
    */
   StaticAttentionMask(
       size_t cache_len,
       size_t input_len,
       size_t head_dim,
       T zero_val,
-      T mask_val)
+      T mask_val,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
-        cache_mask_len_(cache_len_),
+        cache_valid_len_(0),
         zero_val_(zero_val),
-        mask_val_(mask_val) {
+        mask_val_(mask_val),
+        style_(style) {
     data_size_ = input_len_ * (cache_len_ + input_len_);
     data_ = allocator_.allocate(data_size_);
     ET_CHECK(data_ != nullptr);
@@ -224,7 +276,7 @@ class StaticAttentionMask {
    * Reset the mask to the state where the cache contains no valid data.
    */
   void reset() {
-    cache_mask_len_ = cache_len_;
+    cache_valid_len_ = 0;
     for (size_t i = 0; i < input_len_; i++) {
       auto* p = data_ + (cache_len_ + input_len_) * i;
       std::fill(p, p + cache_len_, mask_val_);
@@ -233,19 +285,29 @@ class StaticAttentionMask {
 
   /**
    * Update the mask to indicate update_len elements have been added to the
-   * cache. Note that update_len might be smaller than maxSeqLen when prefilling
-   * with padded inputs.
+   * cache. Note that update_len might be smaller than input_len_ when
+   * prefilling with padded inputs.
    */
-  void updateCacheMask(size_t update_len) {
-    for (size_t i = 0; i < input_len_; i++) {
-      auto* p = data_ + (cache_len_ + input_len_) * i;
-      std::fill(
-          p + cache_mask_len_ - update_len, p + cache_mask_len_, zero_val_);
+  void unmask(size_t update_len) {
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_len_ - cache_valid_len_ - update_len,
+            p + cache_len_ - cache_valid_len_,
+            zero_val_);
+      }
+    } else {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
+      }
     }
-    cache_mask_len_ -= update_len;
+    cache_valid_len_ += update_len;
   }
 
-  void setCausalMask() {
+  void set_causal_mask() {
     for (size_t i = 0; i < input_len_ - 1; i++) {
       auto* p = data_ + (cache_len_ + input_len_) * i;
       std::fill(p + cache_len_, p + cache_len_ + 1 + i, zero_val_);
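
Note (illustrative, not part of the diff): the two unmask branches open up opposite ends of the cache slice, because SLIDING_CACHE keeps valid data at the tail of the cache while SMART_MASK keeps it at the head. A single-row sketch with hypothetical names:

    #include <algorithm>
    #include <cstddef>

    // Unmask update_len newly valid cache positions in one mask row. Only the
    // first cache_len entries of the row (the cache slice) are touched here.
    template <typename T>
    void unmask_row(
        T* row, // points at the first cache_len entries of one mask row
        size_t cache_len,
        size_t cache_valid_len, // valid elements before this update
        size_t update_len,
        T zero_val, // value that allows attention
        bool sliding_cache) {
      if (sliding_cache) {
        // Valid data sits at the end; clear the slots just before it.
        std::fill(
            row + cache_len - cache_valid_len - update_len,
            row + cache_len - cache_valid_len,
            zero_val);
      } else {
        // SMART_MASK: valid data sits at the start; clear the next slots.
        std::fill(
            row + cache_valid_len, row + cache_valid_len + update_len, zero_val);
      }
    }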
@@ -261,9 +323,10 @@ class StaticAttentionMask {
   size_t cache_len_;
   size_t input_len_;
   size_t head_dim_;
-  size_t cache_mask_len_;
+  size_t cache_valid_len_;
   T zero_val_;
   T mask_val_;
+  StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t data_size_ = 0;
   T* data_;
@@ -285,7 +348,9 @@ class StaticAttentionIOManager {
       size_t rope_freqs_cos_index,
       size_t rope_freqs_sin_index,
       RopeT* rope_freqs_cos,
-      RopeT* rope_freqs_sin)
+      RopeT* rope_freqs_sin,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : cache_len_(cache_len),
         head_dim_(head_dim),
         kCaches_(n_caches, cache_len, head_dim, max_input_len),
@@ -295,6 +360,9 @@ class StaticAttentionIOManager {
         rope_freqs_cos_(rope_freqs_cos),
         rope_freqs_sin_(rope_freqs_sin) {}
 
+  /**
+   * Create a new StaticAttentionMask that will be managed by this object.
+   */
   StaticAttentionMask<MaskT, MaskAllocatorT>&
   addMask(size_t input_len, MaskT zero_val, MaskT mask_val) {
     auto it = attentionMasks_.emplace(
@@ -305,10 +373,16 @@ class StaticAttentionIOManager {
     return it.first->second;
   }
 
+  /**
+   * Retrieve a mask suitable for the given input length.
+   */
   StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
 
+  /**
+   * Set I/O pointers for KV cache and RoPE frequencies.
+   */
   void prepare(
       torch::executor::Method& method,
       const std::vector<size_t>& k_cache_input_indices,
@@ -327,6 +401,10 @@ class StaticAttentionIOManager {
         rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
   }
 
+  /**
+   * Update all caches and masks under management to reflect that the model
+   * produced update_len new elements.
+   */
   void update(
       torch::executor::Method& method,
       const std::vector<size_t>& k_cache_output_indices,
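
Note (illustrative, not part of the diff): the context line above advances the RoPE tables by head_dim / 2 per position, which suggests the cos/sin tables store head_dim / 2 frequency values per token. A small sketch of that offset, assuming that layout:

    #include <cstddef>

    // Offset into a RoPE cos/sin table for the token at position input_pos,
    // assuming each position stores head_dim / 2 frequency values (matching
    // rope_freqs_cos_ + input_pos_ * head_dim_ / 2 in the hunk above).
    inline size_t rope_table_offset(size_t input_pos, size_t head_dim) {
      return input_pos * (head_dim / 2);
    }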
@@ -336,10 +414,13 @@ class StaticAttentionIOManager {
     kCaches_.update(method, k_cache_output_indices, update_len);
     vCaches_.update(method, v_cache_output_indices, update_len);
     for (auto& it : attentionMasks_) {
-      it.second.updateCacheMask(update_len);
+      it.second.unmask(update_len);
     }
   }
 
+  /**
+   * Reset all caches and masks under management.
+   */
   void reset() {
     input_pos_ = 0;
     kCaches_.reset();
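
Note (illustrative, not part of the diff): a rough sketch of how a runner might drive one decode step using the pieces the manager composes, restricted to the StaticKVCache and StaticAttentionMask calls visible in this file; the Method, index vectors, and template argument are assumed to come from the surrounding runner, and the sketch relies on this header's includes.

    // One decode step against the K/V caches and a mask, mirroring what
    // StaticAttentionIOManager::prepare/update do for all managed objects.
    template <typename T>
    void decode_step_sketch(
        torch::executor::Method& method,
        example::StaticKVCache<T>& k_caches,
        example::StaticKVCache<T>& v_caches,
        example::StaticAttentionMask<T>& mask,
        const std::vector<size_t>& k_in,
        const std::vector<size_t>& k_out,
        const std::vector<size_t>& v_in,
        const std::vector<size_t>& v_out) {
      k_caches.prepare(method, k_in, k_out); // point cache I/O at current state
      v_caches.prepare(method, v_in, v_out);
      ET_CHECK(method.execute() == torch::executor::Error::Ok); // decode one token
      k_caches.update(method, k_out, /*update_len=*/1); // absorb the new KV entry
      v_caches.update(method, v_out, /*update_len=*/1);
      mask.unmask(/*update_len=*/1); // let future steps attend to it
    }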
