
Commit e54d9cb

danielkeysers authored and copybara-github committed
Fix Griffin model:
- use HalfRope position encodings
- zero-initialize the caches for each Generate at position 0

The lack of the latter made the tests in gemma_test dependent on each other.

PiperOrigin-RevId: 694509054
1 parent d4050a2 commit e54d9cb
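For context on the second fix: a Griffin recurrent block carries hidden state across tokens in its cache, so reusing a KVCache for a new Generate call that starts at position 0 without clearing that state lets the previous sequence leak into the next one, which is what made the gemma_test cases order-dependent. The toy program below is illustrative only and not gemma.cpp code; the recurrence and all names are invented for the example.

#include <cstdio>
#include <vector>

// Toy illustration: a recurrent block keeps state between tokens. Without a
// reset when a new sequence starts at position 0, the previous run leaks
// into the next one.
struct ToyRecurrentCache {
  float state = 0.0f;
};

float RunSequence(ToyRecurrentCache& cache, const std::vector<float>& tokens) {
  for (float x : tokens) {
    cache.state = 0.5f * cache.state + x;  // simple linear recurrence
  }
  return cache.state;
}

int main() {
  ToyRecurrentCache cache;
  const std::vector<float> seq = {1.0f, 2.0f, 3.0f};
  const float first = RunSequence(cache, seq);  // fresh state
  const float stale = RunSequence(cache, seq);  // sees leftover state, differs
  cache.state = 0.0f;                           // the fix: zero at position 0
  const float clean = RunSequence(cache, seq);  // matches `first` again
  std::printf("first=%g stale=%g clean=%g\n", first, stale, clean);
  return 0;
}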

6 files changed: +39 -15 lines changed

evals/gemma_test.cc

Lines changed: 3 additions & 3 deletions
@@ -246,7 +246,7 @@ TEST_F(GemmaTest, CrossEntropySmall) {
       EXPECT_NEAR(entropy, 2.8f, 0.2f);
       break;
     case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 1.57f, 0.02f);
+      EXPECT_NEAR(entropy, 2.61f, 0.02f);
       break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 1.14f, 0.02f);
@@ -277,7 +277,7 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
       EXPECT_NEAR(entropy, 1.07f, 0.05f);
       break;
     case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 2.09f, 0.02f);
+      EXPECT_NEAR(entropy, 1.62f, 0.02f);
       break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 0.49f, 0.02f);
@@ -308,7 +308,7 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
       EXPECT_NEAR(entropy, 0.75f, 0.1f);
       break;
     case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 0.86f, 0.02f);
+      EXPECT_NEAR(entropy, 0.71f, 0.02f);
       break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 0.20f, 0.02f);

gemma/configs.cc

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ static ModelConfig ConfigGriffin2B() {
       .softmax_attn_output_biases = true,
       .type = LayerAttentionType::kGriffinRecurrentBlock,
       .activation = ActivationType::Gelu,
-      .post_qk = PostQKType::Rope,
+      .post_qk = PostQKType::HalfRope,
   };
   config.layer_configs = {26, layer_config};
   for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
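For reference, HalfRope applies rotary position encoding to only the first half of each head's dimensions and passes the second half through unchanged, which is the convention Griffin-style models expect. The sketch below is a minimal standalone illustration; the pairing of dimensions and the frequency base are assumptions of the example, not gemma.cpp's RoPE implementation.

#include <cmath>
#include <cstddef>
#include <vector>

// Minimal half-RoPE sketch: rotate adjacent pairs within the first half of
// the vector by a position-dependent angle; leave the second half untouched.
void HalfRopeSketch(std::vector<float>& x, size_t pos, float base = 10000.0f) {
  const size_t rope_dims = x.size() / 2;  // only the first half is rotated
  for (size_t i = 0; i + 1 < rope_dims; i += 2) {
    const float freq = std::pow(base, -static_cast<float>(i) / rope_dims);
    const float theta = static_cast<float>(pos) * freq;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + 1];
    x[i] = x0 * c - x1 * s;
    x[i + 1] = x0 * s + x1 * c;
  }
  // x[rope_dims .. x.size()) is passed through unchanged under HalfRope.
}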

gemma/configs_test.cc

Lines changed: 5 additions & 1 deletion
@@ -397,7 +397,11 @@ void AssertMatch(const ModelConfig& config) {
     ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
     ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
     ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
-    ASSERT_EQ(TConfig::kPostQK, config.layer_configs[i].post_qk);
+    PostQKType post_qk = TConfig::kPostQK;
+    if (TConfig::kUseHalfRope) {
+      post_qk = PostQKType::HalfRope;
+    }
+    ASSERT_EQ(post_qk, config.layer_configs[i].post_qk);
   }
 
   ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),

gemma/gemma-inl.h

Lines changed: 8 additions & 3 deletions
@@ -1240,8 +1240,12 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
                const QueriesPos& queries_prefix_end,
                const size_t query_idx_start, const KVCaches& kv_caches,
                TimingInfo& timing_info) {
-  const size_t vocab_size = model.Config().vocab_size;
-  const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
+  // Griffin assumes that the recurrent block cache is zero-initialized.
+  for (size_t i = 0; i < kv_caches.size(); ++i) {
+    if (queries_pos_in[i] == 0) {
+      kv_caches[i].ZeroGriffinCache();  // No-op for non-Griffin models.
+    }
+  }
 
   // Copy so we can increment without requiring users to pass in a mutable span.
   std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
@@ -1268,7 +1272,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   HWY_ASSERT(queries_pos_in.size() == num_queries);
   HWY_ASSERT(kv_caches.size() == num_queries);
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
-
+  const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
@@ -1314,6 +1318,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
                      0.0f);
   }
 
+  const size_t vocab_size = model.Config().vocab_size;
   const double gen_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
     // Decode generates one token per query and increments queries_mutable_pos.

gemma/kv_cache.cc

Lines changed: 16 additions & 7 deletions
@@ -23,6 +23,17 @@
 
 namespace gcpp {
 
+void KVCache::ZeroGriffinCache() {
+  if (conv1d_cache_size != 0) {
+    hwy::ZeroBytes(conv1d_cache.get(),
+                   conv1d_cache_size * sizeof(conv1d_cache[0]));
+  }
+  if (rglru_cache_size != 0) {
+    hwy::ZeroBytes(rglru_cache.get(),
+                   rglru_cache_size * sizeof(rglru_cache[0]));
+  }
+}
+
 // prefill_tbatch_size is the maximum number of tokens from one query to
 // prefill at a time.
 KVCache KVCache::Create(const ModelConfig& weights_config,
@@ -37,9 +48,9 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
     kv_cache.kv_cache =
         hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
   }
-  size_t num_griffin_layers = weights_config.NumLayersOfType(
-      LayerAttentionType::kGriffinRecurrentBlock);
 
+  const size_t num_griffin_layers = weights_config.NumLayersOfType(
+      LayerAttentionType::kGriffinRecurrentBlock);
   // TODO(patrickms): Add query batching support for Griffin.
   if (num_griffin_layers > 0) {
     size_t conv1d_width = 0;
@@ -49,20 +60,18 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
     const size_t conv1d_cache_size =
         num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
         weights_config.model_dim;
+    kv_cache.conv1d_cache_size = conv1d_cache_size;
     if (conv1d_cache_size != 0) {
       kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
-      hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
-                     conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
     }
 
     const size_t rglru_cache_size =
         num_griffin_layers * weights_config.model_dim;
+    kv_cache.rglru_cache_size = rglru_cache_size;
     if (rglru_cache_size != 0) {
       kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
-      hwy::ZeroBytes(kv_cache.rglru_cache.get(),
-                     rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
     }
-  }  // kGriffinLayers
+  }  // num_griffin_layers
 
   return kv_cache;
 }

gemma/kv_cache.h

Lines changed: 6 additions & 0 deletions
@@ -31,9 +31,15 @@ struct KVCache {
 
   // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
  hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
+  size_t conv1d_cache_size = 0;
 
   // kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
+  size_t rglru_cache_size = 0;
+
+  // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
+  // and rglru_cache.
+  void ZeroGriffinCache();
 
   static KVCache Create(const ModelConfig& weights_config,
                         size_t prefill_tbatch_size);
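For orientation, the new size fields let ZeroGriffinCache know how many floats to clear without re-deriving the sizes from the model config. After this change GenerateT performs the reset itself for any query that starts at position 0, so callers normally do not need to; the fragment below only shows the reset pattern in isolation, with config, prefill_tbatch_size and prompts assumed to exist and the actual generation call elided.

// Sketch: reuse one KVCache across independent sequences by clearing the
// Griffin recurrent state whenever generation restarts at position 0.
gcpp::KVCache kv_cache = gcpp::KVCache::Create(config, prefill_tbatch_size);
for (const auto& prompt : prompts) {
  kv_cache.ZeroGriffinCache();  // no-op for non-Griffin models
  // ... run generation for `prompt` starting at position 0 with kv_cache ...
}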
