Skip to content

Commit af1c2d0

Browse files
committed
WIP
1 parent 7907674 commit af1c2d0

File tree

5 files changed

+94
-53
lines changed

5 files changed

+94
-53
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,17 +174,16 @@ def _modify_schema(
174174

175175
async def _get_async_index(self) -> AsyncSearchIndex:
176176
"""Lazily construct the async search index class."""
177+
# Construct async index if necessary
177178
if not self._aindex:
178-
# Construct async index if necessary
179-
self._aindex = AsyncSearchIndex(schema=self._index.schema)
180-
# Connect Redis async client
181179
redis_client = self.redis_kwargs["redis_client"]
182180
redis_url = self.redis_kwargs["redis_url"]
183181
connection_kwargs = self.redis_kwargs["connection_kwargs"]
184-
if redis_client is not None:
185-
await self._aindex.set_client(redis_client)
186-
elif redis_url:
187-
await self._aindex.connect(redis_url, **connection_kwargs) # type: ignore
182+
183+
self._aindex = AsyncSearchIndex(schema=self._index.schema,
184+
redis_client=redis_client,
185+
redis_url=redis_url,
186+
**connection_kwargs)
188187
return self._aindex
189188

190189
@property
@@ -290,7 +289,8 @@ async def _async_refresh_ttl(self, key: str) -> None:
290289
"""Async refresh the time-to-live for the specified key."""
291290
aindex = await self._get_async_index()
292291
if self._ttl:
293-
await aindex.client.expire(key, self._ttl) # type: ignore
292+
client = await aindex.get_client()
293+
await client.expire(key, self._ttl) # type: ignore
294294

295295
def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
296296
"""Converts a text prompt to its vector representation using the

redisvl/index/index.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
Optional,
1616
Union,
1717
)
18+
import warnings
19+
20+
from redisvl.utils.utils import deprecated_function
1821

1922
if TYPE_CHECKING:
2023
from redis.commands.search.aggregation import AggregateResult
@@ -839,6 +842,7 @@ def __init__(
839842
async def disconnect(self):
840843
"""Asynchronously disconnect and cleanup the underlying async redis connection."""
841844
if self._redis_client is not None:
845+
print(self._redis_client)
842846
await self._redis_client.aclose() # type: ignore
843847
self._redis_client = None
844848

@@ -860,12 +864,13 @@ async def from_existing(
860864
redis_url (Optional[str]): The URL of the Redis server to
861865
connect to.
862866
"""
863-
if redis_url:
864-
redis_client = RedisConnectionFactory.connect(
865-
redis_url=redis_url, use_async=True, **kwargs
866-
)
867-
868-
if not redis_client:
867+
if redis_client and redis_url:
868+
raise ValueError("Cannot provide both redis_client and redis_url")
869+
elif redis_url:
870+
redis_client = RedisConnectionFactory.get_async_redis_connection(url=redis_url, **kwargs)
871+
elif redis_client:
872+
pass
873+
else:
869874
raise ValueError(
870875
"Must provide either a redis_url or redis_client to fetch Redis index info."
871876
)
@@ -888,14 +893,21 @@ async def from_existing(
888893
index_info = await cls._info(name, redis_client)
889894
schema_dict = convert_index_info_to_schema(index_info)
890895
schema = IndexSchema.from_dict(schema_dict)
891-
index = cls(schema, **kwargs)
892-
await index.set_client(redis_client)
893-
return index
896+
return cls(schema, redis_client=redis_client, **kwargs)
894897

898+
@deprecated_function("client", "Use await self.get_client()")
895899
@property
896-
def client(self) -> Optional[aredis.Redis]:
900+
def client(self) -> aredis.Redis:
897901
"""The underlying redis-py client object."""
898-
return self._redis_client
902+
if self._redis_client is None:
903+
if asyncio.current_task() is not None:
904+
warnings.warn("Risk of deadlock! Use await self.get_client() if you are "
905+
"in an async context.")
906+
redis_client = asyncio.run_coroutine_threadsafe(self.get_client(), asyncio.get_event_loop())
907+
client = redis_client.result()
908+
else:
909+
client = self._redis_client
910+
return client
899911

900912
async def connect(self, redis_url: Optional[str] = None, **kwargs):
901913
"""[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__."""
@@ -910,31 +922,19 @@ async def connect(self, redis_url: Optional[str] = None, **kwargs):
910922
)
911923
return await self.set_client(client)
912924

925+
@deprecated_function("set_client", "Pass connection parameters in __init__")
913926
async def set_client(self, redis_client: Optional[aredis.Redis]):
914-
"""[DEPRECATED] Manually set the Redis client to use with the search index.
915-
This method is deprecated; please provide connection parameters in __init__.
916927
"""
917-
import warnings
918-
919-
warnings.warn(
920-
"set_client() is deprecated; pass connection parameters in __init__",
921-
DeprecationWarning,
922-
)
923-
return await self._set_client(redis_client)
924-
925-
async def _set_client(self, redis_client: Optional[redis.asyncio.Redis]):
926-
"""
927-
Set the Redis client to use with the search index.
928-
929-
NOTE: Remove this method once the deprecation period is over.
928+
[DEPRECATED] Manually set the Redis client to use with the search index.
929+
This method is deprecated; please provide connection parameters in __init__.
930930
"""
931931
if self._redis_client is not None:
932932
await self._redis_client.aclose() # type: ignore
933933
async with self._lock:
934934
self._redis_client = redis_client
935935
return self
936936

937-
async def _get_client(self) -> aredis.Redis:
937+
async def get_client(self) -> aredis.Redis:
938938
"""Lazily instantiate and return the async Redis client."""
939939
if self._redis_client is None:
940940
async with self._lock:
@@ -976,7 +976,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
976976
# overwrite an index in Redis; drop associated data (clean slate)
977977
await index.create(overwrite=True, drop=True)
978978
"""
979-
client = await self._get_client()
979+
client = await self.get_client()
980980
redis_fields = self.schema.redis_fields
981981

982982
if not redis_fields:
@@ -1012,7 +1012,7 @@ async def delete(self, drop: bool = True):
10121012
Raises:
10131013
redis.exceptions.ResponseError: If the index does not exist.
10141014
"""
1015-
client = await self._get_client()
1015+
client = await self.get_client()
10161016
try:
10171017
await client.ft(self.schema.index.name).dropindex(delete_documents=drop)
10181018
except Exception as e:
@@ -1025,7 +1025,7 @@ async def clear(self) -> int:
10251025
Returns:
10261026
int: Count of records deleted from Redis.
10271027
"""
1028-
client = await self._get_client()
1028+
client = await self.get_client()
10291029
total_records_deleted: int = 0
10301030

10311031
async for batch in self.paginate(
@@ -1046,7 +1046,7 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
10461046
Returns:
10471047
int: Count of records deleted from Redis.
10481048
"""
1049-
client = await self._get_client()
1049+
client = await self.get_client()
10501050
if isinstance(keys, list):
10511051
return await client.delete(*keys)
10521052
else:
@@ -1110,7 +1110,7 @@ async def add_field(d):
11101110
keys = await index.load(data, preprocess=add_field)
11111111
11121112
"""
1113-
client = await self._get_client()
1113+
client = await self.get_client()
11141114
try:
11151115
return await self._storage.awrite(
11161116
client,
@@ -1137,7 +1137,7 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
11371137
Returns:
11381138
Dict[str, Any]: The fetched object.
11391139
"""
1140-
client = await self._get_client()
1140+
client = await self.get_client()
11411141
obj = await self._storage.aget(client, [self.key(id)])
11421142
if obj:
11431143
return convert_bytes(obj[0])
@@ -1153,10 +1153,9 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult":
11531153
Returns:
11541154
Result: Raw Redis aggregation results.
11551155
"""
1156-
client = await self._get_client()
1156+
client = await self.get_client()
11571157
try:
1158-
# TODO: Typing
1159-
return await client.ft(self.schema.index.name).aggregate(*args, **kwargs)
1158+
return client.ft(self.schema.index.name).aggregate(*args, **kwargs)
11601159
except Exception as e:
11611160
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
11621161

@@ -1170,10 +1169,9 @@ async def search(self, *args, **kwargs) -> "Result":
11701169
Returns:
11711170
Result: Raw Redis search results.
11721171
"""
1173-
client = await self._get_client()
1172+
client = await self.get_client()
11741173
try:
1175-
# TODO: Typing
1176-
return await client.ft(self.schema.index.name).search(*args, **kwargs)
1174+
return client.ft(self.schema.index.name).search(*args, **kwargs)
11771175
except Exception as e:
11781176
raise RedisSearchError(f"Error while searching: {str(e)}") from e
11791177

@@ -1264,7 +1262,7 @@ async def listall(self) -> List[str]:
12641262
Returns:
12651263
List[str]: The list of indices in the database.
12661264
"""
1267-
client: aredis.Redis = await self._get_client()
1265+
client: aredis.Redis = await self.get_client()
12681266
return convert_bytes(await client.execute_command("FT._LIST"))
12691267

12701268
async def exists(self) -> bool:
@@ -1285,7 +1283,7 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
12851283
Returns:
12861284
dict: A dictionary containing the information about the index.
12871285
"""
1288-
client: aredis.Redis = await self._get_client()
1286+
client: aredis.Redis = await self.get_client()
12891287
index_name = name or self.schema.index.name
12901288
return await type(self)._info(index_name, client)
12911289

redisvl/utils/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def validate_vector_dims(v1: int, v2: int) -> None:
5252

5353

5454
def serialize(data: Dict[str, Any]) -> str:
55-
"""Serlize the input into a string."""
55+
"""Serialize the input into a string."""
5656
return json.dumps(data)
5757

5858

@@ -88,3 +88,27 @@ def inner(*args, **kwargs):
8888
return inner
8989

9090
return wrapper
91+
92+
93+
def deprecated_function(name: Optional[str] = None, replacement: Optional[str] = None):
94+
"""
95+
Decorator to mark a function as deprecated.
96+
97+
When the wrapped function is called, the decorator will log a deprecation
98+
warning.
99+
"""
100+
def decorator(func):
101+
fn_name = name or func.__name__
102+
warning_message = f"Function {fn_name} is deprecated and will be " \
103+
"removed in the next major release."
104+
if replacement:
105+
warning_message += replacement
106+
107+
@wraps(func)
108+
def wrapper(*args, **kwargs):
109+
warn(warning_message, category=DeprecationWarning, stacklevel=3)
110+
return func(*args, **kwargs)
111+
112+
return wrapper
113+
114+
return decorator

tests/integration/test_async_search_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def test_search_index_redis_url(redis_url, index_schema):
137137
)
138138
assert async_index.client
139139

140-
async_index.disconnect()
140+
await async_index.disconnect()
141141
assert async_index.client == None
142142

143143

@@ -155,7 +155,7 @@ async def test_search_index_set_client(async_client, client, async_index):
155155
assert async_index.client == async_client
156156
await async_index.set_client(client)
157157

158-
async_index.disconnect()
158+
await async_index.disconnect()
159159
assert async_index.client == None
160160

161161

tests/unit/test_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
convert_bytes,
88
make_dict,
99
)
10-
from redisvl.utils.utils import deprecated_argument
10+
from redisvl.utils.utils import deprecated_argument, deprecated_function
1111

1212

1313
def test_even_number_of_elements():
@@ -237,3 +237,22 @@ async def test_func(dtype=None, vectorizer=None):
237237
with pytest.warns(DeprecationWarning):
238238
result = await test_func(dtype="float32")
239239
assert result == 1
240+
241+
242+
243+
class TestDeprecatedFunction:
244+
def test_deprecated_function_warning(self):
245+
@deprecated_function("new_func", "Use new_func2")
246+
def old_func():
247+
pass
248+
249+
with pytest.warns(DeprecationWarning):
250+
old_func()
251+
252+
def test_deprecated_function_warning_with_name(self):
253+
@deprecated_function("new_func", "Use new_func2")
254+
def old_func():
255+
pass
256+
257+
with pytest.warns(DeprecationWarning):
258+
old_func()

0 commit comments

Comments
 (0)