Skip to content

Commit cb61457

Browse files
Standardize redis init in extensions (#188)
1 parent 877f4f2 commit cb61457

File tree

7 files changed

+101
-67
lines changed

7 files changed

+101
-67
lines changed

redisvl/extensions/llmcache/semantic.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
vectorizer: Optional[BaseVectorizer] = None,
2929
redis_client: Optional[Redis] = None,
3030
redis_url: str = "redis://localhost:6379",
31-
connection_args: Dict[str, Any] = {},
31+
connection_kwargs: Dict[str, Any] = {},
3232
**kwargs,
3333
):
3434
"""Semantic Cache for Large Language Models.
@@ -43,14 +43,13 @@ def __init__(
4343
cache. Defaults to 0.1.
4444
ttl (Optional[int], optional): The time-to-live for records cached
4545
in Redis. Defaults to None.
46-
vectorizer (BaseVectorizer, optional): The vectorizer for the cache.
46+
vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
4747
Defaults to HFTextVectorizer.
48-
redis_client(Redis, optional): A redis client connection instance.
48+
redis_client(Optional[Redis], optional): A redis client connection instance.
4949
Defaults to None.
50-
redis_url (str, optional): The redis url. Defaults to
51-
"redis://localhost:6379".
52-
connection_args (Dict[str, Any], optional): The connection arguments
53-
for the redis client. Defaults to None.
50+
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
51+
connection_kwargs (Dict[str, Any]): The connection arguments
52+
for the redis client. Defaults to empty {}.
5453
5554
Raises:
5655
TypeError: If an invalid vectorizer is provided.
@@ -96,8 +95,8 @@ def __init__(
9695
# handle redis connection
9796
if redis_client:
9897
self._index.set_client(redis_client)
99-
else:
100-
self._index.connect(redis_url=redis_url, **connection_args)
98+
elif redis_url:
99+
self._index.connect(redis_url=redis_url, **connection_kwargs)
101100

102101
# initialize other components
103102
self.default_return_fields = [

redisvl/extensions/router/semantic.py

+7-27
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ def __init__(
8686
vectorizer: Optional[BaseVectorizer] = None,
8787
routing_config: Optional[RoutingConfig] = None,
8888
redis_client: Optional[Redis] = None,
89-
redis_url: Optional[str] = None,
89+
redis_url: str = "redis://localhost:6379",
9090
overwrite: bool = False,
91+
connection_kwargs: Dict[str, Any] = {},
9192
**kwargs,
9293
):
9394
"""Initialize the SemanticRouter.
@@ -98,9 +99,10 @@ def __init__(
9899
vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to default HFTextVectorizer.
99100
routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to the default RoutingConfig.
100101
redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None.
101-
redis_url (Optional[str], optional): Redis URL for connection. Defaults to None.
102+
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
102103
overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
103-
**kwargs: Additional arguments.
104+
connection_kwargs (Dict[str, Any]): The connection arguments
105+
for the redis client. Defaults to empty {}.
104106
"""
105107
# Set vectorizer default
106108
if vectorizer is None:
@@ -115,12 +117,12 @@ def __init__(
115117
vectorizer=vectorizer,
116118
routing_config=routing_config,
117119
)
118-
self._initialize_index(redis_client, redis_url, overwrite)
120+
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
119121

120122
def _initialize_index(
121123
self,
122124
redis_client: Optional[Redis] = None,
123-
redis_url: Optional[str] = None,
125+
redis_url: str = "redis://localhost:6379",
124126
overwrite: bool = False,
125127
**connection_kwargs,
126128
):
@@ -132,8 +134,6 @@ def _initialize_index(
132134
self._index.set_client(redis_client)
133135
elif redis_url:
134136
self._index.connect(redis_url=redis_url, **connection_kwargs)
135-
else:
136-
raise ValueError("Must provide either a redis client or redis url string.")
137137

138138
existed = self._index.exists()
139139
self._index.create(overwrite=overwrite)
@@ -479,19 +479,12 @@ def clear(self) -> None:
479479
def from_dict(
480480
cls,
481481
data: Dict[str, Any],
482-
redis_client: Optional[Redis] = None,
483-
redis_url: Optional[str] = None,
484-
overwrite: bool = False,
485482
**kwargs,
486483
) -> "SemanticRouter":
487484
"""Create a SemanticRouter from a dictionary.
488485
489486
Args:
490487
data (Dict[str, Any]): The dictionary containing the semantic router data.
491-
redis_client (Optional[Redis]): Redis client for connection.
492-
redis_url (Optional[str]): Redis URL for connection.
493-
overwrite (bool): Whether to overwrite existing index.
494-
**kwargs: Additional arguments.
495488
496489
Returns:
497490
SemanticRouter: The semantic router instance.
@@ -533,9 +526,6 @@ def from_dict(
533526
routes=routes,
534527
vectorizer=vectorizer,
535528
routing_config=routing_config,
536-
redis_client=redis_client,
537-
redis_url=redis_url,
538-
overwrite=overwrite,
539529
**kwargs,
540530
)
541531

@@ -565,19 +555,12 @@ def to_dict(self) -> Dict[str, Any]:
565555
def from_yaml(
566556
cls,
567557
file_path: str,
568-
redis_client: Optional[Redis] = None,
569-
redis_url: Optional[str] = None,
570-
overwrite: bool = False,
571558
**kwargs,
572559
) -> "SemanticRouter":
573560
"""Create a SemanticRouter from a YAML file.
574561
575562
Args:
576563
file_path (str): The path to the YAML file.
577-
redis_client (Optional[Redis]): Redis client for connection.
578-
redis_url (Optional[str]): Redis URL for connection.
579-
overwrite (bool): Whether to overwrite existing index.
580-
**kwargs: Additional arguments.
581564
582565
Returns:
583566
SemanticRouter: The semantic router instance.
@@ -603,9 +586,6 @@ def from_yaml(
603586
yaml_data = yaml.safe_load(f)
604587
return cls.from_dict(
605588
yaml_data,
606-
redis_client=redis_client,
607-
redis_url=redis_url,
608-
overwrite=overwrite,
609589
**kwargs,
610590
)
611591

redisvl/extensions/session_manager/semantic_session.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from time import time
2-
from typing import Dict, List, Optional, Union
2+
from typing import Any, Dict, List, Optional, Union
33

44
from redis import Redis
55

@@ -27,6 +27,8 @@ def __init__(
2727
distance_threshold: float = 0.3,
2828
redis_client: Optional[Redis] = None,
2929
redis_url: str = "redis://localhost:6379",
30+
connection_kwargs: Dict[str, Any] = {},
31+
**kwargs,
3032
):
3133
"""Initialize session memory with index
3234
@@ -43,12 +45,14 @@ def __init__(
4345
user_tag (str): Tag to be added to entries to link to a specific user.
4446
prefix (Optional[str]): Prefix for the keys for this session data.
4547
Defaults to None and will be replaced with the index name.
46-
vectorizer (Vectorizer): The vectorizer to create embeddings with.
48+
vectorizer (Optional[BaseVectorizer]): The vectorizer used to create embeddings.
4749
distance_threshold (float): The maximum semantic distance to be
4850
included in the context. Defaults to 0.3.
4951
redis_client (Optional[Redis]): A Redis client instance. Defaults to
5052
None.
51-
redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'.
53+
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
54+
connection_kwargs (Dict[str, Any]): The connection arguments
55+
for the redis client. Defaults to empty {}.
5256
5357
The proposed schema will support a single vector embedding constructed
5458
from either the prompt or response in a single string.
@@ -89,10 +93,11 @@ def __init__(
8993

9094
self._index = SearchIndex(schema=schema)
9195

96+
# handle redis connection
9297
if redis_client:
9398
self._index.set_client(redis_client)
94-
else:
95-
self._index.connect(redis_url=redis_url)
99+
elif redis_url:
100+
self._index.connect(redis_url=redis_url, **connection_kwargs)
96101

97102
self._index.create(overwrite=False)
98103

redisvl/extensions/session_manager/standard_session.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import json
22
from time import time
3-
from typing import Dict, List, Optional, Union
3+
from typing import Any, Dict, List, Optional, Union
44

55
from redis import Redis
66

77
from redisvl.extensions.session_manager import BaseSessionManager
8+
from redisvl.redis.connection import RedisConnectionFactory
89

910

1011
class StandardSessionManager(BaseSessionManager):
@@ -16,6 +17,8 @@ def __init__(
1617
user_tag: str,
1718
redis_client: Optional[Redis] = None,
1819
redis_url: str = "redis://localhost:6379",
20+
connection_kwargs: Dict[str, Any] = {},
21+
**kwargs,
1922
):
2023
"""Initialize session memory
2124
@@ -31,18 +34,24 @@ def __init__(
3134
user_tag (str): Tag to be added to entries to link to a specific user.
3235
redis_client (Optional[Redis]): A Redis client instance. Defaults to
3336
None.
34-
redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'.
37+
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
38+
connection_kwargs (Dict[str, Any]): The connection arguments
39+
for the redis client. Defaults to empty {}.
3540
3641
The proposed schema will support a single combined vector embedding
3742
constructed from the prompt & response in a single string.
3843
3944
"""
4045
super().__init__(name, session_tag, user_tag)
4146

47+
# handle redis connection
4248
if redis_client:
4349
self._client = redis_client
44-
else:
45-
self._client = Redis.from_url(redis_url)
50+
elif redis_url:
51+
self._client = RedisConnectionFactory.get_redis_connection(
52+
redis_url, **connection_kwargs
53+
)
54+
RedisConnectionFactory.validate_redis(self._client)
4655

4756
self.set_scope(session_tag, user_tag)
4857

@@ -51,7 +60,7 @@ def set_scope(
5160
session_tag: Optional[str] = None,
5261
user_tag: Optional[str] = None,
5362
) -> None:
54-
"""Set the filter to apply to querries based on the desired scope.
63+
"""Set the filter to apply to queries based on the desired scope.
5564
5665
This new scope persists until another call to set_scope is made, or if
5766
scope is specified in calls to get_recent.

tests/integration/test_llmcache.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from time import sleep
33

44
import pytest
5+
from redis.exceptions import ConnectionError
56

67
from redisvl.extensions.llmcache import SemanticCache
78
from redisvl.index.index import SearchIndex
@@ -40,19 +41,17 @@ def cache_with_ttl(vectorizer, redis_url):
4041

4142

4243
@pytest.fixture
43-
def cache_with_redis_client(vectorizer, client, redis_url):
44+
def cache_with_redis_client(vectorizer, client):
4445
cache_instance = SemanticCache(
4546
vectorizer=vectorizer,
4647
redis_client=client,
4748
distance_threshold=0.2,
48-
redis_url=redis_url,
4949
)
5050
yield cache_instance
5151
cache_instance.clear() # Clear cache after each test
5252
cache_instance._index.delete(True) # Clean up index
5353

5454

55-
# # Test handling invalid input for check method
5655
def test_bad_ttl(cache):
5756
with pytest.raises(ValueError):
5857
cache.set_ttl(2.5)
@@ -76,7 +75,6 @@ def test_reset_ttl(cache):
7675
assert cache.ttl is None
7776

7877

79-
# Test basic store and check functionality
8078
def test_store_and_check(cache, vectorizer):
8179
prompt = "This is a test prompt."
8280
response = "This is a test response."
@@ -91,7 +89,6 @@ def test_store_and_check(cache, vectorizer):
9189
assert "metadata" not in check_result[0]
9290

9391

94-
# Test clearing the cache
9592
def test_clear(cache, vectorizer):
9693
prompt = "This is a test prompt."
9794
response = "This is a test response."
@@ -139,7 +136,6 @@ def test_check_no_match(cache, vectorizer):
139136
assert len(check_result) == 0
140137

141138

142-
# Test handling invalid input for check method
143139
def test_check_invalid_input(cache):
144140
with pytest.raises(ValueError):
145141
cache.check()
@@ -148,7 +144,15 @@ def test_check_invalid_input(cache):
148144
cache.check(prompt="test", return_fields="bad value")
149145

150146

151-
# Test storing with metadata
147+
def test_bad_connection_info(vectorizer):
148+
with pytest.raises(ConnectionError):
149+
SemanticCache(
150+
vectorizer=vectorizer,
151+
distance_threshold=0.2,
152+
redis_url="redis://localhost:6389",
153+
)
154+
155+
152156
def test_store_with_metadata(cache, vectorizer):
153157
prompt = "This is another test prompt."
154158
response = "This is another test response."
@@ -165,7 +169,6 @@ def test_store_with_metadata(cache, vectorizer):
165169
assert check_result[0]["prompt"] == prompt
166170

167171

168-
# Test storing with invalid metadata
169172
def test_store_with_invalid_metadata(cache, vectorizer):
170173
prompt = "This is another test prompt."
171174
response = "This is another test response."
@@ -179,7 +182,6 @@ def test_store_with_invalid_metadata(cache, vectorizer):
179182
cache.store(prompt, response, vector=vector, metadata=metadata)
180183

181184

182-
# Test setting and getting the distance threshold
183185
def test_distance_threshold(cache):
184186
initial_threshold = cache.distance_threshold
185187
new_threshold = 0.1
@@ -189,14 +191,12 @@ def test_distance_threshold(cache):
189191
assert cache.distance_threshold != initial_threshold
190192

191193

192-
# Test out of range distance threshold
193194
def test_distance_threshold_out_of_range(cache):
194195
out_of_range_threshold = -1
195196
with pytest.raises(ValueError):
196197
cache.set_threshold(out_of_range_threshold)
197198

198199

199-
# Test storing and retrieving multiple items
200200
def test_multiple_items(cache, vectorizer):
201201
prompts_responses = {
202202
"prompt1": "response1",
@@ -217,12 +217,10 @@ def test_multiple_items(cache, vectorizer):
217217
assert "metadata" not in check_result[0]
218218

219219

220-
# Test retrieving underlying SearchIndex for the cache.
221220
def test_get_index(cache):
222221
assert isinstance(cache.index, SearchIndex)
223222

224223

225-
# Test basic functionality with cache created with user-provided Redis client
226224
def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
227225
prompt = "This is a test prompt."
228226
response = "This is a test response."
@@ -237,13 +235,11 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
237235
assert "metadata" not in check_result[0]
238236

239237

240-
# Test deleting the cache
241238
def test_delete(cache_no_cleanup):
242239
cache_no_cleanup.delete()
243240
assert not cache_no_cleanup.index.exists()
244241

245242

246-
# Test we can only store and check vectors of correct dimensions
247243
def test_vector_size(cache, vectorizer):
248244
prompt = "This is test prompt."
249245
response = "This is a test response."

0 commit comments

Comments
 (0)