Add self extend support #1090

Closed · wants to merge 6 commits
llama_cpp/llama.py: 10 additions & 0 deletions

@@ -751,6 +751,8 @@ def __init__(
         yarn_beta_fast: float = 32.0,
         yarn_beta_slow: float = 1.0,
         yarn_orig_ctx: int = 0,
+        grp_attn_n: int = 1,
+        grp_attn_w: int = 512,
         mul_mat_q: bool = True,
         logits_all: bool = False,
         embedding: bool = False,
@@ -820,6 +822,8 @@ def __init__(
             yarn_beta_fast: YaRN low correction dim
             yarn_beta_slow: YaRN high correction dim
             yarn_orig_ctx: YaRN original context size
+            grp_attn_n: group-attention factor
+            grp_attn_w: group-attention width
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
@@ -935,6 +939,8 @@ def __init__(
             yarn_beta_slow if yarn_beta_slow != 0.0 else 0
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
+        self.context_params.grp_attn_n = grp_attn_n
+        self.context_params.grp_attn_w = grp_attn_w
         self.context_params.mul_mat_q = mul_mat_q
         self.context_params.logits_all = logits_all
         self.context_params.embedding = embedding
@@ -2197,6 +2203,8 @@ def __getstate__(self):
             yarn_beta_fast=self.context_params.yarn_beta_fast,
             yarn_beta_slow=self.context_params.yarn_beta_slow,
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
+            grp_attn_n=self.context_params.grp_attn_n,
+            grp_attn_w=self.context_params.grp_attn_w,
             mul_mat_q=self.context_params.mul_mat_q,
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embedding,
@@ -2241,6 +2249,8 @@ def __setstate__(self, state):
             yarn_beta_fast=state["yarn_beta_fast"],
             yarn_beta_slow=state["yarn_beta_slow"],
             yarn_orig_ctx=state["yarn_orig_ctx"],
+            grp_attn_n=state["grp_attn_n"],
+            grp_attn_w=state["grp_attn_w"],
             mul_mat_q=state["mul_mat_q"],
             logits_all=state["logits_all"],
             embedding=state["embedding"],
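
With these bindings in place, self-extend can be enabled straight from the high-level constructor. A minimal usage sketch, assuming a hypothetical local GGUF model (the path and parameter values below are illustrative, not part of this PR):

from llama_cpp import Llama

# Enable self-extend (group attention) to run a model past its training
# context. grp_attn_n is the group-attention factor and grp_attn_w the
# group-attention width; per the upstream llama.cpp work referenced by this
# PR, grp_attn_w is normally a multiple of grp_attn_n.
llm = Llama(
    model_path="./models/llama-2-7b.Q4_K_M.gguf",  # illustrative path
    n_ctx=16384,      # extended context, beyond the model's training window
    grp_attn_n=4,     # roughly n_ctx / original training context
    grp_attn_w=2048,  # group-attention width
)

out = llm("Summarize the following document: ...", max_tokens=256)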
llama_cpp/llama_cpp.py: 8 additions & 1 deletion

@@ -396,7 +396,6 @@ class llama_model_kv_override(Structure):
 #     // override key-value pairs of the model meta data
 #     const struct llama_model_kv_override * kv_overrides;
 
-
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool vocab_only; // only load the vocabulary, no weights
 #     bool use_mmap;   // use mmap if possible
@@ -448,6 +447,10 @@ class llama_model_params(Structure):
 #     float    yarn_beta_slow;   // YaRN high correction dim
 #     uint32_t yarn_orig_ctx;    // YaRN original context size
 
+#     // ref: https://github.com/ggerganov/llama.cpp/pull/4815
+#     int32_t  grp_attn_n = 1;   // group-attention factor
+#     int32_t  grp_attn_w = 512; // group-attention width
+
 #     enum ggml_type type_k; // data type for K cache
 #     enum ggml_type type_v; // data type for V cache
 
@@ -475,6 +478,8 @@ class llama_context_params(Structure):
         yarn_beta_fast (float): YaRN low correction dim
         yarn_beta_slow (float): YaRN high correction dim
         yarn_orig_ctx (int): YaRN original context size
+        grp_attn_n (int): group-attention factor
+        grp_attn_w (int): group-attention width
         type_k (int): data type for K cache
         type_v (int): data type for V cache
         mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
@@ -497,6 +502,8 @@
         ("yarn_beta_fast", c_float),
         ("yarn_beta_slow", c_float),
         ("yarn_orig_ctx", c_uint32),
+        ("grp_attn_n", c_int32),
+        ("grp_attn_w", c_int32),
         ("type_k", c_int),
         ("type_v", c_int),
         ("mul_mat_q", c_bool),
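
The _fields_ list is a ctypes mirror of the C struct, so the two new int32 entries must sit exactly where the header declares them, between yarn_orig_ctx and type_k; a misplaced or missing entry would silently shift every field that follows. A quick sanity-check sketch, assuming the installed libllama was built from a llama.cpp revision whose llama_context_params actually declares these fields (otherwise the layouts diverge and the values read back as garbage):

import llama_cpp

# Defaults come from the C side via llama_context_default_params(); if the
# Python struct layout matches the shared library, the new fields should
# read back as the header's stated defaults.
params = llama_cpp.llama_context_default_params()
print(params.grp_attn_n, params.grp_attn_w)  # expected: 1 512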
llama_cpp/server/model.py: 2 additions & 0 deletions

@@ -113,6 +113,8 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         yarn_beta_fast=settings.yarn_beta_fast,
         yarn_beta_slow=settings.yarn_beta_slow,
         yarn_orig_ctx=settings.yarn_orig_ctx,
+        grp_attn_n=settings.grp_attn_n,
+        grp_attn_w=settings.grp_attn_w,
         mul_mat_q=settings.mul_mat_q,
         logits_all=settings.logits_all,
         embedding=settings.embedding,
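
On the server side the values flow from ModelSettings through this function into the same Llama constructor. A short sketch of loading a model programmatically with self-extend enabled (path and values are illustrative, and this assumes model is the only required ModelSettings field):

from llama_cpp.server.model import load_llama_from_model_settings
from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(
    model="./models/llama-2-7b.Q4_K_M.gguf",  # illustrative path
    n_ctx=16384,
    grp_attn_n=4,
    grp_attn_w=2048,
)
llama = load_llama_from_model_settings(settings)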
llama_cpp/server/settings.py: 2 additions & 0 deletions

@@ -84,6 +84,8 @@ class ModelSettings(BaseSettings):
     yarn_beta_fast: float = Field(default=32.0)
     yarn_beta_slow: float = Field(default=1.0)
     yarn_orig_ctx: int = Field(default=0)
+    grp_attn_n: int = Field(default=1)
+    grp_attn_w: int = Field(default=512)
     mul_mat_q: bool = Field(
         default=True, description="if true, use experimental mul_mat_q kernels"
     )
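
Because ModelSettings is a pydantic BaseSettings subclass, the new fields should also be settable from the environment rather than in code. A hedged sketch, assuming no env_prefix is configured so pydantic matches environment variables to field names case-insensitively:

import os
from llama_cpp.server.settings import ModelSettings

# Assumption: field names map directly to environment variable names.
os.environ["GRP_ATTN_N"] = "4"
os.environ["GRP_ATTN_W"] = "2048"

settings = ModelSettings(model="./models/llama-2-7b.Q4_K_M.gguf")
print(settings.grp_attn_n, settings.grp_attn_w)  # expected: 4 2048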