diff --git a/gemma/configs.h b/gemma/configs.h
index 6bbbc45d..209c99f5 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -148,6 +148,15 @@ struct LayerConfig {
   size_t conv1d_width = 0;  // griffin only
   bool ff_biases = false;
   bool softmax_attn_output_biases = false;
+  /**
+   * Self-extend
+   * Jin, Hongye, et al. "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning." arXiv preprint arXiv:2401.01325 (2024).
+   */
+  bool self_extend = false;
+  // Self-extend neighbor size
+  size_t se_neighbor_size = std::numeric_limits<size_t>::max();
+  // Self-extend group window size
+  size_t se_group_size = 1;
   bool optimized_gating = true;
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 51f29995..ad817325 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -300,6 +300,9 @@ class GemmaAttention {
       }
     }  // !is_mha_
 
+    // Self-extension: divisor for computing grouped key positions.
+    const hwy::Divisor div_grp_size(
+        static_cast<uint32_t>(layer_config_.se_group_size));
     // Apply positional encodings for K (and copy KV to cache if MHA).
     pool_.Run(0, kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
@@ -307,21 +310,29 @@
          const size_t interleaved_idx = task / kv_heads;
          const size_t query_idx = interleaved_idx % num_queries_;
          const size_t batch_idx = interleaved_idx / num_queries_;
-          const size_t pos = queries_pos_[query_idx] + batch_idx;
+          size_t pos = queries_pos_[query_idx] + batch_idx;
          const size_t cache_pos = div_seq_len_.Remainder(pos);
          const size_t kv_offset = cache_pos * cache_pos_size_ +
                                   layer_ * cache_layer_size_ +
                                   head * qkv_dim * 2;
          KVCache& kv_cache = kv_caches_[query_idx];
+
+          const size_t se_neighbor_size = layer_config_.se_neighbor_size;
+          const bool enable_self_extend = layer_config_.self_extend;
+
          float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
          const float* HWY_RESTRICT mha_kv =
              activations_.q.Batch(interleaved_idx) + head * q_stride_ +
              qkv_dim;
+          // In self-extend, keys beyond the neighbor window are embedded at
+          // their grouped position instead of the raw position.
+          if (enable_self_extend && pos > se_neighbor_size) {
+            pos = div_grp_size.Divide(pos);
+          }
          // Copy from `q` if MHA, or apply in-place.
          PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, kv);
-
          // If MHA, also copy V into KVCache.
          if (is_mha_) {
            hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
@@ -405,12 +416,25 @@
          const size_t batch_idx = interleaved_idx / num_queries_;
          const size_t head_offset =
              (head / kHeadGroups) * layer_config_.qkv_dim * 2;
+
+          const size_t se_group_size = layer_config_.se_group_size;
+          const size_t se_neighbor_size = layer_config_.se_neighbor_size;
+          const bool enable_self_extend =
+              layer_config_.self_extend;
+
          KVCache& kv_cache = kv_caches_[query_idx];
          float* HWY_RESTRICT q =
              activations_.q.Batch(interleaved_idx) + head * q_stride_;
 
          // Apply rope and scaling to Q.
-          const size_t pos = queries_pos_[query_idx] + batch_idx;
+          size_t pos = queries_pos_[query_idx] + batch_idx;
+          if (enable_self_extend && pos > se_neighbor_size) {
+            const size_t grp_pos = pos / se_group_size;
+            const size_t shift =
+                se_neighbor_size - se_neighbor_size / se_group_size;
+            const size_t shifted_grouped_pos = grp_pos + shift;
+            pos = shifted_grouped_pos;
+          }
          PositionalEncodingQK(q, pos, layer_, query_scale, q);
 
          const size_t start_pos = StartPos(pos, layer_);
@@ -1408,7 +1432,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
                                      qbatch_size);
    QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
    const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                        qbatch_size);
+                                       qbatch_size);
    const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
    GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
                 qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 5b840531..5efe6fa0 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -198,6 +198,7 @@ class Gemma {
  ~Gemma();
 
  const ModelConfig& GetModelConfig() const { return model_.Config(); }
+  ModelConfig& GetMutableModelConfig() { return model_.MutableConfig(); }
  const ModelInfo& Info() const { return info_; }
  const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
  const ModelWeightsStorage& Weights() const { return model_; }
diff --git a/gemma/run.cc b/gemma/run.cc
index 87c7c9dd..e556cbcc 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -77,6 +77,26 @@ std::string GetPrompt(std::istream& input, int verbosity,
  return prompt_string;
}
 
+// Extracts self-extend args from the loader and modifies the model config.
+void ApplySelfExtendIfGiven(Gemma& model, LoaderArgs loader) {
+  ModelConfig& config = model.GetMutableModelConfig();
+  if (loader.self_extend != Tristate::kTrue) {
+    return;
+  }
+
+  // Modify each layer config in-place.
+  auto& layer_configs = config.layer_configs;
+  std::transform(layer_configs.begin(), layer_configs.end(),
+                 layer_configs.begin(), [&loader](LayerConfig& layer_config) {
+                   layer_config.self_extend =
+                       loader.self_extend == Tristate::kTrue;
+                   layer_config.se_group_size = loader.se_group_size;
+                   layer_config.se_neighbor_size = loader.se_neighbor_size;
+
+                   return layer_config;
+                 });
+}
+
// The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
               const InferenceArgs& args, const AcceptFunc& accept_token,
@@ -206,6 +226,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
  Allocator::Init(pools.Topology());
 
  Gemma model = CreateGemma(loader, pools);
+  ApplySelfExtendIfGiven(model, loader);
  KVCache kv_cache = KVCache::Create(model.GetModelConfig(),
                                     inference.prefill_tbatch_size);
diff --git a/gemma/weights.h b/gemma/weights.h
index b9acf899..35aa0e7b 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -549,6 +549,7 @@ class ModelWeightsStorage {
  void CopyWithTranspose(hwy::ThreadPool& pool);
  void LogWeightStats();
  const ModelConfig& Config() const { return config_; }
+  ModelConfig& MutableConfig() { return config_; }
 
  template <typename T>
  ModelWeightsPtrs<T>* GetWeightsOfType() const {
diff --git a/util/app.h b/util/app.h
index 5128a389..d17da9be 100644
--- a/util/app.h
+++ b/util/app.h
@@ -171,6 +171,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
  std::string model_type_str;
  std::string weight_type_str;
 
+  // Self-extend
+  Tristate self_extend;
+  size_t se_group_size;
+  size_t se_neighbor_size;
+
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(tokenizer, "tokenizer", Path(),
@@ -189,6 +194,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
    visitor(weight_type_str, "weight_type", std::string("sfp"),
            "Weight type\n  f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
            "  Required argument.");
+    visitor(self_extend, "self_extend", Tristate::kDefault,
+            "Apply self-extend? -1 = auto, 0 = no, 1 = yes.", 2);
+    visitor(se_group_size, "se_group_size", size_t{1}, "Group size for self-extend");
+    visitor(se_neighbor_size, "se_neighbor_size",
+            std::numeric_limits<size_t>::max(),
+            "Neighbor window size for self-extend");
  }
 
  // Uninitialized before Validate, must call after that.
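
For reference, the position arithmetic this patch applies in GemmaAttention can be summarized outside the diff. Below is a minimal standalone sketch of the grouped/shifted mapping used above; the helper names (SelfExtendKeyPos, SelfExtendQueryPos) and the example values are illustrative only and are not part of the change.

#include <cstddef>
#include <cstdio>

// Key position: beyond the neighbor window, keys collapse onto grouped
// positions (pos / group_size), mirroring div_grp_size.Divide(pos) above.
size_t SelfExtendKeyPos(size_t pos, size_t group_size, size_t neighbor_size) {
  return pos > neighbor_size ? pos / group_size : pos;
}

// Query position: the grouped position is shifted so that it continues from
// the end of the ungrouped neighbor window, mirroring the Q-side branch above.
size_t SelfExtendQueryPos(size_t pos, size_t group_size, size_t neighbor_size) {
  if (pos <= neighbor_size) return pos;
  const size_t shift = neighbor_size - neighbor_size / group_size;
  return pos / group_size + shift;
}

int main() {
  const size_t group_size = 4, neighbor_size = 512;
  const size_t positions[] = {100, 600, 4096};
  for (size_t pos : positions) {
    std::printf("pos=%zu key=%zu query=%zu\n", pos,
                SelfExtendKeyPos(pos, group_size, neighbor_size),
                SelfExtendQueryPos(pos, group_size, neighbor_size));
  }
  return 0;
}

With group_size = 4 and neighbor_size = 512, pos = 600 maps to key position 150 and query position 534: positions past the neighbor window advance at 1/group_size of the normal rate before RoPE is applied, which is the grouped-attention behavior described in the cited self-extend paper.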