llama : run all KQV ops on the CPU with no KV offload #5049

Merged (1 commit) on Jan 20, 2024
34 changes: 20 additions & 14 deletions ggml-backend.c
@@ -1191,6 +1191,24 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
ggml_tallocr_t src_allocr = node_allocr(src);
GGML_ASSERT(src_allocr != NULL); // all inputs should be assigned by now
if (src_allocr != node_allocr) {
+// create a copy of the input in the split's backend
+size_t id = hash_id(src);
+if (sched->node_copies[id][cur_backend_id] == NULL) {
+ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
+struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
+
+sched->node_copies[id][cur_backend_id] = tensor_copy;
+node_allocr(tensor_copy) = cur_allocr;
+SET_CAUSE(tensor_copy, "4.cpy");
+
+int n_inputs = sched->splits[cur_split].n_inputs++;
+GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+sched->splits[cur_split].inputs[n_inputs] = src;
+}
+node->src[j] = sched->node_copies[id][cur_backend_id];
+
+#if 0
// check if the input is already in the split
bool found = false;
for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
@@ -1206,19 +1224,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
sched->splits[cur_split].inputs[n_inputs] = src;
}
-
-// create a copy of the input in the split's backend
-size_t id = hash_id(src);
-if (sched->node_copies[id][cur_backend_id] == NULL) {
-ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
-struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
-ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
-
-sched->node_copies[id][cur_backend_id] = tensor_copy;
-node_allocr(tensor_copy) = cur_allocr;
-SET_CAUSE(tensor_copy, "4.cpy");
-}
-node->src[j] = sched->node_copies[id][cur_backend_id];
+#endif
}
}
}
@@ -1333,7 +1339,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
uint64_t compute_start_us = ggml_time_us();
if (!sched->callback_eval) {
ggml_backend_graph_compute(split_backend, &split->graph);
-//ggml_backend_synchronize(split_backend); // necessary to measure compute time
+//ggml_backend_synchronize(split_backend); // necessary to measure compute time
} else {
// similar to ggml_backend_compare_graph_backend
for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
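For context, the ggml-backend.c change above moves the input-copy logic ahead of the old duplicate check (now disabled with #if 0): the first time a node's input turns out to live on a different backend than the node's split, the scheduler creates the backend-local copy and registers the original tensor as a split input in the same branch, keyed by hash_id(src) so each (tensor, backend) pair is copied at most once. A minimal standalone sketch of that memoized-copy pattern, using toy stand-in types rather than the real scheduler structs:

#include <stdio.h>
#include <stdlib.h>

#define HASH_SIZE    16
#define MAX_BACKENDS 4

typedef struct { char name[64]; } toy_tensor;             /* stand-in for ggml_tensor */

static toy_tensor * node_copies[HASH_SIZE][MAX_BACKENDS]; /* mirrors sched->node_copies */

static size_t hash_id(const toy_tensor * t) {
    return ((size_t) t >> 4) % HASH_SIZE;                 /* toy hash, collisions ignored */
}

/* First use creates the backend-local copy (the real scheduler also records
   src as a split input at this point); later uses hit the cache, so the
   tensor is duplicated at most once per target backend. */
static toy_tensor * get_input_copy(const toy_tensor * src, int backend_id) {
    size_t id = hash_id(src);
    if (node_copies[id][backend_id] == NULL) {
        toy_tensor * copy = malloc(sizeof *copy);
        snprintf(copy->name, sizeof copy->name, "backend%d#%s", backend_id, src->name);
        node_copies[id][backend_id] = copy;
    }
    return node_copies[id][backend_id];
}

int main(void) {
    toy_tensor src = { "Kcur" };
    toy_tensor * a = get_input_copy(&src, 1);
    toy_tensor * b = get_input_copy(&src, 1);
    printf("%s (cached: %s)\n", a->name, a == b ? "yes" : "no"); /* backend1#Kcur (cached: yes) */
    return 0;
}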
145 changes: 79 additions & 66 deletions llama.cpp
@@ -4315,6 +4315,7 @@ static struct ggml_tensor * llm_build_kqv(
const llama_model & model,
const llama_hparams & hparams,
const llama_kv_cache & kv,
+struct ggml_cgraph * graph,
struct ggml_tensor * wo,
struct ggml_tensor * wo_b,
struct ggml_tensor * q_cur,
@@ -4393,6 +4394,8 @@
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);

+ggml_build_forward_expand(graph, cur);
+
cur = ggml_mul_mat(ctx, wo, cur);
if (wo_b) {
cb(cur, "kqv_wo", il);
@@ -4405,6 +4408,44 @@
return cur;
}

+static struct ggml_tensor * llm_build_kv(
+struct ggml_context * ctx,
+const llama_model & model,
+const llama_hparams & hparams,
+const llama_kv_cache & kv,
+struct ggml_cgraph * graph,
+struct ggml_tensor * wo,
+struct ggml_tensor * wo_b,
+struct ggml_tensor * k_cur,
+struct ggml_tensor * v_cur,
+struct ggml_tensor * q_cur,
+struct ggml_tensor * kq_mask,
+int64_t n_ctx,
+int32_t n_tokens,
+int32_t kv_head,
+int32_t n_kv,
+float max_alibi_bias,
+float kq_scale,
+const llm_build_cb & cb,
+int il) {
+
+// these nodes are added to the graph together so that they are not reordered
+// by doing so, the number of splits in the graph is reduced
+ggml_build_forward_expand(graph, k_cur);
+ggml_build_forward_expand(graph, v_cur);
+ggml_build_forward_expand(graph, q_cur);
+
+llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
+
+struct ggml_tensor * cur;
+cur = llm_build_kqv(ctx, model, hparams, kv, graph,
+wo, wo_b,
+q_cur, kq_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);
+cb(cur, "kqv_out", il);
+
+return cur;
+}
+
struct llm_build_context {
const llama_model & model;
const llama_hparams & hparams;
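Every model builder below receives the same mechanical substitution, so the repeated hunks are easiest to read as one pattern (here max_alibi_bias and kq_scale stand for each model's constants: -1.0f or 8.0f, and 1.0f/sqrtf(float(n_embd_head)) or 1.0f):

/* before: each builder stored K/V and computed attention in two steps */
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
        model.layers[il].wo, model.layers[il].bo,
        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);

/* after: one call that expands K/V/Q into the graph, stores them, then attends */
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
        model.layers[il].wo, model.layers[il].bo,
        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv,
        max_alibi_bias, kq_scale, cb, il);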
@@ -4562,12 +4603,6 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
}

-// these nodes are added to the graph together so that they are not reordered
-// by doing so, the number of splits in the graph is reduced
-ggml_build_forward_expand(gf, Qcur);
-ggml_build_forward_expand(gf, Kcur);
-ggml_build_forward_expand(gf, Vcur);
-
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
@@ -4582,11 +4617,9 @@
);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -4763,14 +4796,13 @@ struct llm_build_context {
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);

// apply ALiBi for 13B model
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;

-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -4892,11 +4924,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -4993,11 +5023,9 @@ struct llm_build_context {

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5200,12 +5228,9 @@ struct llm_build_context {
);
cb(Vcur, "Vcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-// TODO: not tested, could be broken
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Q, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5292,11 +5317,9 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5390,11 +5413,9 @@ struct llm_build_context {

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5485,11 +5506,9 @@ struct llm_build_context {

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5597,11 +5616,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5714,11 +5731,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5837,11 +5852,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -5966,11 +5979,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f, cb, il);
cb(cur, "kqv_out", il);
}

@@ -6071,11 +6082,9 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * sa_out = cur;
@@ -6172,11 +6181,9 @@ struct llm_build_context {

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -6283,11 +6290,9 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);

-llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
-
-cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
-Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}

@@ -6355,6 +6360,14 @@ static struct ggml_cgraph * llama_build_graph(
ggml_set_name(cur, name);
}

+
+if (!lctx.cparams.offload_kqv) {
+if (strcmp(name, "kqv_merged_cont") == 0) {
+// all nodes between the KV store and the attention output are run on the CPU
+ggml_backend_sched_set_node_backend(lctx.sched, cur, lctx.backend_cpu);
+}
+}
+
//
// allocate input tensors and set input data
//
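The net effect: with offload_kqv disabled, the callback above pins every node named "kqv_merged_cont" to the CPU backend, and together with the scheduler change in ggml-backend.c the whole attention segment (KV store through merged attention output) stays on the CPU while the weight matmuls remain offloaded. A sketch of selecting this from the C API, assuming the llama.h of this release; the CLI equivalent is --no-kv-offload (-nkvo):

#include "llama.h"

int main(void) {
    llama_backend_init(false); /* numa = false; signature of this era */

    struct llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99;                 /* offload the weights */

    struct llama_context_params cparams = llama_context_default_params();
    cparams.offload_kqv = false;               /* keep KV cache + KQV ops on the CPU */

    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams); /* placeholder path */
    struct llama_context * ctx = llama_new_context_with_model(model, cparams);
    /* ... decode as usual; the attention ops now run on the CPU ... */
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}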