diff --git a/docs/docs/core/flow_def.mdx b/docs/docs/core/flow_def.mdx index 01f63742..a2d63473 100644 --- a/docs/docs/core/flow_def.mdx +++ b/docs/docs/core/flow_def.mdx @@ -313,6 +313,7 @@ Types of the fields must be key types. See [Key Types](data_types#key-types) for * `field_name`: the field to create vector index. * `metric`: the similarity metric to use. + * `method` (optional): the index algorithm and optional tuning parameters. Leave unset to use the target default (HNSW for Postgres). Use `cocoindex.HnswVectorIndexMethod()` or `cocoindex.IvfFlatVectorIndexMethod()` to customize the method and its parameters. #### Similarity Metrics diff --git a/docs/docs/examples/examples/simple_vector_index.md b/docs/docs/examples/examples/simple_vector_index.md index dc55d2d2..98861f3d 100644 --- a/docs/docs/examples/examples/simple_vector_index.md +++ b/docs/docs/examples/examples/simple_vector_index.md @@ -105,6 +105,16 @@ doc_embeddings.export( CocoIndex supports other vector databases as well, with 1-line switch. +Need IVFFlat or custom HNSW parameters? 
Pass a method, for example: + +```python +cocoindex.VectorIndexDef( + field_name="embedding", + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + method=cocoindex.IvfFlatVectorIndexMethod(lists=200), +) +``` + ## Query the index ### Define a shared flow for both indexing and querying diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py index de16a157..0b27cc27 100644 --- a/python/cocoindex/__init__.py +++ b/python/cocoindex/__init__.py @@ -21,7 +21,13 @@ from .flow import update_all_flows_async, setup_all_flows, drop_all_flows from .lib import settings, init, start_server, stop from .llm import LlmSpec, LlmApiType -from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions +from .index import ( + VectorSimilarityMetric, + VectorIndexDef, + IndexOptions, + HnswVectorIndexMethod, + IvfFlatVectorIndexMethod, +) from .setting import DatabaseConnectionSpec, Settings, ServerSettings from .setting import get_app_namespace from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput @@ -82,6 +88,8 @@ "VectorSimilarityMetric", "VectorIndexDef", "IndexOptions", + "HnswVectorIndexMethod", + "IvfFlatVectorIndexMethod", # Settings "DatabaseConnectionSpec", "Settings", diff --git a/python/cocoindex/index.py b/python/cocoindex/index.py index a5ff0626..6c5e11cb 100644 --- a/python/cocoindex/index.py +++ b/python/cocoindex/index.py @@ -1,6 +1,6 @@ from enum import Enum from dataclasses import dataclass -from typing import Sequence +from typing import Sequence, Union class VectorSimilarityMetric(Enum): @@ -9,6 +9,26 @@ class VectorSimilarityMetric(Enum): INNER_PRODUCT = "InnerProduct" +@dataclass +class HnswVectorIndexMethod: + """HNSW vector index parameters.""" + + kind: str = "Hnsw" + m: int | None = None + ef_construction: int | None = None + + +@dataclass +class IvfFlatVectorIndexMethod: + """IVFFlat vector index parameters.""" + + kind: str = "IvfFlat" + lists: int | None = None + + +VectorIndexMethod = 
Union[HnswVectorIndexMethod, IvfFlatVectorIndexMethod] + + @dataclass class VectorIndexDef: """ @@ -17,6 +37,7 @@ class VectorIndexDef: field_name: str metric: VectorSimilarityMetric + method: VectorIndexMethod | None = None @dataclass diff --git a/src/base/spec.rs b/src/base/spec.rs index 3c01ae2e..671e96dc 100644 --- a/src/base/spec.rs +++ b/src/base/spec.rs @@ -384,15 +384,72 @@ impl fmt::Display for VectorSimilarityMetric { } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "kind")] +pub enum VectorIndexMethod { + Hnsw { + #[serde(default, skip_serializing_if = "Option::is_none")] + m: Option<u32>, + #[serde(default, skip_serializing_if = "Option::is_none")] + ef_construction: Option<u32>, + }, + IvfFlat { + #[serde(default, skip_serializing_if = "Option::is_none")] + lists: Option<u32>, + }, +} + +impl VectorIndexMethod { + pub fn kind(&self) -> &'static str { + match self { + Self::Hnsw { .. } => "Hnsw", + Self::IvfFlat { .. } => "IvfFlat", + } + } +} + +impl fmt::Display for VectorIndexMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Hnsw { m, ef_construction } => { + let mut parts = Vec::new(); + if let Some(m) = m { + parts.push(format!("m={}", m)); + } + if let Some(ef) = ef_construction { + parts.push(format!("ef_construction={}", ef)); + } + if parts.is_empty() { + write!(f, "Hnsw") + } else { + write!(f, "Hnsw({})", parts.join(",")) + } + } + Self::IvfFlat { lists } => { + if let Some(lists) = lists { + write!(f, "IvfFlat(lists={lists})") + } else { + write!(f, "IvfFlat") + } + } + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct VectorIndexDef { pub field_name: FieldName, pub metric: VectorSimilarityMetric, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub method: Option<VectorIndexMethod>, } impl fmt::Display for VectorIndexDef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}:{}", self.field_name, self.metric) + match &self.method { + 
None => write!(f, "{}:{}", self.field_name, self.metric), + Some(method) => write!(f, "{}:{}:{}", self.field_name, self.metric, method), + } } } diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index d246ad86..bb2aeb0f 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -74,6 +74,30 @@ impl AiStudioClient { } } +fn build_embed_payload( + model: &str, + text: &str, + task_type: Option<&str>, + output_dimension: Option<u32>, +) -> serde_json::Value { + let mut payload = serde_json::json!({ + "model": model, + "content": { "parts": [{ "text": text }] }, + }); + if let Some(task_type) = task_type { + payload["taskType"] = serde_json::Value::String(task_type.to_string()); + } + if let Some(output_dimension) = output_dimension { + payload["outputDimensionality"] = serde_json::json!(output_dimension); + if model.starts_with("gemini-embedding-") { + payload["config"] = serde_json::json!({ + "outputDimensionality": output_dimension, + }); + } + } + payload +} + #[async_trait] impl LlmGenerationClient for AiStudioClient { async fn generate<'req>( @@ -174,16 +198,12 @@ impl LlmEmbeddingClient for AiStudioClient { request: super::LlmEmbeddingRequest<'req>, ) -> Result<super::LlmEmbeddingResponse> { let url = self.get_api_url(request.model, "embedContent"); - let mut payload = serde_json::json!({ - "model": request.model, - "content": { "parts": [{ "text": request.text }] }, - }); - if let Some(task_type) = request.task_type { - payload["taskType"] = serde_json::Value::String(task_type.into()); - } - if let Some(output_dimension) = request.output_dimension { - payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into()); - } + let payload = build_embed_payload( + request.model, + request.text.as_ref(), + request.task_type.as_deref(), + request.output_dimension, + ); let resp = retryable::run( || async { self.client diff --git a/src/ops/targets/postgres.rs b/src/ops/targets/postgres.rs index 6a583b9f..100812e3 100644 --- a/src/ops/targets/postgres.rs +++ 
b/src/ops/targets/postgres.rs @@ -461,21 +461,51 @@ fn to_vector_similarity_metric_sql(metric: VectorSimilarityMetric) -> &'static s } fn to_index_spec_sql(index_spec: &VectorIndexDef) -> Cow<'static, str> { + let (method, options) = match index_spec.method.as_ref() { + Some(spec::VectorIndexMethod::Hnsw { m, ef_construction }) => { + let mut opts = Vec::new(); + if let Some(m) = m { + opts.push(format!("m = {}", m)); + } + if let Some(ef) = ef_construction { + opts.push(format!("ef_construction = {}", ef)); + } + ("hnsw", opts) + } + Some(spec::VectorIndexMethod::IvfFlat { lists }) => ( + "ivfflat", + lists + .map(|lists| vec![format!("lists = {}", lists)]) + .unwrap_or_default(), + ), + None => ("hnsw", Vec::new()), + }; + let with_clause = if options.is_empty() { + String::new() + } else { + format!(" WITH ({})", options.join(", ")) + }; format!( - "USING hnsw ({} {})", + "USING {method} ({} {}){}", index_spec.field_name, - to_vector_similarity_metric_sql(index_spec.metric) + to_vector_similarity_metric_sql(index_spec.metric), + with_clause ) .into() } fn to_vector_index_name(table_name: &str, vector_index_def: &spec::VectorIndexDef) -> String { - format!( + let mut name = format!( "{}__{}__{}", table_name, vector_index_def.field_name, to_vector_similarity_metric_sql(vector_index_def.metric) - ) + ); + if let Some(method) = vector_index_def.method.as_ref() { + name.push_str("__"); + name.push_str(&method.kind().to_ascii_lowercase()); + } + name } fn describe_index_spec(index_name: &str, index_spec: &VectorIndexDef) -> String {