1515 Optional ,
1616 Union ,
1717)
18+ import warnings
19+
20+ from redisvl .utils .utils import deprecated_function
1821
1922if TYPE_CHECKING :
2023 from redis .commands .search .aggregation import AggregateResult
@@ -839,6 +842,7 @@ def __init__(
839842 async def disconnect (self ):
840843 """Asynchronously disconnect and cleanup the underlying async redis connection."""
841844 if self ._redis_client is not None :
845+ print (self ._redis_client )
842846 await self ._redis_client .aclose () # type: ignore
843847 self ._redis_client = None
844848
@@ -860,12 +864,13 @@ async def from_existing(
860864 redis_url (Optional[str]): The URL of the Redis server to
861865 connect to.
862866 """
863- if redis_url :
864- redis_client = RedisConnectionFactory .connect (
865- redis_url = redis_url , use_async = True , ** kwargs
866- )
867-
868- if not redis_client :
867+ if redis_client and redis_url :
868+ raise ValueError ("Cannot provide both redis_client and redis_url" )
869+ elif redis_url :
870+ redis_client = RedisConnectionFactory .get_async_redis_connection (url = redis_url , ** kwargs )
871+ elif redis_client :
872+ pass
873+ else :
869874 raise ValueError (
870875 "Must provide either a redis_url or redis_client to fetch Redis index info."
871876 )
@@ -888,14 +893,21 @@ async def from_existing(
888893 index_info = await cls ._info (name , redis_client )
889894 schema_dict = convert_index_info_to_schema (index_info )
890895 schema = IndexSchema .from_dict (schema_dict )
891- index = cls (schema , ** kwargs )
892- await index .set_client (redis_client )
893- return index
896+ return cls (schema , redis_client = redis_client , ** kwargs )
894897
898+ @deprecated_function ("client" , "Use await self.get_client()" )
895899 @property
896- def client (self ) -> Optional [ aredis .Redis ] :
900+ def client (self ) -> aredis .Redis :
897901 """The underlying redis-py client object."""
898- return self ._redis_client
902+ if self ._redis_client is None :
903+ if asyncio .current_task () is not None :
904+ warnings .warn ("Risk of deadlock! Use await self.get_client() if you are "
905+ "in an async context." )
906+ redis_client = asyncio .run_coroutine_threadsafe (self .get_client (), asyncio .get_event_loop ())
907+ client = redis_client .result ()
908+ else :
909+ client = self ._redis_client
910+ return client
899911
900912 async def connect (self , redis_url : Optional [str ] = None , ** kwargs ):
901913 """[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__."""
@@ -910,31 +922,19 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs):
910922 )
911923 return await self .set_client (client )
912924
925+ @deprecated_function ("set_client" , "Pass connection parameters in __init__" )
913926 async def set_client (self , redis_client : Optional [aredis .Redis ]):
914- """[DEPRECATED] Manually set the Redis client to use with the search index.
915- This method is deprecated; please provide connection parameters in __init__.
916927 """
917- import warnings
918-
919- warnings .warn (
920- "set_client() is deprecated; pass connection parameters in __init__" ,
921- DeprecationWarning ,
922- )
923- return await self ._set_client (redis_client )
924-
925- async def _set_client (self , redis_client : Optional [redis .asyncio .Redis ]):
926- """
927- Set the Redis client to use with the search index.
928-
929- NOTE: Remove this method once the deprecation period is over.
928+ [DEPRECATED] Manually set the Redis client to use with the search index.
929+ This method is deprecated; please provide connection parameters in __init__.
930930 """
931931 if self ._redis_client is not None :
932932 await self ._redis_client .aclose () # type: ignore
933933 async with self ._lock :
934934 self ._redis_client = redis_client
935935 return self
936936
937- async def _get_client (self ) -> aredis .Redis :
937+ async def get_client (self ) -> aredis .Redis :
938938 """Lazily instantiate and return the async Redis client."""
939939 if self ._redis_client is None :
940940 async with self ._lock :
@@ -976,7 +976,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
976976 # overwrite an index in Redis; drop associated data (clean slate)
977977 await index.create(overwrite=True, drop=True)
978978 """
979- client = await self ._get_client ()
979+ client = await self .get_client ()
980980 redis_fields = self .schema .redis_fields
981981
982982 if not redis_fields :
@@ -1012,7 +1012,7 @@ async def delete(self, drop: bool = True):
10121012 Raises:
10131013 redis.exceptions.ResponseError: If the index does not exist.
10141014 """
1015- client = await self ._get_client ()
1015+ client = await self .get_client ()
10161016 try :
10171017 await client .ft (self .schema .index .name ).dropindex (delete_documents = drop )
10181018 except Exception as e :
@@ -1025,7 +1025,7 @@ async def clear(self) -> int:
10251025 Returns:
10261026 int: Count of records deleted from Redis.
10271027 """
1028- client = await self ._get_client ()
1028+ client = await self .get_client ()
10291029 total_records_deleted : int = 0
10301030
10311031 async for batch in self .paginate (
@@ -1046,7 +1046,7 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
10461046 Returns:
10471047 int: Count of records deleted from Redis.
10481048 """
1049- client = await self ._get_client ()
1049+ client = await self .get_client ()
10501050 if isinstance (keys , list ):
10511051 return await client .delete (* keys )
10521052 else :
@@ -1110,7 +1110,7 @@ async def add_field(d):
11101110 keys = await index.load(data, preprocess=add_field)
11111111
11121112 """
1113- client = await self ._get_client ()
1113+ client = await self .get_client ()
11141114 try :
11151115 return await self ._storage .awrite (
11161116 client ,
@@ -1137,7 +1137,7 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
11371137 Returns:
11381138 Dict[str, Any]: The fetched object.
11391139 """
1140- client = await self ._get_client ()
1140+ client = await self .get_client ()
11411141 obj = await self ._storage .aget (client , [self .key (id )])
11421142 if obj :
11431143 return convert_bytes (obj [0 ])
@@ -1153,10 +1153,9 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
11531153 Returns:
11541154 Result: Raw Redis aggregation results.
11551155 """
1156- client = await self ._get_client ()
1156+ client = await self .get_client ()
11571157 try :
1158- # TODO: Typing
1159- return await client .ft (self .schema .index .name ).aggregate (* args , ** kwargs )
1158+ return client .ft (self .schema .index .name ).aggregate (* args , ** kwargs )
11601159 except Exception as e :
11611160 raise RedisSearchError (f"Error while aggregating: { str (e )} " ) from e
11621161
@@ -1170,10 +1169,9 @@ async def search(self, *args, **kwargs) -> "Result":
11701169 Returns:
11711170 Result: Raw Redis search results.
11721171 """
1173- client = await self ._get_client ()
1172+ client = await self .get_client ()
11741173 try :
1175- # TODO: Typing
1176- return await client .ft (self .schema .index .name ).search (* args , ** kwargs )
1174+ return client .ft (self .schema .index .name ).search (* args , ** kwargs )
11771175 except Exception as e :
11781176 raise RedisSearchError (f"Error while searching: { str (e )} " ) from e
11791177
@@ -1264,7 +1262,7 @@ async def listall(self) -> List[str]:
12641262 Returns:
12651263 List[str]: The list of indices in the database.
12661264 """
1267- client : aredis .Redis = await self ._get_client ()
1265+ client : aredis .Redis = await self .get_client ()
12681266 return convert_bytes (await client .execute_command ("FT._LIST" ))
12691267
12701268 async def exists (self ) -> bool :
@@ -1285,7 +1283,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
12851283 Returns:
12861284 dict: A dictionary containing the information about the index.
12871285 """
1288- client : aredis .Redis = await self ._get_client ()
1286+ client : aredis .Redis = await self .get_client ()
12891287 index_name = name or self .schema .index .name
12901288 return await type (self )._info (index_name , client )
12911289
0 commit comments