1 change: 1 addition & 0 deletions Cargo.toml
@@ -106,3 +106,4 @@ bytes = "1.10.1"
rand = "0.9.0"
indoc = "2.0.6"
owo-colors = "4.2.0"
json5 = "0.4.1"
25 changes: 24 additions & 1 deletion docs/docs/ai/llm.mdx
@@ -97,4 +97,27 @@ cocoindex.LlmSpec(
</TabItem>
</Tabs>

You can find the full list of models supported by Gemini [here](https://ai.google.dev/gemini-api/docs/models).

### Anthropic

To use the Anthropic LLM API, you need to set the environment variable `ANTHROPIC_API_KEY`.
You can create an API key in the [Anthropic console](https://console.anthropic.com/settings/keys).

A spec for Anthropic looks like this:

<Tabs>
<TabItem value="python" label="Python" default>

```python
cocoindex.LlmSpec(
api_type=cocoindex.LlmApiType.ANTHROPIC,
model="claude-3-5-sonnet-latest",
)
```

</TabItem>
</Tabs>

You can find the full list of models supported by Anthropic [here](https://docs.anthropic.com/en/docs/about-claude/models/all-models).
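
As a usage sketch, the spec plugs into an extraction step the same way as the other API types. This is adapted from this PR's `manuals_llm_extraction` example; the source field name and the `ModuleInfo` dataclass are illustrative:

```python
doc["module_info"] = doc["markdown"].transform(
    cocoindex.functions.ExtractByLlm(
        llm_spec=cocoindex.LlmSpec(
            api_type=cocoindex.LlmApiType.ANTHROPIC,
            model="claude-3-5-sonnet-latest",
        ),
        output_type=ModuleInfo,
        instruction="Please extract Python module information from the manual.",
    )
)
```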

4 changes: 4 additions & 0 deletions examples/manuals_llm_extraction/main.py
@@ -98,6 +98,10 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
# Replace with the spec below to use a Gemini API model
# llm_spec=cocoindex.LlmSpec(
# api_type=cocoindex.LlmApiType.GEMINI, model="gemini-2.0-flash"),

# Replace with the spec below to use an Anthropic API model
# llm_spec=cocoindex.LlmSpec(
# api_type=cocoindex.LlmApiType.ANTHROPIC, model="claude-3-5-sonnet-latest"),
output_type=ModuleInfo,
instruction="Please extract Python module information from the manual."))
doc["module_summary"] = doc["module_info"].transform(summarize_module)
1 change: 1 addition & 0 deletions python/cocoindex/llm.py
@@ -6,6 +6,7 @@ class LlmApiType(Enum):
OPENAI = "OpenAi"
OLLAMA = "Ollama"
GEMINI = "Gemini"
ANTHROPIC = "Anthropic"

@dataclass
class LlmSpec:
132 changes: 132 additions & 0 deletions src/llm/anthropic.rs
@@ -0,0 +1,132 @@
use async_trait::async_trait;
use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
use anyhow::{Result, bail, Context};
use serde_json::Value;

use crate::api_bail;

pub struct Client {
model: String,
api_key: String,
client: reqwest::Client,
}

impl Client {
pub async fn new(spec: LlmSpec) -> Result<Self> {
let api_key = match std::env::var("ANTHROPIC_API_KEY") {
Ok(val) => val,
Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),
};
Ok(Self {
model: spec.model,
api_key,
client: reqwest::Client::new(),
})
}
}

#[async_trait]
impl LlmGenerationClient for Client {
async fn generate<'req>(
&self,
request: LlmGenerateRequest<'req>,
) -> Result<LlmGenerateResponse> {
let messages = vec![serde_json::json!({
"role": "user",
"content": request.user_prompt
})];

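// The Messages API requires max_tokens; 4096 is an arbitrary default ceiling here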
let mut payload = serde_json::json!({
"model": self.model,
"messages": messages,
"max_tokens": 4096
});

// Anthropic expects the system prompt as a top-level "system" field, not as a message
if let Some(system) = request.system_prompt {
payload["system"] = serde_json::json!(system);
}

let OutputFormat::JsonSchema { schema, .. } = request.output_format.as_ref().expect("Anthropic client expects OutputFormat::JsonSchema for all requests");
let schema_json = serde_json::to_value(schema)?;
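// Structured output via tool use: expose a single synthetic "extraction"
// tool whose input_schema is the requested output schema, so the model
// returns conforming JSON as the tool call's `input`.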
payload["tools"] = serde_json::json!([
{ "type": "custom", "name": "extraction", "input_schema": schema_json }
]);

let url = "https://api.anthropic.com/v1/messages";

let resp = self.client
.post(url)
.header("x-api-key", encoded_api_key.as_ref())
.header("anthropic-version", "2023-06-01")
.json(&payload)
.send()
.await
.context("HTTP error")?;
let resp_json: Value = resp.json().await.context("Invalid JSON")?;
if let Some(error) = resp_json.get("error") {
bail!("Anthropic API error: {:?}", error);
}

// Debug print full response
// println!("Anthropic API full response: {resp_json:?}");

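// `content` is an array of blocks; find the `tool_use` block for our
// extraction tool and take its `input` as the structured output.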
let resp_content = &resp_json["content"];
let tool_name = "extraction";
let mut extracted_json: Option<Value> = None;
if let Some(array) = resp_content.as_array() {
for item in array {
if item.get("type") == Some(&Value::String("tool_use".to_string()))
&& item.get("name") == Some(&Value::String(tool_name.to_string()))
{
if let Some(input) = item.get("input") {
extracted_json = Some(input.clone());
break;
}
}
}
}
let text = if let Some(json) = extracted_json {
// Serialize the structured tool input back to a JSON string
serde_json::to_string(&json)?
} else {
// Fallback: try text if no tool output found
match &resp_json["content"][0]["text"] {
Value::String(s) => {
// Try strict JSON parsing first
match serde_json::from_str::<serde_json::Value>(s) {
Ok(_) => s.clone(),
Err(e) => {
// Fall back to permissive JSON5 parsing; re-serialize the parsed
// value so callers always receive strict JSON
match json5::from_str::<serde_json::Value>(s) {
Ok(value) => serde_json::to_string(&value)?,
Err(e2) => bail!("Text in response is neither valid JSON nor JSON5: {e}; {e2}"),
}
}
}
},
_ => return Err(anyhow::anyhow!("No structured tool output or text found in response")),
}
};

Ok(LlmGenerateResponse {
text,
})
}

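// Options for converting the caller's output type into the tool's
// input_schema: optional fields are allowed, the `format` keyword is
// assumed unsupported, and the top level must be an object.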
fn json_schema_options(&self) -> ToJsonSchemaOptions {
ToJsonSchemaOptions {
fields_always_required: false,
supports_format: false,
extract_descriptions: false,
top_level_must_be_object: true,
}
}
}
5 changes: 5 additions & 0 deletions src/llm/mod.rs
@@ -12,6 +12,7 @@ pub enum LlmApiType {
Ollama,
OpenAi,
Gemini,
Anthropic,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -54,6 +55,7 @@ pub trait LlmGenerationClient: Send + Sync {
mod ollama;
mod openai;
mod gemini;
mod anthropic;

pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
let client = match spec.api_type {
@@ -66,6 +68,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
LlmApiType::Gemini => {
Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
}
LlmApiType::Anthropic => {
Box::new(anthropic::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
}
};
Ok(client)
}