From c6c775270e594e20250a120cbaeb4dbead27e330 Mon Sep 17 00:00:00 2001
From: Josh XT
Date: Mon, 15 Jan 2024 20:09:16 -0500
Subject: [PATCH 1/4] Add self extend support

---
 llama_cpp/llama_cpp.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 9e8e3cec7..a4fbfe460 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -396,6 +396,8 @@ class llama_model_kv_override(Structure):
 # // override key-value pairs of the model meta data
 #     const struct llama_model_kv_override * kv_overrides;
+# Ref https://github.com/ggerganov/llama.cpp/pull/4815
+
 #
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool vocab_only; // only load the vocabulary, no weights
 #     bool use_mmap;   // use mmap if possible
@@ -448,6 +450,9 @@ class llama_model_params(Structure):
 #     float yarn_beta_slow;   // YaRN high correction dim
 #     uint32_t yarn_orig_ctx; // YaRN original context size

+#     int32_t grp_attn_n = 1;   // group-attention factor
+#     int32_t grp_attn_w = 512; // group-attention width
+
 #     enum ggml_type type_k; // data type for K cache
 #     enum ggml_type type_v; // data type for V cache

@@ -475,6 +480,8 @@ class llama_context_params(Structure):
         yarn_beta_fast (float): YaRN low correction dim
         yarn_beta_slow (float): YaRN high correction dim
         yarn_orig_ctx (int): YaRN original context size
+        grp_attn_n (int): group-attention factor
+        grp_attn_w (int): group-attention width
         type_k (int): data type for K cache
         type_v (int): data type for V cache
         mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
@@ -497,6 +504,8 @@
         ("yarn_beta_fast", c_float),
         ("yarn_beta_slow", c_float),
         ("yarn_orig_ctx", c_uint32),
+        ("grp_attn_n", c_int32),
+        ("grp_attn_w", c_int32),
         ("type_k", c_int),
         ("type_v", c_int),
         ("mul_mat_q", c_bool),

From b0ef3839fe6748e197f80392fa5d2b6e82343066 Mon Sep 17 00:00:00 2001
From: Josh XT
Date: Mon, 15 Jan 2024 20:10:22 -0500
Subject: [PATCH 2/4] More consistent

---
 llama_cpp/llama_cpp.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index a4fbfe460..e7df70e54 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -396,9 +396,6 @@ class llama_model_kv_override(Structure):
 # // override key-value pairs of the model meta data
 #     const struct llama_model_kv_override * kv_overrides;
-# Ref https://github.com/ggerganov/llama.cpp/pull/4815
-
-#
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool vocab_only; // only load the vocabulary, no weights
 #     bool use_mmap;   // use mmap if possible
@@ -450,6 +447,7 @@ class llama_model_params(Structure):
 #     float yarn_beta_slow;   // YaRN high correction dim
 #     uint32_t yarn_orig_ctx; // YaRN original context size

+# // ref: https://github.com/ggerganov/llama.cpp/pull/4815
 #     int32_t grp_attn_n = 1;   // group-attention factor
 #     int32_t grp_attn_w = 512; // group-attention width

From ee070bf1850c52a7759f5755593fd66d05d45551 Mon Sep 17 00:00:00 2001
From: Josh XT
Date: Mon, 15 Jan 2024 20:29:04 -0500
Subject: [PATCH 3/4] Add params

---
 llama_cpp/llama.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index e4be9d1c9..010eaf9b8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -751,6 +751,8 @@ def __init__(
         yarn_beta_fast: float = 32.0,
         yarn_beta_slow: float = 1.0,
         yarn_orig_ctx: int = 0,
+        grp_attn_n: int = 1,
+        grp_attn_w: int = 512,
         mul_mat_q: bool = True,
         logits_all: bool = False,
         embedding: bool = False,
@@ -820,6 +822,8 @@ def __init__(
             yarn_beta_fast: YaRN low correction dim
             yarn_beta_slow: YaRN high correction dim
             yarn_orig_ctx: YaRN original context size
+            grp_attn_n: group-attention factor
+            grp_attn_w: group-attention width
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
@@ -935,6 +939,8 @@ def __init__(
             yarn_beta_slow if yarn_beta_slow != 0.0 else 0
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
+        self.context_params.grp_attn_n = grp_attn_n
+        self.context_params.grp_attn_w = grp_attn_w
         self.context_params.mul_mat_q = mul_mat_q
         self.context_params.logits_all = logits_all
         self.context_params.embedding = embedding
@@ -2197,6 +2203,8 @@ def __getstate__(self):
             yarn_beta_fast=self.context_params.yarn_beta_fast,
             yarn_beta_slow=self.context_params.yarn_beta_slow,
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
+            grp_attn_n=self.context_params.grp_attn_n,
+            grp_attn_w=self.context_params.grp_attn_w,
             mul_mat_q=self.context_params.mul_mat_q,
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embedding,
@@ -2241,6 +2249,8 @@ def __setstate__(self, state):
             yarn_beta_fast=state["yarn_beta_fast"],
             yarn_beta_slow=state["yarn_beta_slow"],
             yarn_orig_ctx=state["yarn_orig_ctx"],
+            grp_attn_n=state["grp_attn_n"],
+            grp_attn_w=state["grp_attn_w"],
             mul_mat_q=state["mul_mat_q"],
             logits_all=state["logits_all"],
             embedding=state["embedding"],

From 82dcd9f55636b0f6205e11069f26ee0dd11fbaab Mon Sep 17 00:00:00 2001
From: Josh XT
Date: Mon, 15 Jan 2024 20:52:28 -0500
Subject: [PATCH 4/4] Add in other places

---
 llama_cpp/server/model.py    | 2 ++
 llama_cpp/server/settings.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py
index f9be3237d..6ab829322 100644
--- a/llama_cpp/server/model.py
+++ b/llama_cpp/server/model.py
@@ -113,6 +113,8 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         yarn_beta_fast=settings.yarn_beta_fast,
         yarn_beta_slow=settings.yarn_beta_slow,
         yarn_orig_ctx=settings.yarn_orig_ctx,
+        grp_attn_n=settings.grp_attn_n,
+        grp_attn_w=settings.grp_attn_w,
         mul_mat_q=settings.mul_mat_q,
         logits_all=settings.logits_all,
         embedding=settings.embedding,
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index 902a43919..693cfdec2 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -84,6 +84,8 @@ class ModelSettings(BaseSettings):
     yarn_beta_fast: float = Field(default=32.0)
     yarn_beta_slow: float = Field(default=1.0)
     yarn_orig_ctx: int = Field(default=0)
+    grp_attn_n: int = Field(default=1)
+    grp_attn_w: int = Field(default=512)
    mul_mat_q: bool = Field(
        default=True, description="if true, use experimental mul_mat_q kernels"
    )
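
For reviewers, here is a minimal usage sketch of the new knobs once this series is applied. Everything except `grp_attn_n` and `grp_attn_w` is the existing llama-cpp-python API; the model path, context size, prompt, and the choice of `grp_attn_n=4` are placeholders for illustration, not part of the PR.

```python
# Sketch only: assumes these patches are applied and that the underlying
# llama.cpp build wires grp_attn_n/grp_attn_w through to self-extend
# (ref: https://github.com/ggerganov/llama.cpp/pull/4815).
from llama_cpp import Llama

llm = Llama(
    model_path="./models/model.gguf",  # placeholder model file
    n_ctx=8192,      # request more context than the model was trained with
    grp_attn_n=4,    # group-attention factor; 1 (the default) leaves attention unchanged
    grp_attn_w=512,  # group-attention width; keep it a multiple of grp_attn_n
)

# A normal completion call; self-extend only changes how positions beyond
# the trained context are grouped inside the attention computation.
output = llm("Q: What does self-extend do? A:", max_tokens=64)
print(output["choices"][0]["text"])
```

Roughly, per the upstream PR: positions within a window of `grp_attn_w` tokens keep ordinary attention, while older positions are floor-divided by `grp_attn_n`, letting a model read past its trained context without fine-tuning.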