diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index 3fe79631..1dfe2fe9 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -24,7 +24,11 @@
 from redisvl.index.storage import HashStorage, JsonStorage
 from redisvl.query.query import BaseQuery, CountQuery, FilterQuery
-from redisvl.redis.connection import RedisConnectionFactory
+from redisvl.redis.connection import (
+    RedisConnectionFactory,
+    convert_index_info_to_schema,
+    validate_modules,
+)
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
 from redisvl.utils.log import get_logger
@@ -102,7 +106,7 @@ def decorator(func):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
             if not self.exists():
-                raise ValueError(
+                raise RuntimeError(
                     f"Index has not been created. Must be created before calling {func.__name__}"
                 )
             return func(self, *args, **kwargs)
@@ -162,7 +166,6 @@ def __init__(
 
         self.schema = schema
 
-        # set custom lib name
         self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
 
         # set up redis connection
@@ -317,6 +320,34 @@ class SearchIndex(BaseSearchIndex):
 
     """
 
+    @classmethod
+    def from_existing(
+        cls,
+        name: str,
+        redis_client: Optional[redis.Redis] = None,
+        redis_url: Optional[str] = None,
+        **kwargs,
+    ):
+        # Handle redis instance
+        if redis_url:
+            redis_client = RedisConnectionFactory.connect(
+                redis_url=redis_url, use_async=False, **kwargs
+            )
+        if not redis_client:
+            raise ValueError(
+                "Must provide either a redis_url or redis_client to fetch Redis index info."
+            )
+
+        # Validate modules
+        installed_modules = RedisConnectionFactory._get_modules(redis_client)
+        validate_modules(installed_modules, [{"name": "search", "ver": 20810}])
+
+        # Fetch index info and convert to schema
+        index_info = cls._info(name, redis_client)
+        schema_dict = convert_index_info_to_schema(index_info)
+        schema = IndexSchema.from_dict(schema_dict)
+        return cls(schema, redis_client, **kwargs)
+
     def connect(self, redis_url: Optional[str] = None, **kwargs):
         """Connect to a Redis instance using the provided `redis_url`, falling
         back to the `REDIS_URL` environment variable (if available).
@@ -653,22 +684,28 @@ def exists(self) -> bool:
         """
         return self.schema.index.name in self.listall()
 
+    @staticmethod
+    def _info(name: str, redis_client: redis.Redis) -> Dict[str, Any]:
+        """Run FT.INFO to fetch information about the index."""
+        try:
+            return convert_bytes(redis_client.ft(name).info())  # type: ignore
+        except:
+            logger.exception(f"Error while fetching {name} index info")
+            raise
+
     @check_index_exists()
-    def info(self) -> Dict[str, Any]:
+    def info(self, name: Optional[str] = None) -> Dict[str, Any]:
         """Get information about the index.
 
+        Args:
+            name (str, optional): Index name to fetch info about.
+                Defaults to None.
+
         Returns:
             dict: A dictionary containing the information about the index.
         """
-        try:
-            return convert_bytes(
-                self._redis_client.ft(self.schema.index.name).info()  # type: ignore
-            )
-        except:
-            logger.exception(
-                f"Error while fetching {self.schema.index.name} index info"
-            )
-            raise
+        index_name = name or self.schema.index.name
+        return self._info(index_name, self._redis_client)  # type: ignore
 
 
 class AsyncSearchIndex(BaseSearchIndex):
@@ -698,6 +735,36 @@ class AsyncSearchIndex(BaseSearchIndex):
 
     """
 
+    @classmethod
+    async def from_existing(
+        cls,
+        name: str,
+        redis_client: Optional[aredis.Redis] = None,
+        redis_url: Optional[str] = None,
+        **kwargs,
+    ):
+        if redis_url:
+            redis_client = RedisConnectionFactory.connect(
+                redis_url=redis_url, use_async=True, **kwargs
+            )
+
+        if not redis_client:
+            raise ValueError(
+                "Must provide either a redis_url or redis_client to fetch Redis index info."
+            )
+
+        # Validate modules
+        installed_modules = await RedisConnectionFactory._get_modules_async(
+            redis_client
+        )
+        validate_modules(installed_modules, [{"name": "search", "ver": 20810}])
+
+        # Fetch index info and convert to schema
+        index_info = await cls._info(name, redis_client)
+        schema_dict = convert_index_info_to_schema(index_info)
+        schema = IndexSchema.from_dict(schema_dict)
+        return cls(schema, redis_client, **kwargs)
+
     def connect(self, redis_url: Optional[str] = None, **kwargs):
         """Connect to a Redis instance using the provided `redis_url`, falling
         back to the `REDIS_URL` environment variable (if available).
@@ -1035,19 +1102,24 @@ async def exists(self) -> bool:
         """
         return self.schema.index.name in await self.listall()
 
+    @staticmethod
+    async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
+        try:
+            return convert_bytes(await redis_client.ft(name).info())  # type: ignore
+        except:
+            logger.exception(f"Error while fetching {name} index info")
+            raise
+
     @check_async_index_exists()
-    async def info(self) -> Dict[str, Any]:
+    async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
         """Get information about the index.
 
+        Args:
+            name (str, optional): Index name to fetch info about.
+                Defaults to None.
+
         Returns:
             dict: A dictionary containing the information about the index.
         """
-        try:
-            return convert_bytes(
-                await self._redis_client.ft(self.schema.index.name).info()  # type: ignore
-            )
-        except:
-            logger.exception(
-                f"Error while fetching {self.schema.index.name} index info"
-            )
-            raise
+        index_name = name or self.schema.index.name
+        return await self._info(index_name, self._redis_client)  # type: ignore
diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py
index d747b7ca..3e949dd3 100644
--- a/redisvl/redis/connection.py
+++ b/redisvl/redis/connection.py
@@ -13,11 +13,16 @@
 )
 from redis.exceptions import ResponseError
 
-from redisvl.redis.constants import REDIS_REQUIRED_MODULES
+from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES
 from redisvl.redis.utils import convert_bytes
 from redisvl.version import __version__
 
 
+def unpack_redis_modules(module_list: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """Unpack a list of Redis modules pulled from the MODULES LIST command."""
+    return {module["name"]: module["ver"] for module in module_list}
+
+
 def get_address_from_env() -> str:
     """Get a redis connection from environment variables.
@@ -43,6 +48,82 @@ def make_lib_name(*args) -> str:
     return f"redis-py({custom_libs})"
 
 
+def convert_index_info_to_schema(index_info: Dict[str, Any]) -> Dict[str, Any]:
+    """Convert the output of FT.INFO into a schema-ready dictionary.
+
+    Args:
+        index_info (Dict[str, Any]): Output of the Redis FT.INFO command.
+
+    Returns:
+        Dict[str, Any]: Schema dictionary.
+    """
+    index_name = index_info["index_name"]
+    prefixes = index_info["index_definition"][3][0]
+    storage_type = index_info["index_definition"][1].lower()
+
+    index_fields = index_info["attributes"]
+
+    def parse_vector_attrs(attrs):
+        vector_attrs = {attrs[i].lower(): attrs[i + 1] for i in range(6, len(attrs), 2)}
+        vector_attrs["dims"] = int(vector_attrs.pop("dim"))
+        vector_attrs["distance_metric"] = vector_attrs.pop("distance_metric").lower()
+        vector_attrs["algorithm"] = vector_attrs.pop("algorithm").lower()
+        vector_attrs["datatype"] = vector_attrs.pop("data_type").lower()
+        return vector_attrs
+
+    def parse_attrs(attrs):
+        return {attrs[i].lower(): attrs[i + 1] for i in range(6, len(attrs), 2)}
+
+    schema_fields = []
+
+    for field_attrs in index_fields:
+        # parse field info
+        name = field_attrs[1] if storage_type == "hash" else field_attrs[3]
+        field = {"name": name, "type": field_attrs[5].lower()}
+        if storage_type == "json":
+            field["path"] = field_attrs[1]
+        # parse field attrs
+        if field_attrs[5] == "VECTOR":
+            field["attrs"] = parse_vector_attrs(field_attrs)
+        else:
+            field["attrs"] = parse_attrs(field_attrs)
+        # append field
+        schema_fields.append(field)
+
+    return {
+        "index": {"name": index_name, "prefix": prefixes, "storage_type": storage_type},
+        "fields": schema_fields,
+    }
+
+
+def validate_modules(
+    installed_modules: Dict[str, Any],
+    required_modules: Optional[List[Dict[str, Any]]] = None,
+) -> None:
+    """
+    Validates if required Redis modules are installed.
+
+    Args:
+        installed_modules: List of installed modules.
+        required_modules: List of required modules.
+
+    Raises:
+        ValueError: If required Redis modules are not installed.
+    """
+    required_modules = required_modules or DEFAULT_REQUIRED_MODULES
+
+    for required_module in required_modules:
+        if required_module["name"] in installed_modules:
+            installed_version = installed_modules[required_module["name"]]  # type: ignore
+            if int(installed_version) >= int(required_module["ver"]):  # type: ignore
+                return
+
+    raise ValueError(
+        f"Required Redis database module {required_module['name']} with version >= {required_module['ver']} not installed. "
+        "See Redis Stack documentation: https://redis.io/docs/stack/"
+    )
+
+
 class RedisConnectionFactory:
     """Builds connections to a Redis database, supporting both
     synchronous and asynchronous clients.
@@ -128,14 +209,14 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi
     def validate_redis(
         client: Union[Redis, AsyncRedis],
         lib_name: Optional[str] = None,
-        redis_required_modules: Optional[List[Dict[str, Any]]] = None,
+        required_modules: Optional[List[Dict[str, Any]]] = None,
     ) -> None:
         """Validates the Redis connection.
 
         Args:
            client (Redis or AsyncRedis): Redis client.
            lib_name (str): Library name to set on the Redis client.
-           redis_required_modules (List[Dict[str, Any]]): List of required modules and their versions.
+           required_modules (List[Dict[str, Any]]): List of required modules and their versions.
 
         Raises:
            ValueError: If required Redis modules are not installed.
@@ -145,18 +226,26 @@ def validate_redis(
                 RedisConnectionFactory._validate_async_redis,
                 client,
                 lib_name,
-                redis_required_modules,
+                required_modules,
             )
         else:
             RedisConnectionFactory._validate_sync_redis(
-                client, lib_name, redis_required_modules
+                client, lib_name, required_modules
             )
 
+    @staticmethod
+    def _get_modules(client: Redis) -> Dict[str, Any]:
+        return unpack_redis_modules(convert_bytes(client.module_list()))
+
+    @staticmethod
+    async def _get_modules_async(client: AsyncRedis) -> Dict[str, Any]:
+        return unpack_redis_modules(convert_bytes(await client.module_list()))
+
     @staticmethod
     def _validate_sync_redis(
         client: Redis,
         lib_name: Optional[str],
-        redis_required_modules: Optional[List[Dict[str, Any]]],
+        required_modules: Optional[List[Dict[str, Any]]],
     ) -> None:
         """Validates the sync client."""
         # Set client library name
@@ -168,16 +257,16 @@ def _validate_sync_redis(
         client.echo(_lib_name)
 
         # Get list of modules
-        modules_list = convert_bytes(client.module_list())
+        installed_modules = RedisConnectionFactory._get_modules(client)
 
         # Validate available modules
-        RedisConnectionFactory._validate_modules(modules_list, redis_required_modules)
+        validate_modules(installed_modules, required_modules)
 
     @staticmethod
     async def _validate_async_redis(
         client: AsyncRedis,
         lib_name: Optional[str],
-        redis_required_modules: Optional[List[Dict[str, Any]]],
+        required_modules: Optional[List[Dict[str, Any]]],
     ) -> None:
         """Validates the async client."""
         # Set client library name
@@ -189,10 +278,10 @@ async def _validate_async_redis(
         await client.echo(_lib_name)
 
         # Get list of modules
-        modules_list = convert_bytes(await client.module_list())
+        installed_modules = await RedisConnectionFactory._get_modules_async(client)
 
         # Validate available modules
-        RedisConnectionFactory._validate_modules(modules_list, redis_required_modules)
+        validate_modules(installed_modules, required_modules)
 
     @staticmethod
     def _run_async(coro, *args, **kwargs):
@@ -232,31 +321,3 @@ def _run_async(coro, *args, **kwargs):
         finally:
             # Close the event loop to release resources
             loop.close()
-
-    @staticmethod
-    def _validate_modules(
-        installed_modules, redis_required_modules: Optional[List[Dict[str, Any]]] = None
-    ) -> None:
-        """
-        Validates if required Redis modules are installed.
-
-        Args:
-            installed_modules: List of installed modules.
-            redis_required_modules: List of required modules.
-
-        Raises:
-            ValueError: If required Redis modules are not installed.
-        """
-        installed_modules = {module["name"]: module for module in installed_modules}
-        redis_required_modules = redis_required_modules or REDIS_REQUIRED_MODULES
-
-        for required_module in redis_required_modules:
-            if required_module["name"] in installed_modules:
-                installed_version = installed_modules[required_module["name"]]["ver"]
-                if int(installed_version) >= int(required_module["ver"]):  # type: ignore
-                    return
-
-        raise ValueError(
-            f"Required Redis database module {required_module['name']} with version >= {required_module['ver']} not installed. "
-            "Refer to Redis Stack documentation: https://redis.io/docs/stack/"
-        )
diff --git a/redisvl/redis/constants.py b/redisvl/redis/constants.py
index 43ec7394..aeae2541 100644
--- a/redisvl/redis/constants.py
+++ b/redisvl/redis/constants.py
@@ -1,5 +1,5 @@
 # required modules
-REDIS_REQUIRED_MODULES = [
+DEFAULT_REQUIRED_MODULES = [
     {"name": "search", "ver": 20600},
     {"name": "searchlight", "ver": 20600},
 ]
diff --git a/tests/unit/test_async_search_index.py b/tests/integration/test_async_search_index.py
similarity index 66%
rename from tests/unit/test_async_search_index.py
rename to tests/integration/test_async_search_index.py
index 3e7f6793..b4f57cf8 100644
--- a/tests/unit/test_async_search_index.py
+++ b/tests/integration/test_async_search_index.py
@@ -18,6 +18,16 @@ def async_index(index_schema):
     return AsyncSearchIndex(schema=index_schema)
 
 
+@pytest.fixture
+def async_index_from_dict():
+    return AsyncSearchIndex.from_dict({"index": {"name": "my_index"}, "fields": fields})
+
+
+@pytest.fixture
+def async_index_from_yaml():
+    return AsyncSearchIndex.from_yaml("schemas/test_json_schema.yaml")
+
+
 def test_search_index_properties(index_schema, async_index):
     assert async_index.schema == index_schema
     # custom settings
@@ -32,6 +42,83 @@ def test_search_index_properties(index_schema, async_index):
     assert async_index.key("foo").startswith(async_index.prefix)
 
 
+def test_search_index_from_yaml(async_index_from_yaml):
+    assert async_index_from_yaml.name == "json-test"
+    assert async_index_from_yaml.client == None
+    assert async_index_from_yaml.prefix == "json"
+    assert async_index_from_yaml.key_separator == ":"
+    assert async_index_from_yaml.storage_type == StorageType.JSON
+    assert async_index_from_yaml.key("foo").startswith(async_index_from_yaml.prefix)
+
+
+def test_search_index_from_dict(async_index_from_dict):
+    assert async_index_from_dict.name == "my_index"
+    assert async_index_from_dict.client == None
+    assert async_index_from_dict.prefix == "rvl"
+    assert async_index_from_dict.key_separator == ":"
+    assert async_index_from_dict.storage_type == StorageType.HASH
+    assert len(async_index_from_dict.schema.fields) == len(fields)
+    assert async_index_from_dict.key("foo").startswith(async_index_from_dict.prefix)
+
+
+@pytest.mark.asyncio
+async def test_search_index_from_existing(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True)
+
+    try:
+        async_index2 = await AsyncSearchIndex.from_existing(
+            async_index.name, redis_client=async_client
+        )
+    except Exception as e:
+        pytest.skip(str(e))
+
+    assert async_index2.schema == async_index.schema
+
+
+@pytest.mark.asyncio
+async def test_search_index_from_existing_complex(async_client):
+    schema = {
+        "index": {
+            "name": "test",
+            "prefix": "test",
+            "storage_type": "json",
+        },
+        "fields": [
+            {"name": "user", "type": "tag", "path": "$.user"},
+            {"name": "credit_score", "type": "tag", "path": "$.metadata.credit_score"},
+            {"name": "job", "type": "text", "path": "$.metadata.job"},
+            {
+                "name": "age",
+                "type": "numeric",
+                "path": "$.metadata.age",
+                "attrs": {"sortable": False},
+            },
+            {
+                "name": "user_embedding",
+                "type": "vector",
+                "attrs": {
+                    "dims": 3,
+                    "distance_metric": "cosine",
+                    "algorithm": "flat",
+                    "datatype": "float32",
+                },
+            },
+        ],
+    }
+    async_index = AsyncSearchIndex.from_dict(schema, redis_client=async_client)
+    await async_index.create(overwrite=True)
+
+    try:
+        async_index2 = await AsyncSearchIndex.from_existing(
+            async_index.name, redis_client=async_client
+        )
+    except Exception as e:
+        pytest.skip(str(e))
+
+    assert async_index2.schema == async_index.schema
+
+
 def test_search_index_no_prefix(index_schema):
     # specify an explicitly empty prefix...
     index_schema.index.prefix = ""
diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py
index cfe6e45b..30f1c3f3 100644
--- a/tests/integration/test_connection.py
+++ b/tests/integration/test_connection.py
@@ -5,7 +5,14 @@
 from redis.asyncio import Redis as AsyncRedis
 from redis.exceptions import ConnectionError
 
-from redisvl.redis.connection import RedisConnectionFactory, get_address_from_env
+from redisvl.redis.connection import (
+    RedisConnectionFactory,
+    convert_index_info_to_schema,
+    get_address_from_env,
+    unpack_redis_modules,
+    validate_modules,
+)
+from redisvl.schema import IndexSchema
 from redisvl.version import __version__
 
 EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})"
@@ -44,6 +51,105 @@ def test_get_address_from_env(redis_url):
     assert get_address_from_env() == redis_url
 
 
+def test_unpack_redis_modules():
+    module_list = [
+        {
+            "name": "search",
+            "ver": 20811,
+            "path": "/opt/redis-stack/lib/redisearch.so",
+            "args": [],
+        },
+        {
+            "name": "ReJSON",
+            "ver": 20609,
+            "path": "/opt/redis-stack/lib/rejson.so",
+            "args": [],
+        },
+    ]
+    installed_modules = unpack_redis_modules(module_list)
+    assert installed_modules == {"search": 20811, "ReJSON": 20609}
+
+
+def test_convert_index_info_to_schema():
+    index_info = {
+        "index_name": "image_summaries",
+        "index_options": [],
+        "index_definition": [
+            "key_type",
+            "HASH",
+            "prefixes",
+            ["summary"],
+            "default_score",
+            "1",
+        ],
+        "attributes": [
+            [
+                "identifier",
+                "content",
+                "attribute",
+                "content",
+                "type",
+                "TEXT",
+                "WEIGHT",
+                "1",
+            ],
+            [
+                "identifier",
+                "doc_id",
+                "attribute",
+                "doc_id",
+                "type",
+                "TAG",
+                "SEPARATOR",
+                ",",
+            ],
+            [
+                "identifier",
+                "content_vector",
+                "attribute",
+                "content_vector",
+                "type",
+                "VECTOR",
+                "algorithm",
+                "FLAT",
+                "data_type",
+                "FLOAT32",
+                "dim",
+                1536,
+                "distance_metric",
+                "COSINE",
+            ],
+        ],
+    }
+    schema_dict = convert_index_info_to_schema(index_info)
+    assert "index" in schema_dict
+    assert "fields" in schema_dict
+    assert len(schema_dict["fields"]) == len(index_info["attributes"])
+
+    schema = IndexSchema.from_dict(schema_dict)
+    assert schema.index.name == index_info["index_name"]
+
+
+def test_validate_modules_exist():
+    validate_modules(
+        installed_modules={"search": 20811},
+        required_modules=[
+            {"name": "search", "ver": 20600},
+            {"name": "searchlight", "ver": 20600},
+        ],
+    )
+
+
+def test_validate_modules_not_exist():
+    with pytest.raises(ValueError):
+        validate_modules(
+            installed_modules={"search": 20811},
+            required_modules=[
+                {"name": "ReJSON", "ver": 20600},
+            ],
+        )
+
+
 def test_sync_redis_connect(redis_url):
     client = RedisConnectionFactory.connect(redis_url)
     assert client is not None
diff --git a/tests/unit/test_search_index.py b/tests/integration/test_search_index.py
similarity index 69%
rename from tests/unit/test_search_index.py
rename to tests/integration/test_search_index.py
index 503153a2..5243e6ba 100644
--- a/tests/unit/test_search_index.py
+++ b/tests/integration/test_search_index.py
@@ -2,6 +2,7 @@
 
 from redisvl.index import SearchIndex
 from redisvl.query import VectorQuery
+from redisvl.redis.connection import RedisConnectionFactory, validate_modules
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
 
@@ -18,6 +19,11 @@ def index(index_schema):
     return SearchIndex(schema=index_schema)
 
 
+@pytest.fixture
+def index_from_dict():
+    return SearchIndex.from_dict({"index": {"name": "my_index"}, "fields": fields})
+
+
 @pytest.fixture
 def index_from_yaml():
     return SearchIndex.from_yaml("schemas/test_json_schema.yaml")
@@ -44,6 +50,68 @@ def test_search_index_from_yaml(index_from_yaml):
     assert index_from_yaml.key("foo").startswith(index_from_yaml.prefix)
 
 
+def test_search_index_from_dict(index_from_dict):
+    assert index_from_dict.name == "my_index"
+    assert index_from_dict.client == None
+    assert index_from_dict.prefix == "rvl"
+    assert index_from_dict.key_separator == ":"
+    assert index_from_dict.storage_type == StorageType.HASH
+    assert len(index_from_dict.schema.fields) == len(fields)
+    assert index_from_dict.key("foo").startswith(index_from_dict.prefix)
+
+
+def test_search_index_from_existing(client, index):
+    index.set_client(client)
+    index.create(overwrite=True)
+
+    try:
+        index2 = SearchIndex.from_existing(index.name, redis_client=client)
+    except Exception as e:
+        pytest.skip(str(e))
+
+    assert index2.schema == index.schema
+
+
+def test_search_index_from_existing_complex(client):
+    schema = {
+        "index": {
+            "name": "test",
+            "prefix": "test",
+            "storage_type": "json",
+        },
+        "fields": [
+            {"name": "user", "type": "tag", "path": "$.user"},
+            {"name": "credit_score", "type": "tag", "path": "$.metadata.credit_score"},
+            {"name": "job", "type": "text", "path": "$.metadata.job"},
+            {
+                "name": "age",
+                "type": "numeric",
+                "path": "$.metadata.age",
+                "attrs": {"sortable": False},
+            },
+            {
+                "name": "user_embedding",
+                "type": "vector",
+                "attrs": {
+                    "dims": 3,
+                    "distance_metric": "cosine",
+                    "algorithm": "flat",
+                    "datatype": "float32",
+                },
+            },
+        ],
+    }
+    index = SearchIndex.from_dict(schema, redis_client=client)
+    index.create(overwrite=True)
+
+    try:
+        index2 = SearchIndex.from_existing(index.name, redis_client=client)
+    except Exception as e:
+        pytest.skip(str(e))
+
+    assert index.schema == index2.schema
+
+
 def test_search_index_no_prefix(index_schema):
     # specify an explicitly empty prefix...
     index_schema.index.prefix = ""
@@ -125,12 +193,6 @@ def bad_preprocess(record):
         index.load(data, id_field="id", preprocess=bad_preprocess)
 
 
-def test_search_index_load_empty(client, index):
-    index.set_client(client)
-    index.create(overwrite=True, drop=True)
-    index.load([])
-
-
 def test_no_id_field(client, index):
     index.set_client(client)
     index.create(overwrite=True, drop=True)
@@ -145,7 +207,7 @@ def test_check_index_exists_before_delete(client, index):
     index.set_client(client)
     index.create(overwrite=True, drop=True)
     index.delete(drop=True)
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         index.delete()
 
 
@@ -160,7 +222,7 @@ def test_check_index_exists_before_search(client, index):
         return_fields=["user", "credit_score", "age", "job", "location"],
         num_results=7,
     )
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
        index.search(query.query, query_params=query.params)
 
 
@@ -169,7 +231,7 @@ def test_check_index_exists_before_info(client, index):
     index.set_client(client)
     index.create(overwrite=True, drop=True)
     index.delete(drop=True)
 
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         index.info()