
Commit 6905674

allow embeddings vector to be used for mmr searching (#2620) (#2625) (#2639)
* allow embeddings vector to be used for mmr searching (#2620)

Signed-off-by: rishabh208gupta <[email protected]>

* Use embedding service if provided

---------

Signed-off-by: rishabh208gupta <[email protected]>
Co-authored-by: Quentin Pradet <[email protected]>
(cherry picked from commit 3b1bce7)
Co-authored-by: Rishabh Gupta <[email protected]>
1 parent 5235aaa commit 6905674

File tree

elasticsearch/helpers/vectorstore/_async/vectorstore.py
elasticsearch/helpers/vectorstore/_sync/vectorstore.py
test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

3 files changed: +91 -16 lines changed
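In practice the change means max_marginal_relevance_search can be driven either by a query string (embedded by a per-call or store-level embedding service) or by a precomputed query_embedding vector. Below is a minimal usage sketch based on the calls exercised in the test file further down this page; the import path, connection URL, index name, and the StubEmbeddings class (including its embed_documents method, which does not appear in this diff) are illustrative assumptions, not part of this commit.

from typing import List

from elasticsearch import Elasticsearch
from elasticsearch.helpers.vectorstore import (
    DenseVectorScriptScoreStrategy,
    EmbeddingService,
    VectorStore,
)


class StubEmbeddings(EmbeddingService):
    """Toy embedding service used only for this sketch."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[1.0] * 9 + [0.0] for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return [1.0] * 9 + [0.0]


client = Elasticsearch("http://localhost:9200")  # illustrative connection details
store = VectorStore(
    index="my-index",  # illustrative index name
    retrieval_strategy=DenseVectorScriptScoreStrategy(),
    embedding_service=StubEmbeddings(),
    client=client,
)
store.add_texts(["foo", "bar", "baz"])

# 1) Text query: embedded via the store-level embedding service.
hits = store.max_marginal_relevance_search(
    query="foo",
    vector_field="vector_field",
    k=3,
    num_candidates=3,
)

# 2) Precomputed vector: no query string or embedding service needed.
hits = store.max_marginal_relevance_search(
    query_embedding=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
    vector_field="vector_field",
    k=3,
    num_candidates=3,
)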

elasticsearch/helpers/vectorstore/_async/vectorstore.py

Lines changed: 19 additions & 6 deletions
@@ -232,7 +232,7 @@ async def delete(  # type: ignore[no-untyped-def]
     async def search(
         self,
         *,
-        query: Optional[str],
+        query: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
         k: int = 4,
         num_candidates: int = 50,
@@ -344,8 +344,9 @@ async def _create_index_if_not_exists(self) -> None:
     async def max_marginal_relevance_search(
         self,
         *,
-        embedding_service: AsyncEmbeddingService,
-        query: str,
+        query: Optional[str] = None,
+        query_embedding: Optional[List[float]] = None,
+        embedding_service: Optional[AsyncEmbeddingService] = None,
         vector_field: str,
         k: int = 4,
         num_candidates: int = 20,
@@ -361,6 +362,8 @@ async def max_marginal_relevance_search(
             among selected documents.

         :param query (str): Text to look up documents similar to.
+        :param query_embedding: Input embedding vector. If given, input query string is
+            ignored.
         :param k (int): Number of Documents to return. Defaults to 4.
         :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
         :param lambda_mult (float): Number between 0 and 1 that determines the degree
@@ -381,12 +384,22 @@ async def max_marginal_relevance_search(
                 remove_vector_query_field_from_metadata = False

         # Embed the query
-        query_embedding = await embedding_service.embed_query(query)
+        if query_embedding:
+            query_vector = query_embedding
+        else:
+            if not query:
+                raise ValueError("specify either query or query_embedding to search")
+            elif embedding_service:
+                query_vector = await embedding_service.embed_query(query)
+            elif self.embedding_service:
+                query_vector = await self.embedding_service.embed_query(query)
+            else:
+                raise ValueError("specify embedding_service to search with query")

         # Fetch the initial documents
         got_hits = await self.search(
             query=None,
-            query_vector=query_embedding,
+            query_vector=query_vector,
             k=num_candidates,
             fields=fields,
             custom_query=custom_query,
@@ -397,7 +410,7 @@ async def max_marginal_relevance_search(

         # Select documents using maximal marginal relevance
         selected_indices = maximal_marginal_relevance(
-            query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k
+            query_vector, got_embeddings, lambda_mult=lambda_mult, k=k
         )
         selected_hits = [got_hits[i] for i in selected_indices]
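Read together, the new branches implement a simple precedence rule: an explicit query_embedding always wins; otherwise a query string is required and is embedded by the per-call embedding_service if one was passed, falling back to the store's own embedding_service; with neither available a ValueError is raised. A standalone, synchronous restatement of that rule follows (a sketch for illustration only, not the library code; the helper name resolve_query_vector is invented here).

from typing import Callable, List, Optional


def resolve_query_vector(
    query: Optional[str],
    query_embedding: Optional[List[float]],
    call_embedder: Optional[Callable[[str], List[float]]],
    store_embedder: Optional[Callable[[str], List[float]]],
) -> List[float]:
    # Mirrors the branching added above in max_marginal_relevance_search.
    if query_embedding:
        return query_embedding  # explicit vector short-circuits everything else
    if not query:
        raise ValueError("specify either query or query_embedding to search")
    if call_embedder:
        return call_embedder(query)  # embedding service passed to the call
    if store_embedder:
        return store_embedder(query)  # embedding service configured on the store
    raise ValueError("specify embedding_service to search with query")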

elasticsearch/helpers/vectorstore/_sync/vectorstore.py

Lines changed: 19 additions & 6 deletions
@@ -229,7 +229,7 @@ def delete(  # type: ignore[no-untyped-def]
     def search(
         self,
         *,
-        query: Optional[str],
+        query: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
         k: int = 4,
         num_candidates: int = 50,
@@ -341,8 +341,9 @@ def _create_index_if_not_exists(self) -> None:
     def max_marginal_relevance_search(
         self,
         *,
-        embedding_service: EmbeddingService,
-        query: str,
+        query: Optional[str] = None,
+        query_embedding: Optional[List[float]] = None,
+        embedding_service: Optional[EmbeddingService] = None,
         vector_field: str,
         k: int = 4,
         num_candidates: int = 20,
@@ -358,6 +359,8 @@ def max_marginal_relevance_search(
             among selected documents.

         :param query (str): Text to look up documents similar to.
+        :param query_embedding: Input embedding vector. If given, input query string is
+            ignored.
         :param k (int): Number of Documents to return. Defaults to 4.
         :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
         :param lambda_mult (float): Number between 0 and 1 that determines the degree
@@ -378,12 +381,22 @@ def max_marginal_relevance_search(
                 remove_vector_query_field_from_metadata = False

         # Embed the query
-        query_embedding = embedding_service.embed_query(query)
+        if query_embedding:
+            query_vector = query_embedding
+        else:
+            if not query:
+                raise ValueError("specify either query or query_embedding to search")
+            elif embedding_service:
+                query_vector = embedding_service.embed_query(query)
+            elif self.embedding_service:
+                query_vector = self.embedding_service.embed_query(query)
+            else:
+                raise ValueError("specify embedding_service to search with query")

         # Fetch the initial documents
         got_hits = self.search(
             query=None,
-            query_vector=query_embedding,
+            query_vector=query_vector,
             k=num_candidates,
             fields=fields,
             custom_query=custom_query,
@@ -394,7 +407,7 @@ def max_marginal_relevance_search(

         # Select documents using maximal marginal relevance
         selected_indices = maximal_marginal_relevance(
-            query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k
+            query_vector, got_embeddings, lambda_mult=lambda_mult, k=k
         )
         selected_hits = [got_hits[i] for i in selected_indices]
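The two ValueError messages are the only new failure modes a caller has to handle, and the test file below exercises both. A hedged example of triggering them against a store built without an embedding service (connection details and index name are illustrative assumptions):

from elasticsearch import Elasticsearch
from elasticsearch.helpers.vectorstore import DenseVectorScriptScoreStrategy, VectorStore

client = Elasticsearch("http://localhost:9200")  # illustrative connection details
no_service_store = VectorStore(
    index="my-index",  # illustrative index name
    retrieval_strategy=DenseVectorScriptScoreStrategy(),
    client=client,  # note: no embedding_service configured
)

# Neither query nor query_embedding supplied.
try:
    no_service_store.max_marginal_relevance_search(
        vector_field="vector_field", k=3, num_candidates=3
    )
except ValueError as err:
    print(err)  # specify either query or query_embedding to search

# Query supplied, but no embedding service available anywhere.
try:
    no_service_store.max_marginal_relevance_search(
        query="foo", vector_field="vector_field", k=3, num_candidates=3
    )
except ValueError as err:
    print(err)  # specify embedding_service to search with query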

test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

Lines changed: 53 additions & 4 deletions
@@ -815,13 +815,55 @@ def test_bulk_args(self, sync_client_request_saving: Any, index: str) -> None:
         # 1 for index exist, 1 for index create, 3 to index docs
         assert len(store.client.transport.requests) == 5  # type: ignore

+    def test_max_marginal_relevance_search_errors(
+        self, sync_client: Elasticsearch, index: str
+    ) -> None:
+        """Test max marginal relevance search error conditions."""
+        texts = ["foo", "bar", "baz"]
+        vector_field = "vector_field"
+        embedding_service = ConsistentFakeEmbeddings()
+        store = VectorStore(
+            index=index,
+            retrieval_strategy=DenseVectorScriptScoreStrategy(),
+            embedding_service=embedding_service,
+            client=sync_client,
+        )
+        store.add_texts(texts)
+
+        # search without query embeddings vector or query
+        with pytest.raises(
+            ValueError, match="specify either query or query_embedding to search"
+        ):
+            store.max_marginal_relevance_search(
+                vector_field=vector_field,
+                k=3,
+                num_candidates=3,
+            )
+
+        # search without service
+        no_service_store = VectorStore(
+            index=index,
+            retrieval_strategy=DenseVectorScriptScoreStrategy(),
+            client=sync_client,
+        )
+        with pytest.raises(
+            ValueError, match="specify embedding_service to search with query"
+        ):
+            no_service_store.max_marginal_relevance_search(
+                query=texts[0],
+                vector_field=vector_field,
+                k=3,
+                num_candidates=3,
+            )
+
     def test_max_marginal_relevance_search(
         self, sync_client: Elasticsearch, index: str
     ) -> None:
         """Test max marginal relevance search."""
         texts = ["foo", "bar", "baz"]
         vector_field = "vector_field"
         text_field = "text_field"
+        query_embedding = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
         embedding_service = ConsistentFakeEmbeddings()
         store = VectorStore(
             index=index,
@@ -833,8 +875,8 @@ def test_max_marginal_relevance_search(
         )
         store.add_texts(texts)

+        # search with query
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=3,
@@ -843,8 +885,17 @@
         sim_output = store.search(query=texts[0], k=3)
         assert mmr_output == sim_output

+        # search with query embeddings
+        mmr_output = store.max_marginal_relevance_search(
+            query_embedding=query_embedding,
+            vector_field=vector_field,
+            k=3,
+            num_candidates=3,
+        )
+        sim_output = store.search(query_vector=query_embedding, k=3)
+        assert mmr_output == sim_output
+
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=2,
@@ -855,7 +906,6 @@
         assert mmr_output[1]["_source"][text_field] == texts[1]

         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=2,
@@ -868,7 +918,6 @@

         # if fetch_k < k, then the output will be less than k
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=3,
