diff --git a/conftest.py b/conftest.py index c6f16d5a..bfd87cd6 100644 --- a/conftest.py +++ b/conftest.py @@ -143,10 +143,3 @@ def clear_db(redis): def app_name(): return "test_app" -@pytest.fixture -def session_tag(): - return "123" - -@pytest.fixture -def user_tag(): - return "abc" diff --git a/docs/user_guide/session_manager_07.ipynb b/docs/user_guide/session_manager_07.ipynb index d47624e1..95a62dd8 100644 --- a/docs/user_guide/session_manager_07.ipynb +++ b/docs/user_guide/session_manager_07.ipynb @@ -74,7 +74,7 @@ ], "source": [ "from redisvl.extensions.session_manager import SemanticSessionManager\n", - "user_session = SemanticSessionManager(name='llm_chef', session_tag='123', user_tag='abc')\n", + "user_session = SemanticSessionManager(name='llm_chef')\n", "user_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful chef, assisting people in making delicious meals\"})\n", "\n", "client = CohereClient()" diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index e1c63875..8ec167b7 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -13,7 +13,7 @@ class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" - entry_id_field_name: str = "id" + entry_id_field_name: str = "_id" prompt_field_name: str = "prompt" vector_field_name: str = "prompt_vector" response_field_name: str = "response" @@ -222,7 +222,8 @@ def _search_cache( cache_hits: List[Dict[str, Any]] = self._index.query(query) # Process cache hits for hit in cache_hits: - self._refresh_ttl(hit[self.entry_id_field_name]) + key = hit["id"] + self._refresh_ttl(key) # Check for metadata and deserialize if self.metadata_field_name in hit: hit[self.metadata_field_name] = self.deserialize( diff --git a/redisvl/extensions/session_manager/base_session.py b/redisvl/extensions/session_manager/base_session.py index 4d09c2cd..97c4a9f1 100644 --- a/redisvl/extensions/session_manager/base_session.py +++ b/redisvl/extensions/session_manager/base_session.py @@ -1,20 +1,23 @@ from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 from redis import Redis +from redisvl.query.filter import FilterExpression + class BaseSessionManager: - id_field_name: str = "id_field" + id_field_name: str = "_id" role_field_name: str = "role" content_field_name: str = "content" tool_field_name: str = "tool_call_id" timestamp_field_name: str = "timestamp" + session_field_name: str = "session_tag" def __init__( self, name: str, - session_tag: str, - user_tag: str, + session_tag: Optional[str] = None, ): """Initialize session memory with index @@ -26,29 +29,10 @@ def __init__( Args: name (str): The name of the session manager index. session_tag (str): Tag to be added to entries to link to a specific - session. - user_tag (str): Tag to be added to entries to link to a specific user. + session. Defaults to instance uuid. """ self._name = name - self._user_tag = user_tag - self._session_tag = session_tag - - def set_scope( - self, - session_tag: Optional[str] = None, - user_tag: Optional[str] = None, - ) -> None: - """Set the filter to apply to querries based on the desired scope. - - This new scope persists until another call to set_scope is made, or if - scope specified in calls to get_recent. - - Args: - session_tag (str): Id of the specific session to filter to. Default is - None. - user_tag (str): Id of the specific user to filter to. Default is None. - """ - raise NotImplementedError + self._session_tag = session_tag or uuid4().hex def clear(self) -> None: """Clears the chat session history.""" @@ -75,23 +59,21 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: def get_recent( self, top_k: int = 5, - session_tag: Optional[str] = None, - user_tag: Optional[str] = None, as_text: bool = False, raw: bool = False, + session_tag: Optional[str] = None, ) -> Union[List[str], List[Dict[str, str]]]: """Retreive the recent conversation history in sequential order. Args: top_k (int): The number of previous exchanges to return. Default is 5. Note that one exchange contains both a prompt and response. - session_tag (str): Tag to be added to entries to link to a specific - session. - user_tag (str): Tag to be added to entries to link to a specific user. as_text (bool): Whether to return the conversation as a single string, or list of alternating prompts and responses. raw (bool): Whether to return the full Redis hash entry or just the prompt and response + session_tag (str): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. Returns: Union[str, List[str]]: A single string transcription of the session @@ -113,6 +95,7 @@ def _format_context( recent conversation history. as_text (bool): Whether to return the conversation as a single string, or list of alternating prompts and responses. + Returns: Union[str, List[str]]: A single string transcription of the session or list of strings if as_text is false. @@ -141,7 +124,9 @@ def _format_context( ) return statements - def store(self, prompt: str, response: str) -> None: + def store( + self, prompt: str, response: str, session_tag: Optional[str] = None + ) -> None: """Insert a prompt:response pair into the session memory. A timestamp is associated with each exchange so that they can be later sorted in sequential ordering after retrieval. @@ -149,25 +134,32 @@ def store(self, prompt: str, response: str) -> None: Args: prompt (str): The user prompt to the LLM. response (str): The corresponding LLM response. + session_tag (Optional[str]): The tag to mark the message with. Defaults to None. """ raise NotImplementedError - def add_messages(self, messages: List[Dict[str, str]]) -> None: + def add_messages( + self, messages: List[Dict[str, str]], session_tag: Optional[str] = None + ) -> None: """Insert a list of prompts and responses into the session memory. A timestamp is associated with each so that they can be later sorted in sequential ordering after retrieval. Args: messages (List[Dict[str, str]]): The list of user prompts and LLM responses. + session_tag (Optional[str]): The tag to mark the messages with. Defaults to None. """ raise NotImplementedError - def add_message(self, message: Dict[str, str]) -> None: + def add_message( + self, message: Dict[str, str], session_tag: Optional[str] = None + ) -> None: """Insert a single prompt or response into the session memory. A timestamp is associated with it so that it can be later sorted in sequential ordering after retrieval. Args: message (Dict[str,str]): The user prompt or LLM response. + session_tag (Optional[str]): The tag to mark the message with. Defaults to None. """ raise NotImplementedError diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index fdea13cd..ccfcf2e7 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -6,7 +6,7 @@ from redisvl.extensions.session_manager import BaseSessionManager from redisvl.index import SearchIndex from redisvl.query import FilterQuery, RangeQuery -from redisvl.query.filter import Tag +from redisvl.query.filter import FilterExpression, Tag from redisvl.redis.utils import array_to_buffer from redisvl.schema.schema import IndexSchema from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer @@ -25,7 +25,6 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int): {"name": "tool_call_id", "type": "text"}, {"name": "timestamp", "type": "numeric"}, {"name": "session_tag", "type": "tag"}, - {"name": "user_tag", "type": "tag"}, { "name": "vector_field", "type": "vector", @@ -41,15 +40,12 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int): class SemanticSessionManager(BaseSessionManager): - session_field_name: str = "session_tag" - user_field_name: str = "user_tag" vector_field_name: str = "vector_field" def __init__( self, name: str, - session_tag: str, - user_tag: str, + session_tag: Optional[str] = None, prefix: Optional[str] = None, vectorizer: Optional[BaseVectorizer] = None, distance_threshold: float = 0.3, @@ -68,9 +64,8 @@ def __init__( Args: name (str): The name of the session manager index. - session_tag (str): Tag to be added to entries to link to a specific - session. - user_tag (str): Tag to be added to entries to link to a specific user. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. prefix (Optional[str]): Prefix for the keys for this session data. Defaults to None and will be replaced with the index name. vectorizer (Optional[BaseVectorizer]): The vectorizer used to create embeddings. @@ -86,7 +81,7 @@ def __init__( from either the prompt or response in a single string. """ - super().__init__(name, session_tag, user_tag) + super().__init__(name, session_tag) prefix = prefix or name @@ -110,37 +105,7 @@ def __init__( self._index.create(overwrite=False) - self.set_scope(session_tag, user_tag) - - def set_scope( - self, - session_tag: Optional[str] = None, - user_tag: Optional[str] = None, - ) -> None: - """Set the tag filter to apply to querries based on the desired scope. - - This new scope persists until another call to set_scope is made, or if - scope specified in calls to get_recent or get_relevant. - - Args: - session_tag (str): Id of the specific session to filter to. Default is - None, which means all sessions will be in scope. - user_tag (str): Id of the specific user to filter to. Default is None, - which means all users will be in scope. - """ - if not (session_tag or user_tag): - return - self._session_tag = session_tag or self._session_tag - self._user_tag = user_tag or self._user_tag - tag_filter = Tag(self.user_field_name) == [] - if user_tag: - tag_filter = tag_filter & (Tag(self.user_field_name) == self._user_tag) - if session_tag: - tag_filter = tag_filter & ( - Tag(self.session_field_name) == self._session_tag - ) - - self._tag_filter = tag_filter + self._default_session_filter = Tag(self.session_field_name) == self._session_tag def clear(self) -> None: """Clears the chat session history.""" @@ -150,28 +115,26 @@ def delete(self) -> None: """Clear all conversation keys and remove the search index.""" self._index.delete(drop=True) - def drop(self, id_field: Optional[str] = None) -> None: + def drop(self, id: Optional[str] = None) -> None: """Remove a specific exchange from the conversation history. Args: - id_field (Optional[str]): The id_field of the entry to delete. + id (Optional[str]): The id of the session entry to delete. If None then the last entry is deleted. """ - if id_field: - sep = self._index.key_separator - key = sep.join([self._index.schema.index.name, id_field]) - else: - key = self.get_recent(top_k=1, raw=True)[0]["id"] # type: ignore - self._index.client.delete(key) # type: ignore + if id is None: + id = self.get_recent(top_k=1, raw=True)[0][self.id_field_name] # type: ignore + + self._index.client.delete(self._index.key(id)) # type: ignore @property def messages(self) -> Union[List[str], List[Dict[str, str]]]: """Returns the full chat history.""" # TODO raw or as_text? + # TODO refactor method to use get_recent and support other session tags return_fields = [ self.id_field_name, self.session_field_name, - self.user_field_name, self.role_field_name, self.content_field_name, self.tool_field_name, @@ -179,7 +142,7 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: ] query = FilterQuery( - filter_expression=self._tag_filter, + filter_expression=self._default_session_filter, return_fields=return_fields, ) @@ -196,7 +159,6 @@ def get_relevant( top_k: int = 5, fall_back: bool = False, session_tag: Optional[str] = None, - user_tag: Optional[str] = None, raw: bool = False, ) -> Union[List[str], List[Dict[str, str]]]: """Searches the chat history for information semantically related to @@ -214,8 +176,8 @@ def get_relevant( top_k (int): The number of previous messages to return. Default is 5. fallback (bool): Whether to drop back to recent conversation history if no relevant context is found. - session_tag (str): Tag of entries linked to a specific session. - user_tag (str): Tag of entries linked to a specific user. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. raw (bool): Whether to return the full Redis hash entry or just the message. @@ -229,10 +191,9 @@ def get_relevant( raise ValueError("top_k must be an integer greater than or equal to -1") if top_k == 0: return [] - self.set_scope(session_tag, user_tag) + return_fields = [ self.session_field_name, - self.user_field_name, self.role_field_name, self.content_field_name, self.timestamp_field_name, @@ -240,6 +201,12 @@ def get_relevant( self.vector_field_name, ] + session_filter = ( + Tag(self.session_field_name) == session_tag + if session_tag + else self._default_session_filter + ) + query = RangeQuery( vector=self._vectorizer.embed(prompt), vector_field_name=self.vector_field_name, @@ -247,7 +214,7 @@ def get_relevant( distance_threshold=self._distance_threshold, num_results=top_k, return_score=True, - filter_expression=self._tag_filter, + filter_expression=session_filter, ) hits = self._index.query(query) @@ -261,22 +228,20 @@ def get_relevant( def get_recent( self, top_k: int = 5, - session_tag: Optional[str] = None, - user_tag: Optional[str] = None, as_text: bool = False, raw: bool = False, + session_tag: Optional[str] = None, ) -> Union[List[str], List[Dict[str, str]]]: """Retreive the recent conversation history in sequential order. Args: top_k (int): The number of previous exchanges to return. Default is 5. - session_tag (str): Tag to be added to entries to link to a specific - session. - user_tag (str): Tag to be added to entries to link to a specific user. as_text (bool): Whether to return the conversation as a single string, or list of alternating prompts and responses. raw (bool): Whether to return the full Redis hash entry or just the prompt and response + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. Returns: Union[str, List[str]]: A single string transcription of the session @@ -288,19 +253,23 @@ def get_recent( if type(top_k) != int or top_k < 0: raise ValueError("top_k must be an integer greater than or equal to 0") - self.set_scope(session_tag, user_tag) return_fields = [ self.id_field_name, self.session_field_name, - self.user_field_name, self.role_field_name, self.content_field_name, self.tool_field_name, self.timestamp_field_name, ] + session_filter = ( + Tag(self.session_field_name) == session_tag + if session_tag + else self._default_session_filter + ) + query = FilterQuery( - filter_expression=self._tag_filter, + filter_expression=session_filter, return_fields=return_fields, num_results=top_k, ) @@ -320,7 +289,9 @@ def distance_threshold(self): def set_distance_threshold(self, threshold): self._distance_threshold = threshold - def store(self, prompt: str, response: str) -> None: + def store( + self, prompt: str, response: str, session_tag: Optional[str] = None + ) -> None: """Insert a prompt:response pair into the session memory. A timestamp is associated with each message so that they can be later sorted in sequential ordering after retrieval. @@ -328,48 +299,60 @@ def store(self, prompt: str, response: str) -> None: Args: prompt (str): The user prompt to the LLM. response (str): The corresponding LLM response. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. """ self.add_messages( [ {self.role_field_name: "user", self.content_field_name: prompt}, {self.role_field_name: "llm", self.content_field_name: response}, - ] + ], + session_tag, ) - def add_messages(self, messages: List[Dict[str, str]]) -> None: + def add_messages( + self, messages: List[Dict[str, str]], session_tag: Optional[str] = None + ) -> None: """Insert a list of prompts and responses into the session memory. A timestamp is associated with each so that they can be later sorted in sequential ordering after retrieval. Args: messages (List[Dict[str, str]]): The list of user prompts and LLM responses. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. """ sep = self._index.key_separator + session_tag = session_tag or self._session_tag payloads = [] for message in messages: vector = self._vectorizer.embed(message[self.content_field_name]) timestamp = time() - id_field = sep.join([self._user_tag, self._session_tag, str(timestamp)]) + id_field = sep.join([self._session_tag, str(timestamp)]) payload = { self.id_field_name: id_field, self.role_field_name: message[self.role_field_name], self.content_field_name: message[self.content_field_name], self.timestamp_field_name: timestamp, - self.session_field_name: self._session_tag, - self.user_field_name: self._user_tag, self.vector_field_name: array_to_buffer(vector), + self.session_field_name: session_tag, } + if self.tool_field_name in message: payload.update({self.tool_field_name: message[self.tool_field_name]}) payloads.append(payload) self._index.load(data=payloads, id_field=self.id_field_name) - def add_message(self, message: Dict[str, str]) -> None: + def add_message( + self, message: Dict[str, str], session_tag: Optional[str] = None + ) -> None: """Insert a single prompt or response into the session memory. A timestamp is associated with it so that it can be later sorted in sequential ordering after retrieval. Args: message (Dict[str,str]): The user prompt or LLM response. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. """ - self.add_messages([message]) + self.add_messages([message], session_tag) diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 640f13f9..0a1a7b25 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -1,4 +1,3 @@ -import json from time import time from typing import Any, Dict, List, Optional, Union @@ -24,20 +23,17 @@ def from_params(cls, name: str, prefix: str): {"name": "tool_call_id", "type": "text"}, {"name": "timestamp", "type": "numeric"}, {"name": "session_tag", "type": "tag"}, - {"name": "user_tag", "type": "tag"}, ], ) class StandardSessionManager(BaseSessionManager): session_field_name: str = "session_tag" - user_field_name: str = "user_tag" def __init__( self, name: str, - session_tag: str, - user_tag: str, + session_tag: Optional[str] = None, prefix: Optional[str] = None, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", @@ -54,8 +50,7 @@ def __init__( Args: name (str): The name of the session manager index. session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. - user_tag (Optional[str]): Tag to be added to entries to link to a specific user. + session. Defaults to instance uuid. prefix (Optional[str]): Prefix for the keys for this session data. Defaults to None and will be replaced with the index name. redis_client (Optional[Redis]): A Redis client instance. Defaults to @@ -68,52 +63,22 @@ def __init__( constructed from the prompt & response in a single string. """ - super().__init__(name, session_tag, user_tag) + super().__init__(name, session_tag) prefix = prefix or name schema = StandardSessionIndexSchema.from_params(name, prefix) + self._index = SearchIndex(schema=schema) - # handle redis connection if redis_client: self._index.set_client(redis_client) - elif redis_url: - self._index.connect(redis_url=redis_url, **connection_kwargs) + else: + self._index.connect(redis_url=redis_url) self._index.create(overwrite=False) - self.set_scope(session_tag, user_tag) - - def set_scope( - self, - session_tag: Optional[str] = None, - user_tag: Optional[str] = None, - ) -> None: - """Set the filter to apply to queries based on the desired scope. - - This new scope persists until another call to set_scope is made, or if - scope specified in calls to get_recent or get_relevant. - - Args: - session_tag (str): Id of the specific session to filter to. Default is - None, which means all sessions will be in scope. - user_tag (str): Id of the specific user to filter to. Default is None, - which means all users will be in scope. - """ - if not (session_tag or user_tag): - return - self._session_tag = session_tag or self._session_tag - self._user_tag = user_tag or self._user_tag - tag_filter = Tag(self.user_field_name) == [] - if user_tag: - tag_filter = tag_filter & (Tag(self.user_field_name) == self._user_tag) - if session_tag: - tag_filter = tag_filter & ( - Tag(self.session_field_name) == self._session_tag - ) - - self._tag_filter = tag_filter + self._default_session_filter = Tag(self.session_field_name) == self._session_tag def clear(self) -> None: """Clears the chat session history.""" @@ -123,28 +88,26 @@ def delete(self) -> None: """Clear all conversation keys and remove the search index.""" self._index.delete(drop=True) - def drop(self, id_field: Optional[str] = None) -> None: + def drop(self, id: Optional[str] = None) -> None: """Remove a specific exchange from the conversation history. Args: - id_field (Optional[str]): The id_field of the entry to delete. + id (Optional[str]): The id of the session entry to delete. If None then the last entry is deleted. """ - if id_field: - sep = self._index.key_separator - key = sep.join([self._index.schema.index.name, id_field]) - else: - key = self.get_recent(top_k=1, raw=True)[0]["id"] # type: ignore - self._index.client.delete(key) # type: ignore + if id is None: + id = self.get_recent(top_k=1, raw=True)[0][self.id_field_name] # type: ignore + + self._index.client.delete(self._index.key(id)) # type: ignore @property def messages(self) -> Union[List[str], List[Dict[str, str]]]: """Returns the full chat history.""" # TODO raw or as_text? + # TODO refactor this method to use get_recent and support other session tags? return_fields = [ self.id_field_name, self.session_field_name, - self.user_field_name, self.role_field_name, self.content_field_name, self.tool_field_name, @@ -152,7 +115,7 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: ] query = FilterQuery( - filter_expression=self._tag_filter, + filter_expression=self._default_session_filter, return_fields=return_fields, ) @@ -165,22 +128,20 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: def get_recent( self, top_k: int = 5, - session_tag: Optional[str] = None, - user_tag: Optional[str] = None, as_text: bool = False, raw: bool = False, + session_tag: Optional[str] = None, ) -> Union[List[str], List[Dict[str, str]]]: """Retreive the recent conversation history in sequential order. Args: top_k (int): The number of previous messages to return. Default is 5. - session_tag (str): Tag to be added to entries to link to a specific - session. - user_tag (str): Tag to be added to entries to link to a specific user. as_text (bool): Whether to return the conversation as a single string, or list of alternating prompts and responses. raw (bool): Whether to return the full Redis hash entry or just the prompt and response + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. Returns: Union[str, List[str]]: A single string transcription of the session @@ -192,19 +153,23 @@ def get_recent( if type(top_k) != int or top_k < 0: raise ValueError("top_k must be an integer greater than or equal to 0") - self.set_scope(session_tag, user_tag) return_fields = [ self.id_field_name, self.session_field_name, - self.user_field_name, self.role_field_name, self.content_field_name, self.tool_field_name, self.timestamp_field_name, ] + session_filter = ( + Tag(self.session_field_name) == session_tag + if session_tag + else self._default_session_filter + ) + query = FilterQuery( - filter_expression=self._tag_filter, + filter_expression=session_filter, return_fields=return_fields, num_results=top_k, ) @@ -217,7 +182,9 @@ def get_recent( return hits[::-1] return self._format_context(hits[::-1], as_text) - def store(self, prompt: str, response: str) -> None: + def store( + self, prompt: str, response: str, session_tag: Optional[str] = None + ) -> None: """Insert a prompt:response pair into the session memory. A timestamp is associated with each exchange so that they can be later sorted in sequential ordering after retrieval. @@ -225,46 +192,59 @@ def store(self, prompt: str, response: str) -> None: Args: prompt (str): The user prompt to the LLM. response (str): The corresponding LLM response. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. """ self.add_messages( [ {self.role_field_name: "user", self.content_field_name: prompt}, {self.role_field_name: "llm", self.content_field_name: response}, - ] + ], + session_tag, ) - def add_messages(self, messages: List[Dict[str, str]]) -> None: + def add_messages( + self, messages: List[Dict[str, str]], session_tag: Optional[str] = None + ) -> None: """Insert a list of prompts and responses into the session memory. A timestamp is associated with each so that they can be later sorted in sequential ordering after retrieval. Args: messages (List[Dict[str, str]]): The list of user prompts and LLM responses. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. """ sep = self._index.key_separator + session_tag = session_tag or self._session_tag payloads = [] for message in messages: timestamp = time() - id_field = sep.join([self._user_tag, self._session_tag, str(timestamp)]) + id_field = sep.join([self._session_tag, str(timestamp)]) payload = { self.id_field_name: id_field, self.role_field_name: message[self.role_field_name], self.content_field_name: message[self.content_field_name], self.timestamp_field_name: timestamp, - self.session_field_name: self._session_tag, - self.user_field_name: self._user_tag, + self.session_field_name: session_tag, } + if self.tool_field_name in message: payload.update({self.tool_field_name: message[self.tool_field_name]}) + payloads.append(payload) self._index.load(data=payloads, id_field=self.id_field_name) - def add_message(self, message: Dict[str, str]) -> None: + def add_message( + self, message: Dict[str, str], session_tag: Optional[str] = None + ) -> None: """Insert a single prompt or response into the session memory. A timestamp is associated with it so that it can be later sorted in sequential ordering after retrieval. Args: message (Dict[str,str]): The user prompt or LLM response. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. """ - self.add_messages([message]) + self.add_messages([message], session_tag) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index cc152291..ef4ad7fe 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -114,6 +114,20 @@ def test_ttl_expiration(cache_with_ttl, vectorizer): assert len(check_result) == 0 +def test_ttl_refresh(cache_with_ttl, vectorizer): + prompt = "This is a test prompt." + response = "This is a test response." + vector = vectorizer.embed(prompt) + + cache_with_ttl.store(prompt, response, vector=vector) + + for _ in range(3): + sleep(1) + check_result = cache_with_ttl.check(vector=vector) + + assert len(check_result) == 1 + + def test_ttl_expiration_after_update(cache_with_ttl, vectorizer): prompt = "This is a test prompt." response = "This is a test response." diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index b3df7d17..ce6b9943 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -1,6 +1,3 @@ -import json -import time - import pytest from redis.exceptions import ConnectionError @@ -11,38 +8,32 @@ @pytest.fixture -def standard_session(app_name, user_tag, session_tag, client): - session = StandardSessionManager( - app_name, session_tag=session_tag, user_tag=user_tag, redis_client=client - ) +def standard_session(app_name, client): + session = StandardSessionManager(app_name, redis_client=client) yield session session.clear() - session.delete() @pytest.fixture -def semantic_session(app_name, user_tag, session_tag, client): - session = SemanticSessionManager( - app_name, session_tag=session_tag, user_tag=user_tag, redis_client=client - ) +def semantic_session(app_name, client): + session = SemanticSessionManager(app_name, redis_client=client) yield session session.clear() session.delete() +# test standard session manager def test_specify_redis_client(client): - session = StandardSessionManager( - name="test_app", session_tag="abc", user_tag="123", redis_client=client - ) + session = StandardSessionManager(name="test_app", redis_client=client) assert isinstance(session._index.client, type(client)) -def test_specify_redis_url(client): +def test_specify_redis_url(client, redis_url): session = StandardSessionManager( name="test_app", session_tag="abc", user_tag="123", - redis_url="redis://localhost:6379", + redis_url=redis_url, ) assert isinstance(session._index.client, type(client)) @@ -52,12 +43,11 @@ def test_standard_bad_connection_info(): StandardSessionManager( name="test_app", session_tag="abc", - user_tag="123", redis_url="redis://localhost:6389", # bad url ) -def test_standard_store_and_get(standard_session): +def test_standard_store(standard_session): context = standard_session.get_recent() assert len(context) == 0 @@ -67,26 +57,6 @@ def test_standard_store_and_get(standard_session): standard_session.store(prompt="fourth prompt", response="fourth response") standard_session.store(prompt="fifth prompt", response="fifth response") - # test default context history size - default_context = standard_session.get_recent() - assert len(default_context) == 5 # default is 5 - - # test specified context history size - partial_context = standard_session.get_recent(top_k=2) - assert len(partial_context) == 2 - assert partial_context == [ - {"role": "user", "content": "fifth prompt"}, - {"role": "llm", "content": "fifth response"}, - ] - - # test larger context history returns full history - too_large_context = standard_session.get_recent(top_k=100) - assert len(too_large_context) == 10 - - # test no context is returned when top_k is 0 - no_context = standard_session.get_recent(top_k=0) - assert len(no_context) == 0 - # test that order is maintained full_context = standard_session.get_recent(top_k=10) assert full_context == [ @@ -102,19 +72,6 @@ def test_standard_store_and_get(standard_session): {"role": "llm", "content": "fifth response"}, ] - # test that a ValueError is raised when top_k is invalid - with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k=-2) - - with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k=-2.0) - - with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k=1.3) - - with pytest.raises(ValueError): - bad_context = standard_session.get_recent(top_k="3") - def test_standard_add_and_get(standard_session): context = standard_session.get_recent() @@ -167,6 +124,19 @@ def test_standard_add_and_get(standard_session): {"role": "llm", "content": "third response"}, ] + # test that a ValueError is raised when top_k is invalid + with pytest.raises(ValueError): + bad_context = standard_session.get_recent(top_k=-2) + + with pytest.raises(ValueError): + bad_context = standard_session.get_recent(top_k=-2.0) + + with pytest.raises(ValueError): + bad_context = standard_session.get_recent(top_k=1.3) + + with pytest.raises(ValueError): + bad_context = standard_session.get_recent(top_k="3") + def test_standard_add_messages(standard_session): context = standard_session.get_recent() @@ -185,7 +155,7 @@ def test_standard_add_messages(standard_session): }, { "role": "tool", - "content": "tool resuilt 2", + "content": "tool result 2", "tool_call_id": "tool call two", }, {"role": "user", "content": "fourth prompt"}, @@ -193,27 +163,15 @@ def test_standard_add_messages(standard_session): ] ) - # test default context history size - default_context = standard_session.get_recent() - assert len(default_context) == 5 # default is 5 - - # test specified context history size - partial_context = standard_session.get_recent(top_k=2) - assert len(partial_context) == 2 - assert partial_context == [ - {"role": "user", "content": "fourth prompt"}, - {"role": "llm", "content": "fourth response"}, - ] - - # test that order is maintained full_context = standard_session.get_recent(top_k=10) + assert len(full_context) == 8 assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, {"role": "user", "content": "second prompt"}, {"role": "llm", "content": "second response"}, {"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"}, - {"role": "tool", "content": "tool resuilt 2", "tool_call_id": "tool call two"}, + {"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"}, {"role": "user", "content": "fourth prompt"}, {"role": "llm", "content": "fourth response"}, ] @@ -239,66 +197,33 @@ def test_standard_messages_property(standard_session): ] -def test_standard_set_scope(standard_session, app_name, user_tag, session_tag): - # test calling set_scope with no params does not change scope +def test_standard_scope(standard_session): + # store entries under default session tag standard_session.store("some prompt", "some response") - standard_session.set_scope() - context = standard_session.get_recent() - assert context == [ - {"role": "user", "content": "some prompt"}, - {"role": "llm", "content": "some response"}, - ] - # test that changing user and session id does indeed change access scope - new_user = "def" - standard_session.set_scope(user_tag=new_user) - standard_session.store("new user prompt", "new user response") - context = standard_session.get_recent() + # test that changing session tag does indeed change access scope + new_session = "def" + standard_session.store( + "new user prompt", "new user response", session_tag=new_session + ) + context = standard_session.get_recent(session_tag=new_session) assert context == [ {"role": "user", "content": "new user prompt"}, {"role": "llm", "content": "new user response"}, ] - # test that previous user and session data is still accessible - previous_user = "abc" - standard_session.set_scope(user_tag=previous_user) + # test that default session data is still accessible context = standard_session.get_recent() assert context == [ {"role": "user", "content": "some prompt"}, {"role": "llm", "content": "some response"}, ] - standard_session.set_scope(session_tag="789", user_tag="ghi") - no_context = standard_session.get_recent() + bad_session = "xyz" + no_context = standard_session.get_recent(session_tag=bad_session) assert no_context == [] -def test_standard_get_recent_with_scope(standard_session, session_tag): - # test that passing user or session id to get_recent(...) changes scope - standard_session.store("first prompt", "first response") - - context = standard_session.get_recent() - assert context == [ - {"role": "user", "content": "first prompt"}, - {"role": "llm", "content": "first response"}, - ] - - context = standard_session.get_recent(session_tag="456") - assert context == [] - - # test that scope change persists after being updated via get_recent(...) - standard_session.store("new session prompt", "new session response") - context = standard_session.get_recent() - assert context == [ - {"role": "user", "content": "new session prompt"}, - {"role": "llm", "content": "new session response"}, - ] - - # clean up lingering sessions - standard_session.clear() - standard_session.set_scope(session_tag=session_tag) - - def test_standard_get_text(standard_session): standard_session.store("first prompt", "first response") text = standard_session.get_recent(as_text=True) @@ -337,7 +262,7 @@ def test_standard_drop(standard_session): # test drop(id) removes the specified element context = standard_session.get_recent(top_k=10, raw=True) - middle_id = context[3]["id_field"] + middle_id = context[3][standard_session.id_field_name] standard_session.drop(middle_id) context = standard_session.get_recent(top_k=6) assert context == [ @@ -360,7 +285,7 @@ def test_standard_clear(standard_session): # test semantic session manager def test_semantic_specify_client(client): session = SemanticSessionManager( - name="test_app", session_tag="abc", user_tag="123", redis_client=client + name="test_app", session_tag="abc", redis_client=client ) assert isinstance(session._index.client, type(client)) @@ -370,42 +295,34 @@ def test_semantic_bad_connection_info(): SemanticSessionManager( name="test_app", session_tag="abc", - user_tag="123", redis_url="redis://localhost:6389", ) -def test_semantic_set_scope(semantic_session, app_name, user_tag, session_tag): - # test calling set_scope with no params does not change scope +def test_semantic_scope(semantic_session): + # store entries under default session tag semantic_session.store("some prompt", "some response") - semantic_session.set_scope() - context = semantic_session.get_recent() - assert context == [ - {"role": "user", "content": "some prompt"}, - {"role": "llm", "content": "some response"}, - ] - # test that changing user and session id does indeed change access scope - new_user = "def" - semantic_session.set_scope(user_tag=new_user) - semantic_session.store("new user prompt", "new user response") - context = semantic_session.get_recent() + # test that changing session tag does indeed change access scope + new_session = "def" + semantic_session.store( + "new user prompt", "new user response", session_tag=new_session + ) + context = semantic_session.get_recent(session_tag=new_session) assert context == [ {"role": "user", "content": "new user prompt"}, {"role": "llm", "content": "new user response"}, ] - # test that previous user and session data is still accessible - previous_user = "abc" - semantic_session.set_scope(user_tag=previous_user) + # test that previous session data is still accessible context = semantic_session.get_recent() assert context == [ {"role": "user", "content": "some prompt"}, {"role": "llm", "content": "some response"}, ] - semantic_session.set_scope(session_tag="789", user_tag="ghi") - no_context = semantic_session.get_recent() + bad_session = "xyz" + no_context = semantic_session.get_recent(session_tag=bad_session) assert no_context == [] @@ -608,7 +525,7 @@ def test_semantic_drop(semantic_session): # test drop(id) removes the specified element context = semantic_session.get_recent(top_k=5, raw=True) - middle_id = context[2]["id_field"] + middle_id = context[2][semantic_session.id_field_name] semantic_session.drop(middle_id) context = semantic_session.get_recent(top_k=4) assert context == [