Expose aggregation API from SearchIndex #230

Merged
merged 5 commits into from Oct 8, 2024
4 changes: 4 additions & 0 deletions redisvl/exceptions.py
@@ -4,3 +4,7 @@ class RedisVLException(Exception):

class RedisModuleVersionError(RedisVLException):
"""Invalid module versions installed"""


class RedisSearchError(RedisVLException):
"""Error while performing a search or aggregate request"""
12 changes: 6 additions & 6 deletions redisvl/extensions/router/semantic.py
@@ -256,9 +256,9 @@ def _classify_route(
)

try:
aggregation_result: AggregateResult = self._index.client.ft( # type: ignore
self._index.name
).aggregate(aggregate_request, vector_range_query.params)
aggregation_result: AggregateResult = self._index.aggregate(
aggregate_request, vector_range_query.params
)
except ResponseError as e:
if "VSS is not yet supported on FT.AGGREGATE" in str(e):
raise RuntimeError(
@@ -308,9 +308,9 @@ def _classify_multi_route(
)

try:
aggregation_result: AggregateResult = self._index.client.ft( # type: ignore
self._index.name
).aggregate(aggregate_request, vector_range_query.params)
aggregation_result: AggregateResult = self._index.aggregate(
aggregate_request, vector_range_query.params
)
except ResponseError as e:
if "VSS is not yet supported on FT.AGGREGATE" in str(e):
raise RuntimeError(
112 changes: 55 additions & 57 deletions redisvl/index/index.py
@@ -17,6 +17,7 @@
)

if TYPE_CHECKING:
from redis.commands.search.aggregation import AggregateResult
from redis.commands.search.document import Document
from redis.commands.search.result import Result
from redisvl.query.query import BaseQuery
@@ -25,7 +26,7 @@
import redis.asyncio as aredis
from redis.commands.search.indexDefinition import IndexDefinition

from redisvl.exceptions import RedisModuleVersionError
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
from redisvl.query import BaseQuery, CountQuery, FilterQuery
from redisvl.query.filter import FilterExpression
@@ -123,36 +124,6 @@ async def wrapper(self, *args, **kwargs):
return decorator


def check_index_exists():
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.exists():
raise RuntimeError(
f"Index has not been created. Must be created before calling {func.__name__}"
)
return func(self, *args, **kwargs)

return wrapper

return decorator


def check_async_index_exists():
def decorator(func):
@wraps(func)
async def wrapper(self, *args, **kwargs):
if not await self.exists():
raise ValueError(
f"Index has not been created. Must be created before calling {func.__name__}"
)
return await func(self, *args, **kwargs)

return wrapper

return decorator


class BaseSearchIndex:
"""Base search engine class"""

@@ -486,7 +457,6 @@ def create(self, overwrite: bool = False, drop: bool = False) -> None:
logger.exception("Error while trying to create the index")
raise

@check_index_exists()
def delete(self, drop: bool = True):
"""Delete the search index while optionally dropping all keys associated
with the index.
@@ -502,8 +472,8 @@ def delete(self, drop: bool = True):
self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
delete_documents=drop
)
except:
logger.exception("Error while deleting index")
except Exception as e:
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e

def clear(self) -> int:
"""Clear all keys in Redis associated with the index, leaving the index
@@ -629,13 +599,29 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
return convert_bytes(obj[0])
return None

@check_index_exists()
def aggregate(self, *args, **kwargs) -> "AggregateResult":
"""Perform an aggregation operation against the index.

Wrapper around the aggregation API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().aggregate() method.

Returns:
            AggregateResult: Raw Redis aggregation results.
"""
try:
return self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
*args, **kwargs
)
except Exception as e:
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e

def search(self, *args, **kwargs) -> "Result":
"""Perform a search against the index.

Wrapper around redis.search.Search that adds the index name
to the search query and passes along the rest of the arguments
to the redis-py ft.search() method.
Wrapper around the search API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().search() method.

Returns:
Result: Raw Redis search results.
@@ -644,9 +630,8 @@ def search(self, *args, **kwargs) -> "Result":
return self._redis_client.ft(self.schema.index.name).search( # type: ignore
*args, **kwargs
)
except:
logger.exception("Error while searching")
raise
except Exception as e:
raise RedisSearchError(f"Error while searching: {str(e)}") from e

def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Execute a query and process results."""
@@ -752,11 +737,11 @@ def _info(name: str, redis_client: redis.Redis) -> Dict[str, Any]:
"""Run FT.INFO to fetch information about the index."""
try:
return convert_bytes(redis_client.ft(name).info()) # type: ignore
except:
logger.exception(f"Error while fetching {name} index info")
raise
except Exception as e:
raise RedisSearchError(
f"Error while fetching {name} index info: {str(e)}"
) from e

@check_index_exists()
def info(self, name: Optional[str] = None) -> Dict[str, Any]:
"""Get information about the index.

@@ -1010,7 +995,6 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
logger.exception("Error while trying to create the index")
raise

@check_async_index_exists()
async def delete(self, drop: bool = True):
"""Delete the search index.

@@ -1025,9 +1009,8 @@ async def delete(self, drop: bool = True):
await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
delete_documents=drop
)
except:
logger.exception("Error while deleting index")
raise
except Exception as e:
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e

async def clear(self) -> int:
"""Clear all keys in Redis associated with the index, leaving the index
@@ -1152,7 +1135,23 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
return convert_bytes(obj[0])
return None

@check_async_index_exists()
async def aggregate(self, *args, **kwargs) -> "AggregateResult":
"""Perform an aggregation operation against the index.

Wrapper around the aggregation API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().aggregate() method.

Returns:
            AggregateResult: Raw Redis aggregation results.
"""
try:
return await self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
*args, **kwargs
)
except Exception as e:
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e

async def search(self, *args, **kwargs) -> "Result":
"""Perform a search on this index.

@@ -1167,9 +1166,8 @@ async def search(self, *args, **kwargs) -> "Result":
return await self._redis_client.ft(self.schema.index.name).search( # type: ignore
*args, **kwargs
)
except:
logger.exception("Error while searching")
raise
except Exception as e:
raise RedisSearchError(f"Error while searching: {str(e)}") from e

async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Asynchronously execute a query and process results."""
@@ -1275,11 +1273,11 @@ async def exists(self) -> bool:
async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
try:
return convert_bytes(await redis_client.ft(name).info()) # type: ignore
except:
logger.exception(f"Error while fetching {name} index info")
raise
except Exception as e:
raise RedisSearchError(
f"Error while fetching {name} index info: {str(e)}"
) from e

@check_async_index_exists()
async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
"""Get information about the index.

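For reference, a rough sketch of how the newly exposed method can be called directly on a SearchIndex. The schema, data, and grouping field are illustrative assumptions; the aggregation request itself is plain redis-py.

import redis.commands.search.reducers as reducers
from redis.commands.search.aggregation import AggregateRequest

from redisvl.exceptions import RedisSearchError
from redisvl.index import SearchIndex

# Hypothetical index, assumed to be created and loaded with documents
# that carry a "job" tag field.
index = SearchIndex.from_dict(
    {
        "index": {"name": "users", "prefix": "user"},
        "fields": [{"name": "job", "type": "tag"}],
    },
    redis_url="redis://localhost:6379",
)

# Count documents per job using a raw redis-py aggregation request.
request = AggregateRequest("*").group_by("@job", reducers.count().alias("count"))

try:
    result = index.aggregate(request)  # no more index.client.ft(...).aggregate(...)
    for row in result.rows:
        print(row)
except RedisSearchError as e:
    print(f"Aggregation failed: {e}")

The async index works the same way, with await async_index.aggregate(request).
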
7 changes: 4 additions & 3 deletions tests/integration/test_async_search_index.py
@@ -1,5 +1,6 @@
import pytest

from redisvl.exceptions import RedisSearchError
from redisvl.index import AsyncSearchIndex
from redisvl.query import VectorQuery
from redisvl.redis.utils import convert_bytes
@@ -291,7 +292,7 @@ async def test_check_index_exists_before_delete(async_client, async_index):
await async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)
with pytest.raises(ValueError):
with pytest.raises(RedisSearchError):
await async_index.delete()


@@ -307,7 +308,7 @@ async def test_check_index_exists_before_search(async_client, async_index):
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(ValueError):
with pytest.raises(RedisSearchError):
await async_index.search(query.query, query_params=query.params)


@@ -317,5 +318,5 @@ async def test_check_index_exists_before_info(async_client, async_index):
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)

with pytest.raises(ValueError):
with pytest.raises(RedisSearchError):
await async_index.info()
7 changes: 4 additions & 3 deletions tests/integration/test_search_index.py
@@ -1,5 +1,6 @@
import pytest

from redisvl.exceptions import RedisSearchError
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.redis.connection import RedisConnectionFactory, validate_modules
@@ -251,7 +252,7 @@ def test_check_index_exists_before_delete(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)
with pytest.raises(RuntimeError):
with pytest.raises(RedisSearchError):
index.delete()


@@ -266,7 +267,7 @@ def test_check_index_exists_before_search(client, index):
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(RuntimeError):
with pytest.raises(RedisSearchError):
index.search(query.query, query_params=query.params)


@@ -275,7 +276,7 @@ def test_check_index_exists_before_info(client, index):
index.create(overwrite=True, drop=True)
index.delete(drop=True)

with pytest.raises(RuntimeError):
with pytest.raises(RedisSearchError):
index.info()

