Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/core/flow_def.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ Types of the fields must be key types. See [Key Types](data_types#key-types) for

* `field_name`: the field to create vector index.
* `metric`: the similarity metric to use.
* `method` (optional): the index algorithm and optional tuning parameters. Leave unset to use the target default (HNSW for Postgres). Use `cocoindex.HnswVectorIndexMethod()` or `cocoindex.IvfFlatVectorIndexMethod()` to customize the method and its parameters.

#### Similarity Metrics

Expand Down
10 changes: 10 additions & 0 deletions docs/docs/examples/examples/simple_vector_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ doc_embeddings.export(
CocoIndex supports other vector databases as well, with 1-line switch.
<DocumentationButton url="https://cocoindex.io/docs/ops/targets" text="Targets" />

Need IVFFlat or custom HNSW parameters? Pass a method, for example:

```python
cocoindex.VectorIndexDef(
field_name="embedding",
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
method=cocoindex.IvfFlatVectorIndexMethod(lists=200),
)
```

## Query the index

### Define a shared flow for both indexing and querying
Expand Down
10 changes: 9 additions & 1 deletion python/cocoindex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from .flow import update_all_flows_async, setup_all_flows, drop_all_flows
from .lib import settings, init, start_server, stop
from .llm import LlmSpec, LlmApiType
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
from .index import (
VectorSimilarityMetric,
VectorIndexDef,
IndexOptions,
HnswVectorIndexMethod,
IvfFlatVectorIndexMethod,
)
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
from .setting import get_app_namespace
from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
Expand Down Expand Up @@ -82,6 +88,8 @@
"VectorSimilarityMetric",
"VectorIndexDef",
"IndexOptions",
"HnswVectorIndexMethod",
"IvfFlatVectorIndexMethod",
# Settings
"DatabaseConnectionSpec",
"Settings",
Expand Down
23 changes: 22 additions & 1 deletion python/cocoindex/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from dataclasses import dataclass
from typing import Sequence
from typing import Sequence, Union


class VectorSimilarityMetric(Enum):
Expand All @@ -9,6 +9,26 @@ class VectorSimilarityMetric(Enum):
INNER_PRODUCT = "InnerProduct"


@dataclass
class HnswVectorIndexMethod:
"""HNSW vector index parameters."""

kind: str = "Hnsw"
m: int | None = None
ef_construction: int | None = None


@dataclass
class IvfFlatVectorIndexMethod:
"""IVFFlat vector index parameters."""

kind: str = "IvfFlat"
lists: int | None = None


VectorIndexMethod = Union[HnswVectorIndexMethod, IvfFlatVectorIndexMethod]


@dataclass
class VectorIndexDef:
"""
Expand All @@ -17,6 +37,7 @@ class VectorIndexDef:

field_name: str
metric: VectorSimilarityMetric
method: VectorIndexMethod | None = None


@dataclass
Expand Down
59 changes: 58 additions & 1 deletion src/base/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,72 @@ impl fmt::Display for VectorSimilarityMetric {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind")]
pub enum VectorIndexMethod {
Hnsw {
#[serde(default, skip_serializing_if = "Option::is_none")]
m: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
ef_construction: Option<u32>,
},
IvfFlat {
#[serde(default, skip_serializing_if = "Option::is_none")]
lists: Option<u32>,
},
}

impl VectorIndexMethod {
pub fn kind(&self) -> &'static str {
match self {
Self::Hnsw { .. } => "Hnsw",
Self::IvfFlat { .. } => "IvfFlat",
}
}
}

impl fmt::Display for VectorIndexMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Hnsw { m, ef_construction } => {
let mut parts = Vec::new();
if let Some(m) = m {
parts.push(format!("m={}", m));
}
if let Some(ef) = ef_construction {
parts.push(format!("ef_construction={}", ef));
}
if parts.is_empty() {
write!(f, "Hnsw")
} else {
write!(f, "Hnsw({})", parts.join(","))
}
}
Self::IvfFlat { lists } => {
if let Some(lists) = lists {
write!(f, "IvfFlat(lists={lists})")
} else {
write!(f, "IvfFlat")
}
}
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct VectorIndexDef {
pub field_name: FieldName,
pub metric: VectorSimilarityMetric,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub method: Option<VectorIndexMethod>,
}

impl fmt::Display for VectorIndexDef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}", self.field_name, self.metric)
match &self.method {
None => write!(f, "{}:{}", self.field_name, self.metric),
Some(method) => write!(f, "{}:{}:{}", self.field_name, self.metric, method),
}
}
}

Expand Down
40 changes: 30 additions & 10 deletions src/llm/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,30 @@ impl AiStudioClient {
}
}

fn build_embed_payload(
model: &str,
text: &str,
task_type: Option<&str>,
output_dimension: Option<u32>,
) -> serde_json::Value {
let mut payload = serde_json::json!({
"model": model,
"content": { "parts": [{ "text": text }] },
});
if let Some(task_type) = task_type {
payload["taskType"] = serde_json::Value::String(task_type.to_string());
}
if let Some(output_dimension) = output_dimension {
payload["outputDimensionality"] = serde_json::json!(output_dimension);
if model.starts_with("gemini-embedding-") {
payload["config"] = serde_json::json!({
"outputDimensionality": output_dimension,
});
}
}
payload
}

#[async_trait]
impl LlmGenerationClient for AiStudioClient {
async fn generate<'req>(
Expand Down Expand Up @@ -174,16 +198,12 @@ impl LlmEmbeddingClient for AiStudioClient {
request: super::LlmEmbeddingRequest<'req>,
) -> Result<super::LlmEmbeddingResponse> {
let url = self.get_api_url(request.model, "embedContent");
let mut payload = serde_json::json!({
"model": request.model,
"content": { "parts": [{ "text": request.text }] },
});
if let Some(task_type) = request.task_type {
payload["taskType"] = serde_json::Value::String(task_type.into());
}
if let Some(output_dimension) = request.output_dimension {
payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
}
let payload = build_embed_payload(
request.model,
request.text.as_ref(),
request.task_type.as_deref(),
request.output_dimension,
);
let resp = retryable::run(
|| async {
self.client
Expand Down
38 changes: 34 additions & 4 deletions src/ops/targets/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,21 +461,51 @@ fn to_vector_similarity_metric_sql(metric: VectorSimilarityMetric) -> &'static s
}

fn to_index_spec_sql(index_spec: &VectorIndexDef) -> Cow<'static, str> {
let (method, options) = match index_spec.method.as_ref() {
Some(spec::VectorIndexMethod::Hnsw { m, ef_construction }) => {
let mut opts = Vec::new();
if let Some(m) = m {
opts.push(format!("m = {}", m));
}
if let Some(ef) = ef_construction {
opts.push(format!("ef_construction = {}", ef));
}
("hnsw", opts)
}
Some(spec::VectorIndexMethod::IvfFlat { lists }) => (
"ivfflat",
lists
.map(|lists| vec![format!("lists = {}", lists)])
.unwrap_or_default(),
),
None => ("hnsw", Vec::new()),
};
let with_clause = if options.is_empty() {
String::new()
} else {
format!(" WITH ({})", options.join(", "))
};
format!(
"USING hnsw ({} {})",
"USING {method} ({} {}){}",
index_spec.field_name,
to_vector_similarity_metric_sql(index_spec.metric)
to_vector_similarity_metric_sql(index_spec.metric),
with_clause
)
.into()
}

fn to_vector_index_name(table_name: &str, vector_index_def: &spec::VectorIndexDef) -> String {
format!(
let mut name = format!(
"{}__{}__{}",
table_name,
vector_index_def.field_name,
to_vector_similarity_metric_sql(vector_index_def.metric)
)
);
if let Some(method) = vector_index_def.method.as_ref() {
name.push_str("__");
name.push_str(&method.kind().to_ascii_lowercase());
}
name
}

fn describe_index_spec(index_name: &str, index_spec: &VectorIndexDef) -> String {
Expand Down
Loading