diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index e91f4754..fdea13cd 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -12,6 +12,34 @@ from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +class SemanticSessionIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str, vectorizer_dims: int): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "role", "type": "text"}, + {"name": "content", "type": "text"}, + {"name": "tool_call_id", "type": "text"}, + {"name": "timestamp", "type": "numeric"}, + {"name": "session_tag", "type": "tag"}, + {"name": "user_tag", "type": "tag"}, + { + "name": "vector_field", + "type": "vector", + "attrs": { + "dims": vectorizer_dims, + "datatype": "float32", + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) + + class SemanticSessionManager(BaseSessionManager): session_field_name: str = "session_tag" user_field_name: str = "user_tag" @@ -68,27 +96,8 @@ def __init__( self.set_distance_threshold(distance_threshold) - schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}}) - - schema.add_fields( - [ - {"name": "role", "type": "text"}, - {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "text"}, - {"name": "timestamp", "type": "numeric"}, - {"name": "session_tag", "type": "tag"}, - {"name": "user_tag", "type": "tag"}, - { - "name": "vector_field", - "type": "vector", - "attrs": { - "dims": self._vectorizer.dims, - "datatype": "float32", - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ] + schema = SemanticSessionIndexSchema.from_params( + name, prefix, self._vectorizer.dims ) self._index = SearchIndex(schema=schema) @@ -260,19 +269,18 @@ def get_recent( """Retreive the recent conversation history in sequential order. Args: - as_text (bool): Whether to return the conversation as a single string, - or list of alternating prompts and responses. top_k (int): The number of previous exchanges to return. Default is 5. - Note that one exchange contains both a prompt and a respoonse. session_tag (str): Tag to be added to entries to link to a specific session. user_tag (str): Tag to be added to entries to link to a specific user. + as_text (bool): Whether to return the conversation as a single string, + or list of alternating prompts and responses. raw (bool): Whether to return the full Redis hash entry or just the prompt and response Returns: Union[str, List[str]]: A single string transcription of the session - or list of strings if as_text is false. + or list of strings if as_text is false. Raises: ValueError: if top_k is not an integer greater than or equal to 0. diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index a9e67fc6..640f13f9 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -5,16 +5,40 @@ from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager -from redisvl.redis.connection import RedisConnectionFactory +from redisvl.index import SearchIndex +from redisvl.query import FilterQuery +from redisvl.query.filter import Tag +from redisvl.schema.schema import IndexSchema + + +class StandardSessionIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "role", "type": "text"}, + {"name": "content", "type": "text"}, + {"name": "tool_call_id", "type": "text"}, + {"name": "timestamp", "type": "numeric"}, + {"name": "session_tag", "type": "tag"}, + {"name": "user_tag", "type": "tag"}, + ], + ) class StandardSessionManager(BaseSessionManager): + session_field_name: str = "session_tag" + user_field_name: str = "user_tag" def __init__( self, name: str, session_tag: str, user_tag: str, + prefix: Optional[str] = None, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, @@ -29,9 +53,11 @@ def __init__( Args: name (str): The name of the session manager index. - session_tag (str): Tag to be added to entries to link to a specific + session_tag (Optional[str]): Tag to be added to entries to link to a specific session. - user_tag (str): Tag to be added to entries to link to a specific user. + user_tag (Optional[str]): Tag to be added to entries to link to a specific user. + prefix (Optional[str]): Prefix for the keys for this session data. + Defaults to None and will be replaced with the index name. redis_client (Optional[Redis]): A Redis client instance. Defaults to None. redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. @@ -44,14 +70,18 @@ def __init__( """ super().__init__(name, session_tag, user_tag) + prefix = prefix or name + + schema = StandardSessionIndexSchema.from_params(name, prefix) + self._index = SearchIndex(schema=schema) + # handle redis connection if redis_client: - self._client = redis_client + self._index.set_client(redis_client) elif redis_url: - self._client = RedisConnectionFactory.get_redis_connection( - redis_url, **connection_kwargs - ) - RedisConnectionFactory.validate_sync_redis(self._client) + self._index.connect(redis_url=redis_url, **connection_kwargs) + + self._index.create(overwrite=False) self.set_scope(session_tag, user_tag) @@ -63,27 +93,35 @@ def set_scope( """Set the filter to apply to queries based on the desired scope. This new scope persists until another call to set_scope is made, or if - scope is specified in calls to get_recent. + scope specified in calls to get_recent or get_relevant. Args: session_tag (str): Id of the specific session to filter to. Default is - None, which means session_tag will be unchanged. + None, which means all sessions will be in scope. user_tag (str): Id of the specific user to filter to. Default is None, - which means user_tag will be unchanged. + which means all users will be in scope. """ if not (session_tag or user_tag): return - self._session_tag = session_tag or self._session_tag self._user_tag = user_tag or self._user_tag + tag_filter = Tag(self.user_field_name) == [] + if user_tag: + tag_filter = tag_filter & (Tag(self.user_field_name) == self._user_tag) + if session_tag: + tag_filter = tag_filter & ( + Tag(self.session_field_name) == self._session_tag + ) + + self._tag_filter = tag_filter def clear(self) -> None: """Clears the chat session history.""" - self._client.delete(self.key) + self._index.clear() def delete(self) -> None: - """Clears the chat session history.""" - self._client.delete(self.key) + """Clear all conversation keys and remove the search index.""" + self._index.delete(drop=True) def drop(self, id_field: Optional[str] = None) -> None: """Remove a specific exchange from the conversation history. @@ -93,19 +131,36 @@ def drop(self, id_field: Optional[str] = None) -> None: If None then the last entry is deleted. """ if id_field: - messages = self._client.lrange(self.key, 0, -1) - messages = [json.loads(msg) for msg in messages] - messages = [msg for msg in messages if msg["id_field"] != id_field] - messages = [json.dumps(msg) for msg in messages] - self.clear() - self._client.rpush(self.key, *messages) + sep = self._index.key_separator + key = sep.join([self._index.schema.index.name, id_field]) else: - self._client.rpop(self.key) + key = self.get_recent(top_k=1, raw=True)[0]["id"] # type: ignore + self._index.client.delete(key) # type: ignore @property def messages(self) -> Union[List[str], List[Dict[str, str]]]: """Returns the full chat history.""" - return self.get_recent(top_k=-1) + # TODO raw or as_text? + return_fields = [ + self.id_field_name, + self.session_field_name, + self.user_field_name, + self.role_field_name, + self.content_field_name, + self.tool_field_name, + self.timestamp_field_name, + ] + + query = FilterQuery( + filter_expression=self._tag_filter, + return_fields=return_fields, + ) + + sorted_query = query.query + sorted_query.sort_by(self.timestamp_field_name, asc=True) + hits = self._index.search(sorted_query, query.params).docs + + return self._format_context(hits, as_text=False) def get_recent( self, @@ -119,7 +174,6 @@ def get_recent( Args: top_k (int): The number of previous messages to return. Default is 5. - To get all messages set top_k = -1. session_tag (str): Tag to be added to entries to link to a specific session. user_tag (str): Tag to be added to entries to link to a specific user. @@ -133,24 +187,35 @@ def get_recent( or list of strings if as_text is false. Raises: - ValueError: if top_k is not an integer greater than or equal to -1. + ValueError: if top_k is not an integer greater than or equal to 0. """ - if type(top_k) != int or top_k < -1: - raise ValueError("top_k must be an integer greater than or equal to -1") - if top_k == 0: - return [] - elif top_k == -1: - top_k = 0 + if type(top_k) != int or top_k < 0: + raise ValueError("top_k must be an integer greater than or equal to 0") + self.set_scope(session_tag, user_tag) - messages = self._client.lrange(self.key, -top_k, -1) - messages = [json.loads(msg) for msg in messages] - if raw: - return messages - return self._format_context(messages, as_text) + return_fields = [ + self.id_field_name, + self.session_field_name, + self.user_field_name, + self.role_field_name, + self.content_field_name, + self.tool_field_name, + self.timestamp_field_name, + ] + + query = FilterQuery( + filter_expression=self._tag_filter, + return_fields=return_fields, + num_results=top_k, + ) - @property - def key(self): - return ":".join([self._name, self._user_tag, self._session_tag]) + sorted_query = query.query + sorted_query.sort_by(self.timestamp_field_name, asc=False) + hits = self._index.search(sorted_query, query.params).docs + + if raw: + return hits[::-1] + return self._format_context(hits[::-1], as_text) def store(self, prompt: str, response: str) -> None: """Insert a prompt:response pair into the session memory. A timestamp @@ -162,7 +227,10 @@ def store(self, prompt: str, response: str) -> None: response (str): The corresponding LLM response. """ self.add_messages( - [{"role": "user", "content": prompt}, {"role": "llm", "content": response}] + [ + {self.role_field_name: "user", self.content_field_name: prompt}, + {self.role_field_name: "llm", self.content_field_name: response}, + ] ) def add_messages(self, messages: List[Dict[str, str]]) -> None: @@ -173,23 +241,23 @@ def add_messages(self, messages: List[Dict[str, str]]) -> None: Args: messages (List[Dict[str, str]]): The list of user prompts and LLM responses. """ + sep = self._index.key_separator payloads = [] for message in messages: timestamp = time() + id_field = sep.join([self._user_tag, self._session_tag, str(timestamp)]) payload = { - self.id_field_name: ":".join( - [self._user_tag, self._session_tag, str(timestamp)] - ), + self.id_field_name: id_field, self.role_field_name: message[self.role_field_name], self.content_field_name: message[self.content_field_name], self.timestamp_field_name: timestamp, + self.session_field_name: self._session_tag, + self.user_field_name: self._user_tag, } if self.tool_field_name in message: payload.update({self.tool_field_name: message[self.tool_field_name]}) - - payloads.append(json.dumps(payload)) - - self._client.rpush(self.key, *payloads) + payloads.append(payload) + self._index.load(data=payloads, id_field=self.id_field_name) def add_message(self, message: Dict[str, str]) -> None: """Insert a single prompt or response into the session memory. diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index c1f277c8..b3df7d17 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -30,20 +30,11 @@ def semantic_session(app_name, user_tag, session_tag, client): session.delete() -# test standard session manager -def test_key_creation(client): - # test default key creation - session = StandardSessionManager( - name="test_app", session_tag="123", user_tag="abc", redis_client=client - ) - assert session.key == "test_app:abc:123" - - def test_specify_redis_client(client): session = StandardSessionManager( name="test_app", session_tag="abc", user_tag="123", redis_client=client ) - assert isinstance(session._client, type(client)) + assert isinstance(session._index.client, type(client)) def test_specify_redis_url(client): @@ -53,7 +44,7 @@ def test_specify_redis_url(client): user_tag="123", redis_url="redis://localhost:6379", ) - assert isinstance(session._client, type(client)) + assert isinstance(session._index.client, type(client)) def test_standard_bad_connection_info(): @@ -96,11 +87,8 @@ def test_standard_store_and_get(standard_session): no_context = standard_session.get_recent(top_k=0) assert len(no_context) == 0 - # test that the full context is returned when top_k is -1 - full_context = standard_session.get_recent(top_k=-1) - assert len(full_context) == 10 - # test that order is maintained + full_context = standard_session.get_recent(top_k=10) assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -150,8 +138,8 @@ def test_standard_add_and_get(standard_session): "tool_call_id": "tool call two", } ) - standard_session.add_message({"role": "user", "content": "fourth prompt"}) - standard_session.add_message({"role": "llm", "content": "fourth response"}) + standard_session.add_message({"role": "user", "content": "third prompt"}) + standard_session.add_message({"role": "llm", "content": "third response"}) # test default context history size default_context = standard_session.get_recent() @@ -162,12 +150,12 @@ def test_standard_add_and_get(standard_session): assert len(partial_context) == 3 assert partial_context == [ {"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"}, - {"role": "user", "content": "fourth prompt"}, - {"role": "llm", "content": "fourth response"}, + {"role": "user", "content": "third prompt"}, + {"role": "llm", "content": "third response"}, ] # test that order is maintained - full_context = standard_session.get_recent(top_k=-1) + full_context = standard_session.get_recent(top_k=10) assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -175,8 +163,8 @@ def test_standard_add_and_get(standard_session): {"role": "llm", "content": "second response"}, {"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"}, {"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"}, - {"role": "user", "content": "fourth prompt"}, - {"role": "llm", "content": "fourth response"}, + {"role": "user", "content": "third prompt"}, + {"role": "llm", "content": "third response"}, ] @@ -218,7 +206,7 @@ def test_standard_add_messages(standard_session): ] # test that order is maintained - full_context = standard_session.get_recent(top_k=-1) + full_context = standard_session.get_recent(top_k=10) assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -253,34 +241,37 @@ def test_standard_messages_property(standard_session): def test_standard_set_scope(standard_session, app_name, user_tag, session_tag): # test calling set_scope with no params does not change scope - current_key = standard_session.key + standard_session.store("some prompt", "some response") standard_session.set_scope() - assert standard_session.key == current_key + context = standard_session.get_recent() + assert context == [ + {"role": "user", "content": "some prompt"}, + {"role": "llm", "content": "some response"}, + ] - # test passing either user_tag or session_tag only changes corresponding value + # test that changing user and session id does indeed change access scope new_user = "def" standard_session.set_scope(user_tag=new_user) - assert standard_session.key == f"{app_name}:{new_user}:{session_tag}" - - new_session = "456" - standard_session.set_scope(session_tag=new_session) - assert standard_session.key == f"{app_name}:{new_user}:{new_session}" - - # test that changing user and session id does indeed change access scope standard_session.store("new user prompt", "new user response") + context = standard_session.get_recent() + assert context == [ + {"role": "user", "content": "new user prompt"}, + {"role": "llm", "content": "new user response"}, + ] + + # test that previous user and session data is still accessible + previous_user = "abc" + standard_session.set_scope(user_tag=previous_user) + context = standard_session.get_recent() + assert context == [ + {"role": "user", "content": "some prompt"}, + {"role": "llm", "content": "some response"}, + ] standard_session.set_scope(session_tag="789", user_tag="ghi") no_context = standard_session.get_recent() assert no_context == [] - # change scope back to read previously stored entries - standard_session.set_scope(session_tag="456", user_tag="def") - previous_context = standard_session.get_recent() - assert previous_context == [ - {"role": "user", "content": "new user prompt"}, - {"role": "llm", "content": "new user response"}, - ] - def test_standard_get_recent_with_scope(standard_session, session_tag): # test that passing user or session id to get_recent(...) changes scope @@ -319,23 +310,14 @@ def test_standard_get_text(standard_session): def test_standard_get_raw(standard_session): - current_time = int(time.time()) standard_session.store("first prompt", "first response") standard_session.store("second prompt", "second response") raw = standard_session.get_recent(raw=True) assert len(raw) == 4 - assert raw[0].keys() == { - "id_field", - "role", - "content", - "timestamp", - } assert raw[0]["role"] == "user" assert raw[0]["content"] == "first prompt" - assert current_time <= raw[0]["timestamp"] <= time.time() assert raw[1]["role"] == "llm" assert raw[1]["content"] == "first response" - assert raw[1]["timestamp"] > raw[0]["timestamp"] def test_standard_drop(standard_session): @@ -354,7 +336,7 @@ def test_standard_drop(standard_session): ] # test drop(id) removes the specified element - context = standard_session.get_recent(top_k=-1, raw=True) + context = standard_session.get_recent(top_k=10, raw=True) middle_id = context[3]["id_field"] standard_session.drop(middle_id) context = standard_session.get_recent(top_k=6) @@ -371,14 +353,7 @@ def test_standard_drop(standard_session): def test_standard_clear(standard_session): standard_session.store("some prompt", "some response") standard_session.clear() - empty_context = standard_session.get_recent(top_k=-1) - assert empty_context == [] - - -def test_standard_delete(standard_session): - standard_session.store("some prompt", "some response") - standard_session.delete() - empty_context = standard_session.get_recent(top_k=-1) + empty_context = standard_session.get_recent(top_k=10) assert empty_context == [] @@ -606,14 +581,14 @@ def test_semantic_add_and_get_relevant(semantic_session): def test_semantic_get_raw(semantic_session): - current_time = int(time.time()) semantic_session.store("first prompt", "first response") semantic_session.store("second prompt", "second response") raw = semantic_session.get_recent(raw=True) assert len(raw) == 4 + assert raw[0]["role"] == "user" assert raw[0]["content"] == "first prompt" + assert raw[1]["role"] == "llm" assert raw[1]["content"] == "first response" - assert current_time <= float(raw[0]["timestamp"]) <= time.time() def test_semantic_drop(semantic_session):