diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index e4be9d1c9..010eaf9b8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -751,6 +751,8 @@ def __init__(
         yarn_beta_fast: float = 32.0,
         yarn_beta_slow: float = 1.0,
         yarn_orig_ctx: int = 0,
+        grp_attn_n: int = 1,
+        grp_attn_w: int = 512,
         mul_mat_q: bool = True,
         logits_all: bool = False,
         embedding: bool = False,
@@ -820,6 +822,8 @@ def __init__(
             yarn_beta_fast: YaRN low correction dim
             yarn_beta_slow: YaRN high correction dim
             yarn_orig_ctx: YaRN original context size
+            grp_attn_n: group-attention factor
+            grp_attn_w: group-attention width
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
@@ -935,6 +939,8 @@ def __init__(
             yarn_beta_slow if yarn_beta_slow != 0.0 else 0
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
+        self.context_params.grp_attn_n = grp_attn_n
+        self.context_params.grp_attn_w = grp_attn_w
         self.context_params.mul_mat_q = mul_mat_q
         self.context_params.logits_all = logits_all
         self.context_params.embedding = embedding
@@ -2197,6 +2203,8 @@ def __getstate__(self):
             yarn_beta_fast=self.context_params.yarn_beta_fast,
             yarn_beta_slow=self.context_params.yarn_beta_slow,
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
+            grp_attn_n=self.context_params.grp_attn_n,
+            grp_attn_w=self.context_params.grp_attn_w,
             mul_mat_q=self.context_params.mul_mat_q,
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embedding,
@@ -2241,6 +2249,8 @@ def __setstate__(self, state):
             yarn_beta_fast=state["yarn_beta_fast"],
             yarn_beta_slow=state["yarn_beta_slow"],
             yarn_orig_ctx=state["yarn_orig_ctx"],
+            grp_attn_n=state["grp_attn_n"],
+            grp_attn_w=state["grp_attn_w"],
             mul_mat_q=state["mul_mat_q"],
             logits_all=state["logits_all"],
             embedding=state["embedding"],
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 9e8e3cec7..e7df70e54 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -396,7 +396,6 @@ class llama_model_kv_override(Structure):
 #     // override key-value pairs of the model meta data
 #     const struct llama_model_kv_override * kv_overrides;
 
-
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool vocab_only; // only load the vocabulary, no weights
 #     bool use_mmap;   // use mmap if possible
@@ -448,6 +447,10 @@ class llama_model_params(Structure):
 #     float    yarn_beta_slow;  // YaRN high correction dim
 #     uint32_t yarn_orig_ctx;   // YaRN original context size
 
+#     // ref: https://github.com/ggerganov/llama.cpp/pull/4815
+#     int32_t  grp_attn_n = 1;   // group-attention factor
+#     int32_t  grp_attn_w = 512; // group-attention width
+
 #     enum ggml_type type_k; // data type for K cache
 #     enum ggml_type type_v; // data type for V cache
 
@@ -475,6 +478,8 @@ class llama_context_params(Structure):
         yarn_beta_fast (float): YaRN low correction dim
         yarn_beta_slow (float): YaRN high correction dim
         yarn_orig_ctx (int): YaRN original context size
+        grp_attn_n (int): group-attention factor
+        grp_attn_w (int): group-attention width
         type_k (int): data type for K cache
         type_v (int): data type for V cache
         mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
@@ -497,6 +502,8 @@ class llama_context_params(Structure):
         ("yarn_beta_fast", c_float),
         ("yarn_beta_slow", c_float),
         ("yarn_orig_ctx", c_uint32),
+        ("grp_attn_n", c_int32),
+        ("grp_attn_w", c_int32),
         ("type_k", c_int),
         ("type_v", c_int),
         ("mul_mat_q", c_bool),
diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py
index f9be3237d..6ab829322 100644
--- a/llama_cpp/server/model.py
+++ b/llama_cpp/server/model.py
@@ -113,6 +113,8 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         yarn_beta_fast=settings.yarn_beta_fast,
         yarn_beta_slow=settings.yarn_beta_slow,
         yarn_orig_ctx=settings.yarn_orig_ctx,
+        grp_attn_n=settings.grp_attn_n,
+        grp_attn_w=settings.grp_attn_w,
         mul_mat_q=settings.mul_mat_q,
         logits_all=settings.logits_all,
         embedding=settings.embedding,
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index a10390c75..a15101fb3 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -84,6 +84,8 @@ class ModelSettings(BaseSettings):
     yarn_beta_fast: float = Field(default=32.0)
     yarn_beta_slow: float = Field(default=1.0)
     yarn_orig_ctx: int = Field(default=0)
+    grp_attn_n: int = Field(default=1)
+    grp_attn_w: int = Field(default=512)
     mul_mat_q: bool = Field(
         default=True, description="if true, use experimental mul_mat_q kernels"
     )
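
These changes plumb llama.cpp's Self-Extend group attention (ref: https://github.com/ggerganov/llama.cpp/pull/4815) through the high-level Python API and the server settings. As a rough sketch of how the new parameters might be passed once this lands (the model path, context size, and parameter values below are illustrative placeholders, not part of this change):

    from llama_cpp import Llama

    # Sketch only: the path and values are hypothetical. Per the Self-Extend
    # approach, grp_attn_n > 1 groups positions outside a neighborhood of
    # grp_attn_w tokens, so the model can attend beyond its trained context;
    # grp_attn_n=1 (the default) leaves the behavior unchanged.
    llm = Llama(
        model_path="./models/llama-2-7b.Q4_K_M.gguf",  # placeholder model file
        n_ctx=8192,       # request a context larger than the trained one
        grp_attn_n=4,     # group-attention factor (default 1)
        grp_attn_w=2048,  # group-attention width (default 512)
    )

    out = llm("Q: What is the capital of France?\nA:", max_tokens=32)
    print(out["choices"][0]["text"])

For the server, the same values would be supplied through `ModelSettings` (e.g. as `GRP_ATTN_N` / `GRP_ATTN_W` environment variables, assuming pydantic's default case-insensitive env parsing for `BaseSettings` fields).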