@@ -232,8 +232,8 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   const size_t num_interleaved = num_tokens * num_queries;
 
   // Self extend
-  constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
-  constexpr size_t grp_size = TConfig::self_extend_grp_size;
+  constexpr size_t ngb_size = TConfig::kSelfExtendNgbSize;
+  constexpr size_t grp_size = TConfig::kSelfExtendGrpSize;
 
   // For the computation of Q, K, and V, it is useful to remember that
   // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
@@ -298,8 +298,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
         // When embedding position, we will use grouped key position
-        if (pos > ngb_size && TConfig::kSelfExtend) {
-          pos /= grp_size;
+        if constexpr (TConfig::kSelfExtend) {
+          if (pos > ngb_size) {
+            pos /= grp_size;
+          }
         }
         if constexpr (kIsMHA) {
           // For MHA, copy KV into the KV cache from scratch space (see above).
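
For reference, the key-side mapping added in the hunk above reduces to a single expression: positions inside the neighbor window are kept as-is, and positions beyond it are divided by the group size before RoPE is applied, so distant keys are folded onto a smaller set of rotary positions. The sketch below is a minimal standalone illustration, not code from this change; the helper name and the kNgbSize/kGrpSize values are made up for the example.

#include <cstddef>
#include <cstdio>

constexpr size_t kNgbSize = 8;  // stands in for TConfig::kSelfExtendNgbSize
constexpr size_t kGrpSize = 4;  // stands in for TConfig::kSelfExtendGrpSize

// Grouped key position: keys beyond the neighbor window share coarser indices.
size_t GroupedKeyPos(size_t pos) {
  return pos > kNgbSize ? pos / kGrpSize : pos;
}

int main() {
  // Positions 0..8 stay as-is; 9, 10 and 11 all fold onto grouped position 2.
  std::printf("%zu %zu %zu\n", GroupedKeyPos(8), GroupedKeyPos(9), GroupedKeyPos(20));
}
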
@@ -331,11 +333,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
 
         // Apply rope and scaling to Q.
         size_t pos = batch_start + batch_idx;
-        if (pos > ngb_size && TConfig::kSelfExtend) {
-          const grp_pos = pos / grp_size;
-          const shift = ngb_size - ngb_size / grp_size
-          const shifted_grouped_pos = grp_pos + shift
-          pos = shifted_grouped_pos;
+        if constexpr (TConfig::kSelfExtend) {
+          if (pos > ngb_size) {
+            const size_t grp_pos = pos / grp_size;
+            const size_t shift = ngb_size - ngb_size / grp_size;
+            const size_t shifted_grouped_pos = grp_pos + shift;
+            pos = shifted_grouped_pos;
+          }
+        }
         }
         PostQK<TConfig>(q, pos, layer);
         MulByConst(kQueryScale, q, kQKVDim);
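
The query-side mapping in this last hunk differs from the key side: the grouped position is shifted by ngb_size - ngb_size / grp_size, so grouped query positions continue from the end of the neighbor window rather than jumping back toward zero. A minimal standalone sketch of the same arithmetic, again with a hypothetical helper name and illustrative sizes:

#include <cstddef>
#include <cstdio>

constexpr size_t kNgbSize = 8;  // stands in for TConfig::kSelfExtendNgbSize
constexpr size_t kGrpSize = 4;  // stands in for TConfig::kSelfExtendGrpSize

// Shifted grouped query position, matching the arithmetic in the hunk above.
size_t SelfExtendQueryPos(size_t pos) {
  if (pos <= kNgbSize) return pos;  // inside the neighbor window: unchanged
  const size_t grp_pos = pos / kGrpSize;
  const size_t shift = kNgbSize - kNgbSize / kGrpSize;  // 8 - 8/4 = 6 here
  return grp_pos + shift;  // continues just past the neighbor window
}

int main() {
  // pos 8 is unchanged; pos 9 maps to 8; pos 12 maps to 9.
  std::printf("%zu %zu %zu\n", SelfExtendQueryPos(8), SelfExtendQueryPos(9),
              SelfExtendQueryPos(12));
}

With these example sizes, positions up to 8 are untouched, position 9 lands on 8, and position 12 lands on 9, keeping query positions roughly contiguous across the window boundary.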