diff --git a/docs/docs/core/flow_def.mdx b/docs/docs/core/flow_def.mdx
index 01f63742..a2d63473 100644
--- a/docs/docs/core/flow_def.mdx
+++ b/docs/docs/core/flow_def.mdx
@@ -313,6 +313,7 @@ Types of the fields must be key types. See [Key Types](data_types#key-types) for
* `field_name`: the field to create vector index.
* `metric`: the similarity metric to use.
+ * `method` (optional): the index algorithm and optional tuning parameters. Leave unset to use the target default (HNSW for Postgres). Use `cocoindex.HnswVectorIndexMethod()` or `cocoindex.IvfFlatVectorIndexMethod()` to customize the method and its parameters.
#### Similarity Metrics
diff --git a/docs/docs/examples/examples/simple_vector_index.md b/docs/docs/examples/examples/simple_vector_index.md
index dc55d2d2..98861f3d 100644
--- a/docs/docs/examples/examples/simple_vector_index.md
+++ b/docs/docs/examples/examples/simple_vector_index.md
@@ -105,6 +105,16 @@ doc_embeddings.export(
CocoIndex supports other vector databases as well, with 1-line switch.
+Need IVFFlat or custom HNSW parameters? Pass a method, for example:
+
+```python
+cocoindex.VectorIndexDef(
+ field_name="embedding",
+ metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
+ method=cocoindex.IvfFlatVectorIndexMethod(lists=200),
+)
+```
+
## Query the index
### Define a shared flow for both indexing and querying
diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py
index de16a157..0b27cc27 100644
--- a/python/cocoindex/__init__.py
+++ b/python/cocoindex/__init__.py
@@ -21,7 +21,13 @@
from .flow import update_all_flows_async, setup_all_flows, drop_all_flows
from .lib import settings, init, start_server, stop
from .llm import LlmSpec, LlmApiType
-from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
+from .index import (
+ VectorSimilarityMetric,
+ VectorIndexDef,
+ IndexOptions,
+ HnswVectorIndexMethod,
+ IvfFlatVectorIndexMethod,
+)
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
from .setting import get_app_namespace
from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
@@ -82,6 +88,8 @@
"VectorSimilarityMetric",
"VectorIndexDef",
"IndexOptions",
+ "HnswVectorIndexMethod",
+ "IvfFlatVectorIndexMethod",
# Settings
"DatabaseConnectionSpec",
"Settings",
diff --git a/python/cocoindex/index.py b/python/cocoindex/index.py
index a5ff0626..6c5e11cb 100644
--- a/python/cocoindex/index.py
+++ b/python/cocoindex/index.py
@@ -1,6 +1,6 @@
from enum import Enum
from dataclasses import dataclass
-from typing import Sequence
+from typing import Sequence, Union
class VectorSimilarityMetric(Enum):
@@ -9,6 +9,26 @@ class VectorSimilarityMetric(Enum):
INNER_PRODUCT = "InnerProduct"
+@dataclass
+class HnswVectorIndexMethod:
+ """HNSW vector index parameters."""
+
+ kind: str = "Hnsw"
+ m: int | None = None
+ ef_construction: int | None = None
+
+
+@dataclass
+class IvfFlatVectorIndexMethod:
+ """IVFFlat vector index parameters."""
+
+ kind: str = "IvfFlat"
+ lists: int | None = None
+
+
+VectorIndexMethod = Union[HnswVectorIndexMethod, IvfFlatVectorIndexMethod]
+
+
@dataclass
class VectorIndexDef:
"""
@@ -17,6 +37,7 @@ class VectorIndexDef:
field_name: str
metric: VectorSimilarityMetric
+ method: VectorIndexMethod | None = None
@dataclass
diff --git a/src/base/spec.rs b/src/base/spec.rs
index 3c01ae2e..671e96dc 100644
--- a/src/base/spec.rs
+++ b/src/base/spec.rs
@@ -384,15 +384,72 @@ impl fmt::Display for VectorSimilarityMetric {
}
}
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(tag = "kind")]
+pub enum VectorIndexMethod {
+ Hnsw {
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ m: Option,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ ef_construction: Option,
+ },
+ IvfFlat {
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ lists: Option,
+ },
+}
+
+impl VectorIndexMethod {
+ pub fn kind(&self) -> &'static str {
+ match self {
+ Self::Hnsw { .. } => "Hnsw",
+ Self::IvfFlat { .. } => "IvfFlat",
+ }
+ }
+}
+
+impl fmt::Display for VectorIndexMethod {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Hnsw { m, ef_construction } => {
+ let mut parts = Vec::new();
+ if let Some(m) = m {
+ parts.push(format!("m={}", m));
+ }
+ if let Some(ef) = ef_construction {
+ parts.push(format!("ef_construction={}", ef));
+ }
+ if parts.is_empty() {
+ write!(f, "Hnsw")
+ } else {
+ write!(f, "Hnsw({})", parts.join(","))
+ }
+ }
+ Self::IvfFlat { lists } => {
+ if let Some(lists) = lists {
+ write!(f, "IvfFlat(lists={lists})")
+ } else {
+ write!(f, "IvfFlat")
+ }
+ }
+ }
+ }
+}
+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct VectorIndexDef {
pub field_name: FieldName,
pub metric: VectorSimilarityMetric,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub method: Option,
}
impl fmt::Display for VectorIndexDef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "{}:{}", self.field_name, self.metric)
+ match &self.method {
+ None => write!(f, "{}:{}", self.field_name, self.metric),
+ Some(method) => write!(f, "{}:{}:{}", self.field_name, self.metric, method),
+ }
}
}
diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index d246ad86..bb2aeb0f 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -74,6 +74,30 @@ impl AiStudioClient {
}
}
+fn build_embed_payload(
+ model: &str,
+ text: &str,
+ task_type: Option<&str>,
+ output_dimension: Option,
+) -> serde_json::Value {
+ let mut payload = serde_json::json!({
+ "model": model,
+ "content": { "parts": [{ "text": text }] },
+ });
+ if let Some(task_type) = task_type {
+ payload["taskType"] = serde_json::Value::String(task_type.to_string());
+ }
+ if let Some(output_dimension) = output_dimension {
+ payload["outputDimensionality"] = serde_json::json!(output_dimension);
+ if model.starts_with("gemini-embedding-") {
+ payload["config"] = serde_json::json!({
+ "outputDimensionality": output_dimension,
+ });
+ }
+ }
+ payload
+}
+
#[async_trait]
impl LlmGenerationClient for AiStudioClient {
async fn generate<'req>(
@@ -174,16 +198,12 @@ impl LlmEmbeddingClient for AiStudioClient {
request: super::LlmEmbeddingRequest<'req>,
) -> Result {
let url = self.get_api_url(request.model, "embedContent");
- let mut payload = serde_json::json!({
- "model": request.model,
- "content": { "parts": [{ "text": request.text }] },
- });
- if let Some(task_type) = request.task_type {
- payload["taskType"] = serde_json::Value::String(task_type.into());
- }
- if let Some(output_dimension) = request.output_dimension {
- payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
- }
+ let payload = build_embed_payload(
+ request.model,
+ request.text.as_ref(),
+ request.task_type.as_deref(),
+ request.output_dimension,
+ );
let resp = retryable::run(
|| async {
self.client
diff --git a/src/ops/targets/postgres.rs b/src/ops/targets/postgres.rs
index 6a583b9f..100812e3 100644
--- a/src/ops/targets/postgres.rs
+++ b/src/ops/targets/postgres.rs
@@ -461,21 +461,51 @@ fn to_vector_similarity_metric_sql(metric: VectorSimilarityMetric) -> &'static s
}
fn to_index_spec_sql(index_spec: &VectorIndexDef) -> Cow<'static, str> {
+ let (method, options) = match index_spec.method.as_ref() {
+ Some(spec::VectorIndexMethod::Hnsw { m, ef_construction }) => {
+ let mut opts = Vec::new();
+ if let Some(m) = m {
+ opts.push(format!("m = {}", m));
+ }
+ if let Some(ef) = ef_construction {
+ opts.push(format!("ef_construction = {}", ef));
+ }
+ ("hnsw", opts)
+ }
+ Some(spec::VectorIndexMethod::IvfFlat { lists }) => (
+ "ivfflat",
+ lists
+ .map(|lists| vec![format!("lists = {}", lists)])
+ .unwrap_or_default(),
+ ),
+ None => ("hnsw", Vec::new()),
+ };
+ let with_clause = if options.is_empty() {
+ String::new()
+ } else {
+ format!(" WITH ({})", options.join(", "))
+ };
format!(
- "USING hnsw ({} {})",
+ "USING {method} ({} {}){}",
index_spec.field_name,
- to_vector_similarity_metric_sql(index_spec.metric)
+ to_vector_similarity_metric_sql(index_spec.metric),
+ with_clause
)
.into()
}
fn to_vector_index_name(table_name: &str, vector_index_def: &spec::VectorIndexDef) -> String {
- format!(
+ let mut name = format!(
"{}__{}__{}",
table_name,
vector_index_def.field_name,
to_vector_similarity_metric_sql(vector_index_def.metric)
- )
+ );
+ if let Some(method) = vector_index_def.method.as_ref() {
+ name.push_str("__");
+ name.push_str(&method.kind().to_ascii_lowercase());
+ }
+ name
}
fn describe_index_spec(index_name: &str, index_spec: &VectorIndexDef) -> String {