diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java index 0b8a938f430..8ef20793167 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java @@ -38,6 +38,7 @@ import io.milvus.param.R; import io.milvus.param.R.Status; import io.milvus.param.RpcStatus; +import io.milvus.param.collection.CollectionSchemaParam; import io.milvus.param.collection.CreateCollectionParam; import io.milvus.param.collection.DropCollectionParam; import io.milvus.param.collection.FieldType; @@ -443,6 +444,8 @@ void createCollection() { if (!isDatabaseCollectionExists()) { createCollection(this.databaseName, this.collectionName, this.idFieldName, this.isAutoId, this.contentFieldName, this.metadataFieldName, this.embeddingFieldName); + createIndex(this.databaseName, this.collectionName, this.embeddingFieldName, this.indexType, + this.metricType, this.indexParameters); } R indexDescriptionResponse = this.milvusClient @@ -452,19 +455,8 @@ void createCollection() { .build()); if (indexDescriptionResponse.getData() == null) { - R indexStatus = this.milvusClient.createIndex(CreateIndexParam.newBuilder() - .withDatabaseName(this.databaseName) - .withCollectionName(this.collectionName) - .withFieldName(this.embeddingFieldName) - .withIndexType(this.indexType) - .withMetricType(this.metricType) - .withExtraParam(this.indexParameters) - .withSyncMode(Boolean.FALSE) - .build()); - - if (indexStatus.getException() != null) { - throw new RuntimeException("Failed to create Index", indexStatus.getException()); - } + createIndex(this.databaseName, this.collectionName, this.embeddingFieldName, this.indexType, + this.metricType, this.indexParameters); } R loadCollectionStatus = this.milvusClient.loadCollection(LoadCollectionParam.newBuilder() @@ -507,10 +499,12 @@ void createCollection(String databaseName, String collectionName, String idField .withDescription("Spring AI Vector Store") .withConsistencyLevel(ConsistencyLevelEnum.STRONG) .withShardsNum(2) - .addFieldType(docIdFieldType) - .addFieldType(contentFieldType) - .addFieldType(metadataFieldType) - .addFieldType(embeddingFieldType) + .withSchema(CollectionSchemaParam.newBuilder() + .addFieldType(docIdFieldType) + .addFieldType(contentFieldType) + .addFieldType(metadataFieldType) + .addFieldType(embeddingFieldType) + .build()) .build(); R collectionStatus = this.milvusClient.createCollection(createCollectionReq); @@ -520,6 +514,23 @@ void createCollection(String databaseName, String collectionName, String idField } + void createIndex(String databaseName, String collectionName, String embeddingFieldName, IndexType indexType, + MetricType metricType, String indexParameters) { + R indexStatus = this.milvusClient.createIndex(CreateIndexParam.newBuilder() + .withDatabaseName(databaseName) + .withCollectionName(collectionName) + .withFieldName(embeddingFieldName) + .withIndexType(indexType) + .withMetricType(metricType) + .withExtraParam(indexParameters) + .withSyncMode(Boolean.FALSE) + .build()); + + if (indexStatus.getException() != null) { + throw new RuntimeException("Failed to create Index", indexStatus.getException()); + } + } + int embeddingDimensions() { if (this.embeddingDimension != INVALID_EMBEDDING_DIMENSION) { return this.embeddingDimension; diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java index 50c5a64c4dc..360b35c7855 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -26,6 +27,10 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.AppenderBase; +import io.milvus.client.AbstractMilvusGrpcClient; import io.milvus.client.MilvusServiceClient; import io.milvus.param.ConnectParam; import io.milvus.param.IndexType; @@ -34,6 +39,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.LoggerFactory; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.milvus.MilvusContainer; @@ -323,6 +329,37 @@ public void deleteWithComplexFilterExpression() { }); } + @Test + void initializeSchema() { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=COSINE").run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + Logger logger = (Logger) LoggerFactory.getLogger(AbstractMilvusGrpcClient.class); + LogAppender logAppender = new LogAppender(); + logger.addAppender(logAppender); + logAppender.start(); + + resetCollection(vectorStore); + + assertThat(logAppender.capturedLogs).isEmpty(); + }); + } + + static class LogAppender extends AppenderBase { + + private final List capturedLogs = new ArrayList<>(); + + @Override + protected void append(ILoggingEvent eventObject) { + capturedLogs.add(eventObject.getFormattedMessage()); + } + + public List getCapturedLogs() { + return capturedLogs; + } + + } + @Test void getNativeClientTest() { this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=COSINE").run(context -> {