Skip to content

Commit a94e6ff

Browse files
authored
update: support Qwen2-57B-A14B (#7835)
* update: convert-hf-to-gguf.py to support Qwen2-57B-A14B * fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shared_exp are now properly calculated * update: convert-hf-to-gguf.py cleanup for Qwen2MoeForCausalLM * fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shexp are now properly calculated
1 parent 5b6da18 commit a94e6ff

File tree

4 files changed

+57
-34
lines changed

4 files changed

+57
-34
lines changed

convert-hf-to-gguf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,12 @@ def set_gguf_parameters(self):
16321632
super().set_gguf_parameters()
16331633
if (n_experts := self.hparams.get("num_experts")) is not None:
16341634
self.gguf_writer.add_expert_count(n_experts)
1635+
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
1636+
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
1637+
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
1638+
if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
1639+
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size)
1640+
logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
16351641

16361642
_experts: list[dict[str, Tensor]] | None = None
16371643

gguf-py/gguf/constants.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,22 @@ class General:
3333
FILE_TYPE = "general.file_type"
3434

3535
class LLM:
36-
VOCAB_SIZE = "{arch}.vocab_size"
37-
CONTEXT_LENGTH = "{arch}.context_length"
38-
EMBEDDING_LENGTH = "{arch}.embedding_length"
39-
BLOCK_COUNT = "{arch}.block_count"
40-
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
41-
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
42-
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
43-
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
44-
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
45-
EXPERT_COUNT = "{arch}.expert_count"
46-
EXPERT_USED_COUNT = "{arch}.expert_used_count"
47-
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
48-
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
49-
POOLING_TYPE = "{arch}.pooling_type"
50-
LOGIT_SCALE = "{arch}.logit_scale"
36+
VOCAB_SIZE = "{arch}.vocab_size"
37+
CONTEXT_LENGTH = "{arch}.context_length"
38+
EMBEDDING_LENGTH = "{arch}.embedding_length"
39+
BLOCK_COUNT = "{arch}.block_count"
40+
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
41+
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
42+
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
43+
EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
44+
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
45+
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
46+
EXPERT_COUNT = "{arch}.expert_count"
47+
EXPERT_USED_COUNT = "{arch}.expert_used_count"
48+
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
49+
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
50+
POOLING_TYPE = "{arch}.pooling_type"
51+
LOGIT_SCALE = "{arch}.logit_scale"
5152

5253
class Attention:
5354
HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,9 @@ def add_feed_forward_length(self, length: int) -> None:
394394
def add_expert_feed_forward_length(self, length: int) -> None:
395395
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
396396

397+
def add_expert_shared_feed_forward_length(self, length: int) -> None:
398+
self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
399+
397400
def add_parallel_residual(self, use: bool) -> None:
398401
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
399402

llama.cpp

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ enum llm_kv {
286286
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
287287
LLM_KV_FEED_FORWARD_LENGTH,
288288
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
289+
LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
289290
LLM_KV_USE_PARALLEL_RESIDUAL,
290291
LLM_KV_TENSOR_DATA_LAYOUT,
291292
LLM_KV_EXPERT_COUNT,
@@ -364,21 +365,22 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
364365
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
365366
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
366367

367-
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
368-
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
369-
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
370-
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
371-
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
372-
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
373-
{ LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
374-
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
375-
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
376-
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
377-
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
378-
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
379-
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
380-
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
381-
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
368+
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
369+
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
370+
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
371+
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
372+
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
373+
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
374+
{ LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
375+
{ LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
376+
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
377+
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
378+
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
379+
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
380+
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
381+
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
382+
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
383+
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
382384

383385
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
384386
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1970,6 +1972,7 @@ struct llama_hparams {
19701972
uint32_t n_lora_q = 0;
19711973
uint32_t n_lora_kv = 0;
19721974
uint32_t n_ff_exp = 0;
1975+
uint32_t n_ff_shexp = 0;
19731976
uint32_t n_expert_shared = 0;
19741977
float expert_weights_scale = 0.0;
19751978

@@ -2018,6 +2021,7 @@ struct llama_hparams {
20182021
if (this->n_lora_q != other.n_lora_q) return true;
20192022
if (this->n_lora_kv != other.n_lora_kv) return true;
20202023
if (this->n_ff_exp != other.n_ff_exp) return true;
2024+
if (this->n_ff_shexp != other.n_ff_shexp) return true;
20212025
if (this->n_expert_shared != other.n_expert_shared) return true;
20222026

20232027
if (this->rope_finetuned != other.rope_finetuned) return true;
@@ -4455,6 +4459,9 @@ static void llm_load_hparams(
44554459
} break;
44564460
case LLM_ARCH_QWEN2MOE:
44574461
{
4462+
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
4463+
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
4464+
44584465
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
44594466
switch (hparams.n_layer) {
44604467
case 24: model.type = e_model::MODEL_A2_7B; break;
@@ -5240,6 +5247,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
52405247
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
52415248
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
52425249
}
5250+
5251+
if (model.arch == LLM_ARCH_QWEN2MOE) {
5252+
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
5253+
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
5254+
}
52435255
}
52445256

52455257
// Returns false if cancelled by progress_callback
@@ -6026,16 +6038,17 @@ static bool llm_load_tensors(
60266038
GGML_ASSERT(hparams.n_expert_used > 0);
60276039

60286040
// MoE branch
6029-
auto n_ff_exp = n_ff / hparams.n_expert_used;
6041+
auto n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / hparams.n_expert_used;
60306042
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
60316043
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
60326044
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
60336045

60346046
// Shared expert branch
6047+
auto n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
60356048
layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
6036-
layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff});
6037-
layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff, n_embd});
6038-
layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff});
6049+
layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp});
6050+
layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd});
6051+
layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp});
60396052
}
60406053
} break;
60416054
case LLM_ARCH_PHI2:

0 commit comments

Comments
 (0)