Skip to content

Commit 19dedcb

Browse files
Expose aggregation API from SearchIndex (#230)
In order to support more advanced queries, we expose the `aggregate` method to pass through to the core Redis FT.AGGREGATE API. This PR also simplifies and standardizes error handling for Redis searches/aggregations on the index.
1 parent 3c74dee commit 19dedcb

File tree

5 files changed

+73
-69
lines changed

5 files changed

+73
-69
lines changed

redisvl/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@ class RedisVLException(Exception):
44

55
class RedisModuleVersionError(RedisVLException):
66
"""Invalid module versions installed"""
7+
8+
9+
class RedisSearchError(RedisVLException):
10+
"""Error while performing a search or aggregate request"""

redisvl/extensions/router/semantic.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,9 @@ def _classify_route(
256256
)
257257

258258
try:
259-
aggregation_result: AggregateResult = self._index.client.ft( # type: ignore
260-
self._index.name
261-
).aggregate(aggregate_request, vector_range_query.params)
259+
aggregation_result: AggregateResult = self._index.aggregate(
260+
aggregate_request, vector_range_query.params
261+
)
262262
except ResponseError as e:
263263
if "VSS is not yet supported on FT.AGGREGATE" in str(e):
264264
raise RuntimeError(
@@ -308,9 +308,9 @@ def _classify_multi_route(
308308
)
309309

310310
try:
311-
aggregation_result: AggregateResult = self._index.client.ft( # type: ignore
312-
self._index.name
313-
).aggregate(aggregate_request, vector_range_query.params)
311+
aggregation_result: AggregateResult = self._index.aggregate(
312+
aggregate_request, vector_range_query.params
313+
)
314314
except ResponseError as e:
315315
if "VSS is not yet supported on FT.AGGREGATE" in str(e):
316316
raise RuntimeError(

redisvl/index/index.py

+55-57
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818

1919
if TYPE_CHECKING:
20+
from redis.commands.search.aggregation import AggregateResult
2021
from redis.commands.search.document import Document
2122
from redis.commands.search.result import Result
2223
from redisvl.query.query import BaseQuery
@@ -25,7 +26,7 @@
2526
import redis.asyncio as aredis
2627
from redis.commands.search.indexDefinition import IndexDefinition
2728

28-
from redisvl.exceptions import RedisModuleVersionError
29+
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
2930
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
3031
from redisvl.query import BaseQuery, CountQuery, FilterQuery
3132
from redisvl.query.filter import FilterExpression
@@ -123,36 +124,6 @@ async def wrapper(self, *args, **kwargs):
123124
return decorator
124125

125126

126-
def check_index_exists():
127-
def decorator(func):
128-
@wraps(func)
129-
def wrapper(self, *args, **kwargs):
130-
if not self.exists():
131-
raise RuntimeError(
132-
f"Index has not been created. Must be created before calling {func.__name__}"
133-
)
134-
return func(self, *args, **kwargs)
135-
136-
return wrapper
137-
138-
return decorator
139-
140-
141-
def check_async_index_exists():
142-
def decorator(func):
143-
@wraps(func)
144-
async def wrapper(self, *args, **kwargs):
145-
if not await self.exists():
146-
raise ValueError(
147-
f"Index has not been created. Must be created before calling {func.__name__}"
148-
)
149-
return await func(self, *args, **kwargs)
150-
151-
return wrapper
152-
153-
return decorator
154-
155-
156127
class BaseSearchIndex:
157128
"""Base search engine class"""
158129

@@ -486,7 +457,6 @@ def create(self, overwrite: bool = False, drop: bool = False) -> None:
486457
logger.exception("Error while trying to create the index")
487458
raise
488459

489-
@check_index_exists()
490460
def delete(self, drop: bool = True):
491461
"""Delete the search index while optionally dropping all keys associated
492462
with the index.
@@ -502,8 +472,8 @@ def delete(self, drop: bool = True):
502472
self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
503473
delete_documents=drop
504474
)
505-
except:
506-
logger.exception("Error while deleting index")
475+
except Exception as e:
476+
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e
507477

508478
def clear(self) -> int:
509479
"""Clear all keys in Redis associated with the index, leaving the index
@@ -629,13 +599,29 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
629599
return convert_bytes(obj[0])
630600
return None
631601

632-
@check_index_exists()
602+
def aggregate(self, *args, **kwargs) -> "AggregateResult":
603+
"""Perform an aggregation operation against the index.
604+
605+
Wrapper around the aggregation API that adds the index name
606+
to the query and passes along the rest of the arguments
607+
to the redis-py ft().aggregate() method.
608+
609+
Returns:
610+
Result: Raw Redis aggregation results.
611+
"""
612+
try:
613+
return self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
614+
*args, **kwargs
615+
)
616+
except Exception as e:
617+
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
618+
633619
def search(self, *args, **kwargs) -> "Result":
634620
"""Perform a search against the index.
635621
636-
Wrapper around redis.search.Search that adds the index name
637-
to the search query and passes along the rest of the arguments
638-
to the redis-py ft.search() method.
622+
Wrapper around the search API that adds the index name
623+
to the query and passes along the rest of the arguments
624+
to the redis-py ft().search() method.
639625
640626
Returns:
641627
Result: Raw Redis search results.
@@ -644,9 +630,8 @@ def search(self, *args, **kwargs) -> "Result":
644630
return self._redis_client.ft(self.schema.index.name).search( # type: ignore
645631
*args, **kwargs
646632
)
647-
except:
648-
logger.exception("Error while searching")
649-
raise
633+
except Exception as e:
634+
raise RedisSearchError(f"Error while searching: {str(e)}") from e
650635

651636
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
652637
"""Execute a query and process results."""
@@ -752,11 +737,11 @@ def _info(name: str, redis_client: redis.Redis) -> Dict[str, Any]:
752737
"""Run FT.INFO to fetch information about the index."""
753738
try:
754739
return convert_bytes(redis_client.ft(name).info()) # type: ignore
755-
except:
756-
logger.exception(f"Error while fetching {name} index info")
757-
raise
740+
except Exception as e:
741+
raise RedisSearchError(
742+
f"Error while fetching {name} index info: {str(e)}"
743+
) from e
758744

759-
@check_index_exists()
760745
def info(self, name: Optional[str] = None) -> Dict[str, Any]:
761746
"""Get information about the index.
762747
@@ -1010,7 +995,6 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
1010995
logger.exception("Error while trying to create the index")
1011996
raise
1012997

1013-
@check_async_index_exists()
1014998
async def delete(self, drop: bool = True):
1015999
"""Delete the search index.
10161000
@@ -1025,9 +1009,8 @@ async def delete(self, drop: bool = True):
10251009
await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
10261010
delete_documents=drop
10271011
)
1028-
except:
1029-
logger.exception("Error while deleting index")
1030-
raise
1012+
except Exception as e:
1013+
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e
10311014

10321015
async def clear(self) -> int:
10331016
"""Clear all keys in Redis associated with the index, leaving the index
@@ -1152,7 +1135,23 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
11521135
return convert_bytes(obj[0])
11531136
return None
11541137

1155-
@check_async_index_exists()
1138+
async def aggregate(self, *args, **kwargs) -> "AggregateResult":
1139+
"""Perform an aggregation operation against the index.
1140+
1141+
Wrapper around the aggregation API that adds the index name
1142+
to the query and passes along the rest of the arguments
1143+
to the redis-py ft().aggregate() method.
1144+
1145+
Returns:
1146+
Result: Raw Redis aggregation results.
1147+
"""
1148+
try:
1149+
return await self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
1150+
*args, **kwargs
1151+
)
1152+
except Exception as e:
1153+
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
1154+
11561155
async def search(self, *args, **kwargs) -> "Result":
11571156
"""Perform a search on this index.
11581157
@@ -1167,9 +1166,8 @@ async def search(self, *args, **kwargs) -> "Result":
11671166
return await self._redis_client.ft(self.schema.index.name).search( # type: ignore
11681167
*args, **kwargs
11691168
)
1170-
except:
1171-
logger.exception("Error while searching")
1172-
raise
1169+
except Exception as e:
1170+
raise RedisSearchError(f"Error while searching: {str(e)}") from e
11731171

11741172
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
11751173
"""Asynchronously execute a query and process results."""
@@ -1275,11 +1273,11 @@ async def exists(self) -> bool:
12751273
async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
12761274
try:
12771275
return convert_bytes(await redis_client.ft(name).info()) # type: ignore
1278-
except:
1279-
logger.exception(f"Error while fetching {name} index info")
1280-
raise
1276+
except Exception as e:
1277+
raise RedisSearchError(
1278+
f"Error while fetching {name} index info: {str(e)}"
1279+
) from e
12811280

1282-
@check_async_index_exists()
12831281
async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
12841282
"""Get information about the index.
12851283

tests/integration/test_async_search_index.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from redisvl.exceptions import RedisSearchError
34
from redisvl.index import AsyncSearchIndex
45
from redisvl.query import VectorQuery
56
from redisvl.redis.utils import convert_bytes
@@ -291,7 +292,7 @@ async def test_check_index_exists_before_delete(async_client, async_index):
291292
await async_index.set_client(async_client)
292293
await async_index.create(overwrite=True, drop=True)
293294
await async_index.delete(drop=True)
294-
with pytest.raises(ValueError):
295+
with pytest.raises(RedisSearchError):
295296
await async_index.delete()
296297

297298

@@ -307,7 +308,7 @@ async def test_check_index_exists_before_search(async_client, async_index):
307308
return_fields=["user", "credit_score", "age", "job", "location"],
308309
num_results=7,
309310
)
310-
with pytest.raises(ValueError):
311+
with pytest.raises(RedisSearchError):
311312
await async_index.search(query.query, query_params=query.params)
312313

313314

@@ -317,5 +318,5 @@ async def test_check_index_exists_before_info(async_client, async_index):
317318
await async_index.create(overwrite=True, drop=True)
318319
await async_index.delete(drop=True)
319320

320-
with pytest.raises(ValueError):
321+
with pytest.raises(RedisSearchError):
321322
await async_index.info()

tests/integration/test_search_index.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from redisvl.exceptions import RedisSearchError
34
from redisvl.index import SearchIndex
45
from redisvl.query import VectorQuery
56
from redisvl.redis.connection import RedisConnectionFactory, validate_modules
@@ -251,7 +252,7 @@ def test_check_index_exists_before_delete(client, index):
251252
index.set_client(client)
252253
index.create(overwrite=True, drop=True)
253254
index.delete(drop=True)
254-
with pytest.raises(RuntimeError):
255+
with pytest.raises(RedisSearchError):
255256
index.delete()
256257

257258

@@ -266,7 +267,7 @@ def test_check_index_exists_before_search(client, index):
266267
return_fields=["user", "credit_score", "age", "job", "location"],
267268
num_results=7,
268269
)
269-
with pytest.raises(RuntimeError):
270+
with pytest.raises(RedisSearchError):
270271
index.search(query.query, query_params=query.params)
271272

272273

@@ -275,7 +276,7 @@ def test_check_index_exists_before_info(client, index):
275276
index.create(overwrite=True, drop=True)
276277
index.delete(drop=True)
277278

278-
with pytest.raises(RuntimeError):
279+
with pytest.raises(RedisSearchError):
279280
index.info()
280281

281282

0 commit comments

Comments
 (0)