Skip to content

Granite Four #13550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 79 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
1f0fea7
llama : initial Mamba-2 support
compilade Aug 1, 2024
dceff23
ggml : SIMD ggml_ssm_scan for Mamba-2
compilade Aug 19, 2024
2bfe9de
llama : support running Mamba-Codestral-7B-v0.1
compilade Aug 19, 2024
aff9692
llama : fix Mamba-2 conv state saving
compilade Aug 21, 2024
e04910d
llama : remove unused variable
compilade Aug 22, 2024
fa358e7
llama : add missing break
compilade Aug 22, 2024
38913dc
convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
compilade Aug 22, 2024
0e601ca
Merge branch 'master' into compilade/mamba2
compilade Sep 18, 2024
273e7a4
llama : avoid redundant state copy for Mamba 1 and 2
compilade Sep 30, 2024
7d6cb36
Merge branch 'master' into compilade/mamba2
compilade Oct 1, 2024
2c77d79
metal : attempt to adapt SSM_SCAN for Mamba-2
compilade Oct 2, 2024
87b97d0
metal : fix SSM_SCAN pipeline scope
compilade Oct 2, 2024
03d0e6e
metal : use log and exp instead of log1pf and expf in SSM_SCAN
compilade Oct 2, 2024
7a351ab
metal : remove unused arguments for SSM_SCAN
compilade Oct 2, 2024
8b15bc6
metal : add back n_seqs to SSM_SCAN args
compilade Oct 2, 2024
5b8ec2b
metal : fix SSM_SCAN state head offset
compilade Oct 2, 2024
62b09b3
metal : fix wrong number of tokens per sequence in SSM_SCAN
compilade Oct 3, 2024
038d958
Merge branch 'master' into compilade/mamba2
compilade Oct 12, 2024
805512a
ggml : remove unused fast broadcast path in GGML_MUL
compilade Oct 12, 2024
7d16e1b
Merge branch 'master' into compilade/mamba2
compilade Nov 1, 2024
3bc7103
ggml : avoid multiply by D in GGML_OP_SSM_SCAN
compilade Nov 4, 2024
8d8f065
Merge branch 'master' into compilade/mamba2
compilade Nov 4, 2024
b4e9c59
convert : fix flake8 lint
compilade Nov 4, 2024
1ee6c48
Merge branch 'master' into compilade/mamba2
compilade Nov 25, 2024
c9ecf62
Merge branch 'master' into compilade/mamba2
compilade Feb 26, 2025
35d06fa
Merge branch 'master' into compilade/mamba2
compilade May 1, 2025
cf4f0a4
metal : fix confusion between ; and ,
compilade May 1, 2025
6def5cd
metal : add missing args for nb references in ssm_scan_f32_group
compilade May 1, 2025
791998b
metal : single-user mamba2 inference works
compilade May 2, 2025
94c3d53
kv-cache : remove const_cast when setting inputs for s_copy
compilade May 2, 2025
929fe85
Merge branch 'master' into compilade/mamba2
compilade May 2, 2025
d55b0d0
convert : avoid AutoConfig for Mamba and Mamba2 hparams
compilade May 2, 2025
e94f393
kv-cache : allow context shift for recurrent models
compilade May 2, 2025
582792b
kv-cache : simplify the "struct llama_kv_cache" interface
ggerganov May 25, 2025
99653c3
kv-cache : revert the (n_swa + n_ubatch) change (for next PR)
ggerganov May 25, 2025
052f3f3
kv-cache : some comments
ggerganov May 25, 2025
5693eb6
context : fix graph reserve for multiple sequences
ggerganov May 25, 2025
cb2175f
kv-cache : fix typo [no ci]
ggerganov May 25, 2025
3c6b330
kv-cache : fix find_slot() logic for free slots
ggerganov May 25, 2025
f98b8d0
llama : add TODO for deprecating the defrag API in the future
ggerganov May 26, 2025
7e6d403
kv-cache : improve find_slot() using min/max seq pos info
ggerganov May 27, 2025
47e570c
llama : handle aborts and compute errors
ggerganov May 28, 2025
2b984f4
memory : extract state into llama_memory_state
ggerganov May 28, 2025
f23e4cc
kv-cache : add comments
ggerganov May 30, 2025
3fd6dd5
feat: Add llama_model_is_hybrid API call
gabe-l-hart May 9, 2025
dbad513
feat: Add c++ side constants for attention layer indices hparam
gabe-l-hart May 9, 2025
453d253
feat: Add support for distinguishing recurrent vs non-recurrent layer…
gabe-l-hart May 9, 2025
26e51f4
feat: Auto-fill hparams.recurrent_layer_arr based on whether the mode…
gabe-l-hart May 9, 2025
33a41f5
refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
gabe-l-hart May 28, 2025
162639c
feat: Move layer_filter_cb up to llama_kv_cache
gabe-l-hart May 20, 2025
a886cc1
feat: Add layer filter to recurrent cache
gabe-l-hart May 20, 2025
5c149d2
fix: Fix indexing into k_l for recurrent cache with filter
gabe-l-hart May 20, 2025
4470221
fix: Use per-layer sizing everywhere in kv caches
gabe-l-hart May 14, 2025
ec7695f
feat: First pass at llama_kv_cache_hybrid_recurrent
gabe-l-hart May 30, 2025
728f514
feat: Construct hybrid recurrent cache for hybrid recurrent models
gabe-l-hart May 28, 2025
b58351e
fix: Fix wrong bool condition for split equal in hybrid cache
gabe-l-hart May 28, 2025
4a2709f
feat: Support hybrid recurrent cache in llm_graph_context
gabe-l-hart May 30, 2025
f5fbd1c
Merge branch 'compilade/mamba2' into GraniteFour
gabe-l-hart May 30, 2025
af11d84
feat: Add conversion for Bamba models
gabe-l-hart May 13, 2025
cb37d37
feat: Add Granite 4 conversion
gabe-l-hart May 9, 2025
143b239
feat: Plumb bamba through llama-arch
gabe-l-hart May 9, 2025
c2f3612
feat: Add bamba to llama_arch_is_hybrid_recurrent
gabe-l-hart May 20, 2025
4b83c93
feat: Add optional mamba ssm_in bias tensor
gabe-l-hart May 13, 2025
2b7dee8
feat: Add template specialization for get_arr to load a vector<uint32…
gabe-l-hart May 13, 2025
07e0aa5
feat: Use an explicit bool to determine mamaba vs mamba2
gabe-l-hart May 14, 2025
f1a071d
feat: Isolate mamba(2) and granite attention layer building in static…
gabe-l-hart May 29, 2025
6870ad0
fix: Use per-layer sizes in mamba layer builders
gabe-l-hart May 29, 2025
e5a7c6a
fix: Use per-layer sizes in granite build_attention_layer
gabe-l-hart May 14, 2025
7de4289
feat: First (broken) pass at end-to-end Bamba implementation
gabe-l-hart May 14, 2025
9c73244
fix: Only do Granite multipliers if set
gabe-l-hart May 14, 2025
a7d868b
refactor: Pull granite ffn portion into a static function and reuse i…
gabe-l-hart May 14, 2025
13f116a
feat(py): Allow gguf duplicate keys if they match by value and type
gabe-l-hart May 14, 2025
7856a31
refactor(py): Simplify granitemoehybrid conversion to use parents better
gabe-l-hart May 14, 2025
300f31e
feat: Add GRANITE_MOE_HYBRID through llama-arch
gabe-l-hart May 14, 2025
9ee4a88
feat: Support GRANITE_MOE_HYBRID in llama-model
gabe-l-hart May 14, 2025
418db4e
style: Fix flake8 errors
gabe-l-hart May 14, 2025
662d32a
fix: Fix recurrent cache get after rebase
gabe-l-hart May 28, 2025
8e2221f
fix: Use @compilade's suggested fix for seq_id indexing with equal sp…
gabe-l-hart May 28, 2025
2b20a09
fix: Fix hybrid granite implementation for signature changes in build…
gabe-l-hart May 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 250 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4566,6 +4566,14 @@ def set_gguf_parameters(self):
class MambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA

def __init__(self, dir_model: Path, *args, **kwargs):
# Avoid using AutoConfig for hparams
hparams = kwargs.pop("hparams", None)
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)

def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 8
Expand Down Expand Up @@ -4640,6 +4648,206 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(new_name, data_torch)]


@ModelBase.register("Mamba2ForCausalLM")
class Mamba2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA2

def __init__(self, dir_model: Path, *args, **kwargs):
# Avoid using AutoConfig for hparams
# It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
hparams = kwargs.pop("hparams", None)
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
self.n_group = self.hparams.get("n_groups", 1)

def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 16
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size

if (self.dir_model / "tokenizer.model").is_file():
self._set_vocab_sentencepiece()
elif (self.dir_model / "tokenizer.model.v3").is_file():
# mamba-codestral
raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
elif (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
self._set_vocab_builtin("gpt-neox", vocab_size)

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1

rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
self.gguf_writer.add_ssm_group_count(n_group)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

if name.startswith("model.backbone") or name.startswith("model.lm_head"):
# map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
name = name.removeprefix("model.")

if name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"

new_name = self.map_tensor_name(name)

if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
data_torch = data_torch.squeeze()
elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]):
# unsqueeze A to use similar shape semantics as Mamba-1
# (D is also unsqueezed, but for more straightforward broadcast internally)
data_torch = data_torch.reshape((*data_torch.shape, 1))
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))

if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)

yield (new_name, data_torch)


@ModelBase.register("BambaForCausalLM")
class BambaModel(Mamba2Model):
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
model_arch = gguf.MODEL_ARCH.BAMBA
undo_permute = True

def __init__(self, *args, **kwargs):

# Hybrid mamba models use a prefix for the mamba-specific params.
# TODO: Extend this if the prefix(es) need to be configurable
self.hparam_prefixes = ["mamba"]

super().__init__(*args, **kwargs)

# Use Llama conversion for attention
self._transformer_model_class: type[TextModel] = LlamaModel

# Lists of which layers use ssm vs attention
self._attn_layers = self.get_attn_layres()
self._ssm_layers = [
i for i in range(self.block_count)
if i not in self._attn_layers
]

# n_group and d_inner are used during reshape_tensors for mamaba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model

def get_attn_layres(self) -> list[int]:
attn_layers = self.hparams.get("attn_layer_indices", [])
if not attn_layers:
attn_period = self.hparams.get("attn_layer_period")
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
attn_offset = self.hparams.get("attn_layer_offset")
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
attn_layers = [
i for i in range(self.block_count)
if i % attn_period == attn_offset
]
return attn_layers

def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
prefixed = []
for pfx in self.hparam_prefixes:
prefixed.extend(
"_".join([pfx, k])
for k in keys
)
keys = list(keys) + prefixed
return super().find_hparam(keys, *args, **kwargs)

def set_gguf_parameters(self):

## General Params ##
self.gguf_writer.add_embedding_length(self.d_model)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

## Attention params ##
self.gguf_writer.add_attn_layer_indices(self._attn_layers)
if rope_dim := self.hparams.get("attn_rotary_emb"):
self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))

## Feed Forward Params ##
self.gguf_writer.add_layer_norm_rms_eps(
self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
)

## Validation ##
d_head = self.find_hparam(["d_head"], optional=True) or 64
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:

# Determine whether this is a mamaba layer or an attention layer
if bid in self._ssm_layers:
for mamba_new_name, data_torch in super().modify_tensors(
data_torch, name, bid
):
yield mamba_new_name, data_torch
elif bid in self._attn_layers:
for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
self, data_torch, name, bid
):
yield llama_new_name, data_torch
else:
yield self.map_tensor_name(name), data_torch


@ModelBase.register("CohereForCausalLM")
class CommandR2Model(TextModel):
model_arch = gguf.MODEL_ARCH.COMMAND_R
Expand Down Expand Up @@ -6021,6 +6229,39 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("GraniteMoeHybridForCausalLM")
class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
"""GraniteMoeHybrid is a hybrid SSM + MoE Attention model that uses Mamba2
SSM layers"""
model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID

def get_attn_layres(self):
if layer_types := self.hparams.get("layer_types"):
return [
i for i, typ in enumerate(layer_types)
if typ == "attention"
]
return super().get_attn_layres()

def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
if (
name.endswith("block_sparse_moe.input_linear.weight")
or name.endswith("shared_mlp.input_linear.weight")
):
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return super().modify_tensors(data_torch, name, bid)

def set_gguf_parameters(self):
GraniteMoeModel.set_gguf_parameters(self)
BambaModel.set_gguf_parameters(self)

def set_vocab(self):
self.hparams["pad_vocab_size_multiple"] = 8
super().set_vocab()


@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
Expand Down Expand Up @@ -6406,12 +6647,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
# maybe we should fallback to text model's arch in that case, since not many models have both
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
arch = None
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
arch = arches[0]
elif "ssm_cfg" in hparams:
# For non-hf Mamba and Mamba2 models
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"

# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
if arch is None:
raise ValueError("Failed to detect model architecture")
return arch


Expand Down
3 changes: 2 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 +1878,8 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C);
struct ggml_tensor * C,
struct ggml_tensor * ids);

// partition into non-overlapping windows with padding if needed
// example:
Expand Down
Loading
Loading