@@ -4814,92 +4814,34 @@ struct llm_build_context {
4814
4814
// self-attention
4815
4815
{
4816
4816
// compute Q and K and RoPE them
4817
- struct ggml_tensor * tmpq = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4818
- cb (tmpq , " tmpq " , il);
4817
+ struct ggml_tensor * Qcur = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4818
+ cb (Qcur , " Qcur " , il);
4819
4819
4820
- struct ggml_tensor * tmpk = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4821
- cb (tmpk , " tmpk " , il);
4820
+ struct ggml_tensor * Kcur = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4821
+ cb (Kcur , " Kcur " , il);
4822
4822
4823
4823
struct ggml_tensor * Vcur = ggml_mul_mat (ctx0, model.layers [il].wv , cur);
4824
4824
cb (Vcur, " Vcur" , il);
4825
4825
4826
- // RoPE the first n_rot of q/k, pass the other half, and concat.
4827
- struct ggml_tensor * qrot = ggml_cont (ctx0, ggml_view_3d (
4828
- ctx0, tmpq, hparams.n_rot , n_head, n_tokens,
4829
- ggml_element_size (tmpq) * n_embd_head,
4830
- ggml_element_size (tmpq) * n_embd_head * n_head,
4831
- 0
4832
- ));
4833
- cb (qrot, " qrot" , il);
4834
-
4835
- struct ggml_tensor * krot = ggml_cont (ctx0, ggml_view_3d (
4836
- ctx0, tmpk, hparams.n_rot , n_head, n_tokens,
4837
- ggml_element_size (tmpk) * n_embd_head,
4838
- ggml_element_size (tmpk) * n_embd_head * n_head_kv,
4839
- 0
4840
- ));
4841
- cb (krot, " krot" , il);
4842
-
4843
- // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
4844
- struct ggml_tensor * qpass = ggml_view_3d (
4845
- ctx0, tmpq, (n_embd_head - hparams.n_rot ), n_head, n_tokens,
4846
- ggml_element_size (tmpq) * n_embd_head,
4847
- ggml_element_size (tmpq) * n_embd_head * n_head,
4848
- ggml_element_size (tmpq) * hparams.n_rot
4849
- );
4850
- cb (qpass, " qpass" , il);
4851
-
4852
- struct ggml_tensor * kpass = ggml_view_3d (
4853
- ctx0, tmpk, (n_embd_head - hparams.n_rot ), n_head_kv, n_tokens,
4854
- ggml_element_size (tmpk) * (n_embd_head),
4855
- ggml_element_size (tmpk) * (n_embd_head) * n_head_kv,
4856
- ggml_element_size (tmpk) * hparams.n_rot
4857
- );
4858
- cb (kpass, " kpass" , il);
4859
-
4860
- struct ggml_tensor * qrotated = ggml_rope_custom (
4861
- ctx0, qrot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4862
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4863
- );
4864
- cb (qrotated, " qrotated" , il);
4865
-
4866
- struct ggml_tensor * krotated = ggml_rope_custom (
4867
- ctx0, krot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4868
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4826
+ Qcur = ggml_rope_custom (
4827
+ ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4828
+ hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4829
+ ext_factor, attn_factor, beta_fast, beta_slow
4869
4830
);
4870
- cb (krotated, " krotated" , il);
4871
-
4872
- // ggml currently only supports concatenation on dim=2
4873
- // so we need to permute qrot, qpass, concat, then permute back.
4874
- qrotated = ggml_cont (ctx0, ggml_permute (ctx0, qrotated, 2 , 1 , 0 , 3 ));
4875
- cb (qrotated, " qrotated" , il);
4876
-
4877
- krotated = ggml_cont (ctx0, ggml_permute (ctx0, krotated, 2 , 1 , 0 , 3 ));
4878
- cb (krotated, " krotated" , il);
4879
-
4880
- qpass = ggml_cont (ctx0, ggml_permute (ctx0, qpass, 2 , 1 , 0 , 3 ));
4881
- cb (qpass, " qpass" , il);
4882
-
4883
- kpass = ggml_cont (ctx0, ggml_permute (ctx0, kpass, 2 , 1 , 0 , 3 ));
4884
- cb (kpass, " kpass" , il);
4885
-
4886
- struct ggml_tensor * Qcur = ggml_concat (ctx0, qrotated, qpass);
4887
4831
cb (Qcur, " Qcur" , il);
4888
4832
4889
- struct ggml_tensor * Kcur = ggml_concat (ctx0, krotated, kpass);
4890
- cb (Kcur, " Kcur" , il);
4891
-
4892
- struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 2 , 1 , 0 , 3 ));
4893
- cb (Q, " Q" , il);
4894
-
4895
- Kcur = ggml_cont (ctx0, ggml_permute (ctx0, Kcur, 2 , 1 , 0 , 3 ));
4833
+ Kcur = ggml_rope_custom (
4834
+ ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4835
+ hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4836
+ ext_factor, attn_factor, beta_fast, beta_slow
4837
+ );
4896
4838
cb (Kcur, " Kcur" , il);
4897
4839
4898
4840
llm_build_kv_store (ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
4899
4841
4900
4842
cur = llm_build_kqv (ctx0, hparams, kv_self,
4901
4843
model.layers [il].wo , NULL ,
4902
- Q , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
4844
+ Qcur , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
4903
4845
cb (cur, " kqv_out" , il);
4904
4846
}
4905
4847
0 commit comments