
Commit ff4bf58

tdakhran authored and pwilkin committed
llama : support LiquidAI LFM2-MoE hybrid model (ggml-org#16464)
* llama : support LiquidAI LFM2-MoE hybrid model

  Add support for the [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B) model. For more information about the model, please read [the blog post](https://www.liquid.ai/company/news).

  [HF PR](huggingface/transformers#41401)
  [GGUFs](https://huggingface.co/LiquidAI/LFM2-8B-A1B-GGUF)

* Do not use defaultdict

* Address PR feedback
1 parent 542bee8 commit ff4bf58

7 files changed: 192 additions, 15 deletions


convert_hf_to_gguf.py

Lines changed: 69 additions & 0 deletions
@@ -8866,6 +8866,75 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Lfm2MoeForCausalLM")
+class LFM2MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LFM2MOE
+
+    def set_gguf_parameters(self):
+        # set num_key_value_heads only for attention layers
+        self.hparams["num_key_value_heads"] = [
+            self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
+            for layer_type in self.hparams["layer_types"]
+        ]
+
+        super().set_gguf_parameters()
+
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+        self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
+
+    # cache for experts weights for merging
+    _experts_cache: dict[int, dict[str, Tensor]] = {}
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # conv op requires 2d tensor
+        if 'conv.conv' in name:
+            data_torch = data_torch.squeeze(1)
+
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        # merge expert weights
+        if 'experts' in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            expert_cache = self._experts_cache.setdefault(bid, {})
+            expert_cache[name] = data_torch
+            expert_weights = ["w1", "w2", "w3"]
+
+            # not enough expert weights to merge
+            if len(expert_cache) < n_experts * len(expert_weights):
+                return []
+
+            tensors: list[tuple[str, Tensor]] = []
+            for w_name in expert_weights:
+                datas: list[Tensor] = []
+
+                for xid in range(n_experts):
+                    ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
+                    datas.append(expert_cache[ename])
+                    del expert_cache[ename]
+
+                data_torch = torch.stack(datas, dim=0)
+                merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
+                new_name = self.map_tensor_name(merged_name)
+                tensors.append((new_name, data_torch))
+
+            del self._experts_cache[bid]
+            return tensors
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        assert not self._experts_cache
+
+
 @ModelBase.register("Lfm2VlForConditionalGeneration")
 class LFM2VLModel(MmprojModel):
     def __init__(self, *args, **kwargs):
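
The heart of the converter change is LFM2MoeModel.modify_tensors(): per-expert 2D projection weights are buffered per block and, once all experts for that block have arrived, emitted as one stacked 3D tensor per projection (w1/w2/w3). Below is a minimal standalone sketch of that stacking step; the expert count, shapes, and layer index are invented toy values, not the real LFM2-8B-A1B dimensions.

import torch

# toy dimensions, for illustration only
n_experts, n_embd, n_ff_exp = 4, 8, 16

# per-expert 2D weights as they would arrive from the HF checkpoint, keyed by HF name
expert_cache = {
    f"model.layers.0.feed_forward.experts.{xid}.w1.weight": torch.randn(n_ff_exp, n_embd)
    for xid in range(n_experts)
}

# once every expert is present, stack along a new leading expert dimension,
# mirroring torch.stack(datas, dim=0) in the diff above
merged = torch.stack(
    [expert_cache[f"model.layers.0.feed_forward.experts.{xid}.w1.weight"] for xid in range(n_experts)],
    dim=0,
)
print(merged.shape)  # torch.Size([4, 16, 8]) -> (n_expert, n_ff_exp, n_embd), one tensor per projection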

gguf-py/gguf/constants.py

Lines changed: 25 additions & 0 deletions
@@ -408,6 +408,7 @@ class MODEL_ARCH(IntEnum):
     SMOLLM3 = auto()
     GPT_OSS = auto()
     LFM2 = auto()
+    LFM2MOE = auto()
     DREAM = auto()
     SMALLTHINKER = auto()
     LLADA = auto()
@@ -753,6 +754,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.SMOLLM3: "smollm3",
     MODEL_ARCH.GPT_OSS: "gpt-oss",
     MODEL_ARCH.LFM2: "lfm2",
+    MODEL_ARCH.LFM2MOE: "lfm2moe",
     MODEL_ARCH.DREAM: "dream",
     MODEL_ARCH.SMALLTHINKER: "smallthinker",
     MODEL_ARCH.LLADA: "llada",
@@ -2733,6 +2735,29 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.OUTPUT,
     ],
+    MODEL_ARCH.LFM2MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.SHORTCONV_CONV,
+        MODEL_TENSOR.SHORTCONV_INPROJ,
+        MODEL_TENSOR.SHORTCONV_OUTPROJ,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.ATTN_NORM,  # operator_norm
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.SMALLTHINKER: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
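
As a quick sanity check of the new registrations above, here is a hedged sketch (not part of the commit) that assumes the gguf-py package from this tree is importable and that MODEL_ARCH_NAMES, MODEL_TENSORS, and TENSOR_NAMES remain exposed at package level:

import gguf

# architecture name string registered for the new enum value
print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LFM2MOE])  # "lfm2moe"

# GGUF tensor name patterns allowed for the new architecture
for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.LFM2MOE]:
    print(gguf.TENSOR_NAMES[t])  # e.g. "blk.{bid}.ffn_gate_inp", "blk.{bid}.exp_probs_b", ...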

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 0 deletions
@@ -358,6 +358,7 @@ class TensorNameMap:
         "model.layers.{bid}.mlp.router",                       # openai-moe
         "model.layers.{bid}.mlp.gate.wg",                      # hunyuan
         "model.layers.{bid}.block_sparse_moe.primary_router",  # smallthinker
+        "model.layers.{bid}.feed_forward.gate",                # lfm2moe
     ),
 
     MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -367,6 +368,7 @@
     MODEL_TENSOR.FFN_EXP_PROBS_B: (
         "model.layers.{bid}.mlp.gate.e_score_correction",         # deepseek-v3 dots1
         "model.layers.{bid}.mlp.moe_statics.e_score_correction",  # ernie4.5-moe
+        "model.layers.{bid}.feed_forward.expert_bias",            # lfm2moe
     ),
 
     # Feed-forward up
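
A hedged sketch (not code from the commit) of how the two new source names are expected to resolve, assuming gguf-py's get_tensor_name_map()/get_name() API as it exists in this tree:

import gguf

# HF-name -> GGUF-name map for the new architecture; n_blocks only needs to cover
# the block indices queried below
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LFM2MOE, n_blocks=1)

# router gate, expected to land on blk.0.ffn_gate_inp(.weight)
print(tmap.get_name("model.layers.0.feed_forward.gate.weight", try_suffixes=(".weight", ".bias")))

# expert bias (renamed to ".expert_bias.bias" by the convert script),
# expected to land on blk.0.exp_probs_b(.bias)
print(tmap.get_name("model.layers.0.feed_forward.expert_bias.bias", try_suffixes=(".weight", ".bias")))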

src/llama-arch.cpp

Lines changed: 28 additions & 0 deletions
@@ -94,6 +94,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SMOLLM3, "smollm3" },
     { LLM_ARCH_OPENAI_MOE, "gpt-oss" },
     { LLM_ARCH_LFM2, "lfm2" },
+    { LLM_ARCH_LFM2MOE, "lfm2moe" },
     { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_SMALLTHINKER, "smallthinker" },
     { LLM_ARCH_LLADA, "llada" },
@@ -2137,6 +2138,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_OUTPUT, "output" },
         }
     },
+    {
+        LLM_ARCH_LFM2MOE,
+        {
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
+            { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
+            { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        }
+    },
     {
         LLM_ARCH_SMALLTHINKER,
         {
@@ -2527,6 +2554,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_LFM2MOE:
         case LLM_ARCH_NEMOTRON_H:
         case LLM_ARCH_QWEN3NEXT:
             return true;

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ enum llm_arch {
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_OPENAI_MOE,
     LLM_ARCH_LFM2,
+    LLM_ARCH_LFM2MOE,
     LLM_ARCH_DREAM,
     LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_LLADA,

src/llama-model.cpp

Lines changed: 66 additions & 15 deletions
@@ -202,6 +202,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
         case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
         case LLM_TYPE_A13B: return "A13B";
+        case LLM_TYPE_8B_A1B: return "8B.A1B";
         case LLM_TYPE_21B_A3B: return "21B.A3B";
         case LLM_TYPE_30B_A3B: return "30B.A3B";
         case LLM_TYPE_80B_A3B: return "80B.A3B";
@@ -2107,14 +2108,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 for (uint32_t il = 0; il < hparams.n_layer; ++il) {
                     hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
                 }
+                hparams.n_layer_dense_lead = hparams.n_layer;
                 switch (hparams.n_ff()) {
                     case 4608: type = LLM_TYPE_350M; break;
                     case 6912: type = LLM_TYPE_700M; break;
                     case 8192: type = LLM_TYPE_1_2B; break;
                     case 10752: type = LLM_TYPE_2_6B; break;
-                    default: type = LLM_TYPE_UNKNOWN;
+                    default:    type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_LFM2MOE:
+            {
+                ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
+
+                for (uint32_t il = 0; il < hparams.n_layer; ++il) {
+                    hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
+                }
+
+                type = LLM_TYPE_8B_A1B;
+            } break;
         case LLM_ARCH_SMALLTHINKER:
             {
                 const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -5995,6 +6011,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 } break;
             case LLM_ARCH_LFM2:
+            case LLM_ARCH_LFM2MOE:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
                     tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
@@ -6006,11 +6023,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
-                        // ffn is same for transformer and conv layers
+
+                        const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead);
+
+                        // ffn/moe is same for transformer and conv layers
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
-                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        if (is_moe_layer) {
+                            GGML_ASSERT(n_expert && n_expert_used);
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
+                        } else { // dense
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                            layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        }
 
                         // for operator_norm
                         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
@@ -6492,7 +6521,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
     }
 
-    if (arch == LLM_ARCH_SMALLTHINKER) {
+    if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
     }
@@ -18784,6 +18813,8 @@ struct llm_build_lfm2 : public llm_graph_context {
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
         for (int il = 0; il < n_layer; ++il) {
+            const bool is_moe_layer = il >= static_cast<int>(hparams.n_layer_dense_lead);
+
             auto * prev_cur = cur;
             cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
             cb(cur, "model.layers.{}.operator_norm", il);
@@ -18798,7 +18829,16 @@ struct llm_build_lfm2 : public llm_graph_context {
             }
 
             cur = ggml_add(ctx0, prev_cur, cur);
-            cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
+
+            auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(ffn_norm_out, "model.layers.{}.ffn_norm", il);
+
+            ggml_tensor * ffn_out = is_moe_layer ?
+                build_moe_feed_forward(ffn_norm_out, il) :
+                build_dense_feed_forward(ffn_norm_out, il);
+            cb(ffn_norm_out, "model.layers.{}.ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_out);
         }
 
         cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
@@ -18813,23 +18853,32 @@ struct llm_build_lfm2 : public llm_graph_context {
         ggml_build_forward_expand(gf, cur);
     }
 
-    ggml_tensor * build_feed_forward(ggml_tensor * cur,
-                                     int il) const {
-        cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
-        cb(cur, "model.layers.{}.ffn_norm", il);
+    ggml_tensor * build_moe_feed_forward(ggml_tensor * cur,
+                                         int il) const {
+        return build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                model.layers[il].ffn_exp_probs_b,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, true,
+                false, 0.0,
+                static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
+                il);
+    }
 
+    ggml_tensor * build_dense_feed_forward(ggml_tensor * cur,
+                                           int il) const {
         GGML_ASSERT(!model.layers[il].ffn_up_b);
         GGML_ASSERT(!model.layers[il].ffn_gate_b);
         GGML_ASSERT(!model.layers[il].ffn_down_b);
-        cur = build_ffn(cur,
+        return build_ffn(cur,
                 model.layers[il].ffn_up, NULL, NULL,
                 model.layers[il].ffn_gate, NULL, NULL,
                 model.layers[il].ffn_down, NULL, NULL,
                 NULL,
                 LLM_FFN_SILU, LLM_FFN_PAR, il);
-        cb(cur, "model.layers.{}.feed_forward.w2", il);
-
-        return cur;
     }
 
     ggml_tensor * build_attn_block(ggml_tensor * cur,
@@ -19999,6 +20048,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
                 llm = std::make_unique<llm_build_falcon_h1>(*this, params);
             } break;
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_LFM2MOE:
             {
                 llm = std::make_unique<llm_build_lfm2>(*this, params);
             } break;
@@ -20226,6 +20276,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_OPENAI_MOE:
         case LLM_ARCH_HUNYUAN_DENSE:
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_LFM2MOE:
         case LLM_ARCH_SMALLTHINKER:
        case LLM_ARCH_GLM4_MOE:
        case LLM_ARCH_SEED_OSS:
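
Summarizing the layer layout implied by the llama-model.cpp changes: a layer is recurrent (short-conv) when its KV head count is zero, and it uses the MoE feed-forward when its index is at or beyond n_layer_dense_lead; everything before that stays dense. A small illustrative sketch follows; the per-layer head counts and the dense-lead count below are invented toy values, not the real LFM2-8B-A1B configuration.

def classify_layers(kv_heads_per_layer: list[int], n_layer_dense_lead: int) -> list[str]:
    kinds = []
    for il, n_head_kv in enumerate(kv_heads_per_layer):
        mixer = "attention" if n_head_kv > 0 else "shortconv"  # hparams.recurrent_layer_arr[il]
        ffn = "moe" if il >= n_layer_dense_lead else "dense"    # is_moe_layer in load_tensors()/build_graph()
        kinds.append(f"layer {il}: {mixer} + {ffn} ffn")
    return kinds

# hypothetical 6-layer toy model: two leading dense blocks, full attention on layers 2 and 5
for line in classify_layers([0, 0, 8, 0, 0, 8], n_layer_dense_lead=2):
    print(line)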

src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@ enum llm_type {
     LLM_TYPE_17B_16E, // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_A13B,
+    LLM_TYPE_8B_A1B, // lfm2moe
     LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
     LLM_TYPE_80B_A3B, // Qwen3 Next
