Skip to content

Commit c2bf280

Browse files
committed
More cleanup
1 parent af1c2d0 commit c2bf280

File tree

5 files changed

+50
-47
lines changed

5 files changed

+50
-47
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,12 @@ async def _get_async_index(self) -> AsyncSearchIndex:
176176
"""Lazily construct the async search index class."""
177177
# Construct async index if necessary
178178
if not self._aindex:
179-
redis_client = self.redis_kwargs["redis_client"]
180-
redis_url = self.redis_kwargs["redis_url"]
181-
connection_kwargs = self.redis_kwargs["connection_kwargs"]
182-
183-
self._aindex = AsyncSearchIndex(schema=self._index.schema,
184-
redis_client=redis_client,
185-
redis_url=redis_url,
186-
**connection_kwargs)
179+
self._aindex = AsyncSearchIndex(
180+
schema=self._index.schema,
181+
redis_client=self.redis_kwargs["redis_client"],
182+
redis_url=self.redis_kwargs["redis_url"],
183+
**self.redis_kwargs["connection_kwargs"]
184+
)
187185
return self._aindex
188186

189187
@property

redisvl/index/index.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
List,
1515
Optional,
1616
Union,
17+
cast,
1718
)
1819
import warnings
19-
2020
from redisvl.utils.utils import deprecated_function
2121

2222
if TYPE_CHECKING:
@@ -797,11 +797,6 @@ class AsyncSearchIndex(BaseSearchIndex):
797797
798798
"""
799799

800-
# TODO: The `aredis.Redis` type is not working for type checks.
801-
_redis_client: Optional[redis.asyncio.Redis] = None
802-
_redis_url: Optional[str] = None
803-
_redis_kwargs: Dict[str, Any] = {}
804-
805800
def __init__(
806801
self,
807802
schema: IndexSchema,
@@ -842,7 +837,6 @@ def __init__(
842837
async def disconnect(self):
843838
"""Asynchronously disconnect and cleanup the underlying async redis connection."""
844839
if self._redis_client is not None:
845-
print(self._redis_client)
846840
await self._redis_client.aclose() # type: ignore
847841
self._redis_client = None
848842

@@ -867,7 +861,9 @@ async def from_existing(
867861
if redis_client and redis_url:
868862
raise ValueError("Cannot provide both redis_client and redis_url")
869863
elif redis_url:
870-
redis_client = RedisConnectionFactory.get_async_redis_connection(url=redis_url, **kwargs)
864+
redis_client = RedisConnectionFactory.get_async_redis_connection(
865+
url=redis_url, **kwargs
866+
)
871867
elif redis_client:
872868
pass
873869
else:
@@ -895,46 +891,36 @@ async def from_existing(
895891
schema = IndexSchema.from_dict(schema_dict)
896892
return cls(schema, redis_client=redis_client, **kwargs)
897893

898-
@deprecated_function("client", "Use await self.get_client()")
899894
@property
900-
def client(self) -> aredis.Redis:
895+
def client(self) -> Optional[aredis.Redis]:
901896
"""The underlying redis-py client object."""
902-
if self._redis_client is None:
903-
if asyncio.current_task() is not None:
904-
warnings.warn("Risk of deadlock! Use await self.get_client() if you are "
905-
"in an async context.")
906-
redis_client = asyncio.run_coroutine_threadsafe(self.get_client(), asyncio.get_event_loop())
907-
client = redis_client.result()
908-
else:
909-
client = self._redis_client
910-
return client
897+
return self._redis_client
911898

899+
@deprecated_function("connect", "Pass connection parameters in __init__.")
912900
async def connect(self, redis_url: Optional[str] = None, **kwargs):
913901
"""[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__."""
914-
import warnings
915-
916902
warnings.warn(
917903
"connect() is deprecated; pass connection parameters in __init__",
918904
DeprecationWarning,
919905
)
920-
client = RedisConnectionFactory.connect(
906+
client: redis.asyncio.Redis = RedisConnectionFactory.connect(
921907
redis_url=redis_url, use_async=True, **kwargs
922-
)
908+
) # type: ignore
923909
return await self.set_client(client)
924910

925-
@deprecated_function("set_client", "Pass connection parameters in __init__")
926-
async def set_client(self, redis_client: Optional[aredis.Redis]):
911+
@deprecated_function("set_client", "Pass connection parameters in __init__.")
912+
async def set_client(self, redis_client: aredis.Redis):
927913
"""
928914
[DEPRECATED] Manually set the Redis client to use with the search index.
929915
This method is deprecated; please provide connection parameters in __init__.
930916
"""
931-
if self._redis_client is not None:
932-
await self._redis_client.aclose() # type: ignore
917+
redis_client = await self._validate_client(redis_client)
918+
await self.disconnect()
933919
async with self._lock:
934920
self._redis_client = redis_client
935921
return self
936922

937-
async def get_client(self) -> aredis.Redis:
923+
async def _get_client(self):
938924
"""Lazily instantiate and return the async Redis client."""
939925
if self._redis_client is None:
940926
async with self._lock:
@@ -951,6 +937,23 @@ async def get_client(self) -> aredis.Redis:
951937
)
952938
return self._redis_client
953939

940+
async def get_client(self) -> aredis.Redis:
941+
"""Return this index's async Redis client."""
942+
return await self._get_client()
943+
944+
async def _validate_client(self, redis_client: aredis.Redis) -> aredis.Redis:
945+
if isinstance(redis_client, redis.Redis):
946+
warnings.warn(
947+
"Converting sync Redis client to async client is deprecated "
948+
"and will be removed in the next major version. Please use an "
949+
"async Redis client instead.",
950+
DeprecationWarning,
951+
)
952+
redis_client = RedisConnectionFactory.sync_to_async_redis(redis_client)
953+
elif not isinstance(redis_client, aredis.Redis):
954+
raise ValueError("Invalid client type: must be redis.asyncio.Redis")
955+
return redis_client
956+
954957
async def create(self, overwrite: bool = False, drop: bool = False) -> None:
955958
"""Asynchronously create an index in Redis with the current schema
956959
and properties.
@@ -1171,7 +1174,7 @@ async def search(self, *args, **kwargs) -> "Result":
11711174
"""
11721175
client = await self.get_client()
11731176
try:
1174-
return client.ft(self.schema.index.name).search(*args, **kwargs)
1177+
return await client.ft(self.schema.index.name).search(*args, **kwargs)
11751178
except Exception as e:
11761179
raise RedisSearchError(f"Error while searching: {str(e)}") from e
11771180

@@ -1262,7 +1265,7 @@ async def listall(self) -> List[str]:
12621265
Returns:
12631266
List[str]: The list of indices in the database.
12641267
"""
1265-
client: aredis.Redis = await self.get_client()
1268+
client = await self.get_client()
12661269
return convert_bytes(await client.execute_command("FT._LIST"))
12671270

12681271
async def exists(self) -> bool:
@@ -1283,7 +1286,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
12831286
Returns:
12841287
dict: A dictionary containing the information about the index.
12851288
"""
1286-
client: aredis.Redis = await self.get_client()
1289+
client = await self.get_client()
12871290
index_name = name or self.schema.index.name
12881291
return await type(self)._info(index_name, client)
12891292

redisvl/redis/connection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, Dict, List, Optional, Type
2+
from typing import Any, Dict, List, Optional, Type, Union
33

44
from redis import Redis
55
from redis.asyncio import Connection as AsyncConnection
@@ -12,6 +12,7 @@
1212
from redisvl.exceptions import RedisModuleVersionError
1313
from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES
1414
from redisvl.redis.utils import convert_bytes
15+
from redisvl.utils.utils import deprecated_function
1516
from redisvl.version import __version__
1617

1718

@@ -191,7 +192,7 @@ class RedisConnectionFactory:
191192
@classmethod
192193
def connect(
193194
cls, redis_url: Optional[str] = None, use_async: bool = False, **kwargs
194-
) -> None:
195+
) -> Union[Redis, AsyncRedis]:
195196
"""Create a connection to the Redis database based on a URL and some
196197
connection kwargs.
197198
@@ -260,14 +261,15 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi
260261
# fallback to env var REDIS_URL
261262
return AsyncRedis.from_url(get_address_from_env(), **kwargs)
262263

264+
@deprecated_function("sync_to_async_redis", "Please use an async Redis client instead.")
263265
@staticmethod
264266
def sync_to_async_redis(redis_client: Redis) -> AsyncRedis:
265267
# pick the right connection class
266268
connection_class: Type[AbstractConnection] = (
267269
AsyncSSLConnection
268270
if redis_client.connection_pool.connection_class == SSLConnection
269271
else AsyncConnection
270-
)
272+
) # type: ignore
271273
# make async client
272274
return AsyncRedis.from_pool( # type: ignore
273275
AsyncConnectionPool(

redisvl/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def deprecated_function(name: Optional[str] = None, replacement: Optional[str] =
100100
def decorator(func):
101101
fn_name = name or func.__name__
102102
warning_message = f"Function {fn_name} is deprecated and will be " \
103-
"removed in the next major release."
103+
"removed in the next major release. "
104104
if replacement:
105105
warning_message += replacement
106106

tests/integration/test_async_search_index.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_search_index_properties(index_schema, async_index):
3333
assert async_index.schema == index_schema
3434
# custom settings
3535
assert async_index.name == index_schema.index.name == "my_index"
36-
assert async_index.client == None
36+
assert async_index.client is None
3737
# default settings
3838
assert async_index.prefix == index_schema.index.prefix == "rvl"
3939
assert async_index.key_separator == index_schema.index.key_separator == ":"
@@ -45,7 +45,7 @@ def test_search_index_properties(index_schema, async_index):
4545

4646
def test_search_index_from_yaml(async_index_from_yaml):
4747
assert async_index_from_yaml.name == "json-test"
48-
assert async_index_from_yaml.client == None
48+
assert async_index_from_yaml.client is None
4949
assert async_index_from_yaml.prefix == "json"
5050
assert async_index_from_yaml.key_separator == ":"
5151
assert async_index_from_yaml.storage_type == StorageType.JSON
@@ -54,7 +54,7 @@ def test_search_index_from_yaml(async_index_from_yaml):
5454

5555
def test_search_index_from_dict(async_index_from_dict):
5656
assert async_index_from_dict.name == "my_index"
57-
assert async_index_from_dict.client == None
57+
assert async_index_from_dict.client is None
5858
assert async_index_from_dict.prefix == "rvl"
5959
assert async_index_from_dict.key_separator == ":"
6060
assert async_index_from_dict.storage_type == StorageType.HASH
@@ -156,7 +156,7 @@ async def test_search_index_set_client(async_client, client, async_index):
156156
await async_index.set_client(client)
157157

158158
await async_index.disconnect()
159-
assert async_index.client == None
159+
assert async_index.client is None
160160

161161

162162
@pytest.mark.asyncio

0 commit comments

Comments
 (0)