Skip to content

Commit 61e7338

Browse files
Introduce the LLM session manager classes (#141)
Here's the proposed class for LLM session management. It support recent or full conversation storage & retrieval, as well as relevance based conversation section retrieval.
1 parent aa05797 commit 61e7338

File tree

8 files changed

+1975
-0
lines changed

8 files changed

+1975
-0
lines changed

conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,15 @@ def clear_db(redis):
138138
redis.flushall()
139139
yield
140140
redis.flushall()
141+
142+
@pytest.fixture
143+
def app_name():
144+
return "test_app"
145+
146+
@pytest.fixture
147+
def session_tag():
148+
return "123"
149+
150+
@pytest.fixture
151+
def user_tag():
152+
return "abc"

docs/user_guide/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ llmcache_03
1717
vectorizers_04
1818
hash_vs_json_05
1919
rerankers_06
20+
session_manager_07
2021
```
2122

docs/user_guide/session_manager_07.ipynb

Lines changed: 615 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from redisvl.extensions.session_manager.base_session import BaseSessionManager
2+
from redisvl.extensions.session_manager.semantic_session import SemanticSessionManager
3+
from redisvl.extensions.session_manager.standard_session import StandardSessionManager
4+
5+
__all__ = ["BaseSessionManager", "StandardSessionManager", "SemanticSessionManager"]
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
3+
from redis import Redis
4+
5+
6+
class BaseSessionManager:
7+
id_field_name: str = "id_field"
8+
role_field_name: str = "role"
9+
content_field_name: str = "content"
10+
tool_field_name: str = "tool_call_id"
11+
timestamp_field_name: str = "timestamp"
12+
13+
def __init__(
14+
self,
15+
name: str,
16+
session_tag: str,
17+
user_tag: str,
18+
):
19+
"""Initialize session memory with index
20+
21+
Session Manager stores the current and previous user text prompts and
22+
LLM responses to allow for enriching future prompts with session
23+
context. Session history is stored in individual user or LLM prompts and
24+
responses.
25+
26+
Args:
27+
name (str): The name of the session manager index.
28+
session_tag (str): Tag to be added to entries to link to a specific
29+
session.
30+
user_tag (str): Tag to be added to entries to link to a specific user.
31+
"""
32+
self._name = name
33+
self._user_tag = user_tag
34+
self._session_tag = session_tag
35+
36+
def set_scope(
37+
self,
38+
session_tag: Optional[str] = None,
39+
user_tag: Optional[str] = None,
40+
) -> None:
41+
"""Set the filter to apply to querries based on the desired scope.
42+
43+
This new scope persists until another call to set_scope is made, or if
44+
scope specified in calls to get_recent.
45+
46+
Args:
47+
session_tag (str): Id of the specific session to filter to. Default is
48+
None.
49+
user_tag (str): Id of the specific user to filter to. Default is None.
50+
"""
51+
raise NotImplementedError
52+
53+
def clear(self) -> None:
54+
"""Clears the chat session history."""
55+
raise NotImplementedError
56+
57+
def delete(self) -> None:
58+
"""Clear all conversation history and remove any search indices."""
59+
raise NotImplementedError
60+
61+
def drop(self, id_field: Optional[str] = None) -> None:
62+
"""Remove a specific exchange from the conversation history.
63+
64+
Args:
65+
id_field (Optional[str]): The id_field of the entry to delete.
66+
If None then the last entry is deleted.
67+
"""
68+
raise NotImplementedError
69+
70+
@property
71+
def messages(self) -> Union[List[str], List[Dict[str, str]]]:
72+
"""Returns the full chat history."""
73+
raise NotImplementedError
74+
75+
def get_recent(
76+
self,
77+
top_k: int = 5,
78+
session_tag: Optional[str] = None,
79+
user_tag: Optional[str] = None,
80+
as_text: bool = False,
81+
raw: bool = False,
82+
) -> Union[List[str], List[Dict[str, str]]]:
83+
"""Retreive the recent conversation history in sequential order.
84+
85+
Args:
86+
top_k (int): The number of previous exchanges to return. Default is 5.
87+
Note that one exchange contains both a prompt and response.
88+
session_tag (str): Tag to be added to entries to link to a specific
89+
session.
90+
user_tag (str): Tag to be added to entries to link to a specific user.
91+
as_text (bool): Whether to return the conversation as a single string,
92+
or list of alternating prompts and responses.
93+
raw (bool): Whether to return the full Redis hash entry or just the
94+
prompt and response
95+
96+
Returns:
97+
Union[str, List[str]]: A single string transcription of the session
98+
or list of strings if as_text is false.
99+
100+
Raises:
101+
ValueError: If top_k is not an integer greater than or equal to 0.
102+
"""
103+
raise NotImplementedError
104+
105+
def _format_context(
106+
self, hits: List[Dict[str, Any]], as_text: bool
107+
) -> Union[List[str], List[Dict[str, str]]]:
108+
"""Extracts the prompt and response fields from the Redis hashes and
109+
formats them as either flat dictionaries or strings.
110+
111+
Args:
112+
hits (List): The hashes containing prompt & response pairs from
113+
recent conversation history.
114+
as_text (bool): Whether to return the conversation as a single string,
115+
or list of alternating prompts and responses.
116+
Returns:
117+
Union[str, List[str]]: A single string transcription of the session
118+
or list of strings if as_text is false.
119+
"""
120+
if as_text:
121+
text_statements = []
122+
for hit in hits:
123+
text_statements.append(hit[self.content_field_name])
124+
return text_statements
125+
else:
126+
statements = []
127+
for hit in hits:
128+
statements.append(
129+
{
130+
self.role_field_name: hit[self.role_field_name],
131+
self.content_field_name: hit[self.content_field_name],
132+
}
133+
)
134+
if (
135+
hasattr(hit, self.tool_field_name)
136+
or isinstance(hit, dict)
137+
and self.tool_field_name in hit
138+
):
139+
statements[-1].update(
140+
{self.tool_field_name: hit[self.tool_field_name]}
141+
)
142+
return statements
143+
144+
def store(self, prompt: str, response: str) -> None:
145+
"""Insert a prompt:response pair into the session memory. A timestamp
146+
is associated with each exchange so that they can be later sorted
147+
in sequential ordering after retrieval.
148+
149+
Args:
150+
prompt (str): The user prompt to the LLM.
151+
response (str): The corresponding LLM response.
152+
"""
153+
raise NotImplementedError
154+
155+
def add_messages(self, messages: List[Dict[str, str]]) -> None:
156+
"""Insert a list of prompts and responses into the session memory.
157+
A timestamp is associated with each so that they can be later sorted
158+
in sequential ordering after retrieval.
159+
160+
Args:
161+
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
162+
"""
163+
raise NotImplementedError
164+
165+
def add_message(self, message: Dict[str, str]) -> None:
166+
"""Insert a single prompt or response into the session memory.
167+
A timestamp is associated with it so that it can be later sorted
168+
in sequential ordering after retrieval.
169+
170+
Args:
171+
message (Dict[str,str]): The user prompt or LLM response.
172+
"""
173+
raise NotImplementedError

0 commit comments

Comments
 (0)