Skip to content

Commit 8bbd1b0

Browse files
Use ChatMessage pydantic model (#194)
Uses pydantic for validation, docs, and custom serialization of chat message objects on the way into and out of Redis. --------- Co-authored-by: Justin Cechmanek <[email protected]>
1 parent 3844d57 commit 8bbd1b0

File tree

8 files changed

+387
-161
lines changed

8 files changed

+387
-161
lines changed

redisvl/extensions/router/schema.py

+35
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from pydantic.v1 import BaseModel, Field, validator
55

6+
from redisvl.schema import IndexInfo, IndexSchema
7+
68

79
class Route(BaseModel):
810
"""Model representing a routing path with associated metadata and thresholds."""
@@ -80,3 +82,36 @@ def distance_threshold_must_be_valid(cls, v):
8082
if v <= 0 or v > 1:
8183
raise ValueError("distance_threshold must be between 0 and 1")
8284
return v
85+
86+
87+
class SemanticRouterIndexSchema(IndexSchema):
    """Index schema tailored to the SemanticRouter."""

    @classmethod
    def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
        """Build the router index schema from a name and a vector size.

        Args:
            name (str): Used as both the index name and the key prefix.
            vector_dims (int): The dimensions of the vectors.

        Returns:
            SemanticRouterIndexSchema: Schema holding a ``route_name`` tag
                field, a ``reference`` text field and a ``vector`` field.
        """
        vector_attrs = {
            "algorithm": "flat",
            "dims": vector_dims,
            "distance_metric": "cosine",
            "datatype": "float32",
        }
        schema_fields = [
            {"name": "route_name", "type": "tag"},
            {"name": "reference", "type": "text"},
            {"name": "vector", "type": "vector", "attrs": vector_attrs},
        ]
        index_info = IndexInfo(name=name, prefix=name)
        return cls(index=index_info, fields=schema_fields)  # type: ignore

redisvl/extensions/router/semantic.py

+1-34
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
Route,
1414
RouteMatch,
1515
RoutingConfig,
16+
SemanticRouterIndexSchema,
1617
)
1718
from redisvl.index import SearchIndex
1819
from redisvl.query import RangeQuery
1920
from redisvl.redis.utils import convert_bytes, hashify, make_dict
20-
from redisvl.schema import IndexInfo, IndexSchema
2121
from redisvl.utils.log import get_logger
2222
from redisvl.utils.utils import model_to_dict
2323
from redisvl.utils.vectorize import (
@@ -29,39 +29,6 @@
2929
logger = get_logger(__name__)
3030

3131

32-
class SemanticRouterIndexSchema(IndexSchema):
33-
"""Customized index schema for SemanticRouter."""
34-
35-
@classmethod
36-
def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
37-
"""Create an index schema based on router name and vector dimensions.
38-
39-
Args:
40-
name (str): The name of the index.
41-
vector_dims (int): The dimensions of the vectors.
42-
43-
Returns:
44-
SemanticRouterIndexSchema: The constructed index schema.
45-
"""
46-
return cls(
47-
index=IndexInfo(name=name, prefix=name),
48-
fields=[ # type: ignore
49-
{"name": "route_name", "type": "tag"},
50-
{"name": "reference", "type": "text"},
51-
{
52-
"name": "vector",
53-
"type": "vector",
54-
"attrs": {
55-
"algorithm": "flat",
56-
"dims": vector_dims,
57-
"distance_metric": "cosine",
58-
"datatype": "float32",
59-
},
60-
},
61-
],
62-
)
63-
64-
6532
class SemanticRouter(BaseModel):
6633
"""Semantic Router for managing and querying route vectors."""
6734

redisvl/extensions/session_manager/base_session.py

+24-31
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from typing import Any, Dict, List, Optional, Union
2-
from uuid import uuid4
32

4-
from redis import Redis
5-
6-
from redisvl.query.filter import FilterExpression
3+
from redisvl.extensions.session_manager.schema import ChatMessage
4+
from redisvl.utils.utils import create_uuid
75

86

97
class BaseSessionManager:
@@ -32,7 +30,7 @@ def __init__(
3230
session. Defaults to instance uuid.
3331
"""
3432
self._name = name
35-
self._session_tag = session_tag or uuid4().hex
33+
self._session_tag = session_tag or create_uuid()
3634

3735
def clear(self) -> None:
3836
"""Clears the chat session history."""
@@ -85,44 +83,39 @@ def get_recent(
8583
raise NotImplementedError
8684

8785
def _format_context(
    self, messages: List[Dict[str, Any]], as_text: bool
) -> Union[List[str], List[Dict[str, str]]]:
    """Validate session messages and shape them for the caller.

    Every raw message is parsed into a ``ChatMessage`` before any of its
    values are read.

    Args:
        messages (List[Dict[str, Any]]): The messages from the session index.
        as_text (bool): If True, return only the content string of each
            message; otherwise return one dictionary per message.

    Returns:
        Union[List[str], List[Dict[str, str]]]: A list of message contents
            when as_text is true; otherwise a list of dictionaries keyed by
            the role and content field names, plus the tool field name when
            the message carries a tool call id.
    """
    parsed = [ChatMessage(**raw) for raw in messages]

    if as_text:
        return [entry.content for entry in parsed]

    formatted = []
    for entry in parsed:
        record = {
            self.role_field_name: entry.role,
            self.content_field_name: entry.content,
        }
        # The tool field is only emitted for messages tied to a tool call.
        if entry.tool_call_id is not None:
            record[self.tool_field_name] = entry.tool_call_id
        formatted.append(record)  # type: ignore

    return formatted
126119

127120
def store(
128121
self, prompt: str, response: str, session_tag: Optional[str] = None
redisvl/extensions/session_manager/schema.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import Dict, List, Optional
2+
3+
from pydantic.v1 import BaseModel, Field, root_validator
4+
5+
from redisvl.redis.utils import array_to_buffer
6+
from redisvl.schema import IndexSchema
7+
from redisvl.utils.utils import current_timestamp
8+
9+
10+
class ChatMessage(BaseModel):
    """A single chat message exchanged between a user and an LLM."""

    # NOTE(review): pydantic v1 does not appear to register underscore-prefixed
    # names as model fields, so an `_id` passed to the constructor may be
    # ignored and the id always regenerated by `generate_id` — confirm, and
    # consider a non-underscore field name.
    _id: Optional[str] = Field(default=None)
    """A unique identifier for the message."""
    role: str  # TODO -- do we enumify this?
    """The role of the message sender (e.g., 'user' or 'llm')."""
    content: str
    """The content of the message."""
    session_tag: str
    """Tag associated with the current session."""
    timestamp: float = Field(default_factory=current_timestamp)
    """The time the message was sent, in UTC, rounded to milliseconds."""
    tool_call_id: Optional[str] = Field(default=None)
    """An optional identifier for a tool call associated with the message."""
    vector_field: Optional[List[float]] = Field(default=None)
    """The vector representation of the message content."""

    class Config:
        arbitrary_types_allowed = True

    # skip_on_failure=True: if field validation already failed (for example
    # `session_tag` is missing), the key lookups below would raise a bare
    # KeyError that pydantic does not convert into a ValidationError. Skipping
    # the validator lets the field errors surface as a normal ValidationError.
    @root_validator(pre=False, skip_on_failure=True)
    @classmethod
    def generate_id(cls, values):
        """Fill in ``_id`` as ``"<session_tag>:<timestamp>"`` when it is absent.

        Args:
            values (dict): The validated field values of the model.

        Returns:
            dict: The same values, with ``_id`` set if it was not present.
        """
        if "_id" not in values:
            values["_id"] = f'{values["session_tag"]}:{values["timestamp"]}'
        return values

    def to_dict(self) -> Dict:
        """Return the message as a dictionary without unset optional fields.

        ``vector_field`` is passed through ``array_to_buffer`` when present
        and dropped when it is None; ``tool_call_id`` is dropped when it is
        None.

        Returns:
            Dict: The dictionary form of this message.
        """
        data = self.dict()

        # handle optional fields
        if data["vector_field"] is not None:
            data["vector_field"] = array_to_buffer(data["vector_field"])
        else:
            del data["vector_field"]

        if self.tool_call_id is None:
            del data["tool_call_id"]

        return data
51+
52+
53+
class StandardSessionIndexSchema(IndexSchema):
    """Index schema for session messages stored without a vector field."""

    @classmethod
    def from_params(cls, name: str, prefix: str):
        """Build the schema from an index name and a key prefix.

        Args:
            name (str): The name of the index.
            prefix (str): The key prefix of the index.

        Returns:
            The schema with ``role``, ``content``, ``tool_call_id``,
            ``timestamp`` and ``session_tag`` fields.
        """
        index_settings = {"name": name, "prefix": prefix}
        field_specs = [
            {"name": field_name, "type": field_type}
            for field_name, field_type in (
                ("role", "tag"),
                ("content", "text"),
                ("tool_call_id", "tag"),
                ("timestamp", "numeric"),
                ("session_tag", "tag"),
            )
        ]
        return cls(index=index_settings, fields=field_specs)  # type: ignore
68+
69+
70+
class SemanticSessionIndexSchema(IndexSchema):
    """Index schema for session messages stored with a vector field."""

    @classmethod
    def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
        """Build the schema from an index name, key prefix and vector size.

        Args:
            name (str): The name of the index.
            prefix (str): The key prefix of the index.
            vectorizer_dims (int): The ``dims`` attribute of the vector field.

        Returns:
            The schema with ``role``, ``content``, ``tool_call_id``,
            ``timestamp``, ``session_tag`` and ``vector_field`` fields.
        """
        vector_spec = {
            "name": "vector_field",
            "type": "vector",
            "attrs": {
                "dims": vectorizer_dims,
                "datatype": "float32",
                "distance_metric": "cosine",
                "algorithm": "flat",
            },
        }
        field_specs = [
            {"name": "role", "type": "tag"},
            {"name": "content", "type": "text"},
            {"name": "tool_call_id", "type": "tag"},
            {"name": "timestamp", "type": "numeric"},
            {"name": "session_tag", "type": "tag"},
        ]
        field_specs.append(vector_spec)
        index_settings = {"name": name, "prefix": prefix}
        return cls(index=index_settings, fields=field_specs)  # type: ignore

0 commit comments

Comments
 (0)