diff --git a/redisvl/extensions/constants.py b/redisvl/extensions/constants.py index a7c78a18..cb58c98c 100644 --- a/redisvl/extensions/constants.py +++ b/redisvl/extensions/constants.py @@ -23,7 +23,7 @@ CACHE_VECTOR_FIELD_NAME: str = "prompt_vector" INSERTED_AT_FIELD_NAME: str = "inserted_at" UPDATED_AT_FIELD_NAME: str = "updated_at" -METADATA_FIELD_NAME: str = "metadata" +METADATA_FIELD_NAME: str = "metadata" # also used in MessageHistory # EmbeddingsCache TEXT_FIELD_NAME: str = "text" diff --git a/redisvl/extensions/message_history/base_history.py b/redisvl/extensions/message_history/base_history.py index 0b40be44..4c842c2a 100644 --- a/redisvl/extensions/message_history/base_history.py +++ b/redisvl/extensions/message_history/base_history.py @@ -2,11 +2,12 @@ from redisvl.extensions.constants import ( CONTENT_FIELD_NAME, + METADATA_FIELD_NAME, ROLE_FIELD_NAME, TOOL_FIELD_NAME, ) from redisvl.extensions.message_history.schema import ChatMessage -from redisvl.utils.utils import create_ulid +from redisvl.utils.utils import create_ulid, deserialize class BaseMessageHistory: @@ -111,6 +112,10 @@ def _format_context( } if chat_message.tool_call_id is not None: chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id + if chat_message.metadata is not None: + chat_message_dict[METADATA_FIELD_NAME] = deserialize( + chat_message.metadata + ) context.append(chat_message_dict) # type: ignore diff --git a/redisvl/extensions/message_history/message_history.py b/redisvl/extensions/message_history/message_history.py index 4520d7a4..fef7ab8f 100644 --- a/redisvl/extensions/message_history/message_history.py +++ b/redisvl/extensions/message_history/message_history.py @@ -5,6 +5,7 @@ from redisvl.extensions.constants import ( CONTENT_FIELD_NAME, ID_FIELD_NAME, + METADATA_FIELD_NAME, ROLE_FIELD_NAME, SESSION_FIELD_NAME, TIMESTAMP_FIELD_NAME, @@ -15,6 +16,7 @@ from redisvl.index import SearchIndex from redisvl.query import FilterQuery from redisvl.query.filter import Tag 
+from redisvl.utils.utils import serialize class MessageHistory(BaseMessageHistory): @@ -98,11 +100,13 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: CONTENT_FIELD_NAME, TOOL_FIELD_NAME, TIMESTAMP_FIELD_NAME, + METADATA_FIELD_NAME, ] query = FilterQuery( filter_expression=self._default_session_filter, return_fields=return_fields, + num_results=1000, ) query.sort_by(TIMESTAMP_FIELD_NAME, asc=True) messages = self._index.query(query) @@ -144,6 +148,7 @@ def get_recent( CONTENT_FIELD_NAME, TOOL_FIELD_NAME, TIMESTAMP_FIELD_NAME, + METADATA_FIELD_NAME, ] session_filter = ( @@ -210,7 +215,8 @@ def add_messages( if TOOL_FIELD_NAME in message: chat_message.tool_call_id = message[TOOL_FIELD_NAME] - + if METADATA_FIELD_NAME in message: + chat_message.metadata = serialize(message[METADATA_FIELD_NAME]) chat_messages.append(chat_message.to_dict()) self._index.load(data=chat_messages, id_field=ID_FIELD_NAME) diff --git a/redisvl/extensions/message_history/schema.py b/redisvl/extensions/message_history/schema.py index 839b84ff..20b1fec2 100644 --- a/redisvl/extensions/message_history/schema.py +++ b/redisvl/extensions/message_history/schema.py @@ -1,11 +1,12 @@ from typing import Dict, List, Optional -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from redisvl.extensions.constants import ( CONTENT_FIELD_NAME, ID_FIELD_NAME, MESSAGE_VECTOR_FIELD_NAME, + METADATA_FIELD_NAME, ROLE_FIELD_NAME, SESSION_FIELD_NAME, TIMESTAMP_FIELD_NAME, @@ -13,7 +14,7 @@ ) from redisvl.redis.utils import array_to_buffer from redisvl.schema import IndexSchema -from redisvl.utils.utils import current_timestamp +from redisvl.utils.utils import current_timestamp, deserialize class ChatMessage(BaseModel): @@ -33,6 +34,8 @@ class ChatMessage(BaseModel): """An optional identifier for a tool call associated with the message.""" vector_field: Optional[List[float]] = Field(default=None) 
"""The vector representation of the message content.""" + metadata: Optional[str] = Field(default=None) + """Optional additional data to store alongside the message""" model_config = ConfigDict(arbitrary_types_allowed=True) @model_validator(mode="before") @@ -54,6 +57,7 @@ def to_dict(self, dtype: Optional[str] = None) -> Dict: data[MESSAGE_VECTOR_FIELD_NAME] = array_to_buffer( data[MESSAGE_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type] ) + return data @@ -70,6 +74,7 @@ def from_params(cls, name: str, prefix: str): {"name": TOOL_FIELD_NAME, "type": "tag"}, {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, {"name": SESSION_FIELD_NAME, "type": "tag"}, + {"name": METADATA_FIELD_NAME, "type": "text"}, ], ) @@ -87,6 +92,7 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str): {"name": TOOL_FIELD_NAME, "type": "tag"}, {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, {"name": SESSION_FIELD_NAME, "type": "tag"}, + {"name": METADATA_FIELD_NAME, "type": "text"}, { "name": MESSAGE_VECTOR_FIELD_NAME, "type": "vector", diff --git a/redisvl/extensions/message_history/semantic_history.py b/redisvl/extensions/message_history/semantic_history.py index 0c6906d5..7861c27b 100644 --- a/redisvl/extensions/message_history/semantic_history.py +++ b/redisvl/extensions/message_history/semantic_history.py @@ -6,6 +6,7 @@ CONTENT_FIELD_NAME, ID_FIELD_NAME, MESSAGE_VECTOR_FIELD_NAME, + METADATA_FIELD_NAME, ROLE_FIELD_NAME, SESSION_FIELD_NAME, TIMESTAMP_FIELD_NAME, @@ -19,7 +20,7 @@ from redisvl.index import SearchIndex from redisvl.query import FilterQuery, RangeQuery from redisvl.query.filter import Tag -from redisvl.utils.utils import deprecated_argument, validate_vector_dims +from redisvl.utils.utils import deprecated_argument, serialize, validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer @@ -149,8 +150,9 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: SESSION_FIELD_NAME, ROLE_FIELD_NAME, 
CONTENT_FIELD_NAME, - TOOL_FIELD_NAME, TIMESTAMP_FIELD_NAME, + TOOL_FIELD_NAME, + METADATA_FIELD_NAME, ] query = FilterQuery( @@ -214,6 +216,7 @@ def get_relevant( CONTENT_FIELD_NAME, TIMESTAMP_FIELD_NAME, TOOL_FIELD_NAME, + METADATA_FIELD_NAME, ] session_filter = ( @@ -274,8 +277,9 @@ def get_recent( SESSION_FIELD_NAME, ROLE_FIELD_NAME, CONTENT_FIELD_NAME, - TOOL_FIELD_NAME, TIMESTAMP_FIELD_NAME, + TOOL_FIELD_NAME, + METADATA_FIELD_NAME, ] session_filter = ( @@ -355,6 +359,8 @@ def add_messages( if TOOL_FIELD_NAME in message: chat_message.tool_call_id = message[TOOL_FIELD_NAME] + if METADATA_FIELD_NAME in message: + chat_message.metadata = serialize(message[METADATA_FIELD_NAME]) chat_messages.append(chat_message.to_dict(dtype=self._vectorizer.dtype)) diff --git a/tests/integration/test_message_history.py b/tests/integration/test_message_history.py index 716c8edd..c3190356 100644 --- a/tests/integration/test_message_history.py +++ b/tests/integration/test_message_history.py @@ -101,6 +101,7 @@ def test_standard_add_and_get(standard_history): "role": "tool", "content": "tool result 1", "tool_call_id": "tool call one", + "metadata": {"tool call params": "abc 123"}, } ) standard_history.add_message( @@ -108,6 +109,7 @@ def test_standard_add_and_get(standard_history): "role": "tool", "content": "tool result 2", "tool_call_id": "tool call two", + "metadata": {"tool call params": "abc 456"}, } ) standard_history.add_message({"role": "user", "content": "third prompt"}) @@ -121,7 +123,12 @@ def test_standard_add_and_get(standard_history): partial_context = standard_history.get_recent(top_k=3) assert len(partial_context) == 3 assert partial_context == [ - {"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"}, + { + "role": "tool", + "content": "tool result 2", + "tool_call_id": "tool call two", + "metadata": {"tool call params": "abc 456"}, + }, {"role": "user", "content": "third prompt"}, {"role": "llm", "content": "third response"}, ] @@ -133,8 
+140,18 @@ def test_standard_add_and_get(standard_history): {"role": "llm", "content": "first response"}, {"role": "user", "content": "second prompt"}, {"role": "llm", "content": "second response"}, - {"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"}, - {"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"}, + { + "role": "tool", + "content": "tool result 1", + "tool_call_id": "tool call one", + "metadata": {"tool call params": "abc 123"}, + }, + { + "role": "tool", + "content": "tool result 2", + "tool_call_id": "tool call two", + "metadata": {"tool call params": "abc 456"}, + }, {"role": "user", "content": "third prompt"}, {"role": "llm", "content": "third response"}, ] @@ -160,7 +177,11 @@ def test_standard_add_messages(standard_history): standard_history.add_messages( [ {"role": "user", "content": "first prompt"}, - {"role": "llm", "content": "first response"}, + { + "role": "llm", + "content": "first response", + "metadata": {"llm provider": "openai"}, + }, {"role": "user", "content": "second prompt"}, {"role": "llm", "content": "second response"}, { @@ -182,7 +203,11 @@ def test_standard_add_messages(standard_history): assert len(full_context) == 8 assert full_context == [ {"role": "user", "content": "first prompt"}, - {"role": "llm", "content": "first response"}, + { + "role": "llm", + "content": "first response", + "metadata": {"llm provider": "openai"}, + }, {"role": "user", "content": "second prompt"}, {"role": "llm", "content": "second response"}, {"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"}, @@ -198,8 +223,12 @@ def test_standard_messages_property(standard_history): {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, {"role": "user", "content": "second prompt"}, - {"role": "llm", "content": "second response"}, - {"role": "user", "content": "third prompt"}, + { + "role": "llm", + "content": "second response", + "metadata": {"params": 
"abc"}, + }, + {"role": "user", "content": "third prompt", "metadata": 42}, ] ) @@ -207,8 +236,8 @@ def test_standard_messages_property(standard_history): {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, {"role": "user", "content": "second prompt"}, - {"role": "llm", "content": "second response"}, - {"role": "user", "content": "third prompt"}, + {"role": "llm", "content": "second response", "metadata": {"params": "abc"}}, + {"role": "user", "content": "third prompt", "metadata": 42}, ] @@ -357,7 +386,14 @@ def test_semantic_store_and_get_recent(semantic_history): semantic_history.add_message( {"role": "tool", "content": "tool result", "tool_call_id": "tool id"} ) - # test default context history size + semantic_history.add_message( + { + "role": "tool", + "content": "tool result", + "tool_call_id": "tool id", + "metadata": "return value from tool", + } + ) # test default context history size default_context = semantic_history.get_recent() assert len(default_context) == 5 # 5 is default @@ -367,10 +403,10 @@ def test_semantic_store_and_get_recent(semantic_history): # test larger context history returns full history too_large_context = semantic_history.get_recent(top_k=100) - assert len(too_large_context) == 9 + assert len(too_large_context) == 10 # test that order is maintained - full_context = semantic_history.get_recent(top_k=9) + full_context = semantic_history.get_recent(top_k=10) assert full_context == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, @@ -381,15 +417,26 @@ def test_semantic_store_and_get_recent(semantic_history): {"role": "user", "content": "fourth prompt"}, {"role": "llm", "content": "fourth response"}, {"role": "tool", "content": "tool result", "tool_call_id": "tool id"}, + { + "role": "tool", + "content": "tool result", + "tool_call_id": "tool id", + "metadata": "return value from tool", + }, ] # test that more recent entries are returned context = 
semantic_history.get_recent(top_k=4) assert context == [ - {"role": "llm", "content": "third response"}, {"role": "user", "content": "fourth prompt"}, {"role": "llm", "content": "fourth response"}, {"role": "tool", "content": "tool result", "tool_call_id": "tool id"}, + { + "role": "tool", + "content": "tool result", + "tool_call_id": "tool id", + "metadata": "return value from tool", + }, ] # test no entries are returned and no error is raised if top_k == 0 @@ -422,11 +469,13 @@ def test_semantic_messages_property(semantic_history): "role": "tool", "content": "tool result 1", "tool_call_id": "tool call one", + "metadata": 42, }, { "role": "tool", "content": "tool result 2", "tool_call_id": "tool call two", + "metadata": [1, 2, 3], }, {"role": "user", "content": "second prompt"}, {"role": "llm", "content": "second response"}, @@ -437,8 +486,18 @@ def test_semantic_messages_property(semantic_history): assert semantic_history.messages == [ {"role": "user", "content": "first prompt"}, {"role": "llm", "content": "first response"}, - {"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"}, - {"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"}, + { + "role": "tool", + "content": "tool result 1", + "tool_call_id": "tool call one", + "metadata": 42, + }, + { + "role": "tool", + "content": "tool result 2", + "tool_call_id": "tool call two", + "metadata": [1, 2, 3], + }, {"role": "user", "content": "second prompt"}, {"role": "llm", "content": "second response"}, {"role": "user", "content": "third prompt"}, diff --git a/tests/unit/test_message_history_schema.py b/tests/unit/test_message_history_schema.py index 6143d2da..9fd0a625 100644 --- a/tests/unit/test_message_history_schema.py +++ b/tests/unit/test_message_history_schema.py @@ -3,7 +3,7 @@ from redisvl.extensions.message_history.schema import ChatMessage from redisvl.redis.utils import array_to_buffer -from redisvl.utils.utils import create_ulid, current_timestamp +from 
redisvl.utils.utils import create_ulid, current_timestamp, deserialize, serialize def test_chat_message_creation(): @@ -26,6 +26,7 @@ def test_chat_message_creation(): assert chat_message.timestamp == timestamp assert chat_message.tool_call_id is None assert chat_message.vector_field is None + assert chat_message.metadata is None def test_chat_message_default_id_generation(): @@ -61,6 +62,36 @@ def test_chat_message_with_tool_call_id(): assert chat_message.tool_call_id == tool_call_id +def test_chat_message_with_metadata(): + session_tag = create_ulid() + timestamp = current_timestamp() + content = "Hello, world!" + metadata = {"language": "Python", "version": "3.13"} + + chat_message = ChatMessage( + entry_id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + metadata=serialize(metadata), + ) + + assert chat_message.metadata == serialize(metadata) + + # test that metadata need not be a dictionary + for other_metadata in ["raw string", 42, [1, 2, 3], ["a", "b", "c"]]: + chat_message = ChatMessage( + entry_id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + metadata=serialize(other_metadata), + ) + assert chat_message.metadata == serialize(other_metadata) + + def test_chat_message_with_vector_field(): session_tag = create_ulid() timestamp = current_timestamp() @@ -84,6 +115,7 @@ def test_chat_message_to_dict(): timestamp = current_timestamp() content = "Hello, world!" 
vector_field = [0.1, 0.2, 0.3] + metadata = {"language": "Python", "version": "3.13"} chat_message = ChatMessage( entry_id=f"{session_tag}:{timestamp}", @@ -92,6 +124,7 @@ def test_chat_message_to_dict(): session_tag=session_tag, timestamp=timestamp, vector_field=vector_field, + metadata=serialize(metadata), ) data = chat_message.to_dict(dtype="float32") @@ -102,6 +135,7 @@ def test_chat_message_to_dict(): assert data["session_tag"] == session_tag assert data["timestamp"] == timestamp assert data["vector_field"] == array_to_buffer(vector_field, "float32") + assert data["metadata"] == serialize(metadata) def test_chat_message_missing_fields():