Skip to content

Commit 04736f1

Browse files
Adds metadata field to chat message history (#357)
1 parent f941ffd commit 04736f1

File tree

7 files changed

+140
-24
lines changed

7 files changed

+140
-24
lines changed

redisvl/extensions/constants.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
2424
INSERTED_AT_FIELD_NAME: str = "inserted_at"
2525
UPDATED_AT_FIELD_NAME: str = "updated_at"
26-
METADATA_FIELD_NAME: str = "metadata"
26+
METADATA_FIELD_NAME: str = "metadata" # also used in MessageHistory
2727

2828
# EmbeddingsCache
2929
TEXT_FIELD_NAME: str = "text"

redisvl/extensions/message_history/base_history.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,12 @@
22

33
from redisvl.extensions.constants import (
44
CONTENT_FIELD_NAME,
5+
METADATA_FIELD_NAME,
56
ROLE_FIELD_NAME,
67
TOOL_FIELD_NAME,
78
)
89
from redisvl.extensions.message_history.schema import ChatMessage
9-
from redisvl.utils.utils import create_ulid
10+
from redisvl.utils.utils import create_ulid, deserialize
1011

1112

1213
class BaseMessageHistory:
@@ -111,6 +112,10 @@ def _format_context(
111112
}
112113
if chat_message.tool_call_id is not None:
113114
chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id
115+
if chat_message.metadata is not None:
116+
chat_message_dict[METADATA_FIELD_NAME] = deserialize(
117+
chat_message.metadata
118+
)
114119

115120
context.append(chat_message_dict) # type: ignore
116121

redisvl/extensions/message_history/message_history.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from redisvl.extensions.constants import (
66
CONTENT_FIELD_NAME,
77
ID_FIELD_NAME,
8+
METADATA_FIELD_NAME,
89
ROLE_FIELD_NAME,
910
SESSION_FIELD_NAME,
1011
TIMESTAMP_FIELD_NAME,
@@ -15,6 +16,7 @@
1516
from redisvl.index import SearchIndex
1617
from redisvl.query import FilterQuery
1718
from redisvl.query.filter import Tag
19+
from redisvl.utils.utils import serialize
1820

1921

2022
class MessageHistory(BaseMessageHistory):
@@ -98,11 +100,13 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
98100
CONTENT_FIELD_NAME,
99101
TOOL_FIELD_NAME,
100102
TIMESTAMP_FIELD_NAME,
103+
METADATA_FIELD_NAME,
101104
]
102105

103106
query = FilterQuery(
104107
filter_expression=self._default_session_filter,
105108
return_fields=return_fields,
109+
num_results=1000,
106110
)
107111
query.sort_by(TIMESTAMP_FIELD_NAME, asc=True)
108112
messages = self._index.query(query)
@@ -144,6 +148,7 @@ def get_recent(
144148
CONTENT_FIELD_NAME,
145149
TOOL_FIELD_NAME,
146150
TIMESTAMP_FIELD_NAME,
151+
METADATA_FIELD_NAME,
147152
]
148153

149154
session_filter = (
@@ -210,7 +215,8 @@ def add_messages(
210215

211216
if TOOL_FIELD_NAME in message:
212217
chat_message.tool_call_id = message[TOOL_FIELD_NAME]
213-
218+
if METADATA_FIELD_NAME in message:
219+
chat_message.metadata = serialize(message[METADATA_FIELD_NAME])
214220
chat_messages.append(chat_message.to_dict())
215221

216222
self._index.load(data=chat_messages, id_field=ID_FIELD_NAME)

redisvl/extensions/message_history/schema.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
from typing import Dict, List, Optional
22

3-
from pydantic import BaseModel, ConfigDict, Field, model_validator
3+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
44

55
from redisvl.extensions.constants import (
66
CONTENT_FIELD_NAME,
77
ID_FIELD_NAME,
88
MESSAGE_VECTOR_FIELD_NAME,
9+
METADATA_FIELD_NAME,
910
ROLE_FIELD_NAME,
1011
SESSION_FIELD_NAME,
1112
TIMESTAMP_FIELD_NAME,
1213
TOOL_FIELD_NAME,
1314
)
1415
from redisvl.redis.utils import array_to_buffer
1516
from redisvl.schema import IndexSchema
16-
from redisvl.utils.utils import current_timestamp
17+
from redisvl.utils.utils import current_timestamp, deserialize
1718

1819

1920
class ChatMessage(BaseModel):
@@ -33,6 +34,8 @@ class ChatMessage(BaseModel):
3334
"""An optional identifier for a tool call associated with the message."""
3435
vector_field: Optional[List[float]] = Field(default=None)
3536
"""The vector representation of the message content."""
37+
metadata: Optional[str] = Field(default=None)
38+
"""Optional additional data to store alongside the message"""
3639
model_config = ConfigDict(arbitrary_types_allowed=True)
3740

3841
@model_validator(mode="before")
@@ -54,6 +57,7 @@ def to_dict(self, dtype: Optional[str] = None) -> Dict:
5457
data[MESSAGE_VECTOR_FIELD_NAME] = array_to_buffer(
5558
data[MESSAGE_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type]
5659
)
60+
5761
return data
5862

5963

@@ -70,6 +74,7 @@ def from_params(cls, name: str, prefix: str):
7074
{"name": TOOL_FIELD_NAME, "type": "tag"},
7175
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
7276
{"name": SESSION_FIELD_NAME, "type": "tag"},
77+
{"name": METADATA_FIELD_NAME, "type": "text"},
7378
],
7479
)
7580

@@ -87,6 +92,7 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str):
8792
{"name": TOOL_FIELD_NAME, "type": "tag"},
8893
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
8994
{"name": SESSION_FIELD_NAME, "type": "tag"},
95+
{"name": METADATA_FIELD_NAME, "type": "text"},
9096
{
9197
"name": MESSAGE_VECTOR_FIELD_NAME,
9298
"type": "vector",

redisvl/extensions/message_history/semantic_history.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
CONTENT_FIELD_NAME,
77
ID_FIELD_NAME,
88
MESSAGE_VECTOR_FIELD_NAME,
9+
METADATA_FIELD_NAME,
910
ROLE_FIELD_NAME,
1011
SESSION_FIELD_NAME,
1112
TIMESTAMP_FIELD_NAME,
@@ -19,7 +20,7 @@
1920
from redisvl.index import SearchIndex
2021
from redisvl.query import FilterQuery, RangeQuery
2122
from redisvl.query.filter import Tag
22-
from redisvl.utils.utils import deprecated_argument, validate_vector_dims
23+
from redisvl.utils.utils import deprecated_argument, serialize, validate_vector_dims
2324
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
2425

2526

@@ -149,8 +150,9 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
149150
SESSION_FIELD_NAME,
150151
ROLE_FIELD_NAME,
151152
CONTENT_FIELD_NAME,
152-
TOOL_FIELD_NAME,
153153
TIMESTAMP_FIELD_NAME,
154+
TOOL_FIELD_NAME,
155+
METADATA_FIELD_NAME,
154156
]
155157

156158
query = FilterQuery(
@@ -214,6 +216,7 @@ def get_relevant(
214216
CONTENT_FIELD_NAME,
215217
TIMESTAMP_FIELD_NAME,
216218
TOOL_FIELD_NAME,
219+
METADATA_FIELD_NAME,
217220
]
218221

219222
session_filter = (
@@ -274,8 +277,9 @@ def get_recent(
274277
SESSION_FIELD_NAME,
275278
ROLE_FIELD_NAME,
276279
CONTENT_FIELD_NAME,
277-
TOOL_FIELD_NAME,
278280
TIMESTAMP_FIELD_NAME,
281+
TOOL_FIELD_NAME,
282+
METADATA_FIELD_NAME,
279283
]
280284

281285
session_filter = (
@@ -355,6 +359,8 @@ def add_messages(
355359

356360
if TOOL_FIELD_NAME in message:
357361
chat_message.tool_call_id = message[TOOL_FIELD_NAME]
362+
if METADATA_FIELD_NAME in message:
363+
chat_message.metadata = serialize(message[METADATA_FIELD_NAME])
358364

359365
chat_messages.append(chat_message.to_dict(dtype=self._vectorizer.dtype))
360366

tests/integration/test_message_history.py

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,15 @@ def test_standard_add_and_get(standard_history):
101101
"role": "tool",
102102
"content": "tool result 1",
103103
"tool_call_id": "tool call one",
104+
"metadata": {"tool call params": "abc 123"},
104105
}
105106
)
106107
standard_history.add_message(
107108
{
108109
"role": "tool",
109110
"content": "tool result 2",
110111
"tool_call_id": "tool call two",
112+
"metadata": {"tool call params": "abc 456"},
111113
}
112114
)
113115
standard_history.add_message({"role": "user", "content": "third prompt"})
@@ -121,7 +123,12 @@ def test_standard_add_and_get(standard_history):
121123
partial_context = standard_history.get_recent(top_k=3)
122124
assert len(partial_context) == 3
123125
assert partial_context == [
124-
{"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"},
126+
{
127+
"role": "tool",
128+
"content": "tool result 2",
129+
"tool_call_id": "tool call two",
130+
"metadata": {"tool call params": "abc 456"},
131+
},
125132
{"role": "user", "content": "third prompt"},
126133
{"role": "llm", "content": "third response"},
127134
]
@@ -133,8 +140,18 @@ def test_standard_add_and_get(standard_history):
133140
{"role": "llm", "content": "first response"},
134141
{"role": "user", "content": "second prompt"},
135142
{"role": "llm", "content": "second response"},
136-
{"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"},
137-
{"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"},
143+
{
144+
"role": "tool",
145+
"content": "tool result 1",
146+
"tool_call_id": "tool call one",
147+
"metadata": {"tool call params": "abc 123"},
148+
},
149+
{
150+
"role": "tool",
151+
"content": "tool result 2",
152+
"tool_call_id": "tool call two",
153+
"metadata": {"tool call params": "abc 456"},
154+
},
138155
{"role": "user", "content": "third prompt"},
139156
{"role": "llm", "content": "third response"},
140157
]
@@ -160,7 +177,11 @@ def test_standard_add_messages(standard_history):
160177
standard_history.add_messages(
161178
[
162179
{"role": "user", "content": "first prompt"},
163-
{"role": "llm", "content": "first response"},
180+
{
181+
"role": "llm",
182+
"content": "first response",
183+
"metadata": {"llm provider": "openai"},
184+
},
164185
{"role": "user", "content": "second prompt"},
165186
{"role": "llm", "content": "second response"},
166187
{
@@ -182,7 +203,11 @@ def test_standard_add_messages(standard_history):
182203
assert len(full_context) == 8
183204
assert full_context == [
184205
{"role": "user", "content": "first prompt"},
185-
{"role": "llm", "content": "first response"},
206+
{
207+
"role": "llm",
208+
"content": "first response",
209+
"metadata": {"llm provider": "openai"},
210+
},
186211
{"role": "user", "content": "second prompt"},
187212
{"role": "llm", "content": "second response"},
188213
{"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"},
@@ -198,17 +223,21 @@ def test_standard_messages_property(standard_history):
198223
{"role": "user", "content": "first prompt"},
199224
{"role": "llm", "content": "first response"},
200225
{"role": "user", "content": "second prompt"},
201-
{"role": "llm", "content": "second response"},
202-
{"role": "user", "content": "third prompt"},
226+
{
227+
"role": "llm",
228+
"content": "second response",
229+
"metadata": {"params": "abc"},
230+
},
231+
{"role": "user", "content": "third prompt", "metadata": 42},
203232
]
204233
)
205234

206235
assert standard_history.messages == [
207236
{"role": "user", "content": "first prompt"},
208237
{"role": "llm", "content": "first response"},
209238
{"role": "user", "content": "second prompt"},
210-
{"role": "llm", "content": "second response"},
211-
{"role": "user", "content": "third prompt"},
239+
{"role": "llm", "content": "second response", "metadata": {"params": "abc"}},
240+
{"role": "user", "content": "third prompt", "metadata": 42},
212241
]
213242

214243

@@ -357,7 +386,14 @@ def test_semantic_store_and_get_recent(semantic_history):
357386
semantic_history.add_message(
358387
{"role": "tool", "content": "tool result", "tool_call_id": "tool id"}
359388
)
360-
# test default context history size
389+
semantic_history.add_message(
390+
{
391+
"role": "tool",
392+
"content": "tool result",
393+
"tool_call_id": "tool id",
394+
"metadata": "return value from tool",
395+
}
396+
) # test default context history size
361397
default_context = semantic_history.get_recent()
362398
assert len(default_context) == 5 # 5 is default
363399

@@ -367,10 +403,10 @@ def test_semantic_store_and_get_recent(semantic_history):
367403

368404
# test larger context history returns full history
369405
too_large_context = semantic_history.get_recent(top_k=100)
370-
assert len(too_large_context) == 9
406+
assert len(too_large_context) == 10
371407

372408
# test that order is maintained
373-
full_context = semantic_history.get_recent(top_k=9)
409+
full_context = semantic_history.get_recent(top_k=10)
374410
assert full_context == [
375411
{"role": "user", "content": "first prompt"},
376412
{"role": "llm", "content": "first response"},
@@ -381,15 +417,26 @@ def test_semantic_store_and_get_recent(semantic_history):
381417
{"role": "user", "content": "fourth prompt"},
382418
{"role": "llm", "content": "fourth response"},
383419
{"role": "tool", "content": "tool result", "tool_call_id": "tool id"},
420+
{
421+
"role": "tool",
422+
"content": "tool result",
423+
"tool_call_id": "tool id",
424+
"metadata": "return value from tool",
425+
},
384426
]
385427

386428
# test that more recent entries are returned
387429
context = semantic_history.get_recent(top_k=4)
388430
assert context == [
389-
{"role": "llm", "content": "third response"},
390431
{"role": "user", "content": "fourth prompt"},
391432
{"role": "llm", "content": "fourth response"},
392433
{"role": "tool", "content": "tool result", "tool_call_id": "tool id"},
434+
{
435+
"role": "tool",
436+
"content": "tool result",
437+
"tool_call_id": "tool id",
438+
"metadata": "return value from tool",
439+
},
393440
]
394441

395442
# test no entries are returned and no error is raised if top_k == 0
@@ -422,11 +469,13 @@ def test_semantic_messages_property(semantic_history):
422469
"role": "tool",
423470
"content": "tool result 1",
424471
"tool_call_id": "tool call one",
472+
"metadata": 42,
425473
},
426474
{
427475
"role": "tool",
428476
"content": "tool result 2",
429477
"tool_call_id": "tool call two",
478+
"metadata": [1, 2, 3],
430479
},
431480
{"role": "user", "content": "second prompt"},
432481
{"role": "llm", "content": "second response"},
@@ -437,8 +486,18 @@ def test_semantic_messages_property(semantic_history):
437486
assert semantic_history.messages == [
438487
{"role": "user", "content": "first prompt"},
439488
{"role": "llm", "content": "first response"},
440-
{"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"},
441-
{"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"},
489+
{
490+
"role": "tool",
491+
"content": "tool result 1",
492+
"tool_call_id": "tool call one",
493+
"metadata": 42,
494+
},
495+
{
496+
"role": "tool",
497+
"content": "tool result 2",
498+
"tool_call_id": "tool call two",
499+
"metadata": [1, 2, 3],
500+
},
442501
{"role": "user", "content": "second prompt"},
443502
{"role": "llm", "content": "second response"},
444503
{"role": "user", "content": "third prompt"},

0 commit comments

Comments
 (0)