Skip to content

Commit 7bb4e0b

Browse files
committed
compile success: set default self extend values in noSSM and griffin
1 parent 5a2a7ee commit 7bb4e0b

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

gemma/configs.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ struct ConfigNoSSM {
151151
static constexpr PostQKType kPostQK = PostQKType::Rope;
152152
static constexpr ActivationType kActivation = ActivationType::Gelu;
153153
static constexpr ResidualType kResidual = ResidualType::Add;
154+
155+
// Self-extend parameters with default values
156+
static constexpr bool kSelfExtend = false;
157+
static constexpr size_t kSelfExtendNgbSize = 0;
158+
static constexpr size_t kSelfExtendGrpSize = 1;
154159
};
155160

156161
struct ConfigBaseGemmaV1 : ConfigNoSSM {
@@ -372,6 +377,11 @@ struct ConfigGriffin2B {
372377
static constexpr ActivationType kActivation = ActivationType::Gelu;
373378
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
374379
static constexpr ResidualType kResidual = ResidualType::Add;
380+
381+
// Self-extend parameters with default values
382+
static constexpr bool kSelfExtend = false;
383+
static constexpr size_t kSelfExtendNgbSize = 0;
384+
static constexpr size_t kSelfExtendGrpSize = 1;
375385
};
376386

377387
} // namespace gcpp

gemma/gemma-inl.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
232232
const size_t num_interleaved = num_tokens * num_queries;
233233

234234
// Self extend
235-
constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
236-
constexpr size_t grp_size = TConfig::self_extend_grp_size;
235+
constexpr size_t ngb_size = TConfig::kSelfExtendNgbSize;
236+
constexpr size_t grp_size = TConfig::kSelfExtendGrpSize;
237237

238238
// For the computation of Q, K, and V, it is useful to remember that
239239
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
@@ -298,8 +298,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
298298
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
299299

300300
// When embedding position, we will use grouped key position
301-
if (pos > ngb_size && TConfig::kSelfExtend) {
302-
pos /= grp_size;
301+
if constexpr (TConfig::kSelfExtend) {
302+
if (pos > ngb_size) {
303+
pos /= grp_size;
304+
}
303305
}
304306
if constexpr (kIsMHA) {
305307
// For MHA, copy KV into the KV cache from scratch space (see above).
@@ -331,11 +333,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
331333

332334
// Apply rope and scaling to Q.
333335
size_t pos = batch_start + batch_idx;
334-
if (pos > ngb_size && TConfig::kSelfExtend) {
335-
const grp_pos = pos / grp_size;
336-
const shift = ngb_size - ngb_size / grp_size
337-
const shifted_grouped_pos = grp_pos + shift
338-
pos = shifted_grouped_pos;
336+
if constexpr (TConfig::kSelfExtend) {
337+
if (pos > ngb_size) {
338+
const size_t grp_pos = pos / grp_size;
339+
const size_t shift = ngb_size - ngb_size / grp_size;
340+
const size_t shifted_grouped_pos = grp_pos + shift;
341+
pos = shifted_grouped_pos;
342+
}
339343
}
340344
PostQK<TConfig>(q, pos, layer);
341345
MulByConst(kQueryScale, q, kQKVDim);

0 commit comments

Comments
 (0)