diff --git a/README.md b/README.md index 7c9b3949..f2cb89bb 100644 --- a/README.md +++ b/README.md @@ -241,6 +241,24 @@ Example invocation for the following configuration: --model 2b-it ``` +### RecurrentGemma + +This repository includes a version of Gemma based on Griffin +([paper](https://arxiv.org/abs/2402.19427), +[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture +includes both recurrent layers and local attention, thus it is more efficient +for longer sequences and has a smaller memory footprint than standard Gemma. We +here provide a C++ implementation of this model based on the paper. + +To use the recurrent version of Gemma included in this repository, build the +gemma binary as noted above in Step 3. Download the compressed weights and +tokenizer from +[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in +Step 1, and run the binary as follows: + +`./gemma --tokenizer tokenizer.spm --model gr2b-it --compressed_weights 2b-it-sfp.sbs` + + ### Troubleshooting and FAQs **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."** @@ -478,4 +496,9 @@ gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google. and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng. +Griffin support was implemented in April 2024 thanks to contributions by Andrey +Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode +Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas +Fischbacher and Zoltan Szabadka. + This is not an officially supported Google product. diff --git a/benchmark.cc b/benchmark.cc index d71bbc60..995b0b8b 100644 --- a/benchmark.cc +++ b/benchmark.cc @@ -10,14 +10,14 @@ #include "nlohmann/json.hpp" // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" #include "hwy/timer.h" +// copybara:import_next_line:gemma_cpp +#include "util/app.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" using json = nlohmann::json; @@ -259,6 +259,13 @@ int main(int argc, char** argv) { gcpp::AppArgs app(argc, argv); BenchmarkArgs benchmark_args(argc, argv); + if (const char* error = loader.Validate()) { + HWY_ABORT("\nInvalid loader args: %s", error); + } + if (const char* error = args.Validate()) { + HWY_ABORT("\nInvalid inference args: %s", error); + } + hwy::ThreadPool inner_pool(0); hwy::ThreadPool pool(app.num_threads); // For many-core, pinning threads to cores helps. 
@@ -275,7 +282,7 @@ int main(int argc, char** argv) { if (!benchmark_args.goldens.path.empty()) { const std::string golden_path = - benchmark_args.goldens.path + "/" + loader.model_type + ".txt"; + benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt"; return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool, golden_path); } else if (!benchmark_args.summarize_text.path.empty()) { diff --git a/compress_weights.cc b/compress_weights.cc index ce4f6428..ae8b0887 100644 --- a/compress_weights.cc +++ b/compress_weights.cc @@ -44,35 +44,14 @@ struct Args : public ArgsBase { ChooseNumThreads(); } - static std::string ToLower(const std::string& text) { - std::string result = text; - std::transform(begin(result), end(result), begin(result), - [](unsigned char c) { return std::tolower(c); }); - return result; - } - - gcpp::Model ModelType() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc.substr(0, 2) == "2b") { - return gcpp::Model::GEMMA_2B; - } else if (model_type_lc.substr(0, 2) == "7b") { - return gcpp::Model::GEMMA_7B; - } else { - HWY_ABORT("Unknown model type %s", model_type_lc.c_str()); - } - } + gcpp::Model ModelType() const { return model_type; } // Returns error string or nullptr if OK. - const char* Validate() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type.empty()) { - return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " - "2b-it, 7b-it."; - } - if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && - model_type_lc != "2b-it" && model_type_lc != "7b-it") { - return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it."; - } + const char* Validate() { + ModelTraining model_training; + const char* parse_result = + ParseModelTypeAndTraining(model_type_str, model_type, model_training); + if (parse_result) return parse_result; if (weights.path.empty()) { return "Missing --weights flag, a file for the uncompressed model."; } @@ -88,7 +67,8 @@ struct Args : public ArgsBase { Path weights; // uncompressed weights file location Path compressed_weights; // compressed weights file location - std::string model_type; + std::string model_type_str; + Model model_type; size_t num_threads; template @@ -96,10 +76,12 @@ struct Args : public ArgsBase { visitor(weights, "weights", Path(), "Path name of model weights (.sbs) file.\n" " Required argument."); - visitor(model_type, "model", std::string(), + visitor(model_type_str, "model", std::string(), "Model type\n 2b-it = 2B parameters, instruction-tuned\n " "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " + "gr2b-it = griffin 2B parameters, instruction-tuned\n " + "gr2b-pt = griffin 2B parameters, pretrained\n " " Required argument."); visitor(compressed_weights, "compressed_weights", Path(), "Path name where compressed weights file will be written.\n" @@ -115,7 +97,7 @@ struct Args : public ArgsBase { void ShowHelp(gcpp::Args& args) { std::cerr << "Usage:\n./compress_weights --weights " - " --model --compressed_weights \n"; + " --model --compressed_weights \n"; std::cerr << "\n*Arguments*\n\n"; args.Help(); std::cerr << "\n"; diff --git a/configs.h b/configs.h index f1d7f9d4..98bcf12b 100644 --- a/configs.h +++ b/configs.h @@ -30,6 +30,8 @@ #include +#include + // copybara:import_next_line:gemma_cpp #include "compression/sfp.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -45,16 +47,41 @@ namespace gcpp { static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; static constexpr size_t 
kTopK = GEMMA_TOPK; +enum class LayerAttentionType { + kGemma, + kGriffinRecurrentBlock, +}; + +template +constexpr std::array FixedLayerConfig( + LayerAttentionType type) { + std::array config = {}; + for (LayerAttentionType& l : config) { + l = type; + } + return config; +} + struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kVocabSize = 256000; - static constexpr int kLayers = 28; + static constexpr std::array kLayerConfig = + FixedLayerConfig<28>(LayerAttentionType::kGemma); + static constexpr int kLayers = kLayerConfig.size(); static constexpr int kModelDim = 3072; static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kHeads = 16; static constexpr int kKVHeads = 16; // standard MHA static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; + + // SSM config. + static constexpr int kConv1dWidth = 0; + static constexpr bool kFFBiases = false; + static constexpr bool kSoftmaxAttnOutputBiases = false; + static constexpr bool kUseHalfRope = false; + static constexpr bool kUseLocalAttention = false; + static constexpr bool kInterleaveQKV = true; static constexpr int kNumTensorScales = 0; using WeightT = GEMMA_WEIGHT_T; }; @@ -62,17 +89,79 @@ struct ConfigGemma7B { struct ConfigGemma2B { static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kVocabSize = 256000; - static constexpr int kLayers = 18; + static constexpr std::array kLayerConfig = + FixedLayerConfig<18>(LayerAttentionType::kGemma); + static constexpr int kLayers = kLayerConfig.size(); static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; static constexpr int kKVHeads = 1; static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; + + // SSM config. + static constexpr int kConv1dWidth = 0; + static constexpr bool kFFBiases = false; + static constexpr bool kSoftmaxAttnOutputBiases = false; + static constexpr bool kUseHalfRope = false; + static constexpr bool kUseLocalAttention = false; + static constexpr bool kInterleaveQKV = true; static constexpr int kNumTensorScales = 0; using WeightT = GEMMA_WEIGHT_T; }; +struct ConfigGriffin2B { + // Griffin uses local attention, so kSeqLen is actually the local attention + // window. 
+ static constexpr int kSeqLen = 2048; + static constexpr int kVocabSize = 256000; + static constexpr std::array kLayerConfig = { + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGemma, + LayerAttentionType::kGriffinRecurrentBlock, + LayerAttentionType::kGriffinRecurrentBlock, + }; + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kModelDim = 2560; + static constexpr int kFFHiddenDim = 7680; + static constexpr int kHeads = 10; + static constexpr int kKVHeads = 1; + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + + // SSM config. + static constexpr int kConv1dWidth = 4; + static constexpr bool kFFBiases = true; + static constexpr bool kSoftmaxAttnOutputBiases = true; + static constexpr bool kUseHalfRope = true; + static constexpr bool kUseLocalAttention = true; + static constexpr bool kInterleaveQKV = false; + static constexpr int kNumTensorScales = 140; + using WeightT = GEMMA_WEIGHT_T; +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ diff --git a/gemma.cc b/gemma.cc index 33e58c61..80f7c176 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,12 +25,12 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" #include "hwy/timer.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // Path // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" // copybara:end @@ -43,11 +43,12 @@ #include // sqrtf #include #include +#include +#include #include #include #include -#include #include // NOLINT #include #include @@ -74,8 +75,35 @@ constexpr bool kDryRunFread = false; // Setting this to false will load and use uncompressed weights. constexpr bool kWeightsAreCompressed = true; +// Set this to true to debug tokenizer tokens. 
+constexpr bool kShowTokenization = false; + namespace gcpp { +template +constexpr size_t NumLayersOfTypeBefore( + const std::array& layers, + LayerAttentionType type, size_t num) { + size_t count = 0; + for (size_t i = 0; i < num; i++) { + if (layers[i] == type) count++; + } + return count; +} + +template +constexpr size_t NumGemmaLayers() { + return NumLayersOfTypeBefore(TConfig::kLayerConfig, + LayerAttentionType::kGemma, TConfig::kLayers); +} + +template +constexpr size_t NumGriffinLayers() { + return NumLayersOfTypeBefore(TConfig::kLayerConfig, + LayerAttentionType::kGriffinRecurrentBlock, + TConfig::kLayers); +} + template struct Layer { Layer() = default; @@ -96,6 +124,25 @@ struct Layer { std::array linear_w; std::array pre_attention_norm_scale; std::array pre_ffw_norm_scale; + // These fields are only used by Griffin, and do not affect loading of the + // model as it is done per-member. + // TODO(veluca): pull weights that are never used at the same time into a + // union or otherwise reduce the memory usage. + std::array ffw_gating_biases; + std::array ffw_output_biases; + std::array attention_output_biases; + + std::array griffin_linear_y_w; + std::array griffin_linear_y_biases; + std::array griffin_linear_x_w; + std::array griffin_linear_x_biases; + std::array griffin_linear_out_w; + std::array griffin_linear_out_biases; + std::array griffin_conv_biases; + std::array griffin_gate_w; + std::array griffin_gate_biases; + std::array griffin_a; + std::array griffin_conv_w; }; float ScaleWeights(float* data, size_t len) { @@ -196,6 +243,7 @@ hwy::AlignedFreeUniquePtr LoadWeights( do_fread(&(weights->final_norm_scale), -1, "final_norm_scale", sizeof(weights->final_norm_scale)); for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + auto type = TConfig::kLayerConfig[layer]; Layer* layer_view = weights->GetLayer(layer); #define READ_WEIGHTS(name) \ @@ -212,16 +260,42 @@ hwy::AlignedFreeUniquePtr LoadWeights( } while (0) // Make sure we don't have uninitialized memory. hwy::ZeroBytes(layer_view, sizeof(*layer_view)); - READ_WEIGHTS(attn_vec_einsum_w); - READ_WEIGHTS(qkv_einsum_w); - SCALE_WEIGHTS(attn_vec_einsum_w); - SCALE_WEIGHTS(qkv_einsum_w); + if (type == LayerAttentionType::kGemma) { + READ_WEIGHTS(attn_vec_einsum_w); + READ_WEIGHTS(qkv_einsum_w); + SCALE_WEIGHTS(attn_vec_einsum_w); + SCALE_WEIGHTS(qkv_einsum_w); + } else { + READ_WEIGHTS(griffin_linear_x_w); + READ_WEIGHTS(griffin_linear_x_biases); + READ_WEIGHTS(griffin_linear_y_w); + READ_WEIGHTS(griffin_linear_y_biases); + READ_WEIGHTS(griffin_linear_out_w); + READ_WEIGHTS(griffin_linear_out_biases); + READ_WEIGHTS(griffin_conv_w); + READ_WEIGHTS(griffin_conv_biases); + READ_WEIGHTS(griffin_gate_w); + READ_WEIGHTS(griffin_gate_biases); + READ_WEIGHTS(griffin_a); + SCALE_WEIGHTS(griffin_linear_x_w); + SCALE_WEIGHTS(griffin_linear_y_w); + SCALE_WEIGHTS(griffin_linear_out_w); + SCALE_WEIGHTS(griffin_gate_w); + } READ_WEIGHTS(gating_einsum_w); READ_WEIGHTS(linear_w); SCALE_WEIGHTS(gating_einsum_w); SCALE_WEIGHTS(linear_w); READ_WEIGHTS(pre_attention_norm_scale); READ_WEIGHTS(pre_ffw_norm_scale); + if (TConfig::kFFBiases) { + READ_WEIGHTS(ffw_gating_biases); + READ_WEIGHTS(ffw_output_biases); + } + if (TConfig::kSoftmaxAttnOutputBiases && + type == LayerAttentionType::kGemma) { + READ_WEIGHTS(attention_output_biases); + } #undef READ_WEIGHTS } if (!ok) { @@ -253,14 +327,30 @@ struct CompressedLayer { // We don't yet have an RMSNorm that accepts all WeightT. 
CompressedArray pre_attention_norm_scale; CompressedArray pre_ffw_norm_scale; + CompressedArray ffw_gating_biases; + CompressedArray ffw_output_biases; + CompressedArray attention_output_biases; CompressedArray gating_einsum_w; CompressedArray linear_w; CompressedArray qkv_einsum_w; CompressedArray attn_vec_einsum_w; + + CompressedArray griffin_linear_y_w; + CompressedArray griffin_linear_x_w; + CompressedArray griffin_linear_out_w; + CompressedArray + griffin_gate_w; + CompressedArray griffin_a; + CompressedArray griffin_linear_y_biases; + CompressedArray griffin_linear_x_biases; + CompressedArray griffin_linear_out_biases; + CompressedArray griffin_conv_biases; + CompressedArray griffin_gate_biases; + CompressedArray griffin_conv_w; }; -// Array instead of single large allocation for parallel mem init. Split out of -// CompressedWeights so that only these pointers are initialized, not the +// Array instead of single large allocation for parallel mem init. Split out +// of CompressedWeights so that only these pointers are initialized, not the // CompressedArray. template struct CompressedLayerPointers { @@ -307,7 +397,8 @@ struct Activations { static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads; - static constexpr size_t kCachePosSize = TConfig::kLayers * kKVHeads * kQKVDim; + static constexpr size_t kCachePosSize = + NumGemmaLayers() * kKVHeads * kQKVDim; static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim; std::array x; // input @@ -327,6 +418,12 @@ struct Activations { // bf_ffw_hidden; std::array ffw_out; std::array logits; + + // Griffin layer internal activations + std::array griffin_x; + std::array griffin_y; + std::array griffin_gate_x; + std::array griffin_multiplier; }; // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we @@ -353,8 +450,13 @@ struct GemmaInterface { template KVCache CreateKVCache() { - return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, - Config::kSeqLen); + constexpr size_t kConv1dWidth = Config::kConv1dWidth; + return CreateKVCache( + NumGemmaLayers() * Config::kKVHeads * Config::kQKVDim, + Config::kSeqLen, + NumGriffinLayers() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) * + Config::kModelDim, + NumGriffinLayers() * Config::kModelDim); } KVCache CreateKVCache(Model type) { @@ -363,6 +465,8 @@ KVCache CreateKVCache(Model type) { return CreateKVCache(); case Model::GEMMA_7B: return CreateKVCache(); + case Model::GRIFFIN_2B: + return CreateKVCache(); default: HWY_ABORT("Model type %d unknown.", static_cast(type)); } @@ -379,7 +483,15 @@ class GemmaTokenizerImpl : public GemmaTokenizer { } bool Encode(const std::string& input, std::vector* pieces) const override { - return impl_->Encode(input, pieces).ok(); + if constexpr (kShowTokenization) { + bool is_ok = impl_->Encode(input, pieces).ok(); + for (int i = 0; i < static_cast(pieces->size()); i++) { + fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]); + } + return is_ok; + } else { + return impl_->Encode(input, pieces).ok(); + } } // Given a sequence of ids, decodes it into a detokenized output. 
bool Decode(const std::vector& ids, @@ -442,6 +554,119 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +template +HWY_NOINLINE void GriffinRecurrent( + size_t batch_start, size_t batch_idx, size_t layer, + Activations& activations, const LayerT* layer_weights, + KVCache& kv_cache, hwy::ThreadPool& pool) { + PROFILER_ZONE("Gen.Griffin"); + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + HWY_DASSERT(batch_idx < kBatchSize); + static constexpr size_t kModelDim = + gcpp::Activations::kModelDim; + static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; + static constexpr size_t kHeads = TConfig::kHeads; + const size_t batch_offset = batch_idx * kModelDim; + const size_t pos = batch_start + batch_idx; + + // X / Y linear layers. + float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; + float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; + TwoMatVecAdd( + layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0, + activations.pre_att_rms_out.data() + batch_offset, + /*add0=*/layer_weights->griffin_linear_x_biases.data(), + /*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x, + /*out1=*/y, pool); + Gelu(y, kModelDim); + + // Conv1D. + { + HWY_FULL(float) df; + HWY_DASSERT(kModelDim % Lanes(df) == 0); + const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1); + + // cache[i] = input at time t-i. + float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)]; + cache[0] = x; + for (size_t i = 1; i < kConv1dWidth; i++) { + cache[i] = + kv_cache.conv1d_cache.get() + layer_offset + + ((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim; + } + for (size_t i = 0; i < kModelDim; i += Lanes(df)) { + auto xv = hn::Load(df, x + i); + auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i); + auto accum1 = hn::Zero(df); + static_assert(kConv1dWidth % 2 == 0, "Conv width must be even"); + for (size_t l = 0; 2 * l < kConv1dWidth; l++) { + auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() + + (kConv1dWidth - 1 - 2 * l) * kModelDim + i); + auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() + + (kConv1dWidth - 2 - 2 * l) * kModelDim + i); + accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); + accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); + } + hn::Store(hn::Add(accum0, accum1), df, x + i); + hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i); + } + } + + // RGLRU + float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset; + float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset; + float* HWY_RESTRICT rnn_state = + kv_cache.rglru_cache.get() + layer * kModelDim; + + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + constexpr size_t kHeadDim = kModelDim / kHeads; + constexpr size_t kMatrixSize = kHeadDim * kHeadDim; + size_t head_offset = head * kHeadDim; + TwoOfsMatVecAddLoop( + layer_weights->griffin_gate_w, kMatrixSize * head, + kMatrixSize * (kHeads + head), x + head_offset, + /*add0=*/layer_weights->griffin_gate_biases.data() + head_offset, + /*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim + + head_offset, + /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); + Sigmoid(gate_x + head_offset, kHeadDim); + Sigmoid(a + head_offset, kHeadDim); + const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) + HWY_ATTR { return hn::Mul(x, gate_x); }; + hn::Transform1(D(), a + head_offset, kHeadDim, + layer_weights->griffin_a.data() 
+ head_offset, fn_mul); + hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, + fn_mul); + // RNN scan + HWY_FULL(float) df; + HWY_DASSERT(kHeadDim % Lanes(df) == 0); + for (size_t i = 0; i < kHeadDim; i += Lanes(df)) { + auto log_a = hn::Load(df, a + head_offset + i); + auto gated_x = hn::Load(df, x + head_offset + i); + auto rnn = hn::Load(df, rnn_state + head_offset + i); + auto a = hn::Exp(df, log_a); + auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0))); + if (pos == 0) { + x_multiplier = hn::Set(df, 1.0); + } + auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); + hn::Store(new_x, df, rnn_state + head_offset + i); + + // Join branches. + auto yv = hn::Load(df, y + head_offset + i); + auto pre_out = hn::Mul(yv, new_x); + hn::Store(pre_out, df, x + head_offset + i); + } + }); + + // Final linear layer. + float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim; + MatVecAdd( + layer_weights->griffin_linear_out_w, 0, x, + layer_weights->griffin_linear_out_biases.data(), out_ptr, pool); +} + template HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, Activations& activations, @@ -462,6 +687,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, static const float kQueryScale = static_cast(1.0 / sqrt(static_cast(kQKVDim))); + size_t cache_pos = pos; + size_t cache_num = pos + 1; + if constexpr (TConfig::kUseLocalAttention) { + cache_pos %= TConfig::kSeqLen; + cache_num = std::min(cache_num, static_cast(TConfig::kSeqLen)); + } + float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim; auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR { @@ -480,7 +712,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, TwoOfsMatVecLoop(layer_weights->qkv_einsum_w, k_offset, v_offset, x, k, v); - Rope(k, kQKVDim, pos); + Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); }; auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR { @@ -491,24 +723,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, head * TConfig::kSeqLen + batch_idx * kHeads * kQKVDim; - Rope(q, kQKVDim, pos); + Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); // Compute Q dot K scores - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + for (size_t pos2 = 0; pos2 < cache_num; ++pos2) { const size_t cache_offset = pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; } - Softmax(head_att, pos + 1); + Softmax(head_att, cache_num); // Weighted summation float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + for (size_t pos2 = 0; pos2 < cache_num; ++pos2) { const size_t cache_offset = pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; @@ -520,22 +752,34 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, head == 0 ? 
activations.att_post2.data() + batch_idx * kModelDim : activations.att_post1.data() + head * kBatchSize * kModelDim; - MatVecLoop(layer_weights->attn_vec_einsum_w, - head * kModelDim * kQKVDim, att_out, - head_out); + if (head == 0) { + MatVecAddLoop( + layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out, + layer_weights->attention_output_biases.data(), head_out); + } else { + MatVecLoop(layer_weights->attn_vec_einsum_w, + head * kModelDim * kQKVDim, att_out, + head_out); + } }; if constexpr (kHeads == kKVHeads) { // Multi-Head Attention pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - const size_t head_offset = head * 3 * kQKVDim * kModelDim; - - ProjQ(head, head_offset); + // linear projections to QKV + const size_t head_offset = TConfig::kInterleaveQKV + ? 3 * kQKVDim * kModelDim + : kQKVDim * kModelDim; + const size_t mat_offset = + TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim; + const size_t q_offset = head * head_offset + 0 * mat_offset; + const size_t k_offset = head * head_offset + 1 * mat_offset; + const size_t v_offset = head * head_offset + 2 * mat_offset; + + ProjQ(head, q_offset); - const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim; - const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim; const size_t kv_offset = - pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; ProjKV(k_offset, v_offset, kv_offset); @@ -546,7 +790,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; - const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; + const size_t kv_offset = + cache_pos * kCachePosSize + layer * kCacheLayerSize; + ProjKV(k_offset, v_offset, kv_offset); pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { @@ -581,13 +827,13 @@ HWY_NOINLINE void FFW(Activations& activations, // Same matrix, first and second half of rows. Could fuse into one MatVec, // but separating them could help on NUMA e.g. multiple sockets. - MatVec(layer_weights->gating_einsum_w, - kFFHiddenDim * kModelDim, vec, out_mul, - pool); - + MatVecAdd( + layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, + layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool); // Gate, will go through the nonlinearity. 
- MatVec(layer_weights->gating_einsum_w, 0, vec, out, - pool); + MatVecAdd( + layer_weights->gating_einsum_w, 0, vec, + layer_weights->ffw_gating_biases.data(), out, pool); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -598,8 +844,9 @@ HWY_NOINLINE void FFW(Activations& activations, } PROFILER_ZONE("Gen.FFW\\GatedGELU"); - MatVec( + MatVecAdd( layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset, + layer_weights->ffw_output_biases.data(), activations.ffw_out.data() + batch_idx * kModelDim, pool); } @@ -639,15 +886,23 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, }); for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + auto type = TConfig::kLayerConfig[layer]; const auto* layer_weights = weights.GetLayer(layer); + size_t layer_of_type = + NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { RMSNorm(activations.x.data() + token_idx * kModelDim, layer_weights->pre_attention_norm_scale.data(), activations.pre_att_rms_out.data() + token_idx * kModelDim, kModelDim); - Attention(pos, token_idx, layer, activations, layer_weights, - kv_cache, pool); + if (type == LayerAttentionType::kGemma) { + Attention(pos, token_idx, layer_of_type, activations, + layer_weights, kv_cache, pool); + } else { + GriffinRecurrent(pos, token_idx, layer_of_type, activations, + layer_weights, kv_cache, pool); + } } // TODO: sink the loop into these functions, i.e. make them matmuls. @@ -678,9 +933,7 @@ template void Transformer(int token, size_t pos, const WeightArrayT& weights, Activations& activations, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) { - static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kModelDim = TConfig::kModelDim; - Decompress(weights.embedder_input_embedding, token * kModelDim, activations.x.data(), kModelDim); @@ -688,12 +941,21 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights, EmbeddingScaling(); MulByConst(kEmbScaling, activations.x.data(), kModelDim); - for (size_t layer = 0; layer < kLayers; ++layer) { + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + auto type = TConfig::kLayerConfig[layer]; const auto* layer_weights = weights.GetLayer(layer); + size_t layer_of_type = + NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); RMSNorm(activations.x.data(), layer_weights->pre_attention_norm_scale.data(), activations.pre_att_rms_out.data(), kModelDim); - Attention<1>(pos, 0, layer, activations, layer_weights, kv_cache, pool); + if (type == LayerAttentionType::kGemma) { + Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache, + pool); + } else { + GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights, + kv_cache, pool); + } AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim); RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(), activations.bf_pre_ffw_rms_out.data(), kModelDim); @@ -707,10 +969,12 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights, template void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, size_t& prompt_size) { - if (max_tokens > TConfig::kSeqLen) { - fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n", - max_tokens, TConfig::kSeqLen); - max_tokens = static_cast(TConfig::kSeqLen); + if (!TConfig::kUseLocalAttention) { + if (max_tokens > TConfig::kSeqLen) { + fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, 
truncating.\n", + max_tokens, TConfig::kSeqLen); + max_tokens = static_cast(TConfig::kSeqLen); + } } if (max_generated_tokens > max_tokens) { @@ -720,12 +984,14 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, max_generated_tokens = max_tokens - 1; } - if (prompt_size + max_generated_tokens > max_tokens) { - fprintf(stderr, - "WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen " - "%d, truncating.\n", - prompt_size, max_generated_tokens, TConfig::kSeqLen); - prompt_size = max_tokens - max_generated_tokens; + if (!TConfig::kUseLocalAttention) { + if (prompt_size + max_generated_tokens > max_tokens) { + fprintf(stderr, + "WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen " + "%d, truncating.\n", + prompt_size, max_generated_tokens, TConfig::kSeqLen); + prompt_size = max_tokens - max_generated_tokens; + } } } @@ -935,6 +1201,19 @@ void Generate7B(GemmaImpl& gemma, size_t max_tokens, accept_token, gen, verbosity); } +void GenerateGriffin2B(GemmaImpl& gemma, size_t max_tokens, + size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { + GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, + start_pos, kv_cache, pool, inner_pool, stream_token, + accept_token, gen, verbosity); +} + float ComputeCrossEntropy2B(GemmaImpl& gemma, size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, @@ -951,6 +1230,15 @@ float ComputeCrossEntropy7B(GemmaImpl& gemma, size_t max_tokens, inner_pool, verbosity); } +float ComputeCrossEntropyGriffin2B(GemmaImpl& gemma, + size_t max_tokens, + const std::vector& prompt, + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, int verbosity) { + return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool, + inner_pool, verbosity); +} + // Calls func(name, float*, CompressedArray&) for each tensor. float* is null // if weights = null, which happens during the first call where we attempt to // load from cache. @@ -967,6 +1255,7 @@ void ForEachTensor(const Weights* weights, char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); const Layer* layer = weights ? 
weights->GetLayer(idx) : nullptr; CompressedLayer* layer_weights = c_weights.GetLayer(idx); @@ -978,9 +1267,33 @@ void ForEachTensor(const Weights* weights, CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale); CALL_FUNC("gating_ein", gating_einsum_w); CALL_FUNC("linear_w", linear_w); - CALL_FUNC("qkv_ein", qkv_einsum_w); - CALL_FUNC("att_ein", attn_vec_einsum_w); + if (type == LayerAttentionType::kGemma) { + CALL_FUNC("qkv_ein", qkv_einsum_w); + CALL_FUNC("att_ein", attn_vec_einsum_w); + } else { + CALL_FUNC("gr_lin_x_w", griffin_linear_x_w); + CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases); + CALL_FUNC("gr_lin_y_w", griffin_linear_y_w); + CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases); + CALL_FUNC("gr_lin_out_w", griffin_linear_out_w); + CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases); + CALL_FUNC("gr_conv_w", griffin_conv_w); + CALL_FUNC("gr_conv_b", griffin_conv_biases); + CALL_FUNC("gr_gate_w", griffin_gate_w); + CALL_FUNC("gr_gate_b", griffin_gate_biases); + CALL_FUNC("gr_a", griffin_a); + } CALL_FUNC("pre_att_ns", pre_attention_norm_scale); + + if (TConfig::kFFBiases) { + CALL_FUNC("ffw_gat_b", ffw_gating_biases); + CALL_FUNC("ffw_out_b", ffw_output_biases); + } + + if (TConfig::kSoftmaxAttnOutputBiases && + type == LayerAttentionType::kGemma) { + CALL_FUNC("attn_ob", attention_output_biases); + } #undef CALL_FUNC } } @@ -1011,10 +1324,18 @@ hwy::AlignedFreeUniquePtr LoadCompressedWeights( if (TConfig::kNumTensorScales > 0) { size_t scale_pos = 0; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); CompressedLayer* layer_weights = c_weights->GetLayer(idx); - layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]); - layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]); + if (type == LayerAttentionType::kGemma) { + layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]); + layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]); + } else { + layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]); + layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]); + layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]); + layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]); + } layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]); layer_weights->linear_w.set_scale(scales[scale_pos++]); } @@ -1031,6 +1352,8 @@ hwy::AlignedFreeUniquePtr LoadCompressedWeightsT( return LoadCompressedWeights(weights, pool); case Model::GEMMA_7B: return LoadCompressedWeights(weights, pool); + case Model::GRIFFIN_2B: + return LoadCompressedWeights(weights, pool); default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } @@ -1044,6 +1367,8 @@ hwy::AlignedFreeUniquePtr LoadWeightsT(gcpp::Model model, return LoadWeights(weights, pool); case Model::GEMMA_7B: return LoadWeights(weights, pool); + case Model::GRIFFIN_2B: + return LoadWeights(weights, pool); default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } @@ -1089,6 +1414,9 @@ void CompressWeightsT(gcpp::Model model, const Path& weights, case Model::GEMMA_7B: CompressWeights(weights, compressed_weights, pool); break; + case Model::GRIFFIN_2B: + CompressWeights(weights, compressed_weights, pool); + break; default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } @@ -1106,13 +1434,29 @@ HWY_EXPORT(LoadWeightsT); HWY_EXPORT(CompressWeightsT); HWY_EXPORT(Generate2B); HWY_EXPORT(Generate7B); +HWY_EXPORT(GenerateGriffin2B); HWY_EXPORT(ComputeCrossEntropy2B); 
HWY_EXPORT(ComputeCrossEntropy7B); +HWY_EXPORT(ComputeCrossEntropyGriffin2B); -KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) { +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, + size_t conv_cache_size, size_t rglru_cache_size) { KVCache kv_cache = {}; - kv_cache.key_cache = hwy::AllocateAligned(seq_len * size_cache_pos); - kv_cache.value_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + if (size_cache_pos != 0) { + kv_cache.key_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + kv_cache.value_cache = + hwy::AllocateAligned(seq_len * size_cache_pos); + } + if (conv_cache_size != 0) { + kv_cache.conv1d_cache = hwy::AllocateAligned(conv_cache_size); + hwy::ZeroBytes(kv_cache.conv1d_cache.get(), + conv_cache_size * sizeof(kv_cache.conv1d_cache[0])); + } + if (rglru_cache_size != 0) { + kv_cache.rglru_cache = hwy::AllocateAligned(rglru_cache_size); + hwy::ZeroBytes(kv_cache.rglru_cache.get(), + rglru_cache_size * sizeof(kv_cache.rglru_cache[0])); + } return kv_cache; } @@ -1136,6 +1480,7 @@ void GemmaImpl::Generate( (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } + template <> void GemmaImpl::Generate( size_t max_tokens, size_t max_generated_tokens, float temperature, @@ -1148,6 +1493,18 @@ void GemmaImpl::Generate( kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } +template <> +void GemmaImpl::Generate( + size_t max_tokens, size_t max_generated_tokens, float temperature, + const std::vector& prompt, size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { + HWY_DYNAMIC_DISPATCH(GenerateGriffin2B) + (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, + kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); +} + template <> float GemmaImpl::ComputeCrossEntropy( size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, @@ -1164,6 +1521,14 @@ float GemmaImpl::ComputeCrossEntropy( *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity); } +template <> +float GemmaImpl::ComputeCrossEntropy( + size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) { + return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)( + *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity); +} + Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, hwy::ThreadPool& pool) { std::unique_ptr tokenizer; @@ -1190,6 +1555,9 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, case Model::GEMMA_7B: impl_.reset(new GemmaImpl(tokenizer, weights_u8, pool)); break; + case Model::GRIFFIN_2B: + impl_.reset(new GemmaImpl(tokenizer, weights_u8, pool)); + break; default: HWY_ABORT("Model type %d unknown.", static_cast(model_type)); } @@ -1240,5 +1608,42 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, return result; } +namespace { +constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", + "2b-it", "7b-it", "gr2b-it"}; +constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B, + Model::GRIFFIN_2B, Model::GEMMA_2B, + Model::GEMMA_7B, Model::GRIFFIN_2B}; +constexpr ModelTraining kModelTraining[] = { + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, + ModelTraining::GEMMA_IT, 
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT}; +} // namespace + +const char* ParseModelTypeAndTraining(const std::string& model_flag, + Model& model, ModelTraining& training) { + constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags); + static char kErrorMessageBuffer[kNum * 8 + 1024]; + kErrorMessageBuffer[0] = 0; + strcat(kErrorMessageBuffer, + "Invalid or missing model flag, need to specify one of "); + for (size_t i = 0; i + 1 < kNum; i++) { + strcat(kErrorMessageBuffer, kModelFlags[i]); + strcat(kErrorMessageBuffer, ", "); + } + strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); + strcat(kErrorMessageBuffer, "."); + std::string model_type_lc = model_flag; + std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc), + [](unsigned char c) { return std::tolower(c); }); + for (size_t i = 0; i < kNum; i++) { + if (kModelFlags[i] == model_type_lc) { + model = kModelTypes[i]; + training = kModelTraining[i]; + return nullptr; + } + } + return kErrorMessageBuffer; +} + } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma.h b/gemma.h index 00141c99..8ae35778 100644 --- a/gemma.h +++ b/gemma.h @@ -42,15 +42,24 @@ constexpr bool kSystemPrompt = false; struct KVCache { hwy::AlignedFreeUniquePtr - key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim + key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim hwy::AlignedFreeUniquePtr - value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim + value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim + hwy::AlignedFreeUniquePtr + conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers + hwy::AlignedFreeUniquePtr + rglru_cache; // kModelDim * kNumGriffinLayers }; // Model variants: see configs.h for details. -enum class Model { GEMMA_2B, GEMMA_7B }; +enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; +// Returns error string or nullptr if OK. +// Thread-hostile. +const char* ParseModelTypeAndTraining(const std::string& model_flag, + Model& model, ModelTraining& training); + struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; @@ -79,7 +88,8 @@ struct Gemma { }; KVCache CreateKVCache(Model type); // convenient workaround for now -KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, + size_t conv1d_cache_size, size_t rglru_cache_size); // StreamFunc is called with (token, probability). For prompt tokens, // probability is 0.0f. 
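
For reference, a minimal usage sketch of the new `ParseModelTypeAndTraining` helper declared in `gemma.h` above. This is not part of the diff; it mirrors how `LoaderArgs::Validate` in `util/app.h` calls the function, and the wrapper name `ResolveModelFlag` is illustrative only.

```cpp
#include <cstdio>
#include <string>
// copybara:import_next_line:gemma_cpp
#include "gemma.h"  // gcpp::Model, gcpp::ModelTraining, ParseModelTypeAndTraining

// Parses a --model flag such as "2b-it" or "gr2b-pt" into the model type and
// training variant. ParseModelTypeAndTraining returns nullptr on success and
// a human-readable error string otherwise.
bool ResolveModelFlag(const std::string& flag) {
  gcpp::Model model;
  gcpp::ModelTraining training;
  if (const char* error =
          gcpp::ParseModelTypeAndTraining(flag, model, training)) {
    fprintf(stderr, "%s\n", error);
    return false;
  }
  // model is now GEMMA_2B / GEMMA_7B / GRIFFIN_2B;
  // training is GEMMA_PT / GEMMA_IT.
  return true;
}
```
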
diff --git a/run.cc b/run.cc index e16910de..1a3fa0a1 100644 --- a/run.cc +++ b/run.cc @@ -27,8 +27,6 @@ #include "compression/compress.h" // copybara:import_next_line:gemma_cpp #include "gemma.h" // Gemma -// copybara:import_next_line:gemma_cpp -#include "util/app.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -36,8 +34,12 @@ #include "hwy/profiler.h" #include "hwy/timer.h" // copybara:import_next_line:gemma_cpp +#include "util/app.h" +// copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp +static constexpr bool kVerboseLogTokens = false; + namespace gcpp { static constexpr std::string_view kAsciiArtBanner = R""( @@ -203,6 +205,12 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, std::cerr << "\n" << "[ Reading prompt ] " << std::flush; + if constexpr (kVerboseLogTokens) { + for (int i = 0; i < static_cast(prompt.size()); ++i) { + fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); + } + } + const double time_start = hwy::platform::Now(); GenerateGemma(model, args.max_tokens, args.max_generated_tokens, args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, diff --git a/util/app.h b/util/app.h index f67849bc..af267123 100644 --- a/util/app.h +++ b/util/app.h @@ -125,46 +125,21 @@ class AppArgs : public ArgsBase { struct LoaderArgs : public ArgsBase { LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - static std::string ToLower(const std::string& text) { - std::string result = text; - std::transform(begin(result), end(result), begin(result), - [](unsigned char c) { return std::tolower(c); }); - return result; - } + gcpp::Model ModelType() const { return model_type; } - gcpp::Model ModelType() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") { - return gcpp::Model::GEMMA_2B; - } else { - return gcpp::Model::GEMMA_7B; - } - } - - gcpp::ModelTraining ModelTraining() const { - const std::string model_type_lc = ToLower(model_type); - if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") { - return gcpp::ModelTraining::GEMMA_PT; - } else { - return gcpp::ModelTraining::GEMMA_IT; - } - } + gcpp::ModelTraining ModelTraining() const { return model_training; } // Returns error string or nullptr if OK. 
const char* Validate() { - const std::string model_type_lc = ToLower(model_type); - if (model_type.empty()) { - return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " - "2b-it, or 7b-it."; - } - if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && - model_type_lc != "2b-it" && model_type_lc != "7b-it") { - return "Model type must be 2b-pt, 7b-pt, 2b-it, or " - "7b-it."; - } + const char* parse_result = + ParseModelTypeAndTraining(model_type_str, model_type, model_training); + if (parse_result) return parse_result; if (tokenizer.path.empty()) { return "Missing --tokenizer flag, a file for the tokenizer is required."; } + if (!tokenizer.exists()) { + return "Can't open file specified with --tokenizer flag."; + } if (!compressed_weights.path.empty()) { if (weights.path.empty()) { weights = compressed_weights; @@ -186,7 +161,9 @@ struct LoaderArgs : public ArgsBase { Path tokenizer; Path weights; // weights file location Path compressed_weights; - std::string model_type; + std::string model_type_str; + Model model_type; + enum ModelTraining model_training; template void ForEach(const Visitor& visitor) { @@ -196,10 +173,12 @@ struct LoaderArgs : public ArgsBase { "Path name of model weights (.sbs) file.\n Required argument."); visitor(compressed_weights, "compressed_weights", Path(), "Alias for --weights."); - visitor(model_type, "model", std::string(), + visitor(model_type_str, "model", std::string(), "Model type\n 2b-it = 2B parameters, instruction-tuned\n " "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " - "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" + "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " + "gr2b-it = griffin 2B parameters, instruction-tuned\n " + "gr2b-pt = griffin 2B parameters, pretrained\n " " Required argument."); } };
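
For readers of the `GriffinRecurrent` kernel added in `gemma.cc`, here is a minimal scalar sketch of the per-element RG-LRU update that the vectorized RNN-scan loop performs. It is an illustration, not code from the diff: `log_a_param` stands in for the value the diff stores in `griffin_a`, and the two gate logits are assumed to come from the per-head `griffin_gate_w` matvec with `griffin_gate_biases` added.

```cpp
#include <cmath>

// Scalar logistic function, matching the Sigmoid op applied to the gates.
static float SigmoidScalar(float v) { return 1.0f / (1.0f + std::exp(-v)); }

// One element of the RG-LRU update.
//   x: gated-input branch element (before the input gate is applied)
//   y: GELU branch element (the "y" linear projection after Gelu)
//   h: recurrent state element, updated in place (zero-initialized in the cache)
//   rec_gate_logit / in_gate_logit: per-head gate projections of x
//   log_a_param: learned log-space recurrence parameter (stored in griffin_a)
//   is_first_pos: true for the first position, which has no history
static float RGLRUElement(float x, float y, float& h, float rec_gate_logit,
                          float in_gate_logit, float log_a_param,
                          bool is_first_pos) {
  const float rec_gate = SigmoidScalar(rec_gate_logit);  // recurrence gate
  const float in_gate = SigmoidScalar(in_gate_logit);    // input gate
  const float a = std::exp(log_a_param * rec_gate);      // recurrence decay
  // Scale the input so the state keeps roughly unit variance; at the first
  // position there is no history, so the input passes through unscaled.
  const float x_scale = is_first_pos ? 1.0f : std::sqrt(1.0f - a * a);
  h = a * h + x_scale * (in_gate * x);  // RNN state update (stored in rglru_cache)
  return y * h;                         // join with the GELU branch
}
```

The vectorized loop in the diff computes exactly this, `Lanes(df)` elements at a time, after the Conv1D stage and the two `Sigmoid` calls on the gate projections.
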