|
15 | 15 | pymilvus_mock = MagicMock() |
16 | 16 | pymilvus_mock.DataType = MagicMock() |
17 | 17 | pymilvus_mock.MilvusClient = MagicMock |
| 18 | +pymilvus_mock.RRFRanker = MagicMock |
| 19 | +pymilvus_mock.WeightedRanker = MagicMock |
| 20 | +pymilvus_mock.AnnSearchRequest = MagicMock |
18 | 21 |
|
19 | 22 | # Apply the mock before importing MilvusIndex |
20 | 23 | with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): |
@@ -183,3 +186,141 @@ async def test_delete_collection(milvus_index, mock_milvus_client): |
183 | 186 | await milvus_index.delete() |
184 | 187 |
|
185 | 188 | mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name) |
| 189 | + |
| 190 | + |
| 191 | +async def test_query_hybrid_search_rrf( |
| 192 | + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client |
| 193 | +): |
| 194 | + """Test hybrid search with RRF reranker.""" |
| 195 | + mock_milvus_client.has_collection.return_value = True |
| 196 | + await milvus_index.add_chunks(sample_chunks, sample_embeddings) |
| 197 | + |
| 198 | + # Mock hybrid search results |
| 199 | + mock_milvus_client.hybrid_search.return_value = [ |
| 200 | + [ |
| 201 | + { |
| 202 | + "id": 0, |
| 203 | + "distance": 0.1, |
| 204 | + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, |
| 205 | + }, |
| 206 | + { |
| 207 | + "id": 1, |
| 208 | + "distance": 0.2, |
| 209 | + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, |
| 210 | + }, |
| 211 | + ] |
| 212 | + ] |
| 213 | + |
| 214 | + # Test hybrid search with RRF reranker |
| 215 | + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) |
| 216 | + query_string = "test query" |
| 217 | + response = await milvus_index.query_hybrid( |
| 218 | + embedding=query_embedding, |
| 219 | + query_string=query_string, |
| 220 | + k=2, |
| 221 | + score_threshold=0.0, |
| 222 | + reranker_type="rrf", |
| 223 | + reranker_params={"impact_factor": 60.0}, |
| 224 | + ) |
| 225 | + |
| 226 | + assert isinstance(response, QueryChunksResponse) |
| 227 | + assert len(response.chunks) == 2 |
| 228 | + assert len(response.scores) == 2 |
| 229 | + |
| 230 | + # Verify hybrid search was called with correct parameters |
| 231 | + mock_milvus_client.hybrid_search.assert_called_once() |
| 232 | + call_args = mock_milvus_client.hybrid_search.call_args |
| 233 | + |
| 234 | + # Check that the request contains both vector and BM25 search requests |
| 235 | + reqs = call_args[1]["reqs"] |
| 236 | + assert len(reqs) == 2 |
| 237 | + assert reqs[0].anns_field == "vector" |
| 238 | + assert reqs[1].anns_field == "sparse" |
| 239 | + ranker = call_args[1]["ranker"] |
| 240 | + assert ranker is not None |
| 241 | + |
| 242 | + |
| 243 | +async def test_query_hybrid_search_weighted( |
| 244 | + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client |
| 245 | +): |
| 246 | + """Test hybrid search with weighted reranker.""" |
| 247 | + mock_milvus_client.has_collection.return_value = True |
| 248 | + await milvus_index.add_chunks(sample_chunks, sample_embeddings) |
| 249 | + |
| 250 | + # Mock hybrid search results |
| 251 | + mock_milvus_client.hybrid_search.return_value = [ |
| 252 | + [ |
| 253 | + { |
| 254 | + "id": 0, |
| 255 | + "distance": 0.1, |
| 256 | + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, |
| 257 | + }, |
| 258 | + { |
| 259 | + "id": 1, |
| 260 | + "distance": 0.2, |
| 261 | + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, |
| 262 | + }, |
| 263 | + ] |
| 264 | + ] |
| 265 | + |
| 266 | + # Test hybrid search with weighted reranker |
| 267 | + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) |
| 268 | + query_string = "test query" |
| 269 | + response = await milvus_index.query_hybrid( |
| 270 | + embedding=query_embedding, |
| 271 | + query_string=query_string, |
| 272 | + k=2, |
| 273 | + score_threshold=0.0, |
| 274 | + reranker_type="weighted", |
| 275 | + reranker_params={"alpha": 0.7}, |
| 276 | + ) |
| 277 | + |
| 278 | + assert isinstance(response, QueryChunksResponse) |
| 279 | + assert len(response.chunks) == 2 |
| 280 | + assert len(response.scores) == 2 |
| 281 | + |
| 282 | + # Verify hybrid search was called with correct parameters |
| 283 | + mock_milvus_client.hybrid_search.assert_called_once() |
| 284 | + call_args = mock_milvus_client.hybrid_search.call_args |
| 285 | + ranker = call_args[1]["ranker"] |
| 286 | + assert ranker is not None |
| 287 | + |
| 288 | + |
| 289 | +async def test_query_hybrid_search_default_rrf( |
| 290 | + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client |
| 291 | +): |
| 292 | + """Test hybrid search with default RRF reranker (no reranker_type specified).""" |
| 293 | + mock_milvus_client.has_collection.return_value = True |
| 294 | + await milvus_index.add_chunks(sample_chunks, sample_embeddings) |
| 295 | + |
| 296 | + # Mock hybrid search results |
| 297 | + mock_milvus_client.hybrid_search.return_value = [ |
| 298 | + [ |
| 299 | + { |
| 300 | + "id": 0, |
| 301 | + "distance": 0.1, |
| 302 | + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, |
| 303 | + }, |
| 304 | + ] |
| 305 | + ] |
| 306 | + |
| 307 | + # Test hybrid search with default reranker (should be RRF) |
| 308 | + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) |
| 309 | + query_string = "test query" |
| 310 | + response = await milvus_index.query_hybrid( |
| 311 | + embedding=query_embedding, |
| 312 | + query_string=query_string, |
| 313 | + k=1, |
| 314 | + score_threshold=0.0, |
| 315 | + reranker_type="unknown_type", # Should default to RRF |
| 316 | + reranker_params=None, # Should use default impact_factor |
| 317 | + ) |
| 318 | + |
| 319 | + assert isinstance(response, QueryChunksResponse) |
| 320 | + assert len(response.chunks) == 1 |
| 321 | + |
| 322 | + # Verify hybrid search was called with RRF reranker |
| 323 | + mock_milvus_client.hybrid_search.assert_called_once() |
| 324 | + call_args = mock_milvus_client.hybrid_search.call_args |
| 325 | + ranker = call_args[1]["ranker"] |
| 326 | + assert ranker is not None |
0 commit comments