namespace example {
 
+ enum class StaticAttentionUpdateStyle {
+   /**
+    * KV caches will have valid data at the end of the cache. New elements are
+    * added at the end and the pointers will shift forward (an extra copy is
+    * allocated) to maintain this invariant. This potentially allows shorter
+    * caches to be passed into the model by adjusting the start pointer.
+    */
+   SLIDING_CACHE,
+   /**
+    * I/O pointers do not change, which can enable persistent memory mapping
+    * between AP and NPU. Can also implement a circular cache by adjusting the
+    * attention mask accordingly.
+    */
+   SMART_MASK,
+ };
+
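To make the tradeoff concrete, the buffer sizes implied by each style (mirroring the constructor below) can be computed directly. A minimal sketch; the dimensions are hypothetical:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical model dimensions, for illustration only.
  const std::size_t n_caches = 32, cache_len = 1024, head_dim = 128;
  // SLIDING_CACHE allocates one extra cache-sized copy so the valid window
  // can slide forward without wrapping.
  std::printf("SLIDING_CACHE: %zu elements\n",
              (n_caches + 1) * cache_len * head_dim);
  // SMART_MASK keeps I/O pointers fixed, so no extra copy is needed.
  std::printf("SMART_MASK:    %zu elements\n",
              n_caches * cache_len * head_dim);
}
```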
 template <typename T, typename AllocatorT = std::allocator<T>>
 class StaticKVCache {
  public:
   /**
-   * Helper class to handle KV cache I/O. Assumes batch size 1, same context
-   * length and head dimension for each cache. Supports hybrid operation mixing
-   * prefill and decode. Create one instance for key caches and another one for
-   * value caches.
+   * Helper class to handle KV cache I/O. Assumes batch size 1, same length and
+   * head dimension for each cache. Supports multi-turn operation mixing prefill
+   * and decode by sharing the same cache between methods with different input
+   * lengths. Create one instance for key caches and another one for value
+   * caches.
    */
   StaticKVCache(
       size_t n_caches,
       size_t cache_len,
       size_t head_dim,
       size_t max_input_len = 1,
-      bool transpose = false)
+      bool transpose = false,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         head_dim_(head_dim),
-        transpose_(transpose) {
-    // Updates are appeneded at the end. Need one extra segment to support the
-    // sliding window.
-    data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_ + max_input_len_;
-    data_ = allocator_.allocate(data_size_);
-    ET_CHECK(data_ != nullptr);
-    reset();
+        transpose_(transpose),
+        style_(style),
+        input_ptrs_(n_caches_),
+        output_ptrs_(n_caches_) {
+    if (transpose_) {
+      throw std::runtime_error("Not implemented.");
+    }
+
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      // Allocate one extra copy to accommodate caches sliding forward.
+      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
+    } else {
+      cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
+    }
+    update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
+
+    cache_data_ = allocator_.allocate(cache_data_size_);
+    update_data_ = allocator_.allocate(update_data_size_);
+    ET_CHECK(cache_data_ != nullptr);
+    ET_CHECK(update_data_ != nullptr);
+    init_ptrs();
   }
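For context, a minimal construction sketch (not part of the diff; dimensions are made up):

```cpp
#include <cstddef>

// Hypothetical dimensions for illustration only.
constexpr std::size_t kNumCaches = 4;    // e.g. one cache per layer
constexpr std::size_t kCacheLen = 512;
constexpr std::size_t kHeadDim = 64;
constexpr std::size_t kMaxInputLen = 32; // longest prefill chunk

// One instance for key caches and another for value caches, as the class
// comment prescribes. SLIDING_CACHE is the default style; pass
// StaticAttentionUpdateStyle::SMART_MASK as the last argument to keep the
// I/O pointers fixed instead.
example::StaticKVCache<float> k_caches(
    kNumCaches, kCacheLen, kHeadDim, kMaxInputLen);
example::StaticKVCache<float> v_caches(
    kNumCaches, kCacheLen, kHeadDim, kMaxInputLen);
```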
 
   StaticKVCache(const StaticKVCache& other) = delete;
@@ -50,23 +83,24 @@ class StaticKVCache {
   StaticKVCache& operator=(StaticKVCache&& other) = delete;
 
   ~StaticKVCache() {
-    allocator_.deallocate(data_, data_size_);
+    allocator_.deallocate(cache_data_, cache_data_size_);
+    allocator_.deallocate(update_data_, update_data_size_);
   }
 
   /**
    * Set up data pointers for the KV cache related inputs and outputs based on
    * the current state of the cache. Call StaticKVCache<T>::update or
-   * StaticKVCache<T>::reset first as needed before calling this function.
+   * StaticKVCache<T>::reset as needed before calling this function.
    */
   void prepare(
       torch::executor::Method& method,
       const std::vector<size_t>& inputIndices,
-      const std::vector<size_t>& outputIndices) {
-    ET_CHECK(inputIndices.size() == outputIndices.size());
+      const std::vector<size_t>& output_indices) {
+    ET_CHECK(inputIndices.size() == output_indices.size());
     auto methodMeta = method.method_meta();
     for (size_t i = 0; i < n_caches_; i++) {
       auto inIdx = inputIndices[i];
-      auto outIdx = outputIndices[i];
+      auto outIdx = output_indices[i];
       auto inMeta = methodMeta.input_tensor_meta(inIdx);
       auto outMeta = methodMeta.output_tensor_meta(outIdx);
       ET_CHECK(inMeta.ok());
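A sketch of calling prepare before an inference. The index values are assumptions about the exported method's I/O layout, not something this diff defines:

```cpp
// Hypothetical layout: method inputs 1..4 are the K caches and outputs 1..4
// are the corresponding K cache updates (one pair per cache).
std::vector<std::size_t> k_input_indices = {1, 2, 3, 4};
std::vector<std::size_t> k_output_indices = {1, 2, 3, 4};
k_caches.prepare(method, k_input_indices, k_output_indices);
// Repeat with the value-cache indices for v_caches.
```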
@@ -106,74 +140,90 @@ class StaticKVCache {
   /**
    * Update the internal data pointers using the cache updates returned by the
    * model. The length of each individual update cannot exceed the max update
-   * length specified during the creation, and the total length cannot exceed
-   * the context length.
+   * length specified during creation, and the total length cannot exceed the
+   * cache length.
    */
   void update(
       torch::executor::Method& method,
-      const std::vector<size_t>& outputIndices,
+      const std::vector<size_t>& output_indices,
       size_t update_len) {
     if (valid_len_ + update_len > cache_len_) {
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      update_sliding_cache(method, output_indices, update_len);
     } else {
-      updateSeqDim(method, outputIndices, update_len);
+      update_smart_mask(method, output_indices, update_len);
     }
-    valid_len_ += update_len;
   }
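Continuing the sketch, a multi-turn flow mixing prefill and decode might look like the following; prefill_method, decode_method, prompt_len, and the index vectors are all assumed:

```cpp
// Prefill: consume the prompt through a wide method, then absorb the updates.
k_caches.prepare(prefill_method, k_input_indices, k_output_indices);
v_caches.prepare(prefill_method, v_input_indices, v_output_indices);
prefill_method.execute();
k_caches.update(prefill_method, k_output_indices, prompt_len);
v_caches.update(prefill_method, v_output_indices, prompt_len);

// Decode: the same caches feed a 1-token method, one element per step.
k_caches.prepare(decode_method, k_input_indices, k_output_indices);
v_caches.prepare(decode_method, v_input_indices, v_output_indices);
decode_method.execute();
k_caches.update(decode_method, k_output_indices, /*update_len=*/1);
v_caches.update(decode_method, v_output_indices, /*update_len=*/1);
```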
 
   /**
    * Reset the cache. After this the cache contains no valid data and is ready
-   * for number of tokens up to the context length.
+   * for a number of tokens up to the cache length.
    */
   void reset() {
     valid_len_ = 0;
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    } else {
-      initSeqDim();
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      init_ptrs();
     }
   }
 
  private:
-  void initSeqDim() {
-    auto cacheSize = cache_len_ * head_dim_;
+  void init_ptrs() {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] = data_ + i * cacheSize;
-      output_ptrs_[i] = input_ptrs_[i] + cacheSize;
+      input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
+      output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
     }
   }
 
-  void updateSeqDim(
+  void update_sliding_cache(
       torch::executor::Method& method,
-      const std::vector<size_t>& outputIndices,
+      const std::vector<size_t>& output_indices,
       size_t update_len) {
-    ET_CHECK(n_caches_ == outputIndices.size());
+    ET_CHECK(n_caches_ == output_indices.size());
     for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor = method.get_output(outputIndices[i]).toTensor();
-      ET_CHECK(
-          input_ptrs_[i] + cache_len_ * head_dim_ ==
-          updateTensor.mutable_data_ptr<T>());
-
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
+      std::copy(
+          output_ptrs_[i],
+          output_ptrs_[i] + update_len * head_dim_,
+          input_ptrs_[i] + cache_len_ * head_dim_);
       input_ptrs_[i] += update_len * head_dim_;
-      output_ptrs_[i] += update_len * head_dim_;
     }
+    valid_len_ += update_len;
+  }
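The sliding invariant is easiest to see in isolation. Below is a self-contained toy simulation of the pointer arithmetic (head_dim fixed to 1; plain C++, not library code):

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int cache_len = 4;
  // One extra cache-sized copy, matching the SLIDING_CACHE allocation above.
  std::vector<int> buf(2 * cache_len, 0);
  int* input_ptr = buf.data();
  for (int token = 1; token <= cache_len; token++) {
    // Each update is copied to just past the current cache window (the real
    // code std::copy's it from the update buffer)...
    input_ptr[cache_len] = token;
    // ...then the window slides forward, keeping valid data at the end.
    input_ptr += 1;
  }
  // The window [input_ptr, input_ptr + cache_len) now holds: 1 2 3 4
  for (int i = 0; i < cache_len; i++) {
    std::printf("%d ", input_ptr[i]);
  }
  std::printf("\n");
}
```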
+
+  void update_smart_mask(
+      torch::executor::Method& method,
+      const std::vector<size_t>& output_indices,
+      size_t update_len) {
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      std::copy(
+          output_ptrs_[i],
+          output_ptrs_[i] + update_len * head_dim_,
+          input_ptrs_[i] + valid_len_ * head_dim_);
+    }
+    valid_len_ += update_len;
   }
 
-  // std::vector<T> pool_;
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t head_dim_;
   bool transpose_;
+  StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
-  size_t data_size_;
-  T* data_;
+  size_t cache_data_size_;
+  T* cache_data_;
+  size_t update_data_size_;
+  T* update_data_;
   std::vector<T*> input_ptrs_;
   std::vector<T*> output_ptrs_;
   size_t valid_len_ = 0;
@@ -183,28 +233,30 @@ template <typename T, typename AllocatorT = std::allocator<T>>
 class StaticAttentionMask {
  public:
   /**
-   * Manages the attention mask in the same style of KV cache IO where valid
-   * data is at the end of the cache. The mask has shape (1, maxSeqLen,
-   * cache_len + maxSeqLen) where maxSeqLen is 1 for decode or the prefill
-   * length. Accepts zero_val and mask_val (which represents -inf) to support
-   * quantized mask.
+   * Manages the attention mask for StaticKVCache. Create one mask for each
+   * input length. Accepts zero_val and mask_val (which represents -inf) to
+   * support a quantized mask.
    *
-   * This class manages the slice of the mask at [:, :, :(cache_len -
-   * validCacheLen)]. User can update the rest of the mask to implement causal
-   * masking for example.
+   * The mask shape is (1, input_len, cache_len + input_len). This class
+   * manages the slice of the mask at [:, :, :cache_len] to only allow valid
+   * cache elements to participate in attention. The user can update the rest
+   * of the mask (to implement a causal mask, for example).
    */
   StaticAttentionMask(
       size_t cache_len,
       size_t input_len,
       size_t head_dim,
       T zero_val,
-      T mask_val)
+      T mask_val,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
-        cache_mask_len_(cache_len_),
+        cache_valid_len_(0),
         zero_val_(zero_val),
-        mask_val_(mask_val) {
+        mask_val_(mask_val),
+        style_(style) {
     data_size_ = input_len_ * (cache_len_ + input_len_);
     data_ = allocator_.allocate(data_size_);
     ET_CHECK(data_ != nullptr);
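A minimal construction sketch for masks (not part of the diff; a quantized model would pass its own codes for zero_val and mask_val):

```cpp
#include <limits>

const float kZero = 0.0f;
const float kMaskVal = -std::numeric_limits<float>::infinity();

// One mask per input length the model supports, sharing the same cache_len.
example::StaticAttentionMask<float> prefill_mask(
    /*cache_len=*/512, /*input_len=*/32, /*head_dim=*/64, kZero, kMaskVal);
example::StaticAttentionMask<float> decode_mask(
    /*cache_len=*/512, /*input_len=*/1, /*head_dim=*/64, kZero, kMaskVal);
// The input_len x input_len block is user-managed; see set_causal_mask below.
prefill_mask.set_causal_mask();
```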
@@ -224,7 +276,7 @@ class StaticAttentionMask {
    * Reset the mask to the state where the cache contains no valid data.
    */
   void reset() {
-    cache_mask_len_ = cache_len_;
+    cache_valid_len_ = 0;
     for (size_t i = 0; i < input_len_; i++) {
       auto* p = data_ + (cache_len_ + input_len_) * i;
       std::fill(p, p + cache_len_, mask_val_);
@@ -233,19 +285,29 @@ class StaticAttentionMask {
 
   /**
    * Update the mask to indicate update_len elements have been added to the
-   * cache. Note that update_len might be smaller than maxSeqLen when prefilling
-   * with padded inputs.
+   * cache. Note that update_len might be smaller than input_len_ when
+   * prefilling with padded inputs.
    */
-  void updateCacheMask(size_t update_len) {
-    for (size_t i = 0; i < input_len_; i++) {
-      auto* p = data_ + (cache_len_ + input_len_) * i;
-      std::fill(
-          p + cache_mask_len_ - update_len, p + cache_mask_len_, zero_val_);
+  void unmask(size_t update_len) {
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_len_ - cache_valid_len_ - update_len,
+            p + cache_len_ - cache_valid_len_,
+            zero_val_);
+      }
+    } else {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
+      }
     }
-    cache_mask_len_ -= update_len;
+    cache_valid_len_ += update_len;
   }
 
-  void setCausalMask() {
+  void set_causal_mask() {
     for (size_t i = 0; i < input_len_ - 1; i++) {
       auto* p = data_ + (cache_len_ + input_len_) * i;
       std::fill(p + cache_len_, p + cache_len_ + 1 + i, zero_val_);
@@ -261,9 +323,10 @@ class StaticAttentionMask {
   size_t cache_len_;
   size_t input_len_;
   size_t head_dim_;
-  size_t cache_mask_len_;
+  size_t cache_valid_len_;
   T zero_val_;
   T mask_val_;
+  StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t data_size_ = 0;
   T* data_;
@@ -285,7 +348,9 @@ class StaticAttentionIOManager {
       size_t rope_freqs_cos_index,
       size_t rope_freqs_sin_index,
       RopeT* rope_freqs_cos,
-      RopeT* rope_freqs_sin)
+      RopeT* rope_freqs_sin,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : cache_len_(cache_len),
         head_dim_(head_dim),
         kCaches_(n_caches, cache_len, head_dim, max_input_len),
@@ -295,6 +360,9 @@ class StaticAttentionIOManager {
         rope_freqs_cos_(rope_freqs_cos),
         rope_freqs_sin_(rope_freqs_sin) {}
 
+  /**
+   * Create a new StaticAttentionMask that will be managed by this object.
+   */
   StaticAttentionMask<MaskT, MaskAllocatorT>&
   addMask(size_t input_len, MaskT zero_val, MaskT mask_val) {
     auto it = attentionMasks_.emplace(
@@ -305,10 +373,16 @@ class StaticAttentionIOManager {
     return it.first->second;
   }
 
+  /**
+   * Retrieve the mask suitable for the given input length.
+   */
   StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
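Registry usage sketch, reusing kZero/kMaskVal from the earlier mask example (manager construction omitted; input lengths are placeholders):

```cpp
// Register one mask per exported input length once at startup.
io_manager.addMask(/*input_len=*/32, kZero, kMaskVal); // prefill
io_manager.addMask(/*input_len=*/1, kZero, kMaskVal);  // decode
// The prefill mask needs causal masking within its 32-token block.
io_manager.getMask(32).set_causal_mask();
```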
 
+  /**
+   * Set I/O pointers for KV cache and RoPE frequencies.
+   */
   void prepare(
       torch::executor::Method& method,
       const std::vector<size_t>& k_cache_input_indices,
@@ -327,6 +401,10 @@ class StaticAttentionIOManager {
         rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
   }
 
+  /**
+   * Update all caches and masks under management to reflect that the model
+   * produced update_len new elements.
+   */
   void update(
       torch::executor::Method& method,
       const std::vector<size_t>& k_cache_output_indices,
@@ -336,10 +414,13 @@ class StaticAttentionIOManager {
     kCaches_.update(method, k_cache_output_indices, update_len);
     vCaches_.update(method, v_cache_output_indices, update_len);
     for (auto& it : attentionMasks_) {
-      it.second.updateCacheMask(update_len);
+      it.second.unmask(update_len);
     }
   }
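Putting it together, a hedged end-to-end sketch. prepare's full parameter list is truncated by the hunk above; it is assumed here to also take the V cache input indices, and all names are placeholders:

```cpp
io_manager.reset();

// Prefill once.
io_manager.prepare(
    prefill_method, k_cache_input_indices, v_cache_input_indices);
prefill_method.execute();
io_manager.update(
    prefill_method, k_cache_output_indices, v_cache_output_indices, prompt_len);

// Then decode one token at a time until done.
while (!done) {
  io_manager.prepare(
      decode_method, k_cache_input_indices, v_cache_input_indices);
  decode_method.execute();
  io_manager.update(
      decode_method, k_cache_output_indices, v_cache_output_indices, 1);
}
```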
 
+  /**
+   * Reset all caches and masks under management.
+   */
   void reset() {
     input_pos_ = 0;
     kCaches_.reset();