Skip to content

Commit 5249360

Browse files
add support for connection kwargs
1 parent 521548a commit 5249360

File tree

2 files changed

+14
-22
lines changed

2 files changed

+14
-22
lines changed

redisvl/extensions/router/semantic.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
redis_client: Optional[Redis] = None,
8989
redis_url: str = "redis://localhost:6379",
9090
overwrite: bool = False,
91+
connection_kwargs: Dict[str, Any] = {},
9192
**kwargs,
9293
):
9394
"""Initialize the SemanticRouter.
@@ -100,7 +101,8 @@ def __init__(
100101
redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None.
101102
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
102103
overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
103-
**kwargs: Additional arguments.
104+
connection_kwargs (Dict[str, Any]): The connection arguments
105+
for the redis client. Defaults to empty {}.
104106
"""
105107
# Set vectorizer default
106108
if vectorizer is None:
@@ -115,7 +117,7 @@ def __init__(
115117
vectorizer=vectorizer,
116118
routing_config=routing_config,
117119
)
118-
self._initialize_index(redis_client, redis_url, overwrite)
120+
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
119121

120122
def _initialize_index(
121123
self,
@@ -477,19 +479,12 @@ def clear(self) -> None:
477479
def from_dict(
478480
cls,
479481
data: Dict[str, Any],
480-
redis_client: Optional[Redis] = None,
481-
redis_url: str = "redis://localhost:6379",
482-
overwrite: bool = False,
483482
**kwargs,
484483
) -> "SemanticRouter":
485484
"""Create a SemanticRouter from a dictionary.
486485
487486
Args:
488487
data (Dict[str, Any]): The dictionary containing the semantic router data.
489-
redis_client (Optional[Redis]): Redis client for connection.
490-
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
491-
overwrite (bool): Whether to overwrite existing index.
492-
**kwargs: Additional arguments.
493488
494489
Returns:
495490
SemanticRouter: The semantic router instance.
@@ -531,9 +526,6 @@ def from_dict(
531526
routes=routes,
532527
vectorizer=vectorizer,
533528
routing_config=routing_config,
534-
redis_client=redis_client,
535-
redis_url=redis_url,
536-
overwrite=overwrite,
537529
**kwargs,
538530
)
539531

@@ -563,19 +555,12 @@ def to_dict(self) -> Dict[str, Any]:
563555
def from_yaml(
564556
cls,
565557
file_path: str,
566-
redis_client: Optional[Redis] = None,
567-
redis_url: str = "redis://localhost:6379",
568-
overwrite: bool = False,
569558
**kwargs,
570559
) -> "SemanticRouter":
571560
"""Create a SemanticRouter from a YAML file.
572561
573562
Args:
574563
file_path (str): The path to the YAML file.
575-
redis_client (Optional[Redis]): Redis client for connection.
576-
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
577-
overwrite (bool): Whether to overwrite existing index.
578-
**kwargs: Additional arguments.
579564
580565
Returns:
581566
SemanticRouter: The semantic router instance.
@@ -601,9 +586,6 @@ def from_yaml(
601586
yaml_data = yaml.safe_load(f)
602587
return cls.from_dict(
603588
yaml_data,
604-
redis_client=redis_client,
605-
redis_url=redis_url,
606-
overwrite=overwrite,
607589
**kwargs,
608590
)
609591

tests/integration/test_session_manager.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ def test_specify_redis_client(client):
4646
assert isinstance(session._client, type(client))
4747

4848

49+
def test_specify_redis_url(client):
50+
session = StandardSessionManager(
51+
name="test_app",
52+
session_tag="abc",
53+
user_tag="123",
54+
redis_url="redis://localhost:6379",
55+
)
56+
assert isinstance(session._client, type(client))
57+
58+
4959
def test_standard_bad_connection_info():
5060
with pytest.raises(ConnectionError):
5161
StandardSessionManager(

0 commit comments

Comments
 (0)