@@ -4626,16 +4626,6 @@ static void llm_load_hparams(

        // non-transformer models do not have attention heads
        if (hparams.n_head() > 0) {
-           // sanity check for n_rot (optional)
-           hparams.n_rot = hparams.n_embd / hparams.n_head();
-
-           ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
-           if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
-               if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
-                   throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
-               }
-           }
            // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
            // gpt-j n_rot = rotary_dim
@@ -4644,6 +4634,17 @@ static void llm_load_hparams(

            hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
            ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+
+           // sanity check for n_rot (optional)
+           hparams.n_rot = hparams.n_embd_head_k;
+
+           ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
+
+           if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
+               if (hparams.n_rot != hparams.n_embd_head_k) {
+                   throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
+               }
+           }
        } else {
            hparams.n_rot = 0;
            hparams.n_embd_head_k = 0;
@@ -11491,7 +11492,7 @@ struct llm_build_context {

                Qcur = ggml_rope_ext(
                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
-                       n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Qcur, "Qcur", il);

@@ -11500,7 +11501,7 @@ struct llm_build_context {

                Kcur = ggml_rope_ext(
                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                       n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Kcur, "Kcur", il);

@@ -11604,7 +11605,7 @@ struct llm_build_context {

                Qcur = ggml_rope_ext(
                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
-                       n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Qcur, "Qcur", il);

@@ -11613,7 +11614,7 @@ struct llm_build_context {

                Kcur = ggml_rope_ext(
                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                       n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Kcur, "Kcur", il);

0 commit comments