@@ -4745,16 +4745,6 @@ static void llm_load_hparams(
4745
4745
4746
4746
// non-transformer models do not have attention heads
4747
4747
if (hparams.n_head() > 0) {
4748
- // sanity check for n_rot (optional)
4749
- hparams.n_rot = hparams.n_embd / hparams.n_head();
4750
-
4751
- ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4752
-
4753
- if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4754
- if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
4755
- throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
4756
- }
4757
- }
4758
4748
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
4759
4749
// gpt-j n_rot = rotary_dim
4760
4750
@@ -4763,6 +4753,17 @@ static void llm_load_hparams(
4763
4753
4764
4754
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
4765
4755
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
4756
+
4757
+ // sanity check for n_rot (optional)
4758
+ hparams.n_rot = hparams.n_embd_head_k;
4759
+
4760
+ ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4761
+
4762
+ if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4763
+ if (hparams.n_rot != hparams.n_embd_head_k) {
4764
+ throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
4765
+ }
4766
+ }
4766
4767
} else {
4767
4768
hparams.n_rot = 0;
4768
4769
hparams.n_embd_head_k = 0;
@@ -11633,7 +11634,7 @@ struct llm_build_context {
11633
11634
11634
11635
Qcur = ggml_rope_ext(
11635
11636
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11636
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11637
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11637
11638
ext_factor, attn_factor, beta_fast, beta_slow);
11638
11639
cb(Qcur, "Qcur", il);
11639
11640
@@ -11642,7 +11643,7 @@ struct llm_build_context {
11642
11643
11643
11644
Kcur = ggml_rope_ext(
11644
11645
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11645
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11646
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11646
11647
ext_factor, attn_factor, beta_fast, beta_slow);
11647
11648
cb(Kcur, "Kcur", il);
11648
11649
@@ -11746,7 +11747,7 @@ struct llm_build_context {
11746
11747
11747
11748
Qcur = ggml_rope_ext(
11748
11749
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11749
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11750
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11750
11751
ext_factor, attn_factor, beta_fast, beta_slow);
11751
11752
cb(Qcur, "Qcur", il);
11752
11753
@@ -11755,7 +11756,7 @@ struct llm_build_context {
11755
11756
11756
11757
Kcur = ggml_rope_ext(
11757
11758
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11758
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11759
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11759
11760
ext_factor, attn_factor, beta_fast, beta_slow);
11760
11761
cb(Kcur, "Kcur", il);
11761
11762
0 commit comments