 #include <stddef.h>
 
+#include <array>
+
 // copybara:import_next_line:gemma_cpp
 #include "compression/sfp.h"
 #include "hwy/base.h"  // hwy::bfloat16_t
@@ -45,34 +47,121 @@ namespace gcpp {
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
 
+enum class LayerAttentionType {
+  kGemma,
+  kGriffinRecurrentBlock,
+};
+
+template <size_t kNum>
+constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
+    LayerAttentionType type) {
+  std::array<LayerAttentionType, kNum> config = {};
+  for (LayerAttentionType& l : config) {
+    l = type;
+  }
+  return config;
+}
+
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
-  static constexpr int kLayers = 28;
+  static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
+      FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kModelDim = 3072;
   static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
   static constexpr int kHeads = 16;
   static constexpr int kKVHeads = 16;  // standard MHA
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
   using WeightT = GEMMA_WEIGHT_T;
 };
 
 struct ConfigGemma2B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
-  static constexpr int kLayers = 18;
+  static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
+      FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kModelDim = 2048;
   static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
   static constexpr int kHeads = 8;
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
   using WeightT = GEMMA_WEIGHT_T;
 };
 
+struct ConfigGriffin2B {
+  // Griffin uses local attention, so kSeqLen is actually the local attention
+  // window.
+  static constexpr int kSeqLen = 2048;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+  };
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kModelDim = 2560;
+  static constexpr int kFFHiddenDim = 7680;
+  static constexpr int kHeads = 10;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 4;
+  static constexpr bool kFFBiases = true;
+  static constexpr bool kSoftmaxAttnOutputBiases = true;
+  static constexpr bool kUseHalfRope = true;
+  static constexpr bool kUseLocalAttention = true;
+  static constexpr bool kInterleaveQKV = false;
+  static constexpr int kNumTensorScales = 140;
+  using WeightT = GEMMA_WEIGHT_T;
+};
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
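
What FixedLayerConfig does, as a minimal self-contained sketch (not part of the diff): it fills an array with one attention type at compile time, so homogeneous models like Gemma 2B/7B keep a uniform stack while the layer count is derived from the array size. The declarations are restated here so the snippet compiles on its own; the static_asserts only confirm behavior visible in the diff.

// Minimal sketch of FixedLayerConfig, restated from the header above so it
// compiles standalone. kDemo is a hypothetical example, not a real config.
#include <stddef.h>

#include <array>

enum class LayerAttentionType {
  kGemma,
  kGriffinRecurrentBlock,
};

template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
    LayerAttentionType type) {
  std::array<LayerAttentionType, kNum> config = {};
  for (LayerAttentionType& l : config) {
    l = type;
  }
  return config;
}

// A homogeneous 4-layer config: every entry is kGemma, and the layer count
// falls out of the array size, as in ConfigGemma7B/ConfigGemma2B above.
static constexpr auto kDemo = FixedLayerConfig<4>(LayerAttentionType::kGemma);
static_assert(kDemo.size() == 4,
              "kLayers is derived from kLayerConfig.size()");
static_assert(kDemo[0] == LayerAttentionType::kGemma &&
                  kDemo[3] == LayerAttentionType::kGemma,
              "FixedLayerConfig fills every entry with the same type");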
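
The point of replacing a plain kLayers count with an array is that downstream code can branch per layer, which is what ConfigGriffin2B's mixed recurrent/attention layout needs. The loop below is only a hypothetical consumer reusing the declarations from the sketch above; TestConfig and ForwardAllLayers are placeholder names, not gemma.cpp APIs. Only LayerAttentionType and the kLayerConfig member pattern come from this header.

// Hypothetical consumer (reuses the declarations from the previous sketch).
// TestConfig stands in for ConfigGriffin2B etc.; ForwardAllLayers is a
// placeholder, not a real gemma.cpp function.
struct TestConfig {
  static constexpr std::array<LayerAttentionType, 3> kLayerConfig = {
      LayerAttentionType::kGriffinRecurrentBlock,
      LayerAttentionType::kGriffinRecurrentBlock,
      LayerAttentionType::kGemma,
  };
};

template <class TConfig>
void ForwardAllLayers() {
  for (size_t layer = 0; layer < TConfig::kLayerConfig.size(); ++layer) {
    switch (TConfig::kLayerConfig[layer]) {
      case LayerAttentionType::kGemma:
        // Run a standard Gemma attention block for this layer (placeholder).
        break;
      case LayerAttentionType::kGriffinRecurrentBlock:
        // Run a Griffin recurrent block for this layer (placeholder).
        break;
    }
  }
}

int main() {
  // Visits 3 layers: recurrent, recurrent, attention.
  ForwardAllLayers<TestConfig>();
  return 0;
}

The switch makes adding a new block type a compile-time-visible change: every per-layer loop that dispatches on LayerAttentionType has an obvious place to handle it.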