
Commit 2ca002d

Rishabh Gupta authored and committed
allow embeddings vector to be used for mmr searching (#2620)
1 parent 6521b55 commit 2ca002d

3 files changed (+16, -10 lines)

elasticsearch/helpers/vectorstore/_async/vectorstore.py

Lines changed: 8 additions & 3 deletions
@@ -344,8 +344,8 @@ async def _create_index_if_not_exists(self) -> None:
     async def max_marginal_relevance_search(
         self,
         *,
-        embedding_service: AsyncEmbeddingService,
-        query: str,
+        query: Optional[str],
+        query_embedding: Optional[List[float]] = None,
         vector_field: str,
         k: int = 4,
         num_candidates: int = 20,
@@ -361,6 +361,8 @@ async def max_marginal_relevance_search(
             among selected documents.
 
         :param query (str): Text to look up documents similar to.
+        :param query_embedding: Input embedding vector. If given, input query string is
+            ignored.
         :param k (int): Number of Documents to return. Defaults to 4.
         :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
         :param lambda_mult (float): Number between 0 and 1 that determines the degree
@@ -381,7 +383,10 @@ async def max_marginal_relevance_search(
         remove_vector_query_field_from_metadata = False
 
         # Embed the query
-        query_embedding = await embedding_service.embed_query(query)
+        if self.embedding_service and not query_embedding:
+            if not query:
+                raise ValueError("specify a query or a query_embedding to search")
+            query_embedding = await self.embedding_service.embed_query(query)
 
         # Fetch the initial documents
         got_hits = await self.search(
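
With this change, a caller that already has a vector can skip the embedding step entirely. The following is a minimal sketch of the new call style against the updated async signature; `store` is assumed to be an already-configured AsyncVectorStore, and the field name and parameter values are placeholders rather than part of this commit:

from typing import List, Optional

async def mmr_with_vector(store, vector: List[float], text: Optional[str] = None):
    # New call style: pass a precomputed embedding; `query` may stay None,
    # in which case the store's embedding_service is not consulted.
    return await store.max_marginal_relevance_search(
        query=text,
        query_embedding=vector,
        vector_field="vector_field",  # assumed field name from the index mapping
        k=3,
        num_candidates=20,
    )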

elasticsearch/helpers/vectorstore/_sync/vectorstore.py

Lines changed: 8 additions & 3 deletions
@@ -341,8 +341,8 @@ def _create_index_if_not_exists(self) -> None:
     def max_marginal_relevance_search(
         self,
         *,
-        embedding_service: EmbeddingService,
-        query: str,
+        query: Optional[str],
+        query_embedding: Optional[List[float]] = None,
         vector_field: str,
         k: int = 4,
         num_candidates: int = 20,
@@ -358,6 +358,8 @@ def max_marginal_relevance_search(
            among selected documents.
 
         :param query (str): Text to look up documents similar to.
+        :param query_embedding: Input embedding vector. If given, input query string is
+            ignored.
         :param k (int): Number of Documents to return. Defaults to 4.
         :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
         :param lambda_mult (float): Number between 0 and 1 that determines the degree
@@ -378,7 +380,10 @@ def max_marginal_relevance_search(
         remove_vector_query_field_from_metadata = False
 
         # Embed the query
-        query_embedding = embedding_service.embed_query(query)
+        if self.embedding_service and not query_embedding:
+            if not query:
+                raise ValueError("specify a query or a query_embedding to search")
+            query_embedding = self.embedding_service.embed_query(query)
 
         # Fetch the initial documents
         got_hits = self.search(
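
The synchronous variant mirrors this: MMR can be driven either by a text query (embedded through the store's embedding_service) or by a ready-made vector. A short sketch of both paths, assuming `store` is a configured VectorStore and the field name and parameter values are placeholders:

from typing import List

def mmr_both_ways(store, text: str, vector: List[float]):
    # Path 1: text query; the store embeds it via embedding_service.embed_query(text).
    by_text = store.max_marginal_relevance_search(
        query=text,
        vector_field="vector_field",
        k=4,
        num_candidates=20,
    )
    # Path 2: precomputed vector; the query string is ignored when a vector is given.
    by_vector = store.max_marginal_relevance_search(
        query=None,
        query_embedding=vector,
        vector_field="vector_field",
        k=4,
        num_candidates=20,
    )
    return by_text, by_vector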

test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

Lines changed: 0 additions & 4 deletions
@@ -834,7 +834,6 @@ def test_max_marginal_relevance_search(
         store.add_texts(texts)
 
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=3,
@@ -844,7 +843,6 @@ def test_max_marginal_relevance_search(
         assert mmr_output == sim_output
 
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=2,
@@ -855,7 +853,6 @@ def test_max_marginal_relevance_search(
         assert mmr_output[1]["_source"][text_field] == texts[1]
 
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=2,
@@ -868,7 +865,6 @@ def test_max_marginal_relevance_search(
 
         # if fetch_k < k, then the output will be less than k
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=3,
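
The existing tests only drop the now-removed embedding_service argument. A hypothetical additional check for the new guard could look like the sketch below; the `store` fixture and the surrounding test layout are assumptions, not part of this commit:

import pytest

def test_mmr_requires_query_or_vector(store):
    # With an embedding_service configured on the store but neither a query
    # nor a query_embedding supplied, the new guard should raise.
    with pytest.raises(ValueError, match="specify a query or a query_embedding"):
        store.max_marginal_relevance_search(
            query=None,
            vector_field="vector_field",
            k=2,
            num_candidates=10,
        )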
