Skip to content

Internal change. #576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ ModelConfig::ModelConfig(const Model model, Type weight,
static Model FindModel(const std::string& specifier) {
Model found_model = Model::UNKNOWN;
ForEachModel([&](Model model) {
const char* prefix = ModelPrefix(model);
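// Some model names are prefixes of other model names, so also require the
// trailing "-" separator; otherwise one specifier could match several models.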
const std::string prefix = std::string(ModelPrefix(model)) + "-";
if (specifier.rfind(prefix, 0) == 0) { // Starts with prefix.
// We only expect one match.
HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());
Expand Down
95 changes: 46 additions & 49 deletions gemma/gemma-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.PackedScale1() + head_offset,
fn_mul);
Expand Down Expand Up @@ -424,51 +424,49 @@ class GemmaAttention {
const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads;

// For each head (token, query), compute Q.K, softmax, and weighted V.
pool_.Run(0, layer_config_.heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t interleaved_idx = task / layer_config_.heads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;

float* HWY_RESTRICT q =
activations_.q.Row(interleaved_idx) + head * q_stride_;
float* HWY_RESTRICT att =
activations_.att.Row(interleaved_idx) +
head * activations_.seq_len;
float* HWY_RESTRICT att_out =
activations_.att_out.Row(interleaved_idx) + head * qkv_dim;

// Make strided views into the kv cache entries for the current
// query and head.
KVCache& kv_cache = kv_caches_[query_idx];
const size_t kv_head_offset =
layer_ * cache_layer_size_ + head_offset;
MatPtrT<float> k("k_view",
Extents2D(kv_cache.seq_len, qkv_dim));
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
/*stride=*/cache_pos_size_);
MatPtrT<float> v("v_view",
Extents2D(kv_cache.seq_len, qkv_dim));
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
/*stride=*/cache_pos_size_);

// Find the token position in the query and calculate the range
// of cache positions to attend to.
const size_t pos = queries_pos_[query_idx] + batch_idx;
const size_t start_pos = StartPos(pos, layer_);
size_t last_pos = pos;
const size_t prefix_end = queries_prefix_end_[query_idx];
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
}

SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale,
pos, start_pos, last_pos);
});
pool_.Run(
0, layer_config_.heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t interleaved_idx = task / layer_config_.heads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;

float* HWY_RESTRICT q =
activations_.q.Row(interleaved_idx) + head * q_stride_;
float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
head * activations_.seq_len;
float* HWY_RESTRICT att_out =
activations_.att_out.Row(interleaved_idx) + head * qkv_dim;

// Make strided views into the kv cache entries for the current
// query and head.
KVCache& kv_cache = kv_caches_[query_idx];
const size_t kv_head_offset =
layer_ * cache_layer_size_ + head_offset;
MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
/*stride=*/cache_pos_size_);
MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
/*stride=*/cache_pos_size_);

// Find the token position in the query and calculate the range
// of cache positions to attend to.
const size_t pos = queries_pos_[query_idx] + batch_idx;
const size_t start_pos = StartPos(pos, layer_);
size_t last_pos = pos;
const size_t prefix_end = queries_prefix_end_[query_idx];
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
}

SingleDotSoftmaxWeightedSum(q, k, v, att, att_out, query_scale, pos,
start_pos, last_pos);
});
}

private:
Expand Down Expand Up @@ -1510,8 +1508,7 @@ void GenerateSingleT(const ModelStore& model,
}

template <typename T>
void GenerateBatchT(const ModelStore& model,
const ModelWeightsPtrs<T>& weights,
void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
Expand All @@ -1536,7 +1533,7 @@ void GenerateBatchT(const ModelStore& model,
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<T>(model, weights, activations, runtime_config, qbatch_prompts,
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
Expand Down
Loading