Skip to content

Commit c65d1e5

Browse files
authored
feat: fix Gemini embedding config and expose Postgres index tuning (#1050)
* fixing issue of gemini-embedding-001 wrt outputDimensionality * feat: allow configuring Postgres vector index method and options * fixing some rust formatting issues * addressed the review comments * removed unnecessary defaulting functionality
1 parent 416962e commit c65d1e5

File tree

7 files changed

+164
-17
lines changed

7 files changed

+164
-17
lines changed

docs/docs/core/flow_def.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ Types of the fields must be key types. See [Key Types](data_types#key-types) for
313313

314314
* `field_name`: the field to create the vector index on.
315315
* `metric`: the similarity metric to use.
316+
* `method` (optional): the index algorithm and optional tuning parameters. Leave unset to use the target default (HNSW for Postgres). Use `cocoindex.HnswVectorIndexMethod()` or `cocoindex.IvfFlatVectorIndexMethod()` to customize the method and its parameters.
316317

317318
#### Similarity Metrics
318319

docs/docs/examples/examples/simple_vector_index.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ doc_embeddings.export(
105105
CocoIndex supports other vector databases as well, with 1-line switch.
106106
<DocumentationButton url="https://cocoindex.io/docs/ops/targets" text="Targets" />
107107

108+
Need IVFFlat or custom HNSW parameters? Pass a method, for example:
109+
110+
```python
111+
cocoindex.VectorIndexDef(
112+
field_name="embedding",
113+
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
114+
method=cocoindex.IvfFlatVectorIndexMethod(lists=200),
115+
)
116+
```
117+
108118
## Query the index
109119

110120
### Define a shared flow for both indexing and querying

python/cocoindex/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
from .flow import update_all_flows_async, setup_all_flows, drop_all_flows
2222
from .lib import settings, init, start_server, stop
2323
from .llm import LlmSpec, LlmApiType
24-
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
24+
from .index import (
25+
VectorSimilarityMetric,
26+
VectorIndexDef,
27+
IndexOptions,
28+
HnswVectorIndexMethod,
29+
IvfFlatVectorIndexMethod,
30+
)
2531
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
2632
from .setting import get_app_namespace
2733
from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
@@ -82,6 +88,8 @@
8288
"VectorSimilarityMetric",
8389
"VectorIndexDef",
8490
"IndexOptions",
91+
"HnswVectorIndexMethod",
92+
"IvfFlatVectorIndexMethod",
8593
# Settings
8694
"DatabaseConnectionSpec",
8795
"Settings",

python/cocoindex/index.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22
from dataclasses import dataclass
3-
from typing import Sequence
3+
from typing import Sequence, Union
44

55

66
class VectorSimilarityMetric(Enum):
@@ -9,6 +9,26 @@ class VectorSimilarityMetric(Enum):
99
INNER_PRODUCT = "InnerProduct"
1010

1111

12+
@dataclass
13+
class HnswVectorIndexMethod:
14+
"""HNSW vector index parameters."""
15+
16+
kind: str = "Hnsw"
17+
m: int | None = None
18+
ef_construction: int | None = None
19+
20+
21+
@dataclass
22+
class IvfFlatVectorIndexMethod:
23+
"""IVFFlat vector index parameters."""
24+
25+
kind: str = "IvfFlat"
26+
lists: int | None = None
27+
28+
29+
VectorIndexMethod = Union[HnswVectorIndexMethod, IvfFlatVectorIndexMethod]
30+
31+
1232
@dataclass
1333
class VectorIndexDef:
1434
"""
@@ -17,6 +37,7 @@ class VectorIndexDef:
1737

1838
field_name: str
1939
metric: VectorSimilarityMetric
40+
method: VectorIndexMethod | None = None
2041

2142

2243
@dataclass

src/base/spec.rs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,15 +384,72 @@ impl fmt::Display for VectorSimilarityMetric {
384384
}
385385
}
386386

387+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
388+
#[serde(tag = "kind")]
389+
pub enum VectorIndexMethod {
390+
Hnsw {
391+
#[serde(default, skip_serializing_if = "Option::is_none")]
392+
m: Option<u32>,
393+
#[serde(default, skip_serializing_if = "Option::is_none")]
394+
ef_construction: Option<u32>,
395+
},
396+
IvfFlat {
397+
#[serde(default, skip_serializing_if = "Option::is_none")]
398+
lists: Option<u32>,
399+
},
400+
}
401+
402+
impl VectorIndexMethod {
403+
pub fn kind(&self) -> &'static str {
404+
match self {
405+
Self::Hnsw { .. } => "Hnsw",
406+
Self::IvfFlat { .. } => "IvfFlat",
407+
}
408+
}
409+
}
410+
411+
impl fmt::Display for VectorIndexMethod {
412+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413+
match self {
414+
Self::Hnsw { m, ef_construction } => {
415+
let mut parts = Vec::new();
416+
if let Some(m) = m {
417+
parts.push(format!("m={}", m));
418+
}
419+
if let Some(ef) = ef_construction {
420+
parts.push(format!("ef_construction={}", ef));
421+
}
422+
if parts.is_empty() {
423+
write!(f, "Hnsw")
424+
} else {
425+
write!(f, "Hnsw({})", parts.join(","))
426+
}
427+
}
428+
Self::IvfFlat { lists } => {
429+
if let Some(lists) = lists {
430+
write!(f, "IvfFlat(lists={lists})")
431+
} else {
432+
write!(f, "IvfFlat")
433+
}
434+
}
435+
}
436+
}
437+
}
438+
387439
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
388440
pub struct VectorIndexDef {
389441
pub field_name: FieldName,
390442
pub metric: VectorSimilarityMetric,
443+
#[serde(default, skip_serializing_if = "Option::is_none")]
444+
pub method: Option<VectorIndexMethod>,
391445
}
392446

393447
impl fmt::Display for VectorIndexDef {
394448
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395-
write!(f, "{}:{}", self.field_name, self.metric)
449+
match &self.method {
450+
None => write!(f, "{}:{}", self.field_name, self.metric),
451+
Some(method) => write!(f, "{}:{}:{}", self.field_name, self.metric, method),
452+
}
396453
}
397454
}
398455

src/llm/gemini.rs

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,30 @@ impl AiStudioClient {
7474
}
7575
}
7676

77+
/// Builds the JSON body for an embedding request.
///
/// The body always contains `model` and `content.parts[0].text`. `taskType` is
/// added when `task_type` is given. When `output_dimension` is given,
/// `outputDimensionality` is set at the top level, and for models whose name
/// starts with `gemini-embedding-` the same value is also placed under
/// `config.outputDimensionality`.
fn build_embed_payload(
    model: &str,
    text: &str,
    task_type: Option<&str>,
    output_dimension: Option<u32>,
) -> serde_json::Value {
    let mut body = serde_json::json!({
        "model": model,
        "content": { "parts": [{ "text": text }] },
    });

    if let Some(kind) = task_type {
        body["taskType"] = serde_json::Value::String(kind.to_string());
    }

    match output_dimension {
        Some(dim) => {
            body["outputDimensionality"] = serde_json::json!(dim);
            // Models in the `gemini-embedding-` family additionally get the
            // dimension mirrored inside a nested `config` object.
            if model.starts_with("gemini-embedding-") {
                body["config"] = serde_json::json!({
                    "outputDimensionality": dim,
                });
            }
        }
        None => {}
    }

    body
}
100+
77101
#[async_trait]
78102
impl LlmGenerationClient for AiStudioClient {
79103
async fn generate<'req>(
@@ -174,16 +198,12 @@ impl LlmEmbeddingClient for AiStudioClient {
174198
request: super::LlmEmbeddingRequest<'req>,
175199
) -> Result<super::LlmEmbeddingResponse> {
176200
let url = self.get_api_url(request.model, "embedContent");
177-
let mut payload = serde_json::json!({
178-
"model": request.model,
179-
"content": { "parts": [{ "text": request.text }] },
180-
});
181-
if let Some(task_type) = request.task_type {
182-
payload["taskType"] = serde_json::Value::String(task_type.into());
183-
}
184-
if let Some(output_dimension) = request.output_dimension {
185-
payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
186-
}
201+
let payload = build_embed_payload(
202+
request.model,
203+
request.text.as_ref(),
204+
request.task_type.as_deref(),
205+
request.output_dimension,
206+
);
187207
let resp = retryable::run(
188208
|| async {
189209
self.client

src/ops/targets/postgres.rs

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,21 +461,51 @@ fn to_vector_similarity_metric_sql(metric: VectorSimilarityMetric) -> &'static s
461461
}
462462

463463
fn to_index_spec_sql(index_spec: &VectorIndexDef) -> Cow<'static, str> {
464+
let (method, options) = match index_spec.method.as_ref() {
465+
Some(spec::VectorIndexMethod::Hnsw { m, ef_construction }) => {
466+
let mut opts = Vec::new();
467+
if let Some(m) = m {
468+
opts.push(format!("m = {}", m));
469+
}
470+
if let Some(ef) = ef_construction {
471+
opts.push(format!("ef_construction = {}", ef));
472+
}
473+
("hnsw", opts)
474+
}
475+
Some(spec::VectorIndexMethod::IvfFlat { lists }) => (
476+
"ivfflat",
477+
lists
478+
.map(|lists| vec![format!("lists = {}", lists)])
479+
.unwrap_or_default(),
480+
),
481+
None => ("hnsw", Vec::new()),
482+
};
483+
let with_clause = if options.is_empty() {
484+
String::new()
485+
} else {
486+
format!(" WITH ({})", options.join(", "))
487+
};
464488
format!(
465-
"USING hnsw ({} {})",
489+
"USING {method} ({} {}){}",
466490
index_spec.field_name,
467-
to_vector_similarity_metric_sql(index_spec.metric)
491+
to_vector_similarity_metric_sql(index_spec.metric),
492+
with_clause
468493
)
469494
.into()
470495
}
471496

472497
fn to_vector_index_name(table_name: &str, vector_index_def: &spec::VectorIndexDef) -> String {
473-
format!(
498+
let mut name = format!(
474499
"{}__{}__{}",
475500
table_name,
476501
vector_index_def.field_name,
477502
to_vector_similarity_metric_sql(vector_index_def.metric)
478-
)
503+
);
504+
if let Some(method) = vector_index_def.method.as_ref() {
505+
name.push_str("__");
506+
name.push_str(&method.kind().to_ascii_lowercase());
507+
}
508+
name
479509
}
480510

481511
fn describe_index_spec(index_name: &str, index_spec: &VectorIndexDef) -> String {

0 commit comments

Comments
 (0)