
Commit 7de1961

feat: implement keyword and hybrid search for Weaviate provider
1 parent: cec00c5

8 files changed: +482, -24 lines

llama_stack/providers/remote/vector_io/weaviate/weaviate.py

Lines changed: 105 additions & 3 deletions
@@ -10,7 +10,7 @@
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
-from weaviate.classes.query import Filter
+from weaviate.classes.query import Filter, HybridFusion

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
@@ -26,6 +26,7 @@
     OpenAIVectorStoreMixin,
 )
 from llama_stack.providers.utils.memory.vector_store import (
+    RERANKER_TYPE_RRF,
     ChunkForDeletion,
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -88,6 +89,9 @@ async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> No
         collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))

     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+        log.info(
+            f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
+        )
         sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
         collection = self.client.collections.get(sanitized_collection_name)

@@ -109,12 +113,16 @@ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float)
                 continue

             score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
+            log.info(f"📈 Document distance: {doc.metadata.distance}, calculated score: {score}")
+
             if score < score_threshold:
                 continue

+            log.info(f"Document {chunk.metadata.get('document_id')} has score {score}")
             chunks.append(chunk)
             scores.append(score)

+        log.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
         return QueryChunksResponse(chunks=chunks, scores=scores)

     async def delete(self, chunk_ids: list[str] | None = None) -> None:
@@ -136,7 +144,46 @@ async def query_keyword(
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Weaviate")
+        """
+        Performs BM25-based keyword search using Weaviate's built-in full-text search.
+        Args:
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        log.info(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Perform BM25 keyword search on chunk_content field
+        results = collection.query.bm25(
+            query=query_string,
+            limit=k,
+            return_metadata=wvc.query.MetadataQuery(score=True),
+        )
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.info(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
+        return QueryChunksResponse(chunks=chunks, scores=scores)

     async def query_hybrid(
         self,
@@ -147,7 +194,62 @@ async def query_hybrid(
         reranker_type: str,
         reranker_params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Hybrid search is not supported in Weaviate")
+        """
+        Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search.
+        Args:
+            embedding: The query embedding vector
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum similarity score threshold
+            reranker_type: Type of reranker to use ("rrf" or "normalized")
+            reranker_params: Parameters for the reranker
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        log.info(
+            f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
+        )
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Ranked (RRF) reranker fusion type
+        if reranker_type == RERANKER_TYPE_RRF:
+            rerank = HybridFusion.RANKED
+        # Relative score (Normalized) reranker fusion type
+        else:
+            rerank = HybridFusion.RELATIVE_SCORE
+
+        # Perform hybrid search using Weaviate's native hybrid search
+        results = collection.query.hybrid(
+            query=query_string,
+            alpha=0.5,  # Range <0, 1>, where 0.5 will equally favor vector and keyword search
+            vector=embedding.tolist(),
+            limit=k,
+            fusion_type=rerank,
+            return_metadata=wvc.query.MetadataQuery(score=True),
+        )
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            log.info(f"Document {chunk.metadata.get('document_id')} has score {score}")
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.info(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
+        return QueryChunksResponse(chunks=chunks, scores=scores)


 class WeaviateVectorIOAdapter(
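
For reference, a minimal standalone sketch of the two weaviate-client v4 calls that the new query_keyword and query_hybrid methods build on: BM25 keyword search and native hybrid search with a selectable fusion type. This is illustrative only; it assumes a locally running Weaviate instance, an existing collection named "Vs_demo" with a chunk_content text property and caller-supplied vectors, and a 384-dimensional placeholder query vector.

    import weaviate
    import weaviate.classes as wvc
    from weaviate.classes.query import HybridFusion

    # Assumes a Weaviate instance reachable on localhost:8080 (e.g. started via Docker).
    client = weaviate.connect_to_local(host="localhost", port=8080)
    collection = client.collections.get("Vs_demo")  # hypothetical collection name

    # BM25 keyword search, mirroring query_keyword above.
    bm25_results = collection.query.bm25(
        query="neural networks",
        limit=5,
        return_metadata=wvc.query.MetadataQuery(score=True),
    )
    for obj in bm25_results.objects:
        print("bm25", obj.metadata.score, obj.properties.get("chunk_content"))

    # Native hybrid search, mirroring query_hybrid above: alpha balances vector vs. keyword,
    # and fusion_type selects RRF (RANKED) or normalized (RELATIVE_SCORE) fusion.
    query_vector = [0.0] * 384  # placeholder; in the provider this is the query embedding
    hybrid_results = collection.query.hybrid(
        query="neural networks",
        vector=query_vector,
        alpha=0.5,
        limit=5,
        fusion_type=HybridFusion.RANKED,
        return_metadata=wvc.query.MetadataQuery(score=True),
    )
    for obj in hybrid_results.objects:
        print("hybrid", obj.metadata.score, obj.properties.get("chunk_content"))

    client.close()

In both cases the relevance score surfaces as obj.metadata.score, which is what the provider code compares against score_threshold.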

llama_stack/providers/utils/memory/vector_store.py

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel):
 # Constants for reranker types
 RERANKER_TYPE_RRF = "rrf"
 RERANKER_TYPE_WEIGHTED = "weighted"
+RERANKER_TYPE_NORMALIZED = "normalized"


 def parse_pdf(data: bytes) -> str:
@@ -325,6 +326,8 @@ async def query_chunks(
             weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
             reranker_type = RERANKER_TYPE_WEIGHTED
             reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
+        elif strategy == "normalized":
+            reranker_type = RERANKER_TYPE_NORMALIZED
         else:
             reranker_type = RERANKER_TYPE_RRF
             k_value = ranker.get("params", {}).get("k", 60.0)
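
The new "normalized" strategy maps a search request's ranker configuration onto RERANKER_TYPE_NORMALIZED. A rough, self-contained sketch of that mapping (the helper name resolve_reranker is hypothetical; in the codebase the logic lives inline in VectorDBWithIndex.query_chunks):

    RERANKER_TYPE_RRF = "rrf"
    RERANKER_TYPE_WEIGHTED = "weighted"
    RERANKER_TYPE_NORMALIZED = "normalized"


    def resolve_reranker(ranker: dict | None) -> tuple[str, dict]:
        """Map a ranker config dict onto (reranker_type, reranker_params)."""
        ranker = ranker or {}
        strategy = ranker.get("strategy", "rrf")
        params = ranker.get("params", {})
        if strategy == "weighted":
            weights = params.get("weights", [0.5, 0.5])
            return RERANKER_TYPE_WEIGHTED, {"alpha": weights[0] if len(weights) > 0 else 0.5}
        if strategy == "normalized":
            return RERANKER_TYPE_NORMALIZED, {}
        # Default: reciprocal rank fusion with impact factor k.
        return RERANKER_TYPE_RRF, {"k": params.get("k", 60.0)}


    assert resolve_reranker({"strategy": "normalized"})[0] == RERANKER_TYPE_NORMALIZED
    assert resolve_reranker(None) == (RERANKER_TYPE_RRF, {"k": 60.0})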

pyproject.toml

Lines changed: 8 additions & 7 deletions
@@ -25,8 +25,8 @@ classifiers = [
 ]
 dependencies = [
     "aiohttp",
-    "fastapi>=0.115.0,<1.0", # server
-    "fire", # for MCP in LLS client
+    "fastapi>=0.115.0,<1.0", # server
+    "fire", # for MCP in LLS client
     "httpx",
     "huggingface-hub>=0.34.0,<1.0",
     "jinja2>=3.1.6",
@@ -44,12 +44,13 @@ dependencies = [
     "tiktoken",
     "pillow",
     "h11>=0.16.0",
-    "python-multipart>=0.0.20", # For fastapi Form
-    "uvicorn>=0.34.0", # server
-    "opentelemetry-sdk>=1.30.0", # server
+    "python-multipart>=0.0.20", # For fastapi Form
+    "uvicorn>=0.34.0", # server
+    "opentelemetry-sdk>=1.30.0", # server
     "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
-    "aiosqlite>=0.21.0", # server - for metadata store
-    "asyncpg", # for metadata store
+    "aiosqlite>=0.21.0", # server - for metadata store
+    "asyncpg", # for metadata store
+    "weaviate-client>=4.16.5",
 ]

 [project.optional-dependencies]

tests/integration/vector_io/test_openai_vector_stores.py

Lines changed: 12 additions & 10 deletions
@@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
         if p.provider_type in [
+            "inline::chromadb",
             "inline::faiss",
-            "inline::sqlite-vec",
             "inline::milvus",
-            "inline::chromadb",
-            "remote::pgvector",
+            "inline::qdrant",
+            "inline::sqlite-vec",
             "remote::chromadb",
+            "remote::milvus",
+            "remote::pgvector",
             "remote::qdrant",
-            "inline::qdrant",
             "remote::weaviate",
-            "remote::milvus",
         ]:
             return

@@ -47,21 +47,23 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
             "inline::milvus",
             "inline::chromadb",
             "inline::qdrant",
-            "remote::pgvector",
             "remote::chromadb",
-            "remote::weaviate",
-            "remote::qdrant",
             "remote::milvus",
+            "remote::pgvector",
+            "remote::qdrant",
+            "remote::weaviate",
         ],
         "keyword": [
+            "inline::milvus",
             "inline::sqlite-vec",
             "remote::milvus",
-            "inline::milvus",
+            "remote::weaviate",
         ],
         "hybrid": [
-            "inline::sqlite-vec",
             "inline::milvus",
+            "inline::sqlite-vec",
             "remote::milvus",
+            "remote::weaviate",
         ],
     }
     supported_providers = search_mode_support.get(search_mode, [])
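
These per-mode provider lists gate which backends exercise each search mode in the integration suite; this commit adds remote::weaviate to the keyword and hybrid lists. A hedged sketch of how such a table is typically consumed (the helper name and skip message are illustrative, not the actual fixture code):

    import pytest


    def skip_if_unsupported(search_mode: str, provider_type: str, search_mode_support: dict[str, list[str]]) -> None:
        # Skip the current test when the provider does not support the requested search mode.
        supported_providers = search_mode_support.get(search_mode, [])
        if provider_type not in supported_providers:
            pytest.skip(f"Provider {provider_type} does not support {search_mode} search")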

tests/unit/providers/vector_io/conftest.py

Lines changed: 76 additions & 1 deletion
@@ -23,13 +23,15 @@
 from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
 from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
+from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
+from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter

 EMBEDDING_DIMENSION = 384
 COLLECTION_PREFIX = "test_collection"
 MILVUS_ALIAS = "test_milvus"


-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "weaviate"])
 def vector_provider(request):
     return request.param

@@ -333,6 +335,78 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
     await index.delete()


+@pytest.fixture
+def weaviate_vec_db_path():
+    return "localhost:8080"
+
+
+@pytest.fixture
+async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension):
+    import uuid
+
+    import weaviate
+
+    # Connect to local Weaviate instance
+    client = weaviate.connect_to_local(
+        host="localhost",
+        port=8080,
+    )
+
+    collection_name = f"{COLLECTION_PREFIX}_{uuid.uuid4()}"
+    index = WeaviateIndex(client=client, collection_name=collection_name)
+
+    # Create the collection for this test
+    import weaviate.classes as wvc
+    from weaviate.collections.classes.config import _CollectionConfig
+
+    from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
+
+    sanitized_name = sanitize_collection_name(collection_name, weaviate_format=True)
+    collection_config = _CollectionConfig(
+        name=sanitized_name,
+        vectorizer_config=wvc.config.Configure.Vectorizer.none(),
+        properties=[
+            wvc.config.Property(
+                name="chunk_content",
+                data_type=wvc.config.DataType.TEXT,
+            ),
+        ],
+    )
+    if not client.collections.exists(sanitized_name):
+        client.collections.create_from_config(collection_config)
+
+    yield index
+    await index.delete()
+    client.close()
+
+
+@pytest.fixture
+async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
+    config = WeaviateVectorIOConfig(
+        weaviate_cluster_url=weaviate_vec_db_path,
+        weaviate_api_key=None,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = WeaviateVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=collection_id,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    adapter.test_collection_id = collection_id
+    yield adapter
+    await adapter.shutdown()
+
+
 @pytest.fixture
 def vector_io_adapter(vector_provider, request):
     """Returns the appropriate vector IO adapter based on the provider parameter."""
@@ -342,6 +416,7 @@ def vector_io_adapter(vector_provider, request):
         "sqlite_vec": "sqlite_vec_adapter",
         "chroma": "chroma_vec_adapter",
         "qdrant": "qdrant_vec_adapter",
+        "weaviate": "weaviate_vec_adapter",
     }
     return request.getfixturevalue(vector_provider_dict[vector_provider])

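
The new fixtures create the test collection through the client-internal _CollectionConfig / create_from_config path. For comparison, a rough equivalent using only the public weaviate-client v4 API might look like the sketch below (the collection name is hypothetical, and a local Weaviate instance on localhost:8080 is assumed):

    import weaviate
    import weaviate.classes as wvc

    client = weaviate.connect_to_local(host="localhost", port=8080)

    name = "Test_collection_demo"  # hypothetical; the real fixtures derive the name from COLLECTION_PREFIX
    if not client.collections.exists(name):
        client.collections.create(
            name=name,
            vectorizer_config=wvc.config.Configure.Vectorizer.none(),  # vectors are supplied by the caller
            properties=[
                wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
            ],
        )

    client.close()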

tests/unit/providers/vector_io/remote/test_milvus.py

Lines changed: 6 additions & 2 deletions
@@ -23,13 +23,13 @@
 with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
     from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex

-# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
+# This test is a unit test for the MilvusIndex class. This should only contain
 # tests which are specific to this class. More general (API-level) tests should be placed in
 # tests/integration/vector_io/
 #
 # How to run this test:
 #
-# pytest tests/unit/providers/vector_io/test_milvus.py \
+# pytest tests/unit/providers/vector_io/remote/test_milvus.py \
 # -v -s --tb=short --disable-warnings --asyncio-mode=auto

 MILVUS_PROVIDER = "milvus"
@@ -106,6 +106,7 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m

     # Verify the insert call had the right number of chunks
     insert_call = mock_milvus_client.insert.call_args
+    print(insert_call[1])
     assert len(insert_call[1]["data"]) == len(sample_chunks)


@@ -324,3 +325,6 @@ async def test_query_hybrid_search_default_rrf(
     call_args = mock_milvus_client.hybrid_search.call_args
     ranker = call_args[1]["ranker"]
     assert ranker is not None
+
+
+# TODO: Write tests for the MilvusVectorIOAdapter class.
