Skip to content

Commit 3844d57

Browse files
Improve session manager scope handling (#193)
This PR removes the requirement that user and session tags be specified on session manager initialization. Session tags can be added when storing messages and filter expressions can be used to retrieve specific chat histories. A default uuid is used when no session is provided. --------- Co-authored-by: Tyler Hutcherson <[email protected]>
1 parent a5854be commit 3844d57

File tree

8 files changed

+196
-316
lines changed

8 files changed

+196
-316
lines changed

conftest.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,3 @@ def clear_db(redis):
143143
def app_name():
144144
return "test_app"
145145

146-
@pytest.fixture
147-
def session_tag():
148-
return "123"
149-
150-
@pytest.fixture
151-
def user_tag():
152-
return "abc"

docs/user_guide/session_manager_07.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
],
7575
"source": [
7676
"from redisvl.extensions.session_manager import SemanticSessionManager\n",
77-
"user_session = SemanticSessionManager(name='llm_chef', session_tag='123', user_tag='abc')\n",
77+
"user_session = SemanticSessionManager(name='llm_chef')\n",
7878
"user_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful chef, assisting people in making delicious meals\"})\n",
7979
"\n",
8080
"client = CohereClient()"

redisvl/extensions/llmcache/semantic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class SemanticCache(BaseLLMCache):
1414
"""Semantic Cache for Large Language Models."""
1515

16-
entry_id_field_name: str = "id"
16+
entry_id_field_name: str = "_id"
1717
prompt_field_name: str = "prompt"
1818
vector_field_name: str = "prompt_vector"
1919
response_field_name: str = "response"
@@ -222,7 +222,8 @@ def _search_cache(
222222
cache_hits: List[Dict[str, Any]] = self._index.query(query)
223223
# Process cache hits
224224
for hit in cache_hits:
225-
self._refresh_ttl(hit[self.entry_id_field_name])
225+
key = hit["id"]
226+
self._refresh_ttl(key)
226227
# Check for metadata and deserialize
227228
if self.metadata_field_name in hit:
228229
hit[self.metadata_field_name] = self.deserialize(

redisvl/extensions/session_manager/base_session.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
from typing import Any, Dict, List, Optional, Union
2+
from uuid import uuid4
23

34
from redis import Redis
45

6+
from redisvl.query.filter import FilterExpression
7+
58

69
class BaseSessionManager:
7-
id_field_name: str = "id_field"
10+
id_field_name: str = "_id"
811
role_field_name: str = "role"
912
content_field_name: str = "content"
1013
tool_field_name: str = "tool_call_id"
1114
timestamp_field_name: str = "timestamp"
15+
session_field_name: str = "session_tag"
1216

1317
def __init__(
1418
self,
1519
name: str,
16-
session_tag: str,
17-
user_tag: str,
20+
session_tag: Optional[str] = None,
1821
):
1922
"""Initialize session memory with index
2023
@@ -26,29 +29,10 @@ def __init__(
2629
Args:
2730
name (str): The name of the session manager index.
2831
session_tag (str): Tag to be added to entries to link to a specific
29-
session.
30-
user_tag (str): Tag to be added to entries to link to a specific user.
32+
session. Defaults to instance uuid.
3133
"""
3234
self._name = name
33-
self._user_tag = user_tag
34-
self._session_tag = session_tag
35-
36-
def set_scope(
37-
self,
38-
session_tag: Optional[str] = None,
39-
user_tag: Optional[str] = None,
40-
) -> None:
41-
"""Set the filter to apply to querries based on the desired scope.
42-
43-
This new scope persists until another call to set_scope is made, or if
44-
scope specified in calls to get_recent.
45-
46-
Args:
47-
session_tag (str): Id of the specific session to filter to. Default is
48-
None.
49-
user_tag (str): Id of the specific user to filter to. Default is None.
50-
"""
51-
raise NotImplementedError
35+
self._session_tag = session_tag or uuid4().hex
5236

5337
def clear(self) -> None:
5438
"""Clears the chat session history."""
@@ -75,23 +59,21 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
7559
def get_recent(
7660
self,
7761
top_k: int = 5,
78-
session_tag: Optional[str] = None,
79-
user_tag: Optional[str] = None,
8062
as_text: bool = False,
8163
raw: bool = False,
64+
session_tag: Optional[str] = None,
8265
) -> Union[List[str], List[Dict[str, str]]]:
8366
"""Retreive the recent conversation history in sequential order.
8467
8568
Args:
8669
top_k (int): The number of previous exchanges to return. Default is 5.
8770
Note that one exchange contains both a prompt and response.
88-
session_tag (str): Tag to be added to entries to link to a specific
89-
session.
90-
user_tag (str): Tag to be added to entries to link to a specific user.
9171
as_text (bool): Whether to return the conversation as a single string,
9272
or list of alternating prompts and responses.
9373
raw (bool): Whether to return the full Redis hash entry or just the
9474
prompt and response
75+
session_tag (str): Tag to be added to entries to link to a specific
76+
session. Defaults to instance uuid.
9577
9678
Returns:
9779
Union[str, List[str]]: A single string transcription of the session
@@ -113,6 +95,7 @@ def _format_context(
11395
recent conversation history.
11496
as_text (bool): Whether to return the conversation as a single string,
11597
or list of alternating prompts and responses.
98+
11699
Returns:
117100
Union[str, List[str]]: A single string transcription of the session
118101
or list of strings if as_text is false.
@@ -141,33 +124,42 @@ def _format_context(
141124
)
142125
return statements
143126

144-
def store(self, prompt: str, response: str) -> None:
127+
def store(
128+
self, prompt: str, response: str, session_tag: Optional[str] = None
129+
) -> None:
145130
"""Insert a prompt:response pair into the session memory. A timestamp
146131
is associated with each exchange so that they can be later sorted
147132
in sequential ordering after retrieval.
148133
149134
Args:
150135
prompt (str): The user prompt to the LLM.
151136
response (str): The corresponding LLM response.
137+
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
152138
"""
153139
raise NotImplementedError
154140

155-
def add_messages(self, messages: List[Dict[str, str]]) -> None:
141+
def add_messages(
142+
self, messages: List[Dict[str, str]], session_tag: Optional[str] = None
143+
) -> None:
156144
"""Insert a list of prompts and responses into the session memory.
157145
A timestamp is associated with each so that they can be later sorted
158146
in sequential ordering after retrieval.
159147
160148
Args:
161149
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
150+
session_tag (Optional[str]): The tag to mark the messages with. Defaults to None.
162151
"""
163152
raise NotImplementedError
164153

165-
def add_message(self, message: Dict[str, str]) -> None:
154+
def add_message(
155+
self, message: Dict[str, str], session_tag: Optional[str] = None
156+
) -> None:
166157
"""Insert a single prompt or response into the session memory.
167158
A timestamp is associated with it so that it can be later sorted
168159
in sequential ordering after retrieval.
169160
170161
Args:
171162
message (Dict[str,str]): The user prompt or LLM response.
163+
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
172164
"""
173165
raise NotImplementedError

0 commit comments

Comments
 (0)