diff --git a/src/ops/targets/kuzu.rs b/src/ops/targets/kuzu.rs
index d6b0bbc1..41d0848f 100644
--- a/src/ops/targets/kuzu.rs
+++ b/src/ops/targets/kuzu.rs
@@ -165,6 +165,22 @@ struct SetupState {
 
     #[serde(default, skip_serializing_if = "Option::is_none")]
     referenced_node_tables: Option<(ReferencedNodeTable, ReferencedNodeTable)>,
+
+    /// Vector indexes declared for this table; left out of the serialized state when empty.
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    vector_indexes: Vec<VectorIndexState>,
+}
+
+/// Description of one vector index as recorded in [`SetupState`], so that a later setup
+/// run can compare the desired indexes with the existing ones.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+struct VectorIndexState {
+    /// Name of the indexed field.
+    field_name: String,
+    /// Similarity metric the index is built with.
+    metric: spec::VectorSimilarityMetric,
+    /// Index method and its parameters, if one was specified.
+    method: Option<spec::VectorIndexMethod>,
 }
 
 impl<'a> From<&'a SetupState> for Cow<'a, TableColumnsSchema> {
@@ -178,6 +194,8 @@ struct GraphElementDataSetupChange {
     actions: TableMainSetupAction,
     referenced_node_tables: Option<(String, String)>,
     drop_affected_referenced_node_tables: IndexSet<String>,
+    vector_indexes_to_create: Vec<spec::VectorIndexDef>,
+    vector_indexes_to_drop: Vec<String>, // field names
 }
 
 impl setup::ResourceSetupChange for GraphElementDataSetupChange {
@@ -190,6 +208,142 @@ impl setup::ResourceSetupChange for GraphElementDataSetupChange {
     }
 }
 
+////////////////////////////////////////////////////////////
+// Vector Index Support Functions
+////////////////////////////////////////////////////////////
+
+/// Checks that `method` is a vector index method this target can build.
+///
+/// `None` and HNSW are accepted.
+///
+/// # Errors
+///
+/// Returns an API error when IVFFlat is requested.
+fn validate_vector_index_method(method: &Option<spec::VectorIndexMethod>) -> Result<()> {
+    match method {
+        Some(spec::VectorIndexMethod::IvfFlat { .. }) => {
+            api_bail!(
+                "IVFFlat vector index method is not supported by Kuzu. Only HNSW is supported."
+            )
+        }
+        Some(spec::VectorIndexMethod::Hnsw { .. }) | None => Ok(()),
+    }
+}
+
+/// Builds the name of the vector index on `field_name` of `table_name`.
+///
+/// Index creation and removal both use this function so that they always agree on the name.
+fn vector_index_name(table_name: &str, field_name: &str) -> String {
+    format!("{}_{}_vector_idx", table_name, field_name)
+}
+
+/// Appends a `CALL CREATE_VECTOR_INDEX(...)` statement for `index_def` on `table_name`.
+///
+/// The HNSW parameter `m` is emitted as `mu := m` and `ml := 2 * m`, and `ef_construction`
+/// as `efc`; parameters that are not set are left out. The metric is always emitted.
+fn append_create_vector_index(
+    cypher: &mut CypherBuilder,
+    table_name: &str,
+    index_def: &spec::VectorIndexDef,
+) -> Result<()> {
+    let mut params = Vec::new();
+
+    // Map parameters from cocoindex to Kuzu
+    if let Some(spec::VectorIndexMethod::Hnsw { m, ef_construction }) = &index_def.method {
+        if let Some(m_val) = m {
+            params.push(format!("mu := {}", m_val));
+            params.push(format!("ml := {}", m_val * 2));
+        }
+        if let Some(ef_val) = ef_construction {
+            params.push(format!("efc := {}", ef_val));
+        }
+    }
+
+    // Map metric
+    let metric = match index_def.metric {
+        spec::VectorSimilarityMetric::CosineSimilarity => "cosine",
+        spec::VectorSimilarityMetric::L2Distance => "l2",
+        spec::VectorSimilarityMetric::InnerProduct => "dotproduct",
+    };
+    params.push(format!("metric := '{}'", metric));
+
+    // `params` always holds at least the metric, so the option list is never empty.
+    writeln!(
+        cypher.query_mut(),
+        "CALL CREATE_VECTOR_INDEX('{}', '{}', '{}', {});",
+        table_name,
+        vector_index_name(table_name, &index_def.field_name),
+        index_def.field_name,
+        params.join(", ")
+    )?;
+    Ok(())
+}
+
+/// Appends a `CALL DROP_VECTOR_INDEX(...)` statement for the index on `field_name` of
+/// `table_name`.
+fn append_drop_vector_index(
+    cypher: &mut CypherBuilder,
+    table_name: &str,
+    field_name: &str,
+) -> Result<()> {
+    writeln!(
+        cypher.query_mut(),
+        "CALL DROP_VECTOR_INDEX('{}', '{}');",
+        table_name,
+        vector_index_name(table_name, field_name)
+    )?;
+    Ok(())
+}
+
+/// Computes the vector indexes to create and the fields whose vector index must be dropped
+/// so that the table ends up with exactly the `desired` indexes.
+///
+/// `existing_versions` holds the index list of each possible existing version of the table.
+/// An index is left untouched only if every version that has an index on its field matches
+/// the desired definition; otherwise it is dropped and created again. A field name appears
+/// at most once in the drop list, even if several versions have an index on it.
+///
+/// When `table_recreated` is true, nothing is dropped explicitly and every desired index is
+/// created. NOTE(review): this assumes that dropping the table also removes its indexes and
+/// that the drop is issued before the index statements -- confirm.
+fn compute_vector_index_changes(
+    desired: &[VectorIndexState],
+    existing_versions: &[&[VectorIndexState]],
+    table_recreated: bool,
+) -> (Vec<spec::VectorIndexDef>, Vec<String>) {
+    let to_index_def = |vi: &VectorIndexState| spec::VectorIndexDef {
+        field_name: vi.field_name.clone(),
+        metric: vi.metric,
+        method: vi.method.clone(),
+    };
+    if table_recreated {
+        return (desired.iter().map(to_index_def).collect(), Vec::new());
+    }
+
+    let existing = || existing_versions.iter().copied().flatten();
+    let mut to_create = Vec::new();
+    let mut to_drop = IndexSet::new();
+    for desired_vi in desired {
+        let same_field: Vec<&VectorIndexState> = existing()
+            .filter(|vi| vi.field_name == desired_vi.field_name)
+            .collect();
+        if !same_field.is_empty() {
+            if same_field.iter().all(|vi| *vi == desired_vi) {
+                continue;
+            }
+            to_drop.insert(desired_vi.field_name.clone());
+        }
+        to_create.push(to_index_def(desired_vi));
+    }
+    // Indexes on fields that are no longer desired.
+    for existing_vi in existing() {
+        if desired.iter().all(|vi| vi.field_name != existing_vi.field_name) {
+            to_drop.insert(existing_vi.field_name.clone());
+        }
+    }
+    (to_create, to_drop.into_iter().collect())
+}
+
 fn append_drop_table(
     cypher: &mut CypherBuilder,
     setup_change: &GraphElementDataSetupChange,
@@ -772,8 +926,9 @@ impl TargetFactoryBase for Factory {
         let data_coll_outputs: Vec<TypedExportDataCollectionBuildOutput<Self>> =
             std::iter::zip(data_collections, analyzed_data_colls.into_iter())
                 .map(|(data_coll, analyzed)| {
-                    if !data_coll.index_options.vector_indexes.is_empty() {
-                        api_bail!("Vector indexes are not supported for Kuzu yet");
+                    // Validate vector index methods
+                    for vector_index in &data_coll.index_options.vector_indexes {
+                        validate_vector_index_method(&vector_index.method)?;
                     }
                     fn to_dep_table(
                         field_mapping: &AnalyzedGraphElementFieldMapping,
@@ -797,6 +952,16 @@ impl TargetFactoryBase for Factory {
                                 anyhow::Ok((to_dep_table(&rel.source)?, to_dep_table(&rel.target)?))
                             })
                             .transpose()?,
+                        vector_indexes: data_coll
+                            .index_options
+                            .vector_indexes
+                            .iter()
+                            .map(|vi| VectorIndexState {
+                                field_name: vi.field_name.clone(),
+                                metric: vi.metric,
+                                method: vi.method.clone(),
+                            })
+                            .collect(),
                     };
 
                     let export_context = ExportContext {
@@ -824,6 +989,7 @@ impl TargetFactoryBase for Factory {
                         value_columns: to_kuzu_cols(&graph_elem_schema.value_fields)?,
                     },
                     referenced_node_tables: None,
+                    vector_indexes: Vec::new(),
                 };
                 let setup_key = GraphElementType {
                     connection: decl.connection,
@@ -847,8 +1013,10 @@ impl TargetFactoryBase for Factory {
                 .possible_versions()
                 .any(|v| v.referenced_node_tables != desired.referenced_node_tables)
         });
+
         let actions =
             TableMainSetupAction::from_states(desired.as_ref(), &existing, existing_invalidated);
+
         let drop_affected_referenced_node_tables = if actions.drop_existing {
             existing
                 .possible_versions()
@@ -858,12 +1026,32 @@ impl TargetFactoryBase for Factory {
         } else {
             IndexSet::new()
         };
+
+        // Compute vector index changes
+        let desired_vector_indexes: &[VectorIndexState] = match &desired {
+            Some(desired_state) => &desired_state.vector_indexes,
+            None => &[],
+        };
+        let existing_vector_indexes: Vec<&[VectorIndexState]> = existing
+            .possible_versions()
+            .map(|v| v.vector_indexes.as_slice())
+            .collect();
+        // If the table itself is dropped, its indexes must not be dropped separately, and all
+        // desired indexes have to be created on the new table.
+        let (vector_indexes_to_create, vector_indexes_to_drop) = compute_vector_index_changes(
+            desired_vector_indexes,
+            &existing_vector_indexes,
+            actions.drop_existing,
+        );
+
         Ok(GraphElementDataSetupChange {
             actions,
             referenced_node_tables: desired
                 .and_then(|desired| desired.referenced_node_tables)
                 .map(|(src, tgt)| (src.table_name, tgt.table_name)),
             drop_affected_referenced_node_tables,
+            vector_indexes_to_create,
+            vector_indexes_to_drop,
         })
     }
 
@@ -1080,6 +1268,33 @@ impl TargetFactoryBase for Factory {
                 append_delete_orphaned_nodes(&mut cypher, table)?;
             }
 
+            // Install vector extension if needed
+            let has_vector_changes = node_changes.iter().any(|c| {
+                !c.setup_change.vector_indexes_to_create.is_empty()
+                    || !c.setup_change.vector_indexes_to_drop.is_empty()
+            });
+
+            if has_vector_changes {
+                writeln!(cypher.query_mut(), "INSTALL vector;")?;
+                writeln!(cypher.query_mut(), "LOAD vector;")?;
+            }
+
+            // Drop vector indexes first
+            for change in node_changes.iter() {
+                let table_name = change.key.typ.label();
+                for field_name in &change.setup_change.vector_indexes_to_drop {
+                    append_drop_vector_index(&mut cypher, table_name, field_name)?;
+                }
+            }
+
+            // Create vector indexes
+            for change in node_changes.iter() {
+                let table_name = change.key.typ.label();
+                for index_def in &change.setup_change.vector_indexes_to_create {
+                    append_create_vector_index(&mut cypher, table_name, index_def)?;
+                }
+            }
+
             kuzu_client.run_cypher(cypher).await?;
         }
         Ok(())
@@ -1092,3 +1307,239 @@ pub fn register(
 ) -> Result<()> {
     Factory { reqwest_client }.register(registry)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::base::spec::{VectorIndexDef, VectorIndexMethod, VectorSimilarityMetric};
+
+    #[test]
+    fn test_validate_vector_index_method_accepts_hnsw() {
+        // Test HNSW with parameters
+        let hnsw_with_params = Some(VectorIndexMethod::Hnsw {
+            m: Some(16),
+            ef_construction: Some(200),
+        });
+        assert!(validate_vector_index_method(&hnsw_with_params).is_ok());
+
+        // Test HNSW without parameters
+        let hnsw_no_params = Some(VectorIndexMethod::Hnsw {
+            m: None,
+            ef_construction: None,
+        });
+        assert!(validate_vector_index_method(&hnsw_no_params).is_ok());
+
+        // Test None (default)
+        assert!(validate_vector_index_method(&None).is_ok());
+    }
+
+    #[test]
+    fn test_validate_vector_index_method_rejects_ivfflat() {
+        let ivfflat = Some(VectorIndexMethod::IvfFlat { lists: Some(100) });
+        let result = validate_vector_index_method(&ivfflat);
+
+        assert!(result.is_err());
+        let error_msg = format!("{}", result.unwrap_err());
+        assert!(error_msg.contains("IVFFlat vector index method is not supported by Kuzu"));
+        assert!(error_msg.contains("Only HNSW is supported"));
+    }
+
+    #[test]
+    fn test_append_create_vector_index_basic() {
+        let mut cypher = CypherBuilder::new();
+        let index_def = VectorIndexDef {
+            field_name: "embedding".to_string(),
+            metric: VectorSimilarityMetric::CosineSimilarity,
+            method: None,
+        };
+
+        let result = append_create_vector_index(&mut cypher, "documents", &index_def);
+        assert!(result.is_ok());
+
+        let query = cypher.query;
+        assert!(query.contains(
+            "CALL CREATE_VECTOR_INDEX('documents', 'documents_embedding_vector_idx', 'embedding'"
+        ));
+        assert!(query.contains("metric := 'cosine'"));
+        assert!(query.ends_with(");\n"));
+    }
+
+    #[test]
+    fn test_append_create_vector_index_with_hnsw_params() {
+        let mut cypher = CypherBuilder::new();
+        let index_def = VectorIndexDef {
+            field_name: "embedding".to_string(),
+            metric: VectorSimilarityMetric::L2Distance,
+            method: Some(VectorIndexMethod::Hnsw {
+                m: Some(16),
+                ef_construction: Some(200),
+            }),
+        };
+
+        let result = append_create_vector_index(&mut cypher, "documents", &index_def);
+        assert!(result.is_ok());
+
+        let query = cypher.query;
+        assert!(query.contains("mu := 16"));
+        assert!(query.contains("ml := 32")); // ml = 2 * m
+        assert!(query.contains("efc := 200"));
+        assert!(query.contains("metric := 'l2'"));
+    }
+
+    #[test]
+    fn test_append_create_vector_index_metric_mapping() {
+        let test_cases = vec![
+            (VectorSimilarityMetric::CosineSimilarity, "cosine"),
+            (VectorSimilarityMetric::L2Distance, "l2"),
+            (VectorSimilarityMetric::InnerProduct, "dotproduct"),
+        ];
+
+        for (metric, expected_kuzu_metric) in test_cases {
+            let mut cypher = CypherBuilder::new();
+            let index_def = VectorIndexDef {
+                field_name: "embedding".to_string(),
+                metric,
+                method: None,
+            };
+
+            append_create_vector_index(&mut cypher, "test_table", &index_def).unwrap();
+            assert!(
+                cypher
+                    .query
+                    .contains(&format!("metric := '{}'", expected_kuzu_metric))
+            );
+        }
+    }
+
+    #[test]
+    fn test_append_drop_vector_index() {
+        let mut cypher = CypherBuilder::new();
+
+        let result = append_drop_vector_index(&mut cypher, "documents", "embedding");
+        assert!(result.is_ok());
+
+        let query = cypher.query;
+        assert_eq!(
+            query,
+            "CALL DROP_VECTOR_INDEX('documents', 'documents_embedding_vector_idx');\n"
+        );
+    }
+
+    #[test]
+    fn test_parameter_mapping_edge_cases() {
+        // Test with only m parameter
+        let mut cypher = CypherBuilder::new();
+        let index_def = VectorIndexDef {
+            field_name: "embedding".to_string(),
+            metric: VectorSimilarityMetric::CosineSimilarity,
+            method: Some(VectorIndexMethod::Hnsw {
+                m: Some(8),
+                ef_construction: None,
+            }),
+        };
+
+        append_create_vector_index(&mut cypher, "test", &index_def).unwrap();
+        assert!(cypher.query.contains("mu := 8"));
+        assert!(cypher.query.contains("ml := 16"));
+        assert!(!cypher.query.contains("efc :="));
+
+        // Test with only ef_construction parameter
+        let mut cypher = CypherBuilder::new();
+        let index_def = VectorIndexDef {
+            field_name: "embedding".to_string(),
+            metric: VectorSimilarityMetric::CosineSimilarity,
+            method: Some(VectorIndexMethod::Hnsw {
+                m: None,
+                ef_construction: Some(100),
+            }),
+        };
+
+        append_create_vector_index(&mut cypher, "test", &index_def).unwrap();
+        assert!(!cypher.query.contains("mu :="));
+        assert!(!cypher.query.contains("ml :="));
+        assert!(cypher.query.contains("efc := 100"));
+    }
+
+    #[test]
+    fn test_index_naming_consistency() {
+        let test_cases = vec![
+            (
+                "users",
+                "profile_embedding",
+                "users_profile_embedding_vector_idx",
+            ),
+            (
+                "documents",
+                "content_vector",
+                "documents_content_vector_vector_idx",
+            ),
+            ("items", "embedding", "items_embedding_vector_idx"),
+        ];
+
+        for (table, field, expected_name) in test_cases {
+            let mut cypher = CypherBuilder::new();
+            let index_def = VectorIndexDef {
+                field_name: field.to_string(),
+                metric: VectorSimilarityMetric::CosineSimilarity,
+                method: None,
+            };
+
+            append_create_vector_index(&mut cypher, table, &index_def).unwrap();
+            assert!(cypher.query.contains(&format!("'{}'", expected_name)));
+
+            // Test drop as well
+            let mut drop_cypher = CypherBuilder::new();
+            append_drop_vector_index(&mut drop_cypher, table, field).unwrap();
+            assert!(drop_cypher.query.contains(&format!("'{}'", expected_name)));
+        }
+    }
+
+    fn vector_index_state(field_name: &str, metric: VectorSimilarityMetric) -> VectorIndexState {
+        VectorIndexState {
+            field_name: field_name.to_string(),
+            metric,
+            method: None,
+        }
+    }
+
+    #[test]
+    fn test_compute_vector_index_changes_drops_each_index_once() {
+        // Two possible existing versions carrying the same index yield a single drop.
+        let old = [vector_index_state("embedding", VectorSimilarityMetric::L2Distance)];
+        let (to_create, to_drop) =
+            compute_vector_index_changes(&[], &[old.as_slice(), old.as_slice()], false);
+        assert!(to_create.is_empty());
+        assert_eq!(to_drop, vec!["embedding".to_string()]);
+    }
+
+    #[test]
+    fn test_compute_vector_index_changes_recreates_when_any_version_differs() {
+        let desired = [vector_index_state("embedding", VectorSimilarityMetric::CosineSimilarity)];
+        let stale = [vector_index_state("embedding", VectorSimilarityMetric::L2Distance)];
+
+        // Unchanged when every existing version matches.
+        let (to_create, to_drop) =
+            compute_vector_index_changes(&desired, &[desired.as_slice()], false);
+        assert!(to_create.is_empty());
+        assert!(to_drop.is_empty());
+
+        // Dropped and created again when one version differs, whatever the version order.
+        let (to_create, to_drop) = compute_vector_index_changes(
+            &desired,
+            &[stale.as_slice(), desired.as_slice()],
+            false,
+        );
+        assert_eq!(to_create.len(), 1);
+        assert_eq!(to_create[0].field_name, "embedding");
+        assert_eq!(to_drop, vec!["embedding".to_string()]);
+    }
+
+    #[test]
+    fn test_compute_vector_index_changes_table_recreated() {
+        let desired = [vector_index_state("embedding", VectorSimilarityMetric::CosineSimilarity)];
+        let (to_create, to_drop) =
+            compute_vector_index_changes(&desired, &[desired.as_slice()], true);
+        assert_eq!(to_create.len(), 1);
+        assert!(to_drop.is_empty());
+    }
+}