
Commit 3241b3d

Reworked to directly update KV cache params using info from name
1 parent 5289a6a commit 3241b3d

File tree

ggml/include/ggml.h
ggml/src/ggml.c
src/llama.cpp

3 files changed: 35 additions & 51 deletions

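Note: the rework replaces the per-node KV-cache flag with a lookup keyed on tensor names. Copy destinations named "k_cache_view-<il>" / "v_cache_view-<il>" are matched by prefix and the layer index is parsed from the suffix. Below is a minimal, self-contained sketch of that matching, separate from the patch itself; the sample names and the main() driver are illustrative only.

#include <cstdio>
#include <cstring>
#include <cstdlib>

// Parse the layer index from a tensor name such as "k_cache_view-12".
// Returns -1 when the name does not start with the given prefix.
static int layer_from_name(const char * name, const char * prefix) {
    const size_t len = strlen(prefix);
    if (strncmp(name, prefix, len) != 0) {
        return -1;
    }
    return atoi(name + len);
}

int main() {
    const char * names[] = { "k_cache_view-0", "v_cache_view-31", "kqv_out-5" };
    for (const char * name : names) {
        printf("%-16s -> K layer %2d, V layer %2d\n", name,
               layer_from_name(name, "k_cache_view-"),
               layer_from_name(name, "v_cache_view-"));
    }
    return 0;
}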

ggml/include/ggml.h

Lines changed: 1 addition & 10 deletions
@@ -552,13 +552,6 @@ extern "C" {
         GGML_TENSOR_FLAG_PARAM = 4,
     };
 
-    // Flag (used on GGML_OP_CPY nodes) on whether node is associated with K or V cache
-    enum ggml_kv_cache_flag {
-        GGML_KV_CACHE_FLAG_NONE = 0,
-        GGML_KV_CACHE_FLAG_K = 1,
-        GGML_KV_CACHE_FLAG_V = 2
-    };
-
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -593,8 +586,6 @@ extern "C" {
         // op params - allocated as int32_t for alignment
         int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
-        enum ggml_kv_cache_flag kv_cache_flag;
-
         int32_t flags;
 
         struct ggml_tensor * grad;
@@ -610,7 +601,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[1];
+        // char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
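Note: with ggml_kv_cache_flag removed from struct ggml_tensor, a cache view is identified only by its name, with the layer index encoded as a "-<il>" suffix (the llm_build_kv_store hunk further down shows the cb(..., "k_cache_view", il) call that produces such names). A rough standalone sketch of attaching a name like that with ggml's existing ggml_format_name helper; the context size, tensor type/shape and layer index are placeholder values, not taken from the patch.

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // Stand-in for a per-layer KV cache view; in llama.cpp the view is created
    // over the cache tensor and named through the graph-build callback.
    struct ggml_tensor * k_cache_view = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 4096);

    const int il = 12; // layer index
    ggml_format_name(k_cache_view, "k_cache_view-%d", il);

    printf("%s\n", ggml_get_name(k_cache_view)); // prints: k_cache_view-12

    ggml_free(ctx);
    return 0;
}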

ggml/src/ggml.c

Lines changed: 1 addition & 2 deletions
@@ -3638,7 +3638,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.nb           =*/ { 0, 0, 0, 0 },
         /*.op           =*/ GGML_OP_NONE,
         /*.op_params    =*/ { 0 },
-        /*.kv_cache_flag=*/ GGML_KV_CACHE_FLAG_NONE,
         /*.flags        =*/ 0,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
@@ -3647,7 +3646,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        /*.padding      =*/ { 0 },
+        ///*.padding      =*/ { 0 },
     };
 
 #ifdef __clang__

src/llama.cpp

Lines changed: 33 additions & 39 deletions
@@ -7794,9 +7794,7 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
-    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
-    ggml_build_forward_expand(graph, tmp);
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -7814,9 +7812,7 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-    tmp=ggml_cpy(ctx, v_cur, v_cache_view);
-    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
-    ggml_build_forward_expand(graph, tmp);
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
 static struct ggml_tensor * llm_build_norm(
@@ -14607,43 +14603,41 @@ static int llama_decode_internal(
     }
     lctx.cached_graph.gf = gf;
 
-    if(ggml_use_cached_graph(lctx.sched)) {
-
-        // Temporarily store KV cache parameters that will need updated in cached graph.
+    // Update K and V cache parameters in cached graph.
+    if(gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
+
         const struct llama_hparams & hparams = model.hparams;
-        const int64_t n_layer = hparams.n_layer;
         const int64_t kv_head = kv_self.head;
-        std::vector<void *> k_cache_ptrs;
-        std::vector<void *> v_cache_ptrs;
-        for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-            ggml_tensor * tmp_tensor = kv_self.k_l[il];
-            size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
-            k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-            tmp_tensor = kv_self.v_l[il];
-            if (cparams.flash_attn) {
-                tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            } else {
-                tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
-            }
-            v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-        }
-
-        // Update KV cache parameters in cached graph.
-        int k_count = 0;
-        int v_count = 0;
-        if(gf != nullptr && gf->nodes != nullptr){
-            for (int i = 0; i < gf->n_nodes; i++) {
-                ggml_tensor * node = gf->nodes[i];
-                if (node->op == GGML_OP_CPY) {
-                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
-                        node->src[1]->data = k_cache_ptrs[k_count++];
-                    }
-                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
-                        node->src[1]->data = v_cache_ptrs[v_count++];
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+        for (int i = 0; i < gf->n_nodes; i++) {
+            ggml_tensor * node = gf->nodes[i];
+            if (node->op == GGML_OP_CPY) {
+
+                // K cache
+                const char* k_prefix = "k_cache_view-";
+                if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
+                    int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
+                    ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                    size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                    node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
+                }
+
+                // V cache
+                const char* v_prefix = "v_cache_view-";
+                if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
+                    int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
+                    ggml_tensor * tmp_tensor = kv_self.v_l[il];
+                    size_t tmp_offset;
+                    if (cparams.flash_attn) {
+                        tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                    } else {
+                        tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
                     }
+                    node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
                 }
+
             }
         }
 
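Note: the data pointers patched into the cached graph are offset according to the cache layout. The K cache (and the V cache when flash_attn is enabled) stores one contiguous row of n_embd_*_gqa elements per cell, so the offset is ggml_row_size(type, n_embd) * kv_head; without flash_attn the V cache is laid out transposed (consistent with the ggml_transpose call in llm_build_kv_store above), so one cell advances by a single element, giving kv_head * ggml_element_size. A toy recalculation assuming a non-quantized cache type, where a row's size is simply elements times bytes per element; all concrete numbers below are made up.

#include <cstddef>
#include <cstdio>

// Byte offset of cache cell kv_head in a row-major K cache:
// skip kv_head rows of n_embd_k_gqa elements each.
static size_t k_view_offset(size_t bytes_per_elem, size_t n_embd_k_gqa, size_t kv_head) {
    return bytes_per_elem * n_embd_k_gqa * kv_head;
}

// Offset for the V cache: row-major like K when flash attention is enabled,
// otherwise transposed, so one cell advances by a single element.
static size_t v_view_offset(size_t bytes_per_elem, size_t n_embd_v_gqa, size_t kv_head, bool flash_attn) {
    return flash_attn ? bytes_per_elem * n_embd_v_gqa * kv_head
                      : bytes_per_elem * kv_head;
}

int main() {
    const size_t bytes_per_elem = 2;    // e.g. an fp16 cache
    const size_t n_embd_gqa     = 4096; // per-layer KV width (example value)
    const size_t kv_head        = 17;   // current cache head position

    printf("K offset:                 %zu bytes\n", k_view_offset(bytes_per_elem, n_embd_gqa, kv_head));
    printf("V offset (flash_attn):    %zu bytes\n", v_view_offset(bytes_per_elem, n_embd_gqa, kv_head, true));
    printf("V offset (no flash_attn): %zu bytes\n", v_view_offset(bytes_per_elem, n_embd_gqa, kv_head, false));
    return 0;
}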
0 commit comments