Skip to content

Commit 7766564

Browse files
younesbelkadacompilade
authored andcommitted
llama : support for falcon-mamba architecture (ggml-org#9074)
* feat: initial support for llama.cpp * fix: lint * refactor: better refactor * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * fix: address comments * Update convert_hf_to_gguf.py Co-authored-by: compilade <[email protected]> * fix: add more cleanup and harmonization * fix: lint * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <[email protected]> * fix: change name * Apply suggestions from code review Co-authored-by: compilade <[email protected]> * add in operator * fix: add `dt_b_c_rms` in `llm_load_print_meta` * fix: correct printf format for bool * fix: correct print format * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * llama : quantize more Mamba tensors * llama : use f16 as the fallback of fallback quant types --------- Co-authored-by: compilade <[email protected]>
1 parent 3fe9fae commit 7766564

File tree

5 files changed

+36
-24
lines changed

5 files changed

+36
-24
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ Typically finetunes of the base models below are supported as well.
140140
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
141141
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
142142
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
143+
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
143144

144145
(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
145146

convert_hf_to_gguf.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def prepare_tensors(self):
295295
gguf.MODEL_TENSOR.FFN_GATE_INP,
296296
gguf.MODEL_TENSOR.POS_EMBD,
297297
gguf.MODEL_TENSOR.TOKEN_TYPES,
298+
gguf.MODEL_TENSOR.SSM_CONV1D,
298299
)
299300
)
300301
or not name.endswith(".weight")
@@ -2711,7 +2712,7 @@ class StarCoder2Model(Model):
27112712
model_arch = gguf.MODEL_ARCH.STARCODER2
27122713

27132714

2714-
@Model.register("MambaForCausalLM", "MambaLMHeadModel")
2715+
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
27152716
class MambaModel(Model):
27162717
model_arch = gguf.MODEL_ARCH.MAMBA
27172718

@@ -2742,20 +2743,24 @@ def set_gguf_parameters(self):
27422743
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
27432744
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
27442745
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
2745-
2746+
use_dt_b_c_norm = False
2747+
# For falconmamba we do apply RMS norm on B / DT and C layers
2748+
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
2749+
use_dt_b_c_norm = True
27462750
# Fail early for models which don't have a block expansion factor of 2
27472751
assert d_inner == 2 * d_model
27482752

27492753
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
27502754
self.gguf_writer.add_embedding_length(d_model)
27512755
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
27522756
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
2753-
self.gguf_writer.add_block_count(self.hparams["n_layer"])
2757+
self.gguf_writer.add_block_count(self.block_count)
27542758
self.gguf_writer.add_ssm_conv_kernel(d_conv)
27552759
self.gguf_writer.add_ssm_inner_size(d_inner)
27562760
self.gguf_writer.add_ssm_state_size(d_state)
27572761
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
27582762
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
2763+
self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
27592764
self.gguf_writer.add_file_type(self.ftype)
27602765

27612766
_tok_embd = None
@@ -2782,23 +2787,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
27822787

27832788
return [(new_name, data_torch)]
27842789

2785-
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
2786-
if bid is not None and new_name in (
2787-
self.format_tensor_name(
2788-
n, bid, ".weight" if name.endswith(".weight") else ""
2789-
)
2790-
for n in [
2791-
gguf.MODEL_TENSOR.SSM_CONV1D,
2792-
gguf.MODEL_TENSOR.SSM_X,
2793-
gguf.MODEL_TENSOR.SSM_DT,
2794-
gguf.MODEL_TENSOR.SSM_A,
2795-
gguf.MODEL_TENSOR.SSM_D,
2796-
]
2797-
):
2798-
return gguf.GGMLQuantizationType.F32
2799-
2800-
return super().tensor_force_quant(name, new_name, bid, n_dims)
2801-
28022790

28032791
@Model.register("CohereForCausalLM")
28042792
class CommandR2Model(Model):
@@ -3792,7 +3780,7 @@ class ExaoneModel(Model):
37923780
def set_gguf_parameters(self):
37933781
hparams = self.hparams
37943782

3795-
assert(hparams["activation_function"] == "silu")
3783+
assert (hparams["activation_function"] == "silu")
37963784

37973785
max_position_embeddings = hparams["max_position_embeddings"]
37983786
embed_dim = hparams["hidden_size"]
@@ -3855,8 +3843,8 @@ def prepare_tensors(self):
38553843

38563844
super().prepare_tensors()
38573845

3858-
###### CONVERSION LOGIC ######
38593846

3847+
###### CONVERSION LOGIC ######
38603848

38613849
# tree of lazy tensors
38623850
class LazyTorchTensor(gguf.LazyBase):

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class SSM:
130130
INNER_SIZE = "{arch}.ssm.inner_size"
131131
STATE_SIZE = "{arch}.ssm.state_size"
132132
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
133+
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
133134

134135
class Tokenizer:
135136
MODEL = "tokenizer.ggml.model"
@@ -1372,6 +1373,7 @@ def get_type(val: Any) -> GGUFValueType:
13721373
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
13731374
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
13741375
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
1376+
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
13751377

13761378
# tokenization
13771379
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None:
730730
def add_ssm_time_step_rank(self, value: int) -> None:
731731
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
732732

733+
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
734+
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
735+
733736
def add_tokenizer_model(self, model: str) -> None:
734737
self.add_string(Keys.Tokenizer.MODEL, model)
735738

src/llama.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ enum llm_kv {
328328
LLM_KV_SSM_CONV_KERNEL,
329329
LLM_KV_SSM_STATE_SIZE,
330330
LLM_KV_SSM_TIME_STEP_RANK,
331+
LLM_KV_SSM_DT_B_C_RMS,
331332

332333
LLM_KV_TOKENIZER_MODEL,
333334
LLM_KV_TOKENIZER_PRE,
@@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
426427
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
427428
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
428429
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
430+
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
429431

430432
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
431433
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -2237,6 +2239,7 @@ struct llama_hparams {
22372239
uint32_t ssm_d_inner = 0;
22382240
uint32_t ssm_d_state = 0;
22392241
uint32_t ssm_dt_rank = 0;
2242+
bool ssm_dt_b_c_rms = false;
22402243

22412244
float f_clamp_kqv = 0.0f;
22422245
float f_max_alibi_bias = 0.0f;
@@ -2286,6 +2289,7 @@ struct llama_hparams {
22862289
if (this->ssm_d_inner != other.ssm_d_inner) return true;
22872290
if (this->ssm_d_state != other.ssm_d_state) return true;
22882291
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
2292+
if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
22892293

22902294
if (this->dec_start_token_id != other.dec_start_token_id) return true;
22912295

@@ -5052,6 +5056,7 @@ static void llm_load_hparams(
50525056
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
50535057
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
50545058
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
5059+
ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);
50555060

50565061
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
50575062

@@ -5907,6 +5912,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
59075912
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
59085913
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
59095914
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
5915+
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
59105916
}
59115917

59125918
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
@@ -12165,6 +12171,10 @@ struct llm_build_context {
1216512171
GGML_ASSERT(2 * d_model == d_inner);
1216612172
const int64_t d_state = hparams.ssm_d_state;
1216712173
const int64_t dt_rank = hparams.ssm_dt_rank;
12174+
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
12175+
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
12176+
// Use the same RMS norm as the final layer norm
12177+
const float norm_rms_eps = hparams.f_norm_rms_eps;
1216812178

1216912179
struct ggml_tensor * cur;
1217012180
struct ggml_tensor * inpL;
@@ -12245,6 +12255,13 @@ struct llm_build_context {
1224512255
struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
1224612256
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
1224712257

12258+
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
12259+
if (ssm_dt_b_c_rms) {
12260+
dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
12261+
B = ggml_rms_norm(ctx0, B, norm_rms_eps);
12262+
C = ggml_rms_norm(ctx0, C, norm_rms_eps);
12263+
}
12264+
1224812265
// {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
1224912266
dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
1225012267
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
@@ -16109,6 +16126,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
1610916126
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
1611016127
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
1611116128
}
16129+
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
16130+
new_type = GGML_TYPE_F16;
16131+
}
1611216132
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
1611316133
++qs.n_fallback;
1611416134
}
@@ -16437,8 +16457,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
1643716457
// do not quantize Mamba's small yet 2D weights
1643816458
// NOTE: can't use LLM_TN here because the layer number is not known
1643916459
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
16440-
quantize &= name.find("ssm_x.weight") == std::string::npos;
16441-
quantize &= name.find("ssm_dt.weight") == std::string::npos;
1644216460

1644316461
// do not quantize relative position bias (T5)
1644416462
quantize &= name.find("attn_rel_b.weight") == std::string::npos;

0 commit comments

Comments
 (0)