Skip to content

Commit d9ccb26

Browse files
mypy formatting
1 parent c40ee6f commit d9ccb26

File tree

1 file changed

+70
-82
lines changed

1 file changed

+70
-82
lines changed

redisvl/extensions/session_manager/session.py

+70-82
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,32 @@
11
import hashlib
22
from datetime import datetime
33
from typing import Any, Dict, List, Optional, Tuple, Union
4+
45
from redis import Redis
6+
57
from redisvl.index import SearchIndex
68
from redisvl.query import FilterQuery, RangeQuery
7-
from redisvl.query.filter import Tag, Num
9+
from redisvl.query.filter import Num, Tag
810
from redisvl.redis.utils import array_to_buffer
911
from redisvl.schema.schema import IndexSchema
1012
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
1113

14+
1215
class SessionManager:
1316
def __init__(
1417
self,
1518
name: str,
1619
session_id: str,
1720
user_id: str,
1821
application_id: str,
19-
scope: str = 'session',
22+
scope: str = "session",
2023
prefix: Optional[str] = None,
2124
vectorizer: Optional[BaseVectorizer] = None,
2225
distance_threshold: float = 0.3,
2326
redis_client: Optional[Redis] = None,
24-
preamble: str = ''
25-
):
26-
""" Initialize session memory with index
27+
preamble: str = "",
28+
):
29+
"""Initialize session memory with index
2730
2831
Session Manager stores the current and previous user text prompts and
2932
LLM responses to allow for enriching future prompts with session
@@ -88,7 +91,7 @@ def __init__(
8891
"distance_metric": "cosine",
8992
"algorithm": "flat",
9093
},
91-
},
94+
},
9295
]
9396
)
9497

@@ -104,22 +107,18 @@ def __init__(
104107
self._index.create(overwrite=False)
105108

106109
self._tag_filter = Tag("application_id") == self._application_id
107-
if self._scope == 'user':
110+
if self._scope == "user":
108111
user_filter = Tag("user_id") == self._user_id
109112
self._tag_filter = self._tag_filter & user_filter
110-
if self._scope == 'session':
113+
if self._scope == "session":
111114
session_filter = Tag("session_id") == self._session_id
112115
user_filter = Tag("user_id") == self._user_id
113116
self._tag_filter = self._tag_filter & user_filter & session_filter
114117

115-
116118
def set_scope(
117-
self,
118-
session_id: str = None,
119-
user_id: str = None,
120-
application_id: str = None
121-
) -> None:
122-
""" Set the tag filter to apply to querries based on the desired scope.
119+
self, session_id: Optional[str] = None, user_id: Optional[str] = None, application_id: Optional[str] = None
120+
) -> None:
121+
"""Set the tag filter to apply to querries based on the desired scope.
123122
124123
This new scope persists until another call to set_scope is made, or if
125124
a scope is specified in calls to fetch_recent or fetch_relevant.
@@ -135,7 +134,7 @@ def set_scope(
135134
if not (session_id or user_id or application_id):
136135
return
137136

138-
tag_filter = Tag('application_id') == []
137+
tag_filter = Tag("application_id") == []
139138
if application_id:
140139
tag_filter = tag_filter & (Tag("application_id") == application_id)
141140
if user_id:
@@ -145,32 +144,29 @@ def set_scope(
145144

146145
self._tag_filter = tag_filter
147146

148-
149147
def clear(self) -> None:
150-
""" Clears the chat session history. """
151-
with self._index.client.pipeline(transaction=False) as pipe:
152-
for key in self._index.client.scan_iter(match=f"{self._index.prefix}:*"):
148+
"""Clears the chat session history."""
149+
with self._index.client.pipeline(transaction=False) as pipe: # type: ignore
150+
for key in self._index.client.scan_iter(match=f"{self._index.prefix}:*"): # type: ignore
153151
pipe.delete(key)
154152
pipe.execute()
155153

156-
157154
def delete(self) -> None:
158-
""" Clear all conversation keys and remove the search index. """
155+
"""Clear all conversation keys and remove the search index."""
159156
self._index.delete(drop=True)
160157

161-
162158
def fetch_relevant(
163159
self,
164160
prompt: str,
165161
as_text: bool = False,
166162
top_k: int = 3,
167163
fall_back: bool = False,
168-
session_id: str = None,
169-
user_id: str = None,
170-
application_id: str = None,
171-
raw: bool = False
172-
) -> Union[List[str], List[Dict[str,str]]]:
173-
""" Searches the chat history for information semantically related to
164+
session_id: Optional[str] = None,
165+
user_id: Optional[str] = None,
166+
application_id: Optional[str] = None,
167+
raw: bool = False,
168+
) -> Union[List[str], List[Dict[str, str]]]:
169+
"""Searches the chat history for information semantically related to
174170
the specified prompt.
175171
176172
This method uses vector similarity search with a text prompt as input.
@@ -216,7 +212,7 @@ def fetch_relevant(
216212
distance_threshold=self._distance_threshold,
217213
num_results=top_k,
218214
return_score=True,
219-
filter_expression=self._tag_filter
215+
filter_expression=self._tag_filter,
220216
)
221217
hits = self._index.query(query)
222218

@@ -227,17 +223,16 @@ def fetch_relevant(
227223
return hits
228224
return self._format_context(hits, as_text)
229225

230-
231226
def fetch_recent(
232227
self,
233228
as_text: bool = False,
234229
top_k: int = 3,
235-
session_id: str = None,
236-
user_id: str = None,
237-
application_id: str = None,
238-
raw = False
239-
) -> Union[List[str], List[Dict[str,str]]]:
240-
""" Retreive the recent conversation history in sequential order.
230+
session_id: Optional[str] = None,
231+
user_id: Optional[str] = None,
232+
application_id: Optional[str] = None,
233+
raw: bool = False,
234+
) -> Union[List[str], List[Dict[str, str]]]:
235+
"""Retreive the recent conversation history in sequential order.
241236
242237
Args:
243238
as_text bool: Whether to return the conversation as a single string,
@@ -265,27 +260,23 @@ def fetch_recent(
265260
"timestamp",
266261
]
267262

268-
count_key = ":".join([self._application_id, self._user_id, self._session_id, "count"])
263+
count_key = ":".join(
264+
[self._application_id, self._user_id, self._session_id, "count"]
265+
)
269266
count = self._redis_client.get(count_key) or 0
270267
last_k_filter = Num("count") > int(count) - top_k
271268
combined = self._tag_filter & last_k_filter
272269

273-
query = FilterQuery(
274-
return_fields=return_fields,
275-
filter_expression=combined
276-
)
270+
query = FilterQuery(return_fields=return_fields, filter_expression=combined)
277271
hits = self._index.query(query)
278272
if raw:
279273
return hits
280274
return self._format_context(hits, as_text)
281275

282-
283276
def _format_context(
284-
self,
285-
hits: List[Dict[str, Any]],
286-
as_text: bool
287-
) -> Union[List[str], List[Dict[str, str]]]:
288-
""" Extracts the prompt and response fields from the Redis hashes and
277+
self, hits: List[Dict[str, Any]], as_text: bool
278+
) -> Union[List[str], List[Dict[str, str]]]:
279+
"""Extracts the prompt and response fields from the Redis hashes and
289280
formats them as either flat dictionaries or strings.
290281
291282
Args:
@@ -298,71 +289,68 @@ def _format_context(
298289
or list of strings if as_text is false.
299290
"""
300291
if hits:
301-
hits.sort(key=lambda x: x['timestamp']) # TODO move sorting to query.py
292+
hits.sort(key=lambda x: x["timestamp"]) # TODO move sorting to query.py
302293

303294
if as_text:
304-
statements = [self._preamble["_content"]]
295+
text_statements = [self._preamble["_content"]]
305296
for hit in hits:
306-
statements.append(hit["prompt"])
307-
statements.append(hit["response"])
297+
text_statements.append(hit["prompt"])
298+
text_statements.append(hit["response"])
299+
return text_statements
308300
else:
309301
statements = [self._preamble]
310302
for hit in hits:
311303
statements.append({"role": "_user", "_content": hit["prompt"]})
312304
statements.append({"role": "_llm", "_content": hit["response"]})
313-
return statements
314-
305+
return statements
315306

316307
@property
317308
def distance_threshold(self):
318309
return self._distance_threshold
319310

320-
321311
def set_distance_threshold(self, threshold):
322312
self._distance_threshold = threshold
323313

324-
325314
def store(self, exchange: Tuple[str, str]) -> str:
326-
""" Insert a prompt:response pair into the session memory. A timestamp
327-
is associated with each exchange so that they can be later sorted
328-
in sequential ordering after retrieval.
315+
"""Insert a prompt:response pair into the session memory. A timestamp
316+
is associated with each exchange so that they can be later sorted
317+
in sequential ordering after retrieval.
329318
330-
Args:
331-
exchange Tuple[str, str]: The user prompt and corresponding LLM
332-
response.
319+
Args:
320+
exchange Tuple[str, str]: The user prompt and corresponding LLM
321+
response.
333322
334-
Returns:
335-
str: The Redis key for the entry added to the database.
323+
Returns:
324+
str: The Redis key for the entry added to the database.
336325
"""
337-
count_key = ":".join([self._application_id, self._user_id, self._session_id, "count"])
326+
count_key = ":".join(
327+
[self._application_id, self._user_id, self._session_id, "count"]
328+
)
338329
count = self._redis_client.incr(count_key)
339330
vector = self._vectorizer.embed(exchange[0] + exchange[1])
340331
timestamp = int(datetime.now().timestamp())
341332
payload = {
342-
"id": self.hash_input(exchange[0]+str(timestamp)),
343-
"prompt": exchange[0],
344-
"response": exchange[1],
345-
"timestamp": timestamp,
346-
"session_id": self._session_id,
347-
"user_id": self._user_id,
348-
"application_id": self._application_id,
349-
"count": count,
350-
"token_count": 1, #TODO get actual token count
351-
"combined_vector_field": array_to_buffer(vector)
333+
"id": self.hash_input(exchange[0] + str(timestamp)),
334+
"prompt": exchange[0],
335+
"response": exchange[1],
336+
"timestamp": timestamp,
337+
"session_id": self._session_id,
338+
"user_id": self._user_id,
339+
"application_id": self._application_id,
340+
"count": count,
341+
"token_count": 1, # TODO get actual token count
342+
"combined_vector_field": array_to_buffer(vector),
352343
}
353-
key = self._index.load(data=[payload])
354-
return key
355-
344+
keys = self._index.load(data=[payload])
345+
return keys[0]
356346

357347
def set_preamble(self, prompt: str) -> None:
358-
""" Add a preamble statement to the the begining of each session to be
359-
included in each subsequent LLM call.
348+
"""Add a preamble statement to the the begining of each session to be
349+
included in each subsequent LLM call.
360350
"""
361351
self._preamble = {"role": "_preamble", "_content": prompt}
362352
# TODO store this in Redis with assigned scope?
363353

364-
365354
def hash_input(self, prompt: str):
366355
"""Hashes the input using SHA256."""
367356
return hashlib.sha256(prompt.encode("utf-8")).hexdigest()
368-

0 commit comments

Comments
 (0)