diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index 81356cf92..3b8c1e9e9 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -232,7 +232,7 @@ async def delete( # type: ignore[no-untyped-def] async def search( self, *, - query: Optional[str], + query: Optional[str] = None, query_vector: Optional[List[float]] = None, k: int = 4, num_candidates: int = 50, @@ -344,8 +344,9 @@ async def _create_index_if_not_exists(self) -> None: async def max_marginal_relevance_search( self, *, - embedding_service: AsyncEmbeddingService, - query: str, + query: Optional[str] = None, + query_embedding: Optional[List[float]] = None, + embedding_service: Optional[AsyncEmbeddingService] = None, vector_field: str, k: int = 4, num_candidates: int = 20, @@ -361,6 +362,8 @@ async def max_marginal_relevance_search( among selected documents. :param query (str): Text to look up documents similar to. + :param query_embedding: Input embedding vector. If given, input query string is + ignored. :param k (int): Number of Documents to return. Defaults to 4. :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. :param lambda_mult (float): Number between 0 and 1 that determines the degree @@ -381,12 +384,22 @@ async def max_marginal_relevance_search( remove_vector_query_field_from_metadata = False # Embed the query - query_embedding = await embedding_service.embed_query(query) + if query_embedding: + query_vector = query_embedding + else: + if not query: + raise ValueError("specify either query or query_embedding to search") + elif embedding_service: + query_vector = await embedding_service.embed_query(query) + elif self.embedding_service: + query_vector = await self.embedding_service.embed_query(query) + else: + raise ValueError("specify embedding_service to search with query") # Fetch the initial documents got_hits = await self.search( query=None, - query_vector=query_embedding, + query_vector=query_vector, k=num_candidates, fields=fields, custom_query=custom_query, @@ -397,7 +410,7 @@ async def max_marginal_relevance_search( # Select documents using maximal marginal relevance selected_indices = maximal_marginal_relevance( - query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k + query_vector, got_embeddings, lambda_mult=lambda_mult, k=k ) selected_hits = [got_hits[i] for i in selected_indices] diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index 9aaa966f3..3c4a0d51a 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -229,7 +229,7 @@ def delete( # type: ignore[no-untyped-def] def search( self, *, - query: Optional[str], + query: Optional[str] = None, query_vector: Optional[List[float]] = None, k: int = 4, num_candidates: int = 50, @@ -341,8 +341,9 @@ def _create_index_if_not_exists(self) -> None: def max_marginal_relevance_search( self, *, - embedding_service: EmbeddingService, - query: str, + query: Optional[str] = None, + query_embedding: Optional[List[float]] = None, + embedding_service: Optional[EmbeddingService] = None, vector_field: str, k: int = 4, num_candidates: int = 20, @@ -358,6 +359,8 @@ def max_marginal_relevance_search( among selected documents. :param query (str): Text to look up documents similar to. + :param query_embedding: Input embedding vector. If given, input query string is + ignored. :param k (int): Number of Documents to return. Defaults to 4. :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. :param lambda_mult (float): Number between 0 and 1 that determines the degree @@ -378,12 +381,22 @@ def max_marginal_relevance_search( remove_vector_query_field_from_metadata = False # Embed the query - query_embedding = embedding_service.embed_query(query) + if query_embedding: + query_vector = query_embedding + else: + if not query: + raise ValueError("specify either query or query_embedding to search") + elif embedding_service: + query_vector = embedding_service.embed_query(query) + elif self.embedding_service: + query_vector = self.embedding_service.embed_query(query) + else: + raise ValueError("specify embedding_service to search with query") # Fetch the initial documents got_hits = self.search( query=None, - query_vector=query_embedding, + query_vector=query_vector, k=num_candidates, fields=fields, custom_query=custom_query, @@ -394,7 +407,7 @@ def max_marginal_relevance_search( # Select documents using maximal marginal relevance selected_indices = maximal_marginal_relevance( - query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k + query_vector, got_embeddings, lambda_mult=lambda_mult, k=k ) selected_hits = [got_hits[i] for i in selected_indices] diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index a8cae670f..820746acd 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -815,6 +815,47 @@ def test_bulk_args(self, sync_client_request_saving: Any, index: str) -> None: # 1 for index exist, 1 for index create, 3 to index docs assert len(store.client.transport.requests) == 5 # type: ignore + def test_max_marginal_relevance_search_errors( + self, sync_client: Elasticsearch, index: str + ) -> None: + """Test max marginal relevance search error conditions.""" + texts = ["foo", "bar", "baz"] + vector_field = "vector_field" + embedding_service = ConsistentFakeEmbeddings() + store = VectorStore( + index=index, + retrieval_strategy=DenseVectorScriptScoreStrategy(), + embedding_service=embedding_service, + client=sync_client, + ) + store.add_texts(texts) + + # search without query embeddings vector or query + with pytest.raises( + ValueError, match="specify either query or query_embedding to search" + ): + store.max_marginal_relevance_search( + vector_field=vector_field, + k=3, + num_candidates=3, + ) + + # search without service + no_service_store = VectorStore( + index=index, + retrieval_strategy=DenseVectorScriptScoreStrategy(), + client=sync_client, + ) + with pytest.raises( + ValueError, match="specify embedding_service to search with query" + ): + no_service_store.max_marginal_relevance_search( + query=texts[0], + vector_field=vector_field, + k=3, + num_candidates=3, + ) + def test_max_marginal_relevance_search( self, sync_client: Elasticsearch, index: str ) -> None: @@ -822,6 +863,7 @@ def test_max_marginal_relevance_search( texts = ["foo", "bar", "baz"] vector_field = "vector_field" text_field = "text_field" + query_embedding = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0] embedding_service = ConsistentFakeEmbeddings() store = VectorStore( index=index, @@ -833,8 +875,8 @@ def test_max_marginal_relevance_search( ) store.add_texts(texts) + # search with query mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=3, @@ -843,8 +885,17 @@ def test_max_marginal_relevance_search( sim_output = store.search(query=texts[0], k=3) assert mmr_output == sim_output + # search with query embeddings + mmr_output = store.max_marginal_relevance_search( + query_embedding=query_embedding, + vector_field=vector_field, + k=3, + num_candidates=3, + ) + sim_output = store.search(query_vector=query_embedding, k=3) + assert mmr_output == sim_output + mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=2, @@ -855,7 +906,6 @@ def test_max_marginal_relevance_search( assert mmr_output[1]["_source"][text_field] == texts[1] mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=2, @@ -868,7 +918,6 @@ def test_max_marginal_relevance_search( # if fetch_k < k, then the output will be less than k mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=3,