Skip to content

Commit 9c3e089

Browse files
The gemma.cpp Authors and copybara-github
The gemma.cpp Authors
authored and committed
Internal change.
PiperOrigin-RevId: 765218260
1 parent 1e8642f commit 9c3e089

File tree

2 files changed

+48
-50
lines changed

2 files changed

+48
-50
lines changed

gemma/configs.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,8 @@ ModelConfig::ModelConfig(const Model model, Type weight,
572572
static Model FindModel(const std::string& specifier) {
573573
Model found_model = Model::UNKNOWN;
574574
ForEachModel([&](Model model) {
575-
const char* prefix = ModelPrefix(model);
575+
// Some model names are prefixes of other model names
576+
const std::string prefix = std::string(ModelPrefix(model)) + "-";
576577
if (specifier.rfind(prefix, 0) == 0) { // Starts with prefix.
577578
// We only expect one match.
578579
HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());

gemma/gemma-inl.h

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
176176
Sigmoid(gate_x + head_offset, kHeadDim);
177177
Sigmoid(a + head_offset, kHeadDim);
178178
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
179-
HWY_ATTR { return hn::Mul(x, gate_x); };
179+
HWY_ATTR { return hn::Mul(x, gate_x); };
180180
hn::Transform1(D(), a + head_offset, kHeadDim,
181181
layer_weights->griffin.a.PackedScale1() + head_offset,
182182
fn_mul);
@@ -424,51 +424,49 @@ class GemmaAttention {
424424
const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads;
425425

426426
// For each head (token, query), compute Q.K, softmax, and weighted V.
427-
pool_.Run(0, layer_config_.heads * num_interleaved,
428-
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
429-
const size_t head = task % layer_config_.heads;
430-
const size_t interleaved_idx = task / layer_config_.heads;
431-
const size_t query_idx = interleaved_idx % num_queries_;
432-
const size_t batch_idx = interleaved_idx / num_queries_;
433-
const size_t qkv_dim = layer_config_.qkv_dim;
434-
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
435-
436-
float* HWY_RESTRICT q =
437-
activations_.q.Row(interleaved_idx) + head * q_stride_;
438-
float* HWY_RESTRICT att =
439-
activations_.att.Row(interleaved_idx) +
440-
head * activations_.seq_len;
441-
float* HWY_RESTRICT att_out =
442-
activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
443-
444-
// Make strided views into the kv cache entries for the current
445-
// query and head.
446-
KVCache& kv_cache = kv_caches_[query_idx];
447-
const size_t kv_head_offset =
448-
layer_ * cache_layer_size_ + head_offset;
449-
MatPtrT<float> k("k_view",
450-
Extents2D(kv_cache.seq_len, qkv_dim));
451-
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
452-
/*stride=*/cache_pos_size_);
453-
MatPtrT<float> v("v_view",
454-
Extents2D(kv_cache.seq_len, qkv_dim));
455-
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
456-
/*stride=*/cache_pos_size_);
457-
458-
// Find the token position in the query and calculate the range
459-
// of cache positions to attend to.
460-
const size_t pos = queries_pos_[query_idx] + batch_idx;
461-
const size_t start_pos = StartPos(pos, layer_);
462-
size_t last_pos = pos;
463-
const size_t prefix_end = queries_prefix_end_[query_idx];
464-
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
465-
// last_pos in QDotK and WeightedSumV is inclusive.
466-
last_pos = prefix_end - 1;
467-
}
468-
469-
SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
470-
pos, start_pos, last_pos);
471-
});
427+
pool_.Run(
428+
0, layer_config_.heads * num_interleaved,
429+
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
430+
const size_t head = task % layer_config_.heads;
431+
const size_t interleaved_idx = task / layer_config_.heads;
432+
const size_t query_idx = interleaved_idx % num_queries_;
433+
const size_t batch_idx = interleaved_idx / num_queries_;
434+
const size_t qkv_dim = layer_config_.qkv_dim;
435+
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
436+
437+
float* HWY_RESTRICT q =
438+
activations_.q.Row(interleaved_idx) + head * q_stride_;
439+
float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
440+
head * activations_.seq_len;
441+
float* HWY_RESTRICT att_out =
442+
activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
443+
444+
// Make strided views into the kv cache entries for the current
445+
// query and head.
446+
KVCache& kv_cache = kv_caches_[query_idx];
447+
const size_t kv_head_offset =
448+
layer_ * cache_layer_size_ + head_offset;
449+
MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
450+
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
451+
/*stride=*/cache_pos_size_);
452+
MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
453+
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
454+
/*stride=*/cache_pos_size_);
455+
456+
// Find the token position in the query and calculate the range
457+
// of cache positions to attend to.
458+
const size_t pos = queries_pos_[query_idx] + batch_idx;
459+
const size_t start_pos = StartPos(pos, layer_);
460+
size_t last_pos = pos;
461+
const size_t prefix_end = queries_prefix_end_[query_idx];
462+
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
463+
// last_pos in QDotK and WeightedSumV is inclusive.
464+
last_pos = prefix_end - 1;
465+
}
466+
467+
SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos,
468+
start_pos, last_pos);
469+
});
472470
}
473471

474472
private:
@@ -1510,8 +1508,7 @@ void GenerateSingleT(const ModelStore& model,
15101508
}
15111509

15121510
template <typename T>
1513-
void GenerateBatchT(const ModelStore& model,
1514-
const ModelWeightsPtrs<T>& weights,
1511+
void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
15151512
const RuntimeConfig& runtime_config,
15161513
const QueriesPromptTokens& queries_prompt,
15171514
const QueriesPos& queries_pos,
@@ -1536,7 +1533,7 @@ void GenerateBatchT(const ModelStore& model,
15361533
qbatch_size);
15371534
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
15381535
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
1539-
qbatch_size);
1536+
qbatch_size);
15401537
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
15411538
GenerateT<T>(model, weights, activations, runtime_config, qbatch_prompts,
15421539
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,

0 commit comments

Comments (0)