Skip to content

Commit 51e58aa

Browse files
moves extension field names into constants file (#225)
The field names used by the cache, session, and router classes and by their corresponding schemas were hard-coded in both places. Because these names must match between each class and its schema and must not be modified, they are moved into a shared constants file.
1 parent 951d630 commit 51e58aa

File tree

10 files changed

+191
-131
lines changed

10 files changed

+191
-131
lines changed

redisvl/extensions/constants.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
Constants used within the extension classes SemanticCache, BaseSessionManager,
3+
StandardSessionManager, SemanticSessionManager and SemanticRouter.
4+
These constants are also used within these classes' corresponding schemas.
5+
"""
6+
7+
# BaseSessionManager
8+
ID_FIELD_NAME: str = "entry_id"
9+
ROLE_FIELD_NAME: str = "role"
10+
CONTENT_FIELD_NAME: str = "content"
11+
TOOL_FIELD_NAME: str = "tool_call_id"
12+
TIMESTAMP_FIELD_NAME: str = "timestamp"
13+
SESSION_FIELD_NAME: str = "session_tag"
14+
15+
# SemanticSessionManager
16+
SESSION_VECTOR_FIELD_NAME: str = "vector_field"
17+
18+
# SemanticCache
19+
REDIS_KEY_FIELD_NAME: str = "key"
20+
ENTRY_ID_FIELD_NAME: str = "entry_id"
21+
PROMPT_FIELD_NAME: str = "prompt"
22+
RESPONSE_FIELD_NAME: str = "response"
23+
CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
24+
INSERTED_AT_FIELD_NAME: str = "inserted_at"
25+
UPDATED_AT_FIELD_NAME: str = "updated_at"
26+
METADATA_FIELD_NAME: str = "metadata"
27+
28+
# SemanticRouter
29+
ROUTE_VECTOR_FIELD_NAME: str = "vector"

redisvl/extensions/llmcache/schema.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
from pydantic.v1 import BaseModel, Field, root_validator, validator
44

5+
from redisvl.extensions.constants import (
6+
CACHE_VECTOR_FIELD_NAME,
7+
INSERTED_AT_FIELD_NAME,
8+
PROMPT_FIELD_NAME,
9+
RESPONSE_FIELD_NAME,
10+
UPDATED_AT_FIELD_NAME,
11+
)
512
from redisvl.redis.utils import array_to_buffer, hashify
613
from redisvl.schema import IndexSchema
714
from redisvl.utils.utils import current_timestamp, deserialize, serialize
@@ -110,12 +117,12 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
110117
return cls(
111118
index={"name": name, "prefix": prefix}, # type: ignore
112119
fields=[ # type: ignore
113-
{"name": "prompt", "type": "text"},
114-
{"name": "response", "type": "text"},
115-
{"name": "inserted_at", "type": "numeric"},
116-
{"name": "updated_at", "type": "numeric"},
120+
{"name": PROMPT_FIELD_NAME, "type": "text"},
121+
{"name": RESPONSE_FIELD_NAME, "type": "text"},
122+
{"name": INSERTED_AT_FIELD_NAME, "type": "numeric"},
123+
{"name": UPDATED_AT_FIELD_NAME, "type": "numeric"},
117124
{
118-
"name": "prompt_vector",
125+
"name": CACHE_VECTOR_FIELD_NAME,
119126
"type": "vector",
120127
"attrs": {
121128
"dims": vector_dims,

redisvl/extensions/llmcache/semantic.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33

44
from redis import Redis
55

6+
from redisvl.extensions.constants import (
7+
CACHE_VECTOR_FIELD_NAME,
8+
ENTRY_ID_FIELD_NAME,
9+
INSERTED_AT_FIELD_NAME,
10+
METADATA_FIELD_NAME,
11+
PROMPT_FIELD_NAME,
12+
REDIS_KEY_FIELD_NAME,
13+
RESPONSE_FIELD_NAME,
14+
UPDATED_AT_FIELD_NAME,
15+
)
616
from redisvl.extensions.llmcache.base import BaseLLMCache
717
from redisvl.extensions.llmcache.schema import (
818
CacheEntry,
@@ -19,15 +29,6 @@
1929
class SemanticCache(BaseLLMCache):
2030
"""Semantic Cache for Large Language Models."""
2131

22-
redis_key_field_name: str = "key"
23-
entry_id_field_name: str = "entry_id"
24-
prompt_field_name: str = "prompt"
25-
response_field_name: str = "response"
26-
vector_field_name: str = "prompt_vector"
27-
inserted_at_field_name: str = "inserted_at"
28-
updated_at_field_name: str = "updated_at"
29-
metadata_field_name: str = "metadata"
30-
3132
_index: SearchIndex
3233
_aindex: Optional[AsyncSearchIndex] = None
3334

@@ -94,12 +95,12 @@ def __init__(
9495
# Process fields and other settings
9596
self.set_threshold(distance_threshold)
9697
self.return_fields = [
97-
self.entry_id_field_name,
98-
self.prompt_field_name,
99-
self.response_field_name,
100-
self.inserted_at_field_name,
101-
self.updated_at_field_name,
102-
self.metadata_field_name,
98+
ENTRY_ID_FIELD_NAME,
99+
PROMPT_FIELD_NAME,
100+
RESPONSE_FIELD_NAME,
101+
INSERTED_AT_FIELD_NAME,
102+
UPDATED_AT_FIELD_NAME,
103+
METADATA_FIELD_NAME,
103104
]
104105

105106
# Create semantic cache schema and index
@@ -133,7 +134,7 @@ def __init__(
133134

134135
validate_vector_dims(
135136
vectorizer.dims,
136-
self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
137+
self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
137138
)
138139
self._vectorizer = vectorizer
139140

@@ -145,9 +146,7 @@ def _modify_schema(
145146
"""Modify the base cache schema using the provided filterable fields"""
146147

147148
if filterable_fields is not None:
148-
protected_field_names = set(
149-
self.return_fields + [self.redis_key_field_name]
150-
)
149+
protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME])
151150
for filter_field in filterable_fields:
152151
field_name = filter_field["name"]
153152
if field_name in protected_field_names:
@@ -300,7 +299,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
300299
def _check_vector_dims(self, vector: List[float]):
301300
"""Checks the size of the provided vector and raises an error if it
302301
doesn't match the search index vector dimensions."""
303-
schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore
302+
schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims # type: ignore
304303
validate_vector_dims(len(vector), schema_vector_dims)
305304

306305
def check(
@@ -363,7 +362,7 @@ def check(
363362

364363
query = RangeQuery(
365364
vector=vector,
366-
vector_field_name=self.vector_field_name,
365+
vector_field_name=CACHE_VECTOR_FIELD_NAME,
367366
return_fields=self.return_fields,
368367
distance_threshold=distance_threshold,
369368
num_results=num_results,
@@ -444,7 +443,7 @@ async def acheck(
444443

445444
query = RangeQuery(
446445
vector=vector,
447-
vector_field_name=self.vector_field_name,
446+
vector_field_name=CACHE_VECTOR_FIELD_NAME,
448447
return_fields=self.return_fields,
449448
distance_threshold=distance_threshold,
450449
num_results=num_results,
@@ -479,7 +478,7 @@ def _process_cache_results(
479478
cache_hit_dict = {
480479
k: v for k, v in cache_hit_dict.items() if k in return_fields
481480
}
482-
cache_hit_dict[self.redis_key_field_name] = redis_key
481+
cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key
483482
cache_hits.append(cache_hit_dict)
484483
return redis_keys, cache_hits
485484

@@ -541,7 +540,7 @@ def store(
541540
keys = self._index.load(
542541
data=[cache_entry.to_dict()],
543542
ttl=ttl,
544-
id_field=self.entry_id_field_name,
543+
id_field=ENTRY_ID_FIELD_NAME,
545544
)
546545
return keys[0]
547546

@@ -605,7 +604,7 @@ async def astore(
605604
keys = await aindex.load(
606605
data=[cache_entry.to_dict()],
607606
ttl=ttl,
608-
id_field=self.entry_id_field_name,
607+
id_field=ENTRY_ID_FIELD_NAME,
609608
)
610609
return keys[0]
611610

@@ -629,21 +628,19 @@ def update(self, key: str, **kwargs) -> None:
629628
for k, v in kwargs.items():
630629

631630
# Make sure the item is in the index schema
632-
if k not in set(
633-
self._index.schema.field_names + [self.metadata_field_name]
634-
):
631+
if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
635632
raise ValueError(f"{k} is not a valid field within the cache entry")
636633

637634
# Check for metadata and deserialize
638-
if k == self.metadata_field_name:
635+
if k == METADATA_FIELD_NAME:
639636
if isinstance(v, dict):
640637
kwargs[k] = serialize(v)
641638
else:
642639
raise TypeError(
643640
"If specified, cached metadata must be a dictionary."
644641
)
645642

646-
kwargs.update({self.updated_at_field_name: current_timestamp()})
643+
kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})
647644

648645
self._index.client.hset(key, mapping=kwargs) # type: ignore
649646

@@ -674,21 +671,19 @@ async def aupdate(self, key: str, **kwargs) -> None:
674671
for k, v in kwargs.items():
675672

676673
# Make sure the item is in the index schema
677-
if k not in set(
678-
self._index.schema.field_names + [self.metadata_field_name]
679-
):
674+
if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
680675
raise ValueError(f"{k} is not a valid field within the cache entry")
681676

682677
# Check for metadata and deserialize
683-
if k == self.metadata_field_name:
678+
if k == METADATA_FIELD_NAME:
684679
if isinstance(v, dict):
685680
kwargs[k] = serialize(v)
686681
else:
687682
raise TypeError(
688683
"If specified, cached metadata must be a dictionary."
689684
)
690685

691-
kwargs.update({self.updated_at_field_name: current_timestamp()})
686+
kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})
692687

693688
await aindex.load(data=[kwargs], keys=[key])
694689

redisvl/extensions/router/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pydantic.v1 import BaseModel, Field, validator
55

6+
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
67
from redisvl.schema import IndexInfo, IndexSchema
78

89

@@ -104,7 +105,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema"
104105
{"name": "route_name", "type": "tag"},
105106
{"name": "reference", "type": "text"},
106107
{
107-
"name": "vector",
108+
"name": ROUTE_VECTOR_FIELD_NAME,
108109
"type": "vector",
109110
"attrs": {
110111
"algorithm": "flat",

redisvl/extensions/router/semantic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
99
from redis.exceptions import ResponseError
1010

11+
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
1112
from redisvl.extensions.router.schema import (
1213
DistanceAggregationMethod,
1314
Route,
@@ -226,7 +227,7 @@ def _classify_route(
226227
"""Classify to a single route using a vector."""
227228
vector_range_query = RangeQuery(
228229
vector=vector,
229-
vector_field_name="vector",
230+
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
230231
distance_threshold=distance_threshold,
231232
return_fields=["route_name"],
232233
)
@@ -278,7 +279,7 @@ def _classify_multi_route(
278279
"""Classify to multiple routes, up to max_k (int), using a vector."""
279280
vector_range_query = RangeQuery(
280281
vector=vector,
281-
vector_field_name="vector",
282+
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
282283
distance_threshold=distance_threshold,
283284
return_fields=["route_name"],
284285
)

redisvl/extensions/session_manager/base_session.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
from typing import Any, Dict, List, Optional, Union
22

3+
from redisvl.extensions.constants import (
4+
CONTENT_FIELD_NAME,
5+
ROLE_FIELD_NAME,
6+
TOOL_FIELD_NAME,
7+
)
38
from redisvl.extensions.session_manager.schema import ChatMessage
49
from redisvl.utils.utils import create_uuid
510

611

712
class BaseSessionManager:
8-
id_field_name: str = "entry_id"
9-
role_field_name: str = "role"
10-
content_field_name: str = "content"
11-
tool_field_name: str = "tool_call_id"
12-
timestamp_field_name: str = "timestamp"
13-
session_field_name: str = "session_tag"
1413

1514
def __init__(
1615
self,
@@ -107,11 +106,11 @@ def _format_context(
107106
context.append(chat_message.content)
108107
else:
109108
chat_message_dict = {
110-
self.role_field_name: chat_message.role,
111-
self.content_field_name: chat_message.content,
109+
ROLE_FIELD_NAME: chat_message.role,
110+
CONTENT_FIELD_NAME: chat_message.content,
112111
}
113112
if chat_message.tool_call_id is not None:
114-
chat_message_dict[self.tool_field_name] = chat_message.tool_call_id
113+
chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id
115114

116115
context.append(chat_message_dict) # type: ignore
117116

redisvl/extensions/session_manager/schema.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22

33
from pydantic.v1 import BaseModel, Field, root_validator
44

5+
from redisvl.extensions.constants import (
6+
CONTENT_FIELD_NAME,
7+
ID_FIELD_NAME,
8+
ROLE_FIELD_NAME,
9+
SESSION_FIELD_NAME,
10+
SESSION_VECTOR_FIELD_NAME,
11+
TIMESTAMP_FIELD_NAME,
12+
TOOL_FIELD_NAME,
13+
)
514
from redisvl.redis.utils import array_to_buffer
615
from redisvl.schema import IndexSchema
716
from redisvl.utils.utils import current_timestamp
@@ -31,18 +40,22 @@ class Config:
3140
@root_validator(pre=True)
3241
@classmethod
3342
def generate_id(cls, values):
34-
if "timestamp" not in values:
35-
values["timestamp"] = current_timestamp()
36-
if "entry_id" not in values:
37-
values["entry_id"] = f'{values["session_tag"]}:{values["timestamp"]}'
43+
if TIMESTAMP_FIELD_NAME not in values:
44+
values[TIMESTAMP_FIELD_NAME] = current_timestamp()
45+
if ID_FIELD_NAME not in values:
46+
values[ID_FIELD_NAME] = (
47+
f"{values[SESSION_FIELD_NAME]}:{values[TIMESTAMP_FIELD_NAME]}"
48+
)
3849
return values
3950

4051
def to_dict(self) -> Dict:
4152
data = self.dict(exclude_none=True)
4253

4354
# handle optional fields
44-
if "vector_field" in data:
45-
data["vector_field"] = array_to_buffer(data["vector_field"])
55+
if SESSION_VECTOR_FIELD_NAME in data:
56+
data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer(
57+
data[SESSION_VECTOR_FIELD_NAME]
58+
)
4659

4760
return data
4861

@@ -55,11 +68,11 @@ def from_params(cls, name: str, prefix: str):
5568
return cls(
5669
index={"name": name, "prefix": prefix}, # type: ignore
5770
fields=[ # type: ignore
58-
{"name": "role", "type": "tag"},
59-
{"name": "content", "type": "text"},
60-
{"name": "tool_call_id", "type": "tag"},
61-
{"name": "timestamp", "type": "numeric"},
62-
{"name": "session_tag", "type": "tag"},
71+
{"name": ROLE_FIELD_NAME, "type": "tag"},
72+
{"name": CONTENT_FIELD_NAME, "type": "text"},
73+
{"name": TOOL_FIELD_NAME, "type": "tag"},
74+
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
75+
{"name": SESSION_FIELD_NAME, "type": "tag"},
6376
],
6477
)
6578

@@ -72,13 +85,13 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
7285
return cls(
7386
index={"name": name, "prefix": prefix}, # type: ignore
7487
fields=[ # type: ignore
75-
{"name": "role", "type": "tag"},
76-
{"name": "content", "type": "text"},
77-
{"name": "tool_call_id", "type": "tag"},
78-
{"name": "timestamp", "type": "numeric"},
79-
{"name": "session_tag", "type": "tag"},
88+
{"name": ROLE_FIELD_NAME, "type": "tag"},
89+
{"name": CONTENT_FIELD_NAME, "type": "text"},
90+
{"name": TOOL_FIELD_NAME, "type": "tag"},
91+
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
92+
{"name": SESSION_FIELD_NAME, "type": "tag"},
8093
{
81-
"name": "vector_field",
94+
"name": SESSION_VECTOR_FIELD_NAME,
8295
"type": "vector",
8396
"attrs": {
8497
"dims": vectorizer_dims,

0 commit comments

Comments
 (0)