Skip to content

Commit 319e47e

Browse files
committed
stablelm : simplify + speedup generation
1 parent dfc7cd4 commit 319e47e

File tree

1 file changed: 14 additions and 72 deletions.

llama.cpp

Lines changed: 14 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -4814,92 +4814,34 @@ struct llm_build_context {
48144814
// self-attention
48154815
{
48164816
// compute Q and K and RoPE them
4817-
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
4818-
cb(tmpq, "tmpq", il);
4817+
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
4818+
cb(Qcur, "Qcur", il);
48194819

4820-
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
4821-
cb(tmpk, "tmpk", il);
4820+
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
4821+
cb(Kcur, "Kcur", il);
48224822

48234823
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
48244824
cb(Vcur, "Vcur", il);
48254825

4826-
// RoPE the first n_rot of q/k, pass the other half, and concat.
4827-
struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d(
4828-
ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
4829-
ggml_element_size(tmpq) * n_embd_head,
4830-
ggml_element_size(tmpq) * n_embd_head * n_head,
4831-
0
4832-
));
4833-
cb(qrot, "qrot", il);
4834-
4835-
struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
4836-
ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
4837-
ggml_element_size(tmpk) * n_embd_head,
4838-
ggml_element_size(tmpk) * n_embd_head * n_head_kv,
4839-
0
4840-
));
4841-
cb(krot, "krot", il);
4842-
4843-
// get the second half of tmpq, e.g tmpq[n_rot:, :, :]
4844-
struct ggml_tensor * qpass = ggml_view_3d(
4845-
ctx0, tmpq, (n_embd_head - hparams.n_rot), n_head, n_tokens,
4846-
ggml_element_size(tmpq) * n_embd_head,
4847-
ggml_element_size(tmpq) * n_embd_head * n_head,
4848-
ggml_element_size(tmpq) * hparams.n_rot
4849-
);
4850-
cb(qpass, "qpass", il);
4851-
4852-
struct ggml_tensor * kpass = ggml_view_3d(
4853-
ctx0, tmpk, (n_embd_head - hparams.n_rot), n_head_kv, n_tokens,
4854-
ggml_element_size(tmpk) * (n_embd_head),
4855-
ggml_element_size(tmpk) * (n_embd_head) * n_head_kv,
4856-
ggml_element_size(tmpk) * hparams.n_rot
4857-
);
4858-
cb(kpass, "kpass", il);
4859-
4860-
struct ggml_tensor * qrotated = ggml_rope_custom(
4861-
ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
4862-
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4863-
);
4864-
cb(qrotated, "qrotated", il);
4865-
4866-
struct ggml_tensor * krotated = ggml_rope_custom(
4867-
ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
4868-
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4826+
Qcur = ggml_rope_custom(
4827+
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4828+
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4829+
ext_factor, attn_factor, beta_fast, beta_slow
48694830
);
4870-
cb(krotated, "krotated", il);
4871-
4872-
// ggml currently only supports concatenation on dim=2
4873-
// so we need to permute qrot, qpass, concat, then permute back.
4874-
qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
4875-
cb(qrotated, "qrotated", il);
4876-
4877-
krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
4878-
cb(krotated, "krotated", il);
4879-
4880-
qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
4881-
cb(qpass, "qpass", il);
4882-
4883-
kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
4884-
cb(kpass, "kpass", il);
4885-
4886-
struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
48874831
cb(Qcur, "Qcur", il);
48884832

4889-
struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
4890-
cb(Kcur, "Kcur", il);
4891-
4892-
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3));
4893-
cb(Q, "Q", il);
4894-
4895-
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
4833+
Kcur = ggml_rope_custom(
4834+
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4835+
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4836+
ext_factor, attn_factor, beta_fast, beta_slow
4837+
);
48964838
cb(Kcur, "Kcur", il);
48974839

48984840
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
48994841

49004842
cur = llm_build_kqv(ctx0, hparams, kv_self,
49014843
model.layers[il].wo, NULL,
4902-
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
4844+
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
49034845
cb(cur, "kqv_out", il);
49044846
}
49054847

0 commit comments

Comments
 (0)