@@ -176,7 +176,7 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
       Sigmoid(gate_x + head_offset, kHeadDim);
       Sigmoid(a + head_offset, kHeadDim);
       const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
-          HWY_ATTR { return hn::Mul(x, gate_x); };
+                              HWY_ATTR { return hn::Mul(x, gate_x); };
       hn::Transform1(D(), a + head_offset, kHeadDim,
                      layer_weights->griffin.a.PackedScale1() + head_offset,
                      fn_mul);
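
In this hunk, both gate_x and a are passed through an in-place sigmoid, and a is then scaled elementwise by the layer's griffin.a weights for the current head. A minimal scalar sketch of that computation (hypothetical helper; the real code uses Highway's Sigmoid and hn::Transform1 across SIMD lanes):

    #include <cmath>
    #include <cstddef>

    // Scalar equivalent of the vectorized gate: sigmoid both buffers in place,
    // then multiply a by the per-head gate weights w_a (indexed like a).
    void GriffinGateScalar(float* gate_x, float* a, const float* w_a,
                           size_t head_offset, size_t head_dim) {
      for (size_t i = head_offset; i < head_offset + head_dim; ++i) {
        gate_x[i] = 1.0f / (1.0f + std::exp(-gate_x[i]));
        a[i] = (1.0f / (1.0f + std::exp(-a[i]))) * w_a[i];
      }
    }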
@@ -424,51 +424,49 @@ class GemmaAttention {
     const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads;

     // For each head (token, query), compute Q.K, softmax, and weighted V.
-    pool_.Run(0, layer_config_.heads * num_interleaved,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-                const size_t head = task % layer_config_.heads;
-                const size_t interleaved_idx = task / layer_config_.heads;
-                const size_t query_idx = interleaved_idx % num_queries_;
-                const size_t batch_idx = interleaved_idx / num_queries_;
-                const size_t qkv_dim = layer_config_.qkv_dim;
-                const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-
-                float* HWY_RESTRICT q =
-                    activations_.q.Row(interleaved_idx) + head * q_stride_;
-                float* HWY_RESTRICT att =
-                    activations_.att.Row(interleaved_idx) +
-                    head * activations_.seq_len;
-                float* HWY_RESTRICT att_out =
-                    activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
-
-                // Make strided views into the kv cache entries for the current
-                // query and head.
-                KVCache& kv_cache = kv_caches_[query_idx];
-                const size_t kv_head_offset =
-                    layer_ * cache_layer_size_ + head_offset;
-                MatPtrT<float> k("k_view",
-                                 Extents2D(kv_cache.seq_len, qkv_dim));
-                k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
-                         /*stride=*/cache_pos_size_);
-                MatPtrT<float> v("v_view",
-                                 Extents2D(kv_cache.seq_len, qkv_dim));
-                v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-                         /*stride=*/cache_pos_size_);
-
-                // Find the token position in the query and calculate the range
-                // of cache positions to attend to.
-                const size_t pos = queries_pos_[query_idx] + batch_idx;
-                const size_t start_pos = StartPos(pos, layer_);
-                size_t last_pos = pos;
-                const size_t prefix_end = queries_prefix_end_[query_idx];
-                if (prefix_end > 0 && prefix_end - 1 > last_pos) {
-                  // last_pos in QDotK and WeightedSumV is inclusive.
-                  last_pos = prefix_end - 1;
-                }
-
-                SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
-                                            pos, start_pos, last_pos);
-              });
+    pool_.Run(
+        0, layer_config_.heads * num_interleaved,
+        [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+          const size_t head = task % layer_config_.heads;
+          const size_t interleaved_idx = task / layer_config_.heads;
+          const size_t query_idx = interleaved_idx % num_queries_;
+          const size_t batch_idx = interleaved_idx / num_queries_;
+          const size_t qkv_dim = layer_config_.qkv_dim;
+          const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
+
+          float* HWY_RESTRICT q =
+              activations_.q.Row(interleaved_idx) + head * q_stride_;
+          float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
+                                    head * activations_.seq_len;
+          float* HWY_RESTRICT att_out =
+              activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
+
+          // Make strided views into the kv cache entries for the current
+          // query and head.
+          KVCache& kv_cache = kv_caches_[query_idx];
+          const size_t kv_head_offset =
+              layer_ * cache_layer_size_ + head_offset;
+          MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
+          k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
+                   /*stride=*/cache_pos_size_);
+          MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
+          v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+                   /*stride=*/cache_pos_size_);
+
+          // Find the token position in the query and calculate the range
+          // of cache positions to attend to.
+          const size_t pos = queries_pos_[query_idx] + batch_idx;
+          const size_t start_pos = StartPos(pos, layer_);
+          size_t last_pos = pos;
+          const size_t prefix_end = queries_prefix_end_[query_idx];
+          if (prefix_end > 0 && prefix_end - 1 > last_pos) {
+            // last_pos in QDotK and WeightedSumV is inclusive.
+            last_pos = prefix_end - 1;
+          }
+
+          SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos,
+                                      start_pos, last_pos);
+        });
   }

 private:
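
The parallel loop above flattens (head, query, batch token) into a single task index and derives per-head offsets into the interleaved K/V cache. A standalone sketch of that index math, with the class members passed as plain parameters (the names below are illustrative, not the class's actual API):

    #include <cstddef>
    #include <cstdint>

    struct TaskIndices {
      size_t head;            // Attention head handled by this task.
      size_t query_idx;       // Which query in the interleaved batch.
      size_t batch_idx;       // Token offset within that query.
      size_t kv_head_offset;  // Element offset of this head's K row per cache position.
    };

    // Mirrors the decomposition in the lambda: tasks enumerate heads fastest,
    // then (query, batch). K and V for a KV head are stored contiguously,
    // hence the factor of 2; grouped query heads share one KV head (GQA).
    TaskIndices DecodeTask(uint64_t task, size_t num_heads, size_t kv_heads,
                           size_t num_queries, size_t qkv_dim, size_t layer,
                           size_t cache_layer_size) {
      const size_t head = task % num_heads;
      const size_t interleaved_idx = task / num_heads;
      const size_t head_groups = num_heads / kv_heads;
      const size_t head_offset = (head / head_groups) * qkv_dim * 2;
      return {head, interleaved_idx % num_queries,
              interleaved_idx / num_queries,
              layer * cache_layer_size + head_offset};
    }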
@@ -1510,8 +1508,7 @@ void GenerateSingleT(const ModelStore& model,
 }

 template <typename T>
-void GenerateBatchT(const ModelStore& model,
-                    const ModelWeightsPtrs<T>& weights,
+void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
                     const RuntimeConfig& runtime_config,
                     const QueriesPromptTokens& queries_prompt,
                     const QueriesPos& queries_pos,
@@ -1536,7 +1533,7 @@ void GenerateBatchT(const ModelStore& model,
                                               qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                           qbatch_size);
+                                       qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT<T>(model, weights, activations, runtime_config, qbatch_prompts,
                  qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
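
GenerateBatchT walks the queries in fixed-size slices and hands GenerateT pointer+size views starting at qbatch_start, as the hunk above shows. A sketch of that slicing pattern with a stand-in Span type and an assumed batch limit (the repo's QueriesPos/KVCaches views are constructed the same way):

    #include <algorithm>
    #include <cstddef>

    // Stand-in for the repo's pointer+size view types (QueriesPos, KVCaches).
    template <typename T>
    struct Span {
      Span(T* data, size_t size) : data(data), size(size) {}
      T* data;
      size_t size;
    };

    // Hypothetical outline of the batching loop: each iteration views the
    // slice [qbatch_start, qbatch_start + qbatch_size) of every per-query
    // array and passes those views on (here: just the positions).
    void ForEachQueryBatch(size_t num_queries, size_t* queries_pos) {
      constexpr size_t kMaxQBatch = 16;  // Assumed limit, for illustration.
      for (size_t qbatch_start = 0; qbatch_start < num_queries;
           qbatch_start += kMaxQBatch) {
        const size_t qbatch_size =
            std::min(kMaxQBatch, num_queries - qbatch_start);
        Span<size_t> qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
        (void)qbatch_pos;  // GenerateT would receive these per-batch views.
      }
    }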