1
1
import hashlib
2
2
from datetime import datetime
3
3
from typing import Any , Dict , List , Optional , Tuple , Union
4
+
4
5
from redis import Redis
6
+
5
7
from redisvl .index import SearchIndex
6
8
from redisvl .query import FilterQuery , RangeQuery
7
- from redisvl .query .filter import Tag , Num
9
+ from redisvl .query .filter import Num , Tag
8
10
from redisvl .redis .utils import array_to_buffer
9
11
from redisvl .schema .schema import IndexSchema
10
12
from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
11
13
14
+
12
15
class SessionManager :
13
16
def __init__ (
14
17
self ,
15
18
name : str ,
16
19
session_id : str ,
17
20
user_id : str ,
18
21
application_id : str ,
19
- scope : str = ' session' ,
22
+ scope : str = " session" ,
20
23
prefix : Optional [str ] = None ,
21
24
vectorizer : Optional [BaseVectorizer ] = None ,
22
25
distance_threshold : float = 0.3 ,
23
26
redis_client : Optional [Redis ] = None ,
24
- preamble : str = ''
25
- ):
26
- """ Initialize session memory with index
27
+ preamble : str = "" ,
28
+ ):
29
+ """Initialize session memory with index
27
30
28
31
Session Manager stores the current and previous user text prompts and
29
32
LLM responses to allow for enriching future prompts with session
@@ -88,7 +91,7 @@ def __init__(
88
91
"distance_metric" : "cosine" ,
89
92
"algorithm" : "flat" ,
90
93
},
91
- },
94
+ },
92
95
]
93
96
)
94
97
@@ -104,22 +107,18 @@ def __init__(
104
107
self ._index .create (overwrite = False )
105
108
106
109
self ._tag_filter = Tag ("application_id" ) == self ._application_id
107
- if self ._scope == ' user' :
110
+ if self ._scope == " user" :
108
111
user_filter = Tag ("user_id" ) == self ._user_id
109
112
self ._tag_filter = self ._tag_filter & user_filter
110
- if self ._scope == ' session' :
113
+ if self ._scope == " session" :
111
114
session_filter = Tag ("session_id" ) == self ._session_id
112
115
user_filter = Tag ("user_id" ) == self ._user_id
113
116
self ._tag_filter = self ._tag_filter & user_filter & session_filter
114
117
115
-
116
118
def set_scope (
117
- self ,
118
- session_id : str = None ,
119
- user_id : str = None ,
120
- application_id : str = None
121
- ) -> None :
122
- """ Set the tag filter to apply to querries based on the desired scope.
119
+ self , session_id : Optional [str ] = None , user_id : Optional [str ] = None , application_id : Optional [str ] = None
120
+ ) -> None :
121
+ """Set the tag filter to apply to querries based on the desired scope.
123
122
124
123
This new scope persists until another call to set_scope is made, or if
125
124
scope specified in calls to fetch_recent or fetch_relevant.
@@ -135,7 +134,7 @@ def set_scope(
135
134
if not (session_id or user_id or application_id ):
136
135
return
137
136
138
- tag_filter = Tag (' application_id' ) == []
137
+ tag_filter = Tag (" application_id" ) == []
139
138
if application_id :
140
139
tag_filter = tag_filter & (Tag ("application_id" ) == application_id )
141
140
if user_id :
@@ -145,32 +144,29 @@ def set_scope(
145
144
146
145
self ._tag_filter = tag_filter
147
146
148
-
149
147
def clear(self) -> None:
    """Delete every stored conversation entry under this index's key prefix.

    Keys are discovered with a cursor-based SCAN (non-blocking) and removed
    in a single non-transactional pipeline round trip. The search index
    itself is left in place; use :meth:`delete` to drop it as well.
    """
    pattern = f"{self._index.prefix}:*"
    client = self._index.client
    with client.pipeline(transaction=False) as pipe:  # type: ignore
        for key in client.scan_iter(match=pattern):  # type: ignore
            pipe.delete(key)
        pipe.execute()
155
153
156
-
157
154
def delete(self) -> None:
    """Remove the search index and, with it, all conversation keys.

    ``drop=True`` instructs redisvl to delete the indexed documents along
    with the index definition.
    """
    self._index.delete(drop=True)
160
157
161
-
162
158
def fetch_relevant (
163
159
self ,
164
160
prompt : str ,
165
161
as_text : bool = False ,
166
162
top_k : int = 3 ,
167
163
fall_back : bool = False ,
168
- session_id : str = None ,
169
- user_id : str = None ,
170
- application_id : str = None ,
171
- raw : bool = False
172
- ) -> Union [List [str ], List [Dict [str ,str ]]]:
173
- """ Searches the chat history for information semantically related to
164
+ session_id : Optional [ str ] = None ,
165
+ user_id : Optional [ str ] = None ,
166
+ application_id : Optional [ str ] = None ,
167
+ raw : bool = False ,
168
+ ) -> Union [List [str ], List [Dict [str , str ]]]:
169
+ """Searches the chat history for information semantically related to
174
170
the specified prompt.
175
171
176
172
This method uses vector similarity search with a text prompt as input.
@@ -216,7 +212,7 @@ def fetch_relevant(
216
212
distance_threshold = self ._distance_threshold ,
217
213
num_results = top_k ,
218
214
return_score = True ,
219
- filter_expression = self ._tag_filter
215
+ filter_expression = self ._tag_filter ,
220
216
)
221
217
hits = self ._index .query (query )
222
218
@@ -227,17 +223,16 @@ def fetch_relevant(
227
223
return hits
228
224
return self ._format_context (hits , as_text )
229
225
230
-
231
226
def fetch_recent (
232
227
self ,
233
228
as_text : bool = False ,
234
229
top_k : int = 3 ,
235
- session_id : str = None ,
236
- user_id : str = None ,
237
- application_id : str = None ,
238
- raw = False
239
- ) -> Union [List [str ], List [Dict [str ,str ]]]:
240
- """ Retreive the recent conversation history in sequential order.
230
+ session_id : Optional [ str ] = None ,
231
+ user_id : Optional [ str ] = None ,
232
+ application_id : Optional [ str ] = None ,
233
+ raw : bool = False ,
234
+ ) -> Union [List [str ], List [Dict [str , str ]]]:
235
+ """Retreive the recent conversation history in sequential order.
241
236
242
237
Args:
243
238
as_text bool: Whether to return the conversation as a single string,
@@ -265,27 +260,23 @@ def fetch_recent(
265
260
"timestamp" ,
266
261
]
267
262
268
- count_key = ":" .join ([self ._application_id , self ._user_id , self ._session_id , "count" ])
263
+ count_key = ":" .join (
264
+ [self ._application_id , self ._user_id , self ._session_id , "count" ]
265
+ )
269
266
count = self ._redis_client .get (count_key ) or 0
270
267
last_k_filter = Num ("count" ) > int (count ) - top_k
271
268
combined = self ._tag_filter & last_k_filter
272
269
273
- query = FilterQuery (
274
- return_fields = return_fields ,
275
- filter_expression = combined
276
- )
270
+ query = FilterQuery (return_fields = return_fields , filter_expression = combined )
277
271
hits = self ._index .query (query )
278
272
if raw :
279
273
return hits
280
274
return self ._format_context (hits , as_text )
281
275
282
-
283
276
def _format_context (
284
- self ,
285
- hits : List [Dict [str , Any ]],
286
- as_text : bool
287
- ) -> Union [List [str ], List [Dict [str , str ]]]:
288
- """ Extracts the prompt and response fields from the Redis hashes and
277
+ self , hits : List [Dict [str , Any ]], as_text : bool
278
+ ) -> Union [List [str ], List [Dict [str , str ]]]:
279
+ """Extracts the prompt and response fields from the Redis hashes and
289
280
formats them as either flat dictionaries oor strings.
290
281
291
282
Args:
@@ -298,71 +289,68 @@ def _format_context(
298
289
or list of strings if as_text is false.
299
290
"""
300
291
if hits :
301
- hits .sort (key = lambda x : x [' timestamp' ]) # TODO move sorting to query.py
292
+ hits .sort (key = lambda x : x [" timestamp" ]) # TODO move sorting to query.py
302
293
303
294
if as_text :
304
- statements = [self ._preamble ["_content" ]]
295
+ text_statements = [self ._preamble ["_content" ]]
305
296
for hit in hits :
306
- statements .append (hit ["prompt" ])
307
- statements .append (hit ["response" ])
297
+ text_statements .append (hit ["prompt" ])
298
+ text_statements .append (hit ["response" ])
299
+ return text_statements
308
300
else :
309
301
statements = [self ._preamble ]
310
302
for hit in hits :
311
303
statements .append ({"role" : "_user" , "_content" : hit ["prompt" ]})
312
304
statements .append ({"role" : "_llm" , "_content" : hit ["response" ]})
313
- return statements
314
-
305
+ return statements
315
306
316
307
@property
317
308
def distance_threshold (self ):
318
309
return self ._distance_threshold
319
310
320
-
321
311
def set_distance_threshold (self , threshold ):
322
312
self ._distance_threshold = threshold
323
313
324
-
325
314
def store(self, exchange: Tuple[str, str]) -> str:
    """Insert a prompt:response pair into the session memory.

    A timestamp is attached to each exchange so entries can later be
    sorted back into sequential order after retrieval, and a per-scope
    counter is incremented to support recency queries.

    Args:
        exchange Tuple[str, str]: The user prompt and corresponding LLM
            response.

    Returns:
        str: The Redis key for the entry added to the database.
    """
    prompt, response = exchange
    count_key = ":".join(
        [self._application_id, self._user_id, self._session_id, "count"]
    )
    count = self._redis_client.incr(count_key)
    vector = self._vectorizer.embed(prompt + response)
    timestamp = int(datetime.now().timestamp())
    payload = {
        "id": self.hash_input(prompt + str(timestamp)),
        "prompt": prompt,
        "response": response,
        "timestamp": timestamp,
        "session_id": self._session_id,
        "user_id": self._user_id,
        "application_id": self._application_id,
        "count": count,
        "token_count": 1,  # TODO get actual token count
        "combined_vector_field": array_to_buffer(vector),
    }
    keys = self._index.load(data=[payload])
    return keys[0]
356
346
357
347
def set_preamble (self , prompt : str ) -> None :
358
- """ Add a preamble statement to the the begining of each session to be
359
- included in each subsequent LLM call.
348
+ """Add a preamble statement to the the begining of each session to be
349
+ included in each subsequent LLM call.
360
350
"""
361
351
self ._preamble = {"role" : "_preamble" , "_content" : prompt }
362
352
# TODO store this in Redis with asigned scope?
363
353
364
-
365
354
def hash_input(self, prompt: str):
    """Return the SHA256 hex digest of the given prompt text."""
    digest = hashlib.sha256(prompt.encode("utf-8"))
    return digest.hexdigest()
368
-
0 commit comments