5
5
from redis import Redis
6
6
7
7
from redisvl .extensions .session_manager import BaseSessionManager
8
- from redisvl .redis .connection import RedisConnectionFactory
8
+ from redisvl .index import SearchIndex
9
+ from redisvl .query import FilterQuery
10
+ from redisvl .query .filter import Tag
11
+ from redisvl .schema .schema import IndexSchema
12
+
13
+
14
+ class StandardSessionIndexSchema (IndexSchema ):
15
+
16
+ @classmethod
17
+ def from_params (cls , name : str , prefix : str ):
18
+
19
+ return cls (
20
+ index = {"name" : name , "prefix" : prefix }, # type: ignore
21
+ fields = [ # type: ignore
22
+ {"name" : "role" , "type" : "text" },
23
+ {"name" : "content" , "type" : "text" },
24
+ {"name" : "tool_call_id" , "type" : "text" },
25
+ {"name" : "timestamp" , "type" : "numeric" },
26
+ {"name" : "session_tag" , "type" : "tag" },
27
+ {"name" : "user_tag" , "type" : "tag" },
28
+ ],
29
+ )
9
30
10
31
11
32
class StandardSessionManager (BaseSessionManager ):
33
+ session_field_name : str = "session_tag"
34
+ user_field_name : str = "user_tag"
12
35
13
36
def __init__ (
14
37
self ,
15
38
name : str ,
16
39
session_tag : str ,
17
40
user_tag : str ,
41
+ prefix : Optional [str ] = None ,
18
42
redis_client : Optional [Redis ] = None ,
19
43
redis_url : str = "redis://localhost:6379" ,
20
44
connection_kwargs : Dict [str , Any ] = {},
@@ -29,9 +53,11 @@ def __init__(
29
53
30
54
Args:
31
55
name (str): The name of the session manager index.
32
- session_tag (str): Tag to be added to entries to link to a specific
56
+ session_tag (Optional[ str] ): Tag to be added to entries to link to a specific
33
57
session.
34
- user_tag (str): Tag to be added to entries to link to a specific user.
58
+ user_tag (Optional[str]): Tag to be added to entries to link to a specific user.
59
+ prefix (Optional[str]): Prefix for the keys for this session data.
60
+ Defaults to None and will be replaced with the index name.
35
61
redis_client (Optional[Redis]): A Redis client instance. Defaults to
36
62
None.
37
63
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
@@ -44,14 +70,18 @@ def __init__(
44
70
"""
45
71
super ().__init__ (name , session_tag , user_tag )
46
72
73
+ prefix = prefix or name
74
+
75
+ schema = StandardSessionIndexSchema .from_params (name , prefix )
76
+ self ._index = SearchIndex (schema = schema )
77
+
47
78
# handle redis connection
48
79
if redis_client :
49
- self ._client = redis_client
80
+ self ._index . set_client ( redis_client )
50
81
elif redis_url :
51
- self ._client = RedisConnectionFactory .get_redis_connection (
52
- redis_url , ** connection_kwargs
53
- )
54
- RedisConnectionFactory .validate_sync_redis (self ._client )
82
+ self ._index .connect (redis_url = redis_url , ** connection_kwargs )
83
+
84
+ self ._index .create (overwrite = False )
55
85
56
86
self .set_scope (session_tag , user_tag )
57
87
@@ -63,27 +93,35 @@ def set_scope(
63
93
"""Set the filter to apply to queries based on the desired scope.
64
94
65
95
This new scope persists until another call to set_scope is made, or if
66
- scope is specified in calls to get_recent.
96
+ scope specified in calls to get_recent or get_relevant .
67
97
68
98
Args:
69
99
session_tag (str): Id of the specific session to filter to. Default is
70
- None, which means session_tag will be unchanged .
100
+ None, which means all sessions will be in scope .
71
101
user_tag (str): Id of the specific user to filter to. Default is None,
72
- which means user_tag will be unchanged .
102
+ which means all users will be in scope .
73
103
"""
74
104
if not (session_tag or user_tag ):
75
105
return
76
-
77
106
self ._session_tag = session_tag or self ._session_tag
78
107
self ._user_tag = user_tag or self ._user_tag
108
+ tag_filter = Tag (self .user_field_name ) == []
109
+ if user_tag :
110
+ tag_filter = tag_filter & (Tag (self .user_field_name ) == self ._user_tag )
111
+ if session_tag :
112
+ tag_filter = tag_filter & (
113
+ Tag (self .session_field_name ) == self ._session_tag
114
+ )
115
+
116
+ self ._tag_filter = tag_filter
79
117
80
118
def clear (self ) -> None :
81
119
"""Clears the chat session history."""
82
- self ._client . delete ( self . key )
120
+ self ._index . clear ( )
83
121
84
122
def delete (self ) -> None :
85
- """Clears the chat session history ."""
86
- self ._client .delete (self . key )
123
+ """Clear all conversation keys and remove the search index ."""
124
+ self ._index .delete (drop = True )
87
125
88
126
def drop (self , id_field : Optional [str ] = None ) -> None :
89
127
"""Remove a specific exchange from the conversation history.
@@ -93,19 +131,36 @@ def drop(self, id_field: Optional[str] = None) -> None:
93
131
If None then the last entry is deleted.
94
132
"""
95
133
if id_field :
96
- messages = self ._client .lrange (self .key , 0 , - 1 )
97
- messages = [json .loads (msg ) for msg in messages ]
98
- messages = [msg for msg in messages if msg ["id_field" ] != id_field ]
99
- messages = [json .dumps (msg ) for msg in messages ]
100
- self .clear ()
101
- self ._client .rpush (self .key , * messages )
134
+ sep = self ._index .key_separator
135
+ key = sep .join ([self ._index .schema .index .name , id_field ])
102
136
else :
103
- self ._client .rpop (self .key )
137
+ key = self .get_recent (top_k = 1 , raw = True )[0 ]["id" ] # type: ignore
138
+ self ._index .client .delete (key ) # type: ignore
104
139
105
140
@property
106
141
def messages (self ) -> Union [List [str ], List [Dict [str , str ]]]:
107
142
"""Returns the full chat history."""
108
- return self .get_recent (top_k = - 1 )
143
+ # TODO raw or as_text?
144
+ return_fields = [
145
+ self .id_field_name ,
146
+ self .session_field_name ,
147
+ self .user_field_name ,
148
+ self .role_field_name ,
149
+ self .content_field_name ,
150
+ self .tool_field_name ,
151
+ self .timestamp_field_name ,
152
+ ]
153
+
154
+ query = FilterQuery (
155
+ filter_expression = self ._tag_filter ,
156
+ return_fields = return_fields ,
157
+ )
158
+
159
+ sorted_query = query .query
160
+ sorted_query .sort_by (self .timestamp_field_name , asc = True )
161
+ hits = self ._index .search (sorted_query , query .params ).docs
162
+
163
+ return self ._format_context (hits , as_text = False )
109
164
110
165
def get_recent (
111
166
self ,
@@ -119,7 +174,6 @@ def get_recent(
119
174
120
175
Args:
121
176
top_k (int): The number of previous messages to return. Default is 5.
122
- To get all messages set top_k = -1.
123
177
session_tag (str): Tag to be added to entries to link to a specific
124
178
session.
125
179
user_tag (str): Tag to be added to entries to link to a specific user.
@@ -133,24 +187,35 @@ def get_recent(
133
187
or list of strings if as_text is false.
134
188
135
189
Raises:
136
- ValueError: if top_k is not an integer greater than or equal to -1 .
190
+ ValueError: if top_k is not an integer greater than or equal to 0 .
137
191
"""
138
- if type (top_k ) != int or top_k < - 1 :
139
- raise ValueError ("top_k must be an integer greater than or equal to -1" )
140
- if top_k == 0 :
141
- return []
142
- elif top_k == - 1 :
143
- top_k = 0
192
+ if type (top_k ) != int or top_k < 0 :
193
+ raise ValueError ("top_k must be an integer greater than or equal to 0" )
194
+
144
195
self .set_scope (session_tag , user_tag )
145
- messages = self ._client .lrange (self .key , - top_k , - 1 )
146
- messages = [json .loads (msg ) for msg in messages ]
147
- if raw :
148
- return messages
149
- return self ._format_context (messages , as_text )
196
+ return_fields = [
197
+ self .id_field_name ,
198
+ self .session_field_name ,
199
+ self .user_field_name ,
200
+ self .role_field_name ,
201
+ self .content_field_name ,
202
+ self .tool_field_name ,
203
+ self .timestamp_field_name ,
204
+ ]
205
+
206
+ query = FilterQuery (
207
+ filter_expression = self ._tag_filter ,
208
+ return_fields = return_fields ,
209
+ num_results = top_k ,
210
+ )
150
211
151
- @property
152
- def key (self ):
153
- return ":" .join ([self ._name , self ._user_tag , self ._session_tag ])
212
+ sorted_query = query .query
213
+ sorted_query .sort_by (self .timestamp_field_name , asc = False )
214
+ hits = self ._index .search (sorted_query , query .params ).docs
215
+
216
+ if raw :
217
+ return hits [::- 1 ]
218
+ return self ._format_context (hits [::- 1 ], as_text )
154
219
155
220
def store (self , prompt : str , response : str ) -> None :
156
221
"""Insert a prompt:response pair into the session memory. A timestamp
@@ -162,7 +227,10 @@ def store(self, prompt: str, response: str) -> None:
162
227
response (str): The corresponding LLM response.
163
228
"""
164
229
self .add_messages (
165
- [{"role" : "user" , "content" : prompt }, {"role" : "llm" , "content" : response }]
230
+ [
231
+ {self .role_field_name : "user" , self .content_field_name : prompt },
232
+ {self .role_field_name : "llm" , self .content_field_name : response },
233
+ ]
166
234
)
167
235
168
236
def add_messages (self , messages : List [Dict [str , str ]]) -> None :
@@ -173,23 +241,23 @@ def add_messages(self, messages: List[Dict[str, str]]) -> None:
173
241
Args:
174
242
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
175
243
"""
244
+ sep = self ._index .key_separator
176
245
payloads = []
177
246
for message in messages :
178
247
timestamp = time ()
248
+ id_field = sep .join ([self ._user_tag , self ._session_tag , str (timestamp )])
179
249
payload = {
180
- self .id_field_name : ":" .join (
181
- [self ._user_tag , self ._session_tag , str (timestamp )]
182
- ),
250
+ self .id_field_name : id_field ,
183
251
self .role_field_name : message [self .role_field_name ],
184
252
self .content_field_name : message [self .content_field_name ],
185
253
self .timestamp_field_name : timestamp ,
254
+ self .session_field_name : self ._session_tag ,
255
+ self .user_field_name : self ._user_tag ,
186
256
}
187
257
if self .tool_field_name in message :
188
258
payload .update ({self .tool_field_name : message [self .tool_field_name ]})
189
-
190
- payloads .append (json .dumps (payload ))
191
-
192
- self ._client .rpush (self .key , * payloads )
259
+ payloads .append (payload )
260
+ self ._index .load (data = payloads , id_field = self .id_field_name )
193
261
194
262
def add_message (self , message : Dict [str , str ]) -> None :
195
263
"""Insert a single prompt or response into the session memory.
0 commit comments