Skip to content

Commit 1982a6b

Browse files
pcullitoncopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 657831926
1 parent a24eda8 commit 1982a6b

File tree

8 files changed

+110
-6
lines changed

8 files changed

+110
-6
lines changed

BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ cc_library(
170170
"gemma/instantiations/gr2b_bf16.cc",
171171
"gemma/instantiations/gr2b_f32.cc",
172172
"gemma/instantiations/gr2b_sfp.cc",
173+
"gemma/instantiations/gemma2_2b_bf16.cc",
174+
"gemma/instantiations/gemma2_2b_f32.cc",
175+
"gemma/instantiations/gemma2_2b_sfp.cc",
173176
],
174177
hdrs = [
175178
"gemma/activations.h",

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ set(SOURCES
9191
gemma/instantiations/tiny_bf16.cc
9292
gemma/instantiations/tiny_f32.cc
9393
gemma/instantiations/tiny_sfp.cc
94+
gemma/instantiations/gemma2_2b_bf16.cc
95+
gemma/instantiations/gemma2_2b_f32.cc
96+
gemma/instantiations/gemma2_2b_sfp.cc
9497
gemma/kv_cache.cc
9598
gemma/kv_cache.h
9699
gemma/tokenizer.cc

gemma/common.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@
2929
namespace gcpp {
3030

3131
constexpr const char* kModelFlags[] = {
32-
"2b-pt", "2b-it", // Gemma 2B
33-
"7b-pt", "7b-it", // Gemma 7B
34-
"9b-pt", "9b-it", // Gemma 9B
35-
"27b-pt", "27b-it", // Gemma 27B
36-
"gr2b-pt", "gr2b-it", // RecurrentGemma
37-
"tiny", // Gemma Tiny (mostly for debugging)
32+
"2b-pt", "2b-it", // Gemma 2B
33+
"7b-pt", "7b-it", // Gemma 7B
34+
"9b-pt", "9b-it", // Gemma 9B
35+
"27b-pt", "27b-it", // Gemma 27B
36+
"gr2b-pt", "gr2b-it", // RecurrentGemma
37+
"tiny", // Gemma Tiny (mostly for debugging)
38+
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
3839
};
3940
constexpr Model kModelTypes[] = {
4041
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
@@ -43,6 +44,7 @@ constexpr Model kModelTypes[] = {
4344
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
4445
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
4546
Model::GEMMA_TINY, // Gemma Tiny
47+
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
4648
};
4749
constexpr ModelTraining kModelTraining[] = {
4850
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
@@ -51,6 +53,7 @@ constexpr ModelTraining kModelTraining[] = {
5153
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
5254
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
5355
ModelTraining::GEMMA_IT, // Gemma Tiny
56+
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B2
5457
};
5558

5659
constexpr size_t kNumModelFlags =

gemma/common.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum class Model {
4545
GEMMA_27B,
4646
GRIFFIN_2B,
4747
GEMMA_TINY,
48+
GEMMA2_2B,
4849
};
4950

5051
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
@@ -99,6 +100,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
99100
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
100101
case Model::GRIFFIN_2B:
101102
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
103+
case Model::GEMMA2_2B:
104+
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
105+
102106
default:
103107
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
104108
}
@@ -142,6 +146,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
142146
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
143147
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
144148
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
149+
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
145150
static_assert(true, "Allow trailing ;")
146151

147152
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
@@ -178,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
178183
ARGS; \
179184
break; \
180185
} \
186+
case Model::GEMMA2_2B: { \
187+
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_2B<TWEIGHT>>) \
188+
ARGS; \
189+
break; \
190+
} \
181191
default: \
182192
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
183193
}

gemma/configs.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,28 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
253253
static constexpr bool kAbsolutePE = false;
254254
};
255255

256+
template <typename TWeight>
257+
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
258+
using Weight = TWeight; // make accessible where we only have a TConfig
259+
260+
static constexpr int kSeqLen = 8192;
261+
static constexpr int kVocabSize = 256000;
262+
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
263+
FixedLayerConfig<26>(LayerAttentionType::kGemma);
264+
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
265+
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
266+
static constexpr int kLayers = kLayerConfig.size();
267+
static constexpr int kGemmaLayers = kLayers;
268+
static constexpr int kModelDim = 2304;
269+
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
270+
static constexpr int kHeads = 8;
271+
static constexpr int kKVHeads = 4;
272+
static constexpr int kQKVDim = 256; // query size == key size == value size
273+
static constexpr int kTopK = gcpp::kTopK;
274+
static constexpr bool kAbsolutePE = false;
275+
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
276+
};
277+
256278
template <typename TWeight>
257279
struct ConfigGemmaTiny : public ConfigNoSSM {
258280
using Weight = TWeight; // make accessible where we only have a TConfig
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright 2024 Google LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#undef HWY_TARGET_INCLUDE
17+
#define HWY_TARGET_INCLUDE \
18+
"gemma/instantiations/gemma2_2b_bf16.cc"
19+
#include "hwy/foreach_target.h" // IWYU pragma: keep
20+
#define GEMMA_CONFIG ConfigGemma2_2B<hwy::bfloat16_t>
21+
#include "gemma/gemma-inl.h"

gemma/instantiations/gemma2_2b_f32.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright 2024 Google LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#undef HWY_TARGET_INCLUDE
17+
#define HWY_TARGET_INCLUDE \
18+
"gemma/instantiations/gemma2_2b_f32.cc"
19+
#include "hwy/foreach_target.h" // IWYU pragma: keep
20+
#define GEMMA_CONFIG ConfigGemma2_2B<float>
21+
#include "gemma/gemma-inl.h"

gemma/instantiations/gemma2_2b_sfp.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright 2024 Google LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#undef HWY_TARGET_INCLUDE
17+
#define HWY_TARGET_INCLUDE \
18+
"gemma/instantiations/gemma2_2b_sfp.cc"
19+
#include "hwy/foreach_target.h" // IWYU pragma: keep
20+
#define GEMMA_CONFIG ConfigGemma2_2B<SfpStream>
21+
#include "gemma/gemma-inl.h"

0 commit comments

Comments
 (0)