Commit 83dd08a

Merge pull request #136 from pculliton:griffin
PiperOrigin-RevId: 623054233
2 parents 4326249 + 9c3f969 commit 83dd08a

8 files changed: +640 -137 lines changed

README.md

Lines changed: 23 additions & 0 deletions
@@ -241,6 +241,24 @@ Example invocation for the following configuration:
 --model 2b-it
 ```
 
+### RecurrentGemma
+
+This repository includes a version of Gemma based on Griffin
+([paper](https://arxiv.org/abs/2402.19427),
+[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
+includes both recurrent layers and local attention, thus it is more efficient
+for longer sequences and has a smaller memory footprint than standard Gemma. We
+here provide a C++ implementation of this model based on the paper.
+
+To use the recurrent version of Gemma included in this repository, build the
+gemma binary as noted above in Step 3. Download the compressed weights and
+tokenizer from
+[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
+Step 1, and run the binary as follows:
+
+`./gemma --tokenizer tokenizer.spm --model gr2b-it --compressed_weights 2b-it-sfp.sbs`
+
 ### Troubleshooting and FAQs
 
 **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
@@ -478,4 +496,9 @@ gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.
 and [Jan Wassenberg](mailto:[email protected]), and subsequently released February 2024
 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
 
+Griffin support was implemented in April 2024 thanks to contributions by Andrey
+Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
+Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
+Fischbacher and Zoltan Szabadka.
+
 This is not an officially supported Google product.

benchmark.cc

Lines changed: 12 additions & 5 deletions
@@ -10,14 +10,14 @@
 #include "nlohmann/json.hpp"
 // copybara:import_next_line:gemma_cpp
 #include "gemma.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/app.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 #include "hwy/timer.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/app.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"
 
 using json = nlohmann::json;
 
@@ -259,6 +259,13 @@ int main(int argc, char** argv) {
   gcpp::AppArgs app(argc, argv);
   BenchmarkArgs benchmark_args(argc, argv);
 
+  if (const char* error = loader.Validate()) {
+    HWY_ABORT("\nInvalid loader args: %s", error);
+  }
+  if (const char* error = args.Validate()) {
+    HWY_ABORT("\nInvalid inference args: %s", error);
+  }
+
   hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
@@ -275,7 +282,7 @@ int main(int argc, char** argv) {
 
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
-        benchmark_args.goldens.path + "/" + loader.model_type + ".txt";
+        benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
     return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
                             golden_path);
   } else if (!benchmark_args.summarize_text.path.empty()) {
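The checks added to `main` follow the convention used throughout these args structs: `Validate()` returns a `const char*` describing the problem, or `nullptr` when the arguments are well formed, and the caller aborts on a non-null result. Below is a minimal self-contained sketch of that pattern, with a hypothetical `ExampleArgs` struct standing in for the repository's loader/inference args:

```c++
#include <cstdio>
#include <cstdlib>
#include <string>

// Hypothetical args struct illustrating the Validate() convention:
// return an error message on failure, nullptr on success.
struct ExampleArgs {
  std::string model;

  const char* Validate() const {
    if (model.empty()) return "Missing --model flag.";
    return nullptr;  // All arguments are valid.
  }
};

int main() {
  ExampleArgs args;  // model left empty to exercise the error path.
  if (const char* error = args.Validate()) {
    std::fprintf(stderr, "Invalid args: %s\n", error);
    return EXIT_FAILURE;  // benchmark.cc calls HWY_ABORT here instead.
  }
  return EXIT_SUCCESS;
}
```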

compress_weights.cc

Lines changed: 12 additions & 30 deletions
@@ -44,35 +44,14 @@ struct Args : public ArgsBase<Args> {
     ChooseNumThreads();
   }
 
-  static std::string ToLower(const std::string& text) {
-    std::string result = text;
-    std::transform(begin(result), end(result), begin(result),
-                   [](unsigned char c) { return std::tolower(c); });
-    return result;
-  }
-
-  gcpp::Model ModelType() const {
-    const std::string model_type_lc = ToLower(model_type);
-    if (model_type_lc.substr(0, 2) == "2b") {
-      return gcpp::Model::GEMMA_2B;
-    } else if (model_type_lc.substr(0, 2) == "7b") {
-      return gcpp::Model::GEMMA_7B;
-    } else {
-      HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
-    }
-  }
+  gcpp::Model ModelType() const { return model_type; }
 
   // Returns error string or nullptr if OK.
-  const char* Validate() const {
-    const std::string model_type_lc = ToLower(model_type);
-    if (model_type.empty()) {
-      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
-             "2b-it, 7b-it.";
-    }
-    if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
-        model_type_lc != "2b-it" && model_type_lc != "7b-it") {
-      return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
-    }
+  const char* Validate() {
+    ModelTraining model_training;
+    const char* parse_result =
+        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
+    if (parse_result) return parse_result;
     if (weights.path.empty()) {
       return "Missing --weights flag, a file for the uncompressed model.";
     }
@@ -88,18 +67,21 @@ struct Args : public ArgsBase<Args> {
 
   Path weights;             // uncompressed weights file location
   Path compressed_weights;  // compressed weights file location
-  std::string model_type;
+  std::string model_type_str;
+  Model model_type;
   size_t num_threads;
 
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(weights, "weights", Path(),
             "Path name of model weights (.sbs) file.\n"
             " Required argument.");
-    visitor(model_type, "model", std::string(),
+    visitor(model_type_str, "model", std::string(),
             "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
             "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
             "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
+            "gr2b-it = griffin 2B parameters, instruction-tuned\n "
+            "gr2b-pt = griffin 2B parameters, pretrained\n "
             " Required argument.");
     visitor(compressed_weights, "compressed_weights", Path(),
             "Path name where compressed weights file will be written.\n"
@@ -115,7 +97,7 @@ struct Args : public ArgsBase<Args> {
 void ShowHelp(gcpp::Args& args) {
   std::cerr
       << "Usage:\n./compress_weights --weights <path to uncompressed weights> "
-         " --model <model type> --compressed_weights <output path>\n";
+         " --model <model type> --compressed_weights <output path>\n";
   std::cerr << "\n*Arguments*\n\n";
   args.Help();
   std::cerr << "\n";
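The new `Validate()` delegates flag parsing to `ParseModelTypeAndTraining`, which is defined in one of the files not shown in this excerpt. As a rough sketch of the kind of mapping such a parser performs (the enumerator names, function name, and exact error text below are assumptions for illustration, not the repository's implementation):

```c++
#include <algorithm>
#include <cctype>
#include <string>

enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };  // assumed enumerators
enum class ModelTraining { GEMMA_IT, GEMMA_PT };      // assumed enumerators

// Hypothetical parser: maps a flag value such as "gr2b-it" to a model type
// and training variant. Returns an error string, or nullptr on success.
const char* ParseModelTypeAndTrainingSketch(std::string name, Model& model,
                                            ModelTraining& training) {
  std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (name == "2b-it" || name == "2b-pt") {
    model = Model::GEMMA_2B;
  } else if (name == "7b-it" || name == "7b-pt") {
    model = Model::GEMMA_7B;
  } else if (name == "gr2b-it" || name == "gr2b-pt") {
    model = Model::GRIFFIN_2B;
  } else {
    return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it, gr2b-pt or gr2b-it.";
  }
  training = (name.substr(name.size() - 2) == "it") ? ModelTraining::GEMMA_IT
                                                    : ModelTraining::GEMMA_PT;
  return nullptr;
}
```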

configs.h

Lines changed: 91 additions & 2 deletions
@@ -30,6 +30,8 @@
 
 #include <stddef.h>
 
+#include <array>
+
 // copybara:import_next_line:gemma_cpp
 #include "compression/sfp.h"
 #include "hwy/base.h"  // hwy::bfloat16_t
@@ -45,34 +47,121 @@ namespace gcpp {
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
 
+enum class LayerAttentionType {
+  kGemma,
+  kGriffinRecurrentBlock,
+};
+
+template <size_t kNum>
+constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
+    LayerAttentionType type) {
+  std::array<LayerAttentionType, kNum> config = {};
+  for (LayerAttentionType& l : config) {
+    l = type;
+  }
+  return config;
+}
+
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
-  static constexpr int kLayers = 28;
+  static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
+      FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kModelDim = 3072;
   static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
   static constexpr int kHeads = 16;
   static constexpr int kKVHeads = 16;  // standard MHA
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
   using WeightT = GEMMA_WEIGHT_T;
 };
 
 struct ConfigGemma2B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
-  static constexpr int kLayers = 18;
+  static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
+      FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kModelDim = 2048;
   static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
   static constexpr int kHeads = 8;
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
   using WeightT = GEMMA_WEIGHT_T;
 };
 
+struct ConfigGriffin2B {
+  // Griffin uses local attention, so kSeqLen is actually the local attention
+  // window.
+  static constexpr int kSeqLen = 2048;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+  };
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kModelDim = 2560;
+  static constexpr int kFFHiddenDim = 7680;
+  static constexpr int kHeads = 10;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 4;
+  static constexpr bool kFFBiases = true;
+  static constexpr bool kSoftmaxAttnOutputBiases = true;
+  static constexpr bool kUseHalfRope = true;
+  static constexpr bool kUseLocalAttention = true;
+  static constexpr bool kInterleaveQKV = false;
+  static constexpr int kNumTensorScales = 140;
+  using WeightT = GEMMA_WEIGHT_T;
+};
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
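With per-layer attention types, callers can branch on the kind of each layer instead of assuming a homogeneous stack. A minimal sketch of how code might consume a `kLayerConfig` array (the `NumLayersOfType` helper and the 6-layer example below are illustrative, not part of the diff):

```c++
#include <array>
#include <cstddef>

// Mirrors the enum added to configs.h.
enum class LayerAttentionType { kGemma, kGriffinRecurrentBlock };

// Illustrative helper: count how many layers of a given type a config uses.
template <size_t kNum>
constexpr size_t NumLayersOfType(
    const std::array<LayerAttentionType, kNum>& config,
    LayerAttentionType type) {
  size_t n = 0;
  for (LayerAttentionType l : config) {
    if (l == type) ++n;
  }
  return n;
}

// Small example following ConfigGriffin2B's repeating pattern of two
// recurrent blocks followed by one Gemma (attention) layer.
constexpr std::array<LayerAttentionType, 6> kExampleConfig = {
    LayerAttentionType::kGriffinRecurrentBlock,
    LayerAttentionType::kGriffinRecurrentBlock,
    LayerAttentionType::kGemma,
    LayerAttentionType::kGriffinRecurrentBlock,
    LayerAttentionType::kGriffinRecurrentBlock,
    LayerAttentionType::kGemma,
};

static_assert(NumLayersOfType(kExampleConfig,
                              LayerAttentionType::kGriffinRecurrentBlock) == 4,
              "four recurrent blocks in the 6-layer example");
```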
