Skip to content

Commit 27b0406

Browse files
llama : use n_embd_head_v when reshaping kqv (#7327)
* llama : use n_embd_head_v instead of n_embd_head_k when reshaping kqv

* llama : use n_embd_v_gqa and n_embd_head_v instead of n_embd_k_gqa and n_embd_head_k when making a view of cached value vectors.

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 29c60d8 commit 27b0406

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

llama.cpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
(Note: in the rows below, the original and diff line numbers were fused during extraction, e.g. "66226622" means original line 6622, diff line 6622; "6625+" marks an added line and "6647-" a removed line.)
@@ -6622,6 +6622,7 @@ static struct ggml_tensor * llm_build_kqv(
66226622
const int64_t n_embd_head_k = hparams.n_embd_head_k;
66236623
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
66246624
const int64_t n_embd_head_v = hparams.n_embd_head_v;
6625+
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
66256626

66266627
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
66276628
cb(q, "q", il);
@@ -6644,8 +6645,8 @@ static struct ggml_tensor * llm_build_kqv(
66446645
struct ggml_tensor * v =
66456646
ggml_view_3d(ctx, kv.v_l[il],
66466647
n_embd_head_v, n_kv, n_head_kv,
6647-
ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
6648-
ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
6648+
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
6649+
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
66496650
0);
66506651
cb(v, "v", il);
66516652

@@ -6655,7 +6656,7 @@ static struct ggml_tensor * llm_build_kqv(
66556656
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
66566657
}
66576658

6658-
cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
6659+
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
66596660
} else {
66606661
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
66616662
cb(kq, "kq", il);
@@ -6700,7 +6701,7 @@ static struct ggml_tensor * llm_build_kqv(
67006701
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
67016702
cb(kqv_merged, "kqv_merged", il);
67026703

6703-
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
6704+
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
67046705
cb(cur, "kqv_merged_cont", il);
67056706
}
67066707

0 commit comments

Comments (0)