Skip to content

Commit 0090e5b

Browse files
ggerganov authored and Nexesenex committed
llama : fix n_rot default (ggml-org#8348)
ggml-ci
1 parent 87d76b3 commit 0090e5b

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

llama.cpp

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4745,16 +4745,6 @@ static void llm_load_hparams(
47454745

47464746
// non-transformer models do not have attention heads
47474747
if (hparams.n_head() > 0) {
4748-
// sanity check for n_rot (optional)
4749-
hparams.n_rot = hparams.n_embd / hparams.n_head();
4750-
4751-
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4752-
4753-
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4754-
if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
4755-
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
4756-
}
4757-
}
47584748
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
47594749
// gpt-j n_rot = rotary_dim
47604750

@@ -4763,6 +4753,17 @@ static void llm_load_hparams(
47634753

47644754
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
47654755
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
4756+
4757+
// sanity check for n_rot (optional)
4758+
hparams.n_rot = hparams.n_embd_head_k;
4759+
4760+
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4761+
4762+
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4763+
if (hparams.n_rot != hparams.n_embd_head_k) {
4764+
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
4765+
}
4766+
}
47664767
} else {
47674768
hparams.n_rot = 0;
47684769
hparams.n_embd_head_k = 0;
@@ -11633,7 +11634,7 @@ struct llm_build_context {
1163311634

1163411635
Qcur = ggml_rope_ext(
1163511636
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11636-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11637+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1163711638
ext_factor, attn_factor, beta_fast, beta_slow);
1163811639
cb(Qcur, "Qcur", il);
1163911640

@@ -11642,7 +11643,7 @@ struct llm_build_context {
1164211643

1164311644
Kcur = ggml_rope_ext(
1164411645
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11645-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11646+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1164611647
ext_factor, attn_factor, beta_fast, beta_slow);
1164711648
cb(Kcur, "Kcur", il);
1164811649

@@ -11746,7 +11747,7 @@ struct llm_build_context {
1174611747

1174711748
Qcur = ggml_rope_ext(
1174811749
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11749-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11750+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1175011751
ext_factor, attn_factor, beta_fast, beta_slow);
1175111752
cb(Qcur, "Qcur", il);
1175211753

@@ -11755,7 +11756,7 @@ struct llm_build_context {
1175511756

1175611757
Kcur = ggml_rope_ext(
1175711758
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11758-
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
11759+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
1175911760
ext_factor, attn_factor, beta_fast, beta_slow);
1176011761
cb(Kcur, "Kcur", il);
1176111762

0 commit comments

Comments (0)