
Commit 155ec5b

ggerganov authored and Neo Zhang committed
llama : fix n_rot default (ggml-org#8348)
ggml-ci
1 parent 78706ed commit 155ec5b

File tree: 1 file changed (+15 −14)


src/llama.cpp

Lines changed: 15 additions & 14 deletions
@@ -4626,16 +4626,6 @@ static void llm_load_hparams(
 
     // non-transformer models do not have attention heads
     if (hparams.n_head() > 0) {
-        // sanity check for n_rot (optional)
-        hparams.n_rot = hparams.n_embd / hparams.n_head();
-
-        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
-            if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
-                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
-            }
-        }
         // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
         // gpt-j n_rot = rotary_dim
 
@@ -4644,6 +4634,17 @@ static void llm_load_hparams(
 
         hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
         ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+
+        // sanity check for n_rot (optional)
+        hparams.n_rot = hparams.n_embd_head_k;
+
+        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
+
+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
+            if (hparams.n_rot != hparams.n_embd_head_k) {
+                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
+            }
+        }
     } else {
         hparams.n_rot = 0;
         hparams.n_embd_head_k = 0;
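
The hparams change is an ordering fix: previously the n_rot default was computed as n_embd / n_head() before the optional per-head key length (LLM_KV_ATTENTION_VALUE_LENGTH's sibling, LLM_KV_ATTENTION_KEY_LENGTH) had been applied to n_embd_head_k, so models whose key head size differs from n_embd / n_head got the wrong default, and the LLAMA/FALCON sanity check compared against the wrong value. A standalone sketch of the difference (not llama.cpp code; the numbers are hypothetical, not from any real model):

// Sketch of the ordering bug: a hypothetical model with n_embd = 3072,
// n_head = 12, whose GGUF also carries a key-length override of 128.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd = 3072;
    const uint32_t n_head = 12;
    const uint32_t key_length_override = 128; // hypothetical GGUF value

    // old order: the default n_rot is fixed before the override is read
    const uint32_t n_rot_old = n_embd / n_head; // 256, not the real head size

    // new order: n_embd_head_k is loaded first, then used as the default
    uint32_t n_embd_head_k = n_embd / n_head;
    n_embd_head_k = key_length_override; // what ml.get_key(..., false) applies
    const uint32_t n_rot_new = n_embd_head_k; // 128, matches the head size

    printf("old default n_rot = %u, new default n_rot = %u\n", n_rot_old, n_rot_new);
    return 0;
}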
@@ -11491,7 +11492,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11500,7 +11501,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -11604,7 +11605,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11613,7 +11614,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
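
All four ggml_rope_ext call sites change the same argument: the parameter after nullptr is the number of dimensions to rotate per head, and passing n_embd_head_k forced a full-head rotation even for architectures where n_rot is smaller (e.g. gpt-neox, where n_rot = rotary_pct * (n_embd / n_head), per the comment in the first hunk). A standalone sketch of what a smaller rotation count means, assuming NeoX-style pairing of dims (i, i + n_rot/2) — an illustration, not the ggml kernel:

// Partial RoPE: only the first n_rot dims of each head are rotated;
// the remaining dims pass through unchanged.
#include <cmath>
#include <cstdio>
#include <vector>

static void rope_partial(std::vector<float> & head, int n_rot, int pos, float freq_base) {
    for (int i = 0; i < n_rot / 2; ++i) {
        // per-pair angle: pos * freq_base^(-2i / n_rot)
        const float theta = pos * std::pow(freq_base, -2.0f * i / n_rot);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = head[i];
        const float x1 = head[i + n_rot / 2];
        head[i]             = x0 * c - x1 * s;
        head[i + n_rot / 2] = x0 * s + x1 * c;
    }
    // dims [n_rot, head.size()) are deliberately left untouched
}

int main() {
    std::vector<float> head(8, 1.0f);        // hypothetical head size 8
    rope_partial(head, /*n_rot=*/4, /*pos=*/3, /*freq_base=*/10000.0f);
    for (float x : head) printf("%.3f ", x); // last 4 values stay at 1.000
    printf("\n");
    return 0;
}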
