1414 List ,
1515 Optional ,
1616 Union ,
17+ cast ,
1718)
1819import warnings
19-
2020from redisvl .utils .utils import deprecated_function
2121
2222if TYPE_CHECKING :
@@ -797,11 +797,6 @@ class AsyncSearchIndex(BaseSearchIndex):
797797
798798 """
799799
800- # TODO: The `aredis.Redis` type is not working for type checks.
801- _redis_client : Optional [redis .asyncio .Redis ] = None
802- _redis_url : Optional [str ] = None
803- _redis_kwargs : Dict [str , Any ] = {}
804-
805800 def __init__ (
806801 self ,
807802 schema : IndexSchema ,
@@ -842,7 +837,6 @@ def __init__(
842837 async def disconnect (self ):
843838 """Asynchronously disconnect and cleanup the underlying async redis connection."""
844839 if self ._redis_client is not None :
845- print (self ._redis_client )
846840 await self ._redis_client .aclose () # type: ignore
847841 self ._redis_client = None
848842
@@ -867,7 +861,9 @@ async def from_existing(
867861 if redis_client and redis_url :
868862 raise ValueError ("Cannot provide both redis_client and redis_url" )
869863 elif redis_url :
870- redis_client = RedisConnectionFactory .get_async_redis_connection (url = redis_url , ** kwargs )
864+ redis_client = RedisConnectionFactory .get_async_redis_connection (
865+ url = redis_url , ** kwargs
866+ )
871867 elif redis_client :
872868 pass
873869 else :
@@ -895,46 +891,36 @@ async def from_existing(
895891 schema = IndexSchema .from_dict (schema_dict )
896892 return cls (schema , redis_client = redis_client , ** kwargs )
897893
898- @deprecated_function ("client" , "Use await self.get_client()" )
899894 @property
900- def client (self ) -> aredis .Redis :
895+ def client (self ) -> Optional [ aredis .Redis ] :
901896 """The underlying redis-py client object."""
902- if self ._redis_client is None :
903- if asyncio .current_task () is not None :
904- warnings .warn ("Risk of deadlock! Use await self.get_client() if you are "
905- "in an async context." )
906- redis_client = asyncio .run_coroutine_threadsafe (self .get_client (), asyncio .get_event_loop ())
907- client = redis_client .result ()
908- else :
909- client = self ._redis_client
910- return client
897+ return self ._redis_client
911898
899+ @deprecated_function ("connect" , "Pass connection parameters in __init__." )
912900 async def connect (self , redis_url : Optional [str ] = None , ** kwargs ):
913901 """[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__."""
914- import warnings
915-
916902 warnings .warn (
917903 "connect() is deprecated; pass connection parameters in __init__" ,
918904 DeprecationWarning ,
919905 )
920- client = RedisConnectionFactory .connect (
906+ client : redis . asyncio . Redis = RedisConnectionFactory .connect (
921907 redis_url = redis_url , use_async = True , ** kwargs
922- )
908+ ) # type: ignore
923909 return await self .set_client (client )
924910
925- @deprecated_function ("set_client" , "Pass connection parameters in __init__" )
926- async def set_client (self , redis_client : Optional [ aredis .Redis ] ):
911+ @deprecated_function ("set_client" , "Pass connection parameters in __init__. " )
912+ async def set_client (self , redis_client : aredis .Redis ):
927913 """
928914 [DEPRECATED] Manually set the Redis client to use with the search index.
929915 This method is deprecated; please provide connection parameters in __init__.
930916 """
931- if self . _redis_client is not None :
932- await self ._redis_client . aclose () # type: ignore
917+ redis_client = await self . _validate_client ( redis_client )
918+ await self .disconnect ()
933919 async with self ._lock :
934920 self ._redis_client = redis_client
935921 return self
936922
937- async def get_client (self ) -> aredis . Redis :
923+ async def _get_client (self ):
938924 """Lazily instantiate and return the async Redis client."""
939925 if self ._redis_client is None :
940926 async with self ._lock :
@@ -951,6 +937,23 @@ async def get_client(self) -> aredis.Redis:
951937 )
952938 return self ._redis_client
953939
940+ async def get_client (self ) -> aredis .Redis :
941+ """Return this index's async Redis client."""
942+ return await self ._get_client ()
943+
944+ async def _validate_client (self , redis_client : aredis .Redis ) -> aredis .Redis :
945+ if isinstance (redis_client , redis .Redis ):
946+ warnings .warn (
947+ "Converting sync Redis client to async client is deprecated "
948+ "and will be removed in the next major version. Please use an "
949+ "async Redis client instead." ,
950+ DeprecationWarning ,
951+ )
952+ redis_client = RedisConnectionFactory .sync_to_async_redis (redis_client )
953+ elif not isinstance (redis_client , aredis .Redis ):
954+ raise ValueError ("Invalid client type: must be redis.asyncio.Redis" )
955+ return redis_client
956+
954957 async def create (self , overwrite : bool = False , drop : bool = False ) -> None :
955958 """Asynchronously create an index in Redis with the current schema
956959 and properties.
@@ -1171,7 +1174,7 @@ async def search(self, *args, **kwargs) -> "Result":
11711174 """
11721175 client = await self .get_client ()
11731176 try :
1174- return client .ft (self .schema .index .name ).search (* args , ** kwargs )
1177+ return await client .ft (self .schema .index .name ).search (* args , ** kwargs )
11751178 except Exception as e :
11761179 raise RedisSearchError (f"Error while searching: { str (e )} " ) from e
11771180
@@ -1262,7 +1265,7 @@ async def listall(self) -> List[str]:
12621265 Returns:
12631266 List[str]: The list of indices in the database.
12641267 """
1265- client : aredis . Redis = await self .get_client ()
1268+ client = await self .get_client ()
12661269 return convert_bytes (await client .execute_command ("FT._LIST" ))
12671270
12681271 async def exists (self ) -> bool :
@@ -1283,7 +1286,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
12831286 Returns:
12841287 dict: A dictionary containing the information about the index.
12851288 """
1286- client : aredis . Redis = await self .get_client ()
1289+ client = await self .get_client ()
12871290 index_name = name or self .schema .index .name
12881291 return await type (self )._info (index_name , client )
12891292
0 commit comments