Skip to content

move grammar to llama request level #594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 6 additions & 50 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ hypr-db-user = { path = "crates/db-user", package = "db-user" }
hypr-detect = { path = "crates/detect", package = "detect" }
hypr-diart = { path = "crates/diart", package = "diart" }
hypr-file = { path = "crates/file", package = "file" }
hypr-gbnf = { path = "crates/gbnf", package = "gbnf" }
hypr-gguf = { path = "crates/gguf", package = "gguf" }
hypr-host = { path = "crates/host", package = "host" }
hypr-llama = { path = "crates/llama", package = "llama" }
Expand Down Expand Up @@ -129,6 +130,7 @@ tonic-build = "0.12.3"

async-openai = { git = "https://github.com/fastrepl/async-openai", rev = "6404d307f3f706e818ad91544dc82fac5c545aee", default-features = false }
async-stripe = { version = "0.39.1", default-features = false }
gbnf-validator = { git = "https://github.com/fastrepl/gbnf-validator", rev = "3dec055" }
graph-rs-sdk = "2.0.3"

sentry = "0.36.0"
Expand Down
13 changes: 13 additions & 0 deletions crates/gbnf/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "gbnf"
version = "0.1.0"
edition = "2021"

[dependencies]
minijinja = { workspace = true }

[dev-dependencies]
gbnf-validator = { workspace = true }

colored = "3"
indoc = "2"
26 changes: 14 additions & 12 deletions crates/llama/src/grammar/mod.rs → crates/gbnf/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use include_url_macro::include_url;
pub const ENHANCE_AUTO: &str = include_str!("../assets/enhance-auto.gbnf");

pub const MARKDOWN_GRAMMAR: &str = include_str!("./markdown.gbnf");

#[allow(dead_code)]
pub const JSON_ARR_GRAMMAR: &str = include_url!(
"https://raw.githubusercontent.com/ggml-org/llama.cpp/7a84777/grammars/json_arr.gbnf"
);
/// Grammar selection for constrained sampling, rendered to GBNF source text
/// via [`GBNF::build`].
pub enum GBNF {
    /// The note-enhancement grammar. The optional list of strings is not read
    /// by `build` today — presumably section titles reserved for future
    /// templated grammars (this crate depends on minijinja); confirm before
    /// relying on it.
    Enhance(Option<Vec<String>>),
}

#[allow(dead_code)]
pub const JSON_GRAMMAR: &str =
include_url!("https://raw.githubusercontent.com/ggml-org/llama.cpp/7a84777/grammars/json.gbnf");
impl GBNF {
    /// Render the selected grammar to its GBNF source string.
    ///
    /// Both `Enhance(Some(_))` and `Enhance(None)` currently resolve to the
    /// same bundled grammar; the sections payload is ignored for now, so a
    /// single binding pattern covers the variant.
    pub fn build(&self) -> String {
        let GBNF::Enhance(_sections) = self;
        ENHANCE_AUTO.to_string()
    }
}

#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -55,7 +57,7 @@ mod tests {
"};

assert_eq!(input_1, input_2);
assert!(gbnf.validate(MARKDOWN_GRAMMAR, input_1).unwrap());
assert!(gbnf.validate(ENHANCE_AUTO, input_1).unwrap());
}

#[test]
Expand Down Expand Up @@ -96,7 +98,7 @@ mod tests {
"};

assert!(gbnf.validate(MARKDOWN_GRAMMAR, input).unwrap());
assert!(gbnf.validate(ENHANCE_AUTO, input).unwrap());
}

#[allow(dead_code)]
Expand Down
30 changes: 13 additions & 17 deletions crates/llama/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,11 @@ name = "llama"
version = "0.1.0"
edition = "2021"

[dev-dependencies]
hypr-buffer = { workspace = true }
hypr-data = { workspace = true }
hypr-listener-interface = { workspace = true }
hypr-template = { workspace = true }
hypr-timeline = { workspace = true }

colored = "3.0.0"
dirs = { workspace = true }
gbnf-validator = { git = "https://github.com/fastrepl/gbnf-validator", rev = "3dec055" }
indoc = "2.0.6"
minijinja = { workspace = true }
rand = "0.9.0"
serde_json = { workspace = true }

[dependencies]
hypr-gguf = { workspace = true }
include_url_macro = { workspace = true }

encoding_rs = "0.8.35"
gbnf = "0.1.7"
gbnf-validator = { workspace = true }

async-openai = { workspace = true }
futures-util = { workspace = true }
Expand All @@ -38,3 +22,15 @@ llama-cpp-2 = { git = "https://github.com/utilityai/llama-cpp-rs", default-featu

[target.'cfg(target_os = "macos")'.dependencies]
llama-cpp-2 = { git = "https://github.com/utilityai/llama-cpp-rs", features = ["openmp", "native", "metal"], branch = "update-llama-cpp-2025-04-06" }

[dev-dependencies]
hypr-buffer = { workspace = true }
hypr-data = { workspace = true }
hypr-gbnf = { workspace = true }
hypr-listener-interface = { workspace = true }
hypr-template = { workspace = true }
hypr-timeline = { workspace = true }

dirs = { workspace = true }
rand = "0.9.0"
serde_json = { workspace = true }
39 changes: 27 additions & 12 deletions crates/llama/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use hypr_gguf::GgufExt;

mod error;
mod grammar;
mod message;
mod stream;
mod types;

pub use error::*;
pub use message::*;
pub use stream::filter_tag;
pub use types::*;

const DEFAULT_MAX_INPUT_TOKENS: u32 = 1024 * 8;
const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 1024;
Expand Down Expand Up @@ -100,13 +99,20 @@ impl Llama {

let mut n_cur = batch.n_tokens();
let mut decoder = encoding_rs::UTF_8.new_decoder();
let mut sampler = LlamaSampler::chain_simple([
LlamaSampler::grammar(&model, grammar::MARKDOWN_GRAMMAR, "root"),
LlamaSampler::temp(0.8),
LlamaSampler::penalties(0, 1.4, 0.1, 0.0),
LlamaSampler::mirostat_v2(1234, 3.0, 0.2),
]);

let mut sampler = match request.grammar {
Some(grammar) => LlamaSampler::chain_simple([
LlamaSampler::grammar(&model, grammar.as_str(), "root"),
LlamaSampler::temp(0.8),
LlamaSampler::penalties(0, 1.4, 0.1, 0.0),
LlamaSampler::mirostat_v2(1234, 3.0, 0.2),
]),
None => LlamaSampler::chain_simple([
LlamaSampler::temp(0.8),
LlamaSampler::penalties(0, 1.4, 0.1, 0.0),
LlamaSampler::mirostat_v2(1234, 3.0, 0.2),
]),
};
while n_cur <= last_index + DEFAULT_MAX_OUTPUT_TOKENS as i32 {
let token = sampler.sample(&ctx, batch.n_tokens() - 1);

Expand Down Expand Up @@ -375,7 +381,10 @@ mod tests {
#[tokio::test]
async fn test_english_1() {
    // Constrain generation with the default enhance grammar.
    let grammar = hypr_gbnf::GBNF::Enhance(None).build();
    let request = LlamaRequest {
        grammar: Some(grammar),
        messages: english_1_messages(),
    };

    run(&get_model(), request, true).await;
}
Expand All @@ -385,7 +394,10 @@ mod tests {
#[tokio::test]
async fn test_english_4() {
    // Constrain generation with the default enhance grammar.
    let grammar = hypr_gbnf::GBNF::Enhance(None).build();
    let request = LlamaRequest {
        grammar: Some(grammar),
        messages: english_4_messages(),
    };

    run(&get_model(), request, true).await;
}
Expand All @@ -395,7 +407,10 @@ mod tests {
#[tokio::test]
async fn test_english_5() {
    // Constrain generation with the default enhance grammar.
    let grammar = hypr_gbnf::GBNF::Enhance(None).build();
    let request = LlamaRequest {
        grammar: Some(grammar),
        messages: english_5_messages(),
    };

    run(&get_model(), request, true).await;
}
Expand Down
8 changes: 2 additions & 6 deletions crates/llama/src/message.rs → crates/llama/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,8 @@ impl FromOpenAI for LlamaChatMessage {
}
}

/// A single generation request for the local llama backend.
#[derive(Default)]
pub struct LlamaRequest {
    /// Optional GBNF grammar source text; when set, the sampler is
    /// constrained to productions of this grammar.
    pub grammar: Option<String>,
    /// Chat messages the generation is conditioned on.
    pub messages: Vec<LlamaChatMessage>,
}

impl LlamaRequest {
pub fn new(messages: Vec<LlamaChatMessage>) -> Self {
Self { messages }
}
}
13 changes: 8 additions & 5 deletions crates/whisper/src/local/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,14 @@ mod tests {
.model_path(concat!(env!("CARGO_MANIFEST_DIR"), "/model.bin"))
.build();

let request = hypr_llama::LlamaRequest::new(vec![hypr_llama::LlamaChatMessage::new(
"user".into(),
"Generate a json array of 1 random objects, about animals".into(),
)
.unwrap()]);
let request = hypr_llama::LlamaRequest {
messages: vec![hypr_llama::LlamaChatMessage::new(
"user".into(),
"Generate a json array of 1 random objects, about animals".into(),
)
.unwrap()],
..Default::default()
};

let response: String = llama.generate_stream(request).unwrap().collect().await;
assert!(response.len() > 4);
Expand Down
1 change: 1 addition & 0 deletions plugins/local-llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ specta-typescript = { workspace = true }

[dependencies]
hypr-file = { workspace = true }
hypr-gbnf = { workspace = true }
hypr-llama = { workspace = true }

thiserror = { workspace = true }
Expand Down
7 changes: 6 additions & 1 deletion plugins/local-llm/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ fn build_response(
.map(hypr_llama::FromOpenAI::from_openai)
.collect();

let request = hypr_llama::LlamaRequest::new(messages);
let request = hypr_llama::LlamaRequest {
messages,
// TODO: should not hard-code this
grammar: Some(hypr_gbnf::GBNF::Enhance(None).build()),
};

model.generate_stream(request).map_err(Into::into)
}
Loading