Skip to content

Commit a5854be

Browse files
Feat/raae 194/standard session search index (#191)
Replaces the list-based implementation of the Standard Session Manager with a hash-based implementation.
1 parent 67eee3d commit a5854be

File tree

3 files changed

+184
-133
lines changed

3 files changed

+184
-133
lines changed

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,34 @@
1212
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
1313

1414

15+
class SemanticSessionIndexSchema(IndexSchema):
16+
17+
@classmethod
18+
def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
19+
20+
return cls(
21+
index={"name": name, "prefix": prefix}, # type: ignore
22+
fields=[ # type: ignore
23+
{"name": "role", "type": "text"},
24+
{"name": "content", "type": "text"},
25+
{"name": "tool_call_id", "type": "text"},
26+
{"name": "timestamp", "type": "numeric"},
27+
{"name": "session_tag", "type": "tag"},
28+
{"name": "user_tag", "type": "tag"},
29+
{
30+
"name": "vector_field",
31+
"type": "vector",
32+
"attrs": {
33+
"dims": vectorizer_dims,
34+
"datatype": "float32",
35+
"distance_metric": "cosine",
36+
"algorithm": "flat",
37+
},
38+
},
39+
],
40+
)
41+
42+
1543
class SemanticSessionManager(BaseSessionManager):
1644
session_field_name: str = "session_tag"
1745
user_field_name: str = "user_tag"
@@ -68,27 +96,8 @@ def __init__(
6896

6997
self.set_distance_threshold(distance_threshold)
7098

71-
schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})
72-
73-
schema.add_fields(
74-
[
75-
{"name": "role", "type": "text"},
76-
{"name": "content", "type": "text"},
77-
{"name": "tool_call_id", "type": "text"},
78-
{"name": "timestamp", "type": "numeric"},
79-
{"name": "session_tag", "type": "tag"},
80-
{"name": "user_tag", "type": "tag"},
81-
{
82-
"name": "vector_field",
83-
"type": "vector",
84-
"attrs": {
85-
"dims": self._vectorizer.dims,
86-
"datatype": "float32",
87-
"distance_metric": "cosine",
88-
"algorithm": "flat",
89-
},
90-
},
91-
]
99+
schema = SemanticSessionIndexSchema.from_params(
100+
name, prefix, self._vectorizer.dims
92101
)
93102

94103
self._index = SearchIndex(schema=schema)
@@ -260,19 +269,18 @@ def get_recent(
260269
"""Retrieve the recent conversation history in sequential order.
261270
262271
Args:
263-
as_text (bool): Whether to return the conversation as a single string,
264-
or list of alternating prompts and responses.
265272
top_k (int): The number of previous exchanges to return. Default is 5.
266-
Note that one exchange contains both a prompt and a respoonse.
267273
session_tag (str): Tag to be added to entries to link to a specific
268274
session.
269275
user_tag (str): Tag to be added to entries to link to a specific user.
276+
as_text (bool): Whether to return the conversation as a single string,
277+
or list of alternating prompts and responses.
270278
raw (bool): Whether to return the full Redis hash entry or just the
271279
prompt and response
272280
273281
Returns:
274282
Union[str, List[str]]: A single string transcription of the session
275-
or list of strings if as_text is false.
283+
or list of strings if as_text is false.
276284
277285
Raises:
278286
ValueError: if top_k is not an integer greater than or equal to 0.

redisvl/extensions/session_manager/standard_session.py

Lines changed: 115 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,40 @@
55
from redis import Redis
66

77
from redisvl.extensions.session_manager import BaseSessionManager
8-
from redisvl.redis.connection import RedisConnectionFactory
8+
from redisvl.index import SearchIndex
9+
from redisvl.query import FilterQuery
10+
from redisvl.query.filter import Tag
11+
from redisvl.schema.schema import IndexSchema
12+
13+
14+
class StandardSessionIndexSchema(IndexSchema):
15+
16+
@classmethod
17+
def from_params(cls, name: str, prefix: str):
18+
19+
return cls(
20+
index={"name": name, "prefix": prefix}, # type: ignore
21+
fields=[ # type: ignore
22+
{"name": "role", "type": "text"},
23+
{"name": "content", "type": "text"},
24+
{"name": "tool_call_id", "type": "text"},
25+
{"name": "timestamp", "type": "numeric"},
26+
{"name": "session_tag", "type": "tag"},
27+
{"name": "user_tag", "type": "tag"},
28+
],
29+
)
930

1031

1132
class StandardSessionManager(BaseSessionManager):
33+
session_field_name: str = "session_tag"
34+
user_field_name: str = "user_tag"
1235

1336
def __init__(
1437
self,
1538
name: str,
1639
session_tag: str,
1740
user_tag: str,
41+
prefix: Optional[str] = None,
1842
redis_client: Optional[Redis] = None,
1943
redis_url: str = "redis://localhost:6379",
2044
connection_kwargs: Dict[str, Any] = {},
@@ -29,9 +53,11 @@ def __init__(
2953
3054
Args:
3155
name (str): The name of the session manager index.
32-
session_tag (str): Tag to be added to entries to link to a specific
56+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
3357
session.
34-
user_tag (str): Tag to be added to entries to link to a specific user.
58+
user_tag (Optional[str]): Tag to be added to entries to link to a specific user.
59+
prefix (Optional[str]): Prefix for the keys for this session data.
60+
Defaults to None and will be replaced with the index name.
3561
redis_client (Optional[Redis]): A Redis client instance. Defaults to
3662
None.
3763
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
@@ -44,14 +70,18 @@ def __init__(
4470
"""
4571
super().__init__(name, session_tag, user_tag)
4672

73+
prefix = prefix or name
74+
75+
schema = StandardSessionIndexSchema.from_params(name, prefix)
76+
self._index = SearchIndex(schema=schema)
77+
4778
# handle redis connection
4879
if redis_client:
49-
self._client = redis_client
80+
self._index.set_client(redis_client)
5081
elif redis_url:
51-
self._client = RedisConnectionFactory.get_redis_connection(
52-
redis_url, **connection_kwargs
53-
)
54-
RedisConnectionFactory.validate_sync_redis(self._client)
82+
self._index.connect(redis_url=redis_url, **connection_kwargs)
83+
84+
self._index.create(overwrite=False)
5585

5686
self.set_scope(session_tag, user_tag)
5787

@@ -63,27 +93,35 @@ def set_scope(
6393
"""Set the filter to apply to queries based on the desired scope.
6494
6595
This new scope persists until another call to set_scope is made, or if
66-
scope is specified in calls to get_recent.
96+
scope is specified in calls to get_recent or get_relevant.
6797
6898
Args:
6999
session_tag (str): Id of the specific session to filter to. Default is
70-
None, which means session_tag will be unchanged.
100+
None, which means all sessions will be in scope.
71101
user_tag (str): Id of the specific user to filter to. Default is None,
72-
which means user_tag will be unchanged.
102+
which means all users will be in scope.
73103
"""
74104
if not (session_tag or user_tag):
75105
return
76-
77106
self._session_tag = session_tag or self._session_tag
78107
self._user_tag = user_tag or self._user_tag
108+
tag_filter = Tag(self.user_field_name) == []
109+
if user_tag:
110+
tag_filter = tag_filter & (Tag(self.user_field_name) == self._user_tag)
111+
if session_tag:
112+
tag_filter = tag_filter & (
113+
Tag(self.session_field_name) == self._session_tag
114+
)
115+
116+
self._tag_filter = tag_filter
79117

80118
def clear(self) -> None:
81119
"""Clears the chat session history."""
82-
self._client.delete(self.key)
120+
self._index.clear()
83121

84122
def delete(self) -> None:
85-
"""Clears the chat session history."""
86-
self._client.delete(self.key)
123+
"""Clear all conversation keys and remove the search index."""
124+
self._index.delete(drop=True)
87125

88126
def drop(self, id_field: Optional[str] = None) -> None:
89127
"""Remove a specific exchange from the conversation history.
@@ -93,19 +131,36 @@ def drop(self, id_field: Optional[str] = None) -> None:
93131
If None then the last entry is deleted.
94132
"""
95133
if id_field:
96-
messages = self._client.lrange(self.key, 0, -1)
97-
messages = [json.loads(msg) for msg in messages]
98-
messages = [msg for msg in messages if msg["id_field"] != id_field]
99-
messages = [json.dumps(msg) for msg in messages]
100-
self.clear()
101-
self._client.rpush(self.key, *messages)
134+
sep = self._index.key_separator
135+
key = sep.join([self._index.schema.index.name, id_field])
102136
else:
103-
self._client.rpop(self.key)
137+
key = self.get_recent(top_k=1, raw=True)[0]["id"] # type: ignore
138+
self._index.client.delete(key) # type: ignore
104139

105140
@property
106141
def messages(self) -> Union[List[str], List[Dict[str, str]]]:
107142
"""Returns the full chat history."""
108-
return self.get_recent(top_k=-1)
143+
# TODO raw or as_text?
144+
return_fields = [
145+
self.id_field_name,
146+
self.session_field_name,
147+
self.user_field_name,
148+
self.role_field_name,
149+
self.content_field_name,
150+
self.tool_field_name,
151+
self.timestamp_field_name,
152+
]
153+
154+
query = FilterQuery(
155+
filter_expression=self._tag_filter,
156+
return_fields=return_fields,
157+
)
158+
159+
sorted_query = query.query
160+
sorted_query.sort_by(self.timestamp_field_name, asc=True)
161+
hits = self._index.search(sorted_query, query.params).docs
162+
163+
return self._format_context(hits, as_text=False)
109164

110165
def get_recent(
111166
self,
@@ -119,7 +174,6 @@ def get_recent(
119174
120175
Args:
121176
top_k (int): The number of previous messages to return. Default is 5.
122-
To get all messages set top_k = -1.
123177
session_tag (str): Tag to be added to entries to link to a specific
124178
session.
125179
user_tag (str): Tag to be added to entries to link to a specific user.
@@ -133,24 +187,35 @@ def get_recent(
133187
or list of strings if as_text is false.
134188
135189
Raises:
136-
ValueError: if top_k is not an integer greater than or equal to -1.
190+
ValueError: if top_k is not an integer greater than or equal to 0.
137191
"""
138-
if type(top_k) != int or top_k < -1:
139-
raise ValueError("top_k must be an integer greater than or equal to -1")
140-
if top_k == 0:
141-
return []
142-
elif top_k == -1:
143-
top_k = 0
192+
if type(top_k) != int or top_k < 0:
193+
raise ValueError("top_k must be an integer greater than or equal to 0")
194+
144195
self.set_scope(session_tag, user_tag)
145-
messages = self._client.lrange(self.key, -top_k, -1)
146-
messages = [json.loads(msg) for msg in messages]
147-
if raw:
148-
return messages
149-
return self._format_context(messages, as_text)
196+
return_fields = [
197+
self.id_field_name,
198+
self.session_field_name,
199+
self.user_field_name,
200+
self.role_field_name,
201+
self.content_field_name,
202+
self.tool_field_name,
203+
self.timestamp_field_name,
204+
]
205+
206+
query = FilterQuery(
207+
filter_expression=self._tag_filter,
208+
return_fields=return_fields,
209+
num_results=top_k,
210+
)
150211

151-
@property
152-
def key(self):
153-
return ":".join([self._name, self._user_tag, self._session_tag])
212+
sorted_query = query.query
213+
sorted_query.sort_by(self.timestamp_field_name, asc=False)
214+
hits = self._index.search(sorted_query, query.params).docs
215+
216+
if raw:
217+
return hits[::-1]
218+
return self._format_context(hits[::-1], as_text)
154219

155220
def store(self, prompt: str, response: str) -> None:
156221
"""Insert a prompt:response pair into the session memory. A timestamp
@@ -162,7 +227,10 @@ def store(self, prompt: str, response: str) -> None:
162227
response (str): The corresponding LLM response.
163228
"""
164229
self.add_messages(
165-
[{"role": "user", "content": prompt}, {"role": "llm", "content": response}]
230+
[
231+
{self.role_field_name: "user", self.content_field_name: prompt},
232+
{self.role_field_name: "llm", self.content_field_name: response},
233+
]
166234
)
167235

168236
def add_messages(self, messages: List[Dict[str, str]]) -> None:
@@ -173,23 +241,23 @@ def add_messages(self, messages: List[Dict[str, str]]) -> None:
173241
Args:
174242
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
175243
"""
244+
sep = self._index.key_separator
176245
payloads = []
177246
for message in messages:
178247
timestamp = time()
248+
id_field = sep.join([self._user_tag, self._session_tag, str(timestamp)])
179249
payload = {
180-
self.id_field_name: ":".join(
181-
[self._user_tag, self._session_tag, str(timestamp)]
182-
),
250+
self.id_field_name: id_field,
183251
self.role_field_name: message[self.role_field_name],
184252
self.content_field_name: message[self.content_field_name],
185253
self.timestamp_field_name: timestamp,
254+
self.session_field_name: self._session_tag,
255+
self.user_field_name: self._user_tag,
186256
}
187257
if self.tool_field_name in message:
188258
payload.update({self.tool_field_name: message[self.tool_field_name]})
189-
190-
payloads.append(json.dumps(payload))
191-
192-
self._client.rpush(self.key, *payloads)
259+
payloads.append(payload)
260+
self._index.load(data=payloads, id_field=self.id_field_name)
193261

194262
def add_message(self, message: Dict[str, str]) -> None:
195263
"""Insert a single prompt or response into the session memory.

0 commit comments

Comments
 (0)