Skip to content

Feat: Support for falcon-mamba architecture #9074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

Expand Down
32 changes: 10 additions & 22 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
)
)
or not name.endswith(".weight")
Expand Down Expand Up @@ -2711,7 +2712,7 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

Expand Down Expand Up @@ -2742,20 +2743,24 @@ def set_gguf_parameters(self):
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

use_dt_b_c_norm = False
# For falconmamba we do apply RMS norm on B / DT and C layers
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
use_dt_b_c_norm = True
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
self.gguf_writer.add_file_type(self.ftype)

_tok_embd = None
Expand All @@ -2782,23 +2787,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [(new_name, data_torch)]

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
if bid is not None and new_name in (
self.format_tensor_name(
n, bid, ".weight" if name.endswith(".weight") else ""
)
for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
):
return gguf.GGMLQuantizationType.F32

return super().tensor_force_quant(name, new_name, bid, n_dims)


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
Expand Down Expand Up @@ -3792,7 +3780,7 @@ class ExaoneModel(Model):
def set_gguf_parameters(self):
hparams = self.hparams

assert(hparams["activation_function"] == "silu")
assert (hparams["activation_function"] == "silu")

max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
Expand Down Expand Up @@ -3855,8 +3843,8 @@ def prepare_tensors(self):

super().prepare_tensors()

###### CONVERSION LOGIC ######

###### CONVERSION LOGIC ######

# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
Expand Down
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class SSM:
INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"

class Tokenizer:
MODEL = "tokenizer.ggml.model"
Expand Down Expand Up @@ -1372,6 +1373,7 @@ def get_type(val: Any) -> GGUFValueType:
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS

# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None:
def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)

Expand Down
22 changes: 20 additions & 2 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_DT_B_C_RMS,

LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
Expand Down Expand Up @@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },

{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
Expand Down Expand Up @@ -2237,6 +2239,7 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
bool ssm_dt_b_c_rms = false;

float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
Expand Down Expand Up @@ -2286,6 +2289,7 @@ struct llama_hparams {
if (this->ssm_d_inner != other.ssm_d_inner) return true;
if (this->ssm_d_state != other.ssm_d_state) return true;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;

if (this->dec_start_token_id != other.dec_start_token_id) return true;

Expand Down Expand Up @@ -5052,6 +5056,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);

ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

Expand Down Expand Up @@ -5907,6 +5912,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
}

LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
Expand Down Expand Up @@ -12161,6 +12167,10 @@ struct llm_build_context {
GGML_ASSERT(2 * d_model == d_inner);
const int64_t d_state = hparams.ssm_d_state;
const int64_t dt_rank = hparams.ssm_dt_rank;
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
// Use the same RMS norm as the final layer norm
const float norm_rms_eps = hparams.f_norm_rms_eps;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;
Expand Down Expand Up @@ -12241,6 +12251,13 @@ struct llm_build_context {
struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
if (ssm_dt_b_c_rms) {
dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
B = ggml_rms_norm(ctx0, B, norm_rms_eps);
C = ggml_rms_norm(ctx0, C, norm_rms_eps);
}

// {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
Expand Down Expand Up @@ -16105,6 +16122,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
}
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
new_type = GGML_TYPE_F16;
}
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
++qs.n_fallback;
}
Expand Down Expand Up @@ -16433,8 +16453,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// do not quantize Mamba's small yet 2D weights
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
quantize &= name.find("ssm_x.weight") == std::string::npos;
quantize &= name.find("ssm_dt.weight") == std::string::npos;

// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
Expand Down
Loading