Skip to content

Commit e3928e6

Browse files
feat: Implement hybrid search in Milvus (#2644)
# What does this PR do? This PR implements hybrid search for Milvus DB based on the inbuilt milvus support. To test: ``` pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=long --disable-warnings --asyncio-mode=auto ``` Signed-off-by: Varsha Prasad Narsing <[email protected]>
1 parent 5a2d323 commit e3928e6

File tree

4 files changed

+204
-9
lines changed

4 files changed

+204
-9
lines changed

llama_stack/providers/remote/vector_io/milvus/milvus.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Any
1111

1212
from numpy.typing import NDArray
13-
from pymilvus import DataType, Function, FunctionType, MilvusClient
13+
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
1414

1515
from llama_stack.apis.common.errors import VectorStoreNotFoundError
1616
from llama_stack.apis.files.files import Files
@@ -27,6 +27,7 @@
2727
from llama_stack.providers.utils.kvstore.api import KVStore
2828
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
2929
from llama_stack.providers.utils.memory.vector_store import (
30+
RERANKER_TYPE_WEIGHTED,
3031
EmbeddingIndex,
3132
VectorDBWithIndex,
3233
)
@@ -238,7 +239,53 @@ async def query_hybrid(
238239
reranker_type: str,
239240
reranker_params: dict[str, Any] | None = None,
240241
) -> QueryChunksResponse:
241-
raise NotImplementedError("Hybrid search is not supported in Milvus")
242+
"""
243+
Hybrid search using Milvus's native hybrid search capabilities.
244+
245+
This implementation uses Milvus's hybrid_search method which combines
246+
vector search and BM25 search with configurable reranking strategies.
247+
"""
248+
search_requests = []
249+
250+
# nprobe: Controls search accuracy vs performance trade-off
251+
# 10 balances these trade-offs for RAG applications
252+
search_requests.append(
253+
AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k)
254+
)
255+
256+
# drop_ratio_search: Filters low-importance terms to improve search performance
257+
# 0.2 balances noise reduction with recall
258+
search_requests.append(
259+
AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k)
260+
)
261+
262+
if reranker_type == RERANKER_TYPE_WEIGHTED:
263+
alpha = (reranker_params or {}).get("alpha", 0.5)
264+
rerank = WeightedRanker(alpha, 1 - alpha)
265+
else:
266+
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
267+
rerank = RRFRanker(impact_factor)
268+
269+
search_res = await asyncio.to_thread(
270+
self.client.hybrid_search,
271+
collection_name=self.collection_name,
272+
reqs=search_requests,
273+
ranker=rerank,
274+
limit=k,
275+
output_fields=["chunk_content"],
276+
)
277+
278+
chunks = []
279+
scores = []
280+
for res in search_res[0]:
281+
chunk = Chunk(**res["entity"]["chunk_content"])
282+
chunks.append(chunk)
283+
scores.append(res["distance"])
284+
285+
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
286+
filtered_scores = [score for score in scores if score >= score_threshold]
287+
288+
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
242289

243290
async def delete_chunk(self, chunk_id: str) -> None:
244291
"""Remove a chunk from the Milvus collection."""

llama_stack/providers/utils/memory/vector_store.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -302,23 +302,25 @@ async def query_chunks(
302302
mode = params.get("mode")
303303
score_threshold = params.get("score_threshold", 0.0)
304304

305-
# Get ranker configuration
306305
ranker = params.get("ranker")
307306
if ranker is None:
308-
# Default to RRF with impact_factor=60.0
309307
reranker_type = RERANKER_TYPE_RRF
310308
reranker_params = {"impact_factor": 60.0}
311309
else:
312-
reranker_type = ranker.type
313-
reranker_params = (
314-
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
315-
)
310+
strategy = ranker.get("strategy", "rrf")
311+
if strategy == "weighted":
312+
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
313+
reranker_type = RERANKER_TYPE_WEIGHTED
314+
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
315+
else:
316+
reranker_type = RERANKER_TYPE_RRF
317+
k_value = ranker.get("params", {}).get("k", 60.0)
318+
reranker_params = {"impact_factor": k_value}
316319

317320
query_string = interleaved_content_as_str(query)
318321
if mode == "keyword":
319322
return await self.index.query_keyword(query_string, k, score_threshold)
320323

321-
# Calculate embeddings for both vector and hybrid modes
322324
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
323325
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
324326
if mode == "hybrid":

tests/integration/vector_io/test_openai_vector_stores.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
3030
"remote::qdrant",
3131
"inline::qdrant",
3232
"remote::weaviate",
33+
"remote::milvus",
3334
]:
3435
return
3536

@@ -49,12 +50,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
4950
"remote::chromadb",
5051
"remote::weaviate",
5152
"remote::qdrant",
53+
"remote::milvus",
5254
],
5355
"keyword": [
5456
"inline::sqlite-vec",
57+
"remote::milvus",
5558
],
5659
"hybrid": [
5760
"inline::sqlite-vec",
61+
"inline::milvus",
62+
"remote::milvus",
5863
],
5964
}
6065
supported_providers = search_mode_support.get(search_mode, [])

tests/unit/providers/vector_io/remote/test_milvus.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
pymilvus_mock = MagicMock()
1616
pymilvus_mock.DataType = MagicMock()
1717
pymilvus_mock.MilvusClient = MagicMock
18+
pymilvus_mock.RRFRanker = MagicMock
19+
pymilvus_mock.WeightedRanker = MagicMock
20+
pymilvus_mock.AnnSearchRequest = MagicMock
1821

1922
# Apply the mock before importing MilvusIndex
2023
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
@@ -183,3 +186,141 @@ async def test_delete_collection(milvus_index, mock_milvus_client):
183186
await milvus_index.delete()
184187

185188
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
189+
190+
191+
async def test_query_hybrid_search_rrf(
192+
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
193+
):
194+
"""Test hybrid search with RRF reranker."""
195+
mock_milvus_client.has_collection.return_value = True
196+
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
197+
198+
# Mock hybrid search results
199+
mock_milvus_client.hybrid_search.return_value = [
200+
[
201+
{
202+
"id": 0,
203+
"distance": 0.1,
204+
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
205+
},
206+
{
207+
"id": 1,
208+
"distance": 0.2,
209+
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
210+
},
211+
]
212+
]
213+
214+
# Test hybrid search with RRF reranker
215+
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
216+
query_string = "test query"
217+
response = await milvus_index.query_hybrid(
218+
embedding=query_embedding,
219+
query_string=query_string,
220+
k=2,
221+
score_threshold=0.0,
222+
reranker_type="rrf",
223+
reranker_params={"impact_factor": 60.0},
224+
)
225+
226+
assert isinstance(response, QueryChunksResponse)
227+
assert len(response.chunks) == 2
228+
assert len(response.scores) == 2
229+
230+
# Verify hybrid search was called with correct parameters
231+
mock_milvus_client.hybrid_search.assert_called_once()
232+
call_args = mock_milvus_client.hybrid_search.call_args
233+
234+
# Check that the request contains both vector and BM25 search requests
235+
reqs = call_args[1]["reqs"]
236+
assert len(reqs) == 2
237+
assert reqs[0].anns_field == "vector"
238+
assert reqs[1].anns_field == "sparse"
239+
ranker = call_args[1]["ranker"]
240+
assert ranker is not None
241+
242+
243+
async def test_query_hybrid_search_weighted(
244+
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
245+
):
246+
"""Test hybrid search with weighted reranker."""
247+
mock_milvus_client.has_collection.return_value = True
248+
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
249+
250+
# Mock hybrid search results
251+
mock_milvus_client.hybrid_search.return_value = [
252+
[
253+
{
254+
"id": 0,
255+
"distance": 0.1,
256+
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
257+
},
258+
{
259+
"id": 1,
260+
"distance": 0.2,
261+
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
262+
},
263+
]
264+
]
265+
266+
# Test hybrid search with weighted reranker
267+
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
268+
query_string = "test query"
269+
response = await milvus_index.query_hybrid(
270+
embedding=query_embedding,
271+
query_string=query_string,
272+
k=2,
273+
score_threshold=0.0,
274+
reranker_type="weighted",
275+
reranker_params={"alpha": 0.7},
276+
)
277+
278+
assert isinstance(response, QueryChunksResponse)
279+
assert len(response.chunks) == 2
280+
assert len(response.scores) == 2
281+
282+
# Verify hybrid search was called with correct parameters
283+
mock_milvus_client.hybrid_search.assert_called_once()
284+
call_args = mock_milvus_client.hybrid_search.call_args
285+
ranker = call_args[1]["ranker"]
286+
assert ranker is not None
287+
288+
289+
async def test_query_hybrid_search_default_rrf(
290+
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
291+
):
292+
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
293+
mock_milvus_client.has_collection.return_value = True
294+
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
295+
296+
# Mock hybrid search results
297+
mock_milvus_client.hybrid_search.return_value = [
298+
[
299+
{
300+
"id": 0,
301+
"distance": 0.1,
302+
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
303+
},
304+
]
305+
]
306+
307+
# Test hybrid search with default reranker (should be RRF)
308+
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
309+
query_string = "test query"
310+
response = await milvus_index.query_hybrid(
311+
embedding=query_embedding,
312+
query_string=query_string,
313+
k=1,
314+
score_threshold=0.0,
315+
reranker_type="unknown_type", # Should default to RRF
316+
reranker_params=None, # Should use default impact_factor
317+
)
318+
319+
assert isinstance(response, QueryChunksResponse)
320+
assert len(response.chunks) == 1
321+
322+
# Verify hybrid search was called with RRF reranker
323+
mock_milvus_client.hybrid_search.assert_called_once()
324+
call_args = mock_milvus_client.hybrid_search.call_args
325+
ranker = call_args[1]["ranker"]
326+
assert ranker is not None

0 commit comments

Comments
 (0)