1
- from typing import Any , Dict , List , Optional , Union
1
+ from typing import Any , Dict , List , Optional
2
2
3
3
from redis import Redis
4
4
10
10
)
11
11
from redisvl .index import SearchIndex
12
12
from redisvl .query import RangeQuery
13
- from redisvl .query .filter import FilterExpression , Tag
14
- from redisvl .redis .utils import array_to_buffer
15
- from redisvl .utils .utils import (
16
- current_timestamp ,
17
- deserialize ,
18
- serialize ,
19
- validate_vector_dims ,
20
- )
13
+ from redisvl .query .filter import FilterExpression
14
+ from redisvl .utils .utils import current_timestamp , serialize , validate_vector_dims
21
15
from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
22
16
23
17
24
18
class SemanticCache (BaseLLMCache ):
25
19
"""Semantic Cache for Large Language Models."""
26
20
21
+ redis_key_field_name : str = "key"
27
22
entry_id_field_name : str = "entry_id"
28
23
prompt_field_name : str = "prompt"
29
24
response_field_name : str = "response"
@@ -55,6 +50,8 @@ def __init__(
55
50
in Redis. Defaults to None.
56
51
vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
57
52
Defaults to HFTextVectorizer.
53
+ filterable_fields (Optional[List[Dict[str, Any]]]): An optional list of RedisVL fields
54
+ that can be used to customize cache retrieval with filters.
58
55
redis_client(Optional[Redis], optional): A redis client connection instance.
59
56
Defaults to None.
60
57
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
@@ -81,9 +78,6 @@ def __init__(
81
78
model = "sentence-transformers/all-mpnet-base-v2"
82
79
)
83
80
84
- # Create semantic cache schema
85
- schema = SemanticCacheIndexSchema .from_params (name , prefix , vectorizer .dims )
86
-
87
81
# Process fields
88
82
self .return_fields = [
89
83
self .entry_id_field_name ,
@@ -94,18 +88,9 @@ def __init__(
94
88
self .metadata_field_name ,
95
89
]
96
90
97
- if filterable_fields is not None :
98
- for filter_field in filterable_fields :
99
- if (
100
- filter_field ["name" ] in self .return_fields
101
- or filter_field ["name" ] == "key"
102
- ):
103
- raise ValueError (
104
- f'{ filter_field ["name" ]} is a reserved field name for the semantic cache schema'
105
- )
106
- schema .add_field (filter_field )
107
- # Add to return fields too
108
- self .return_fields .append (filter_field ["name" ])
91
+ # Create semantic cache schema and index
92
+ schema = SemanticCacheIndexSchema .from_params (name , prefix , vectorizer .dims )
93
+ schema = self ._modify_schema (schema , filterable_fields )
109
94
110
95
self ._index = SearchIndex (schema = schema )
111
96
@@ -120,6 +105,30 @@ def __init__(
120
105
self .set_threshold (distance_threshold )
121
106
self ._index .create (overwrite = False )
122
107
108
+ def _modify_schema (
109
+ self ,
110
+ schema : SemanticCacheIndexSchema ,
111
+ filterable_fields : Optional [List [Dict [str , Any ]]] = None ,
112
+ ) -> SemanticCacheIndexSchema :
113
+ """Modify the base cache schema using the provided filterable fields"""
114
+
115
+ if filterable_fields is not None :
116
+ protected_field_names = set (
117
+ self .return_fields + [self .redis_key_field_name ]
118
+ )
119
+ for filter_field in filterable_fields :
120
+ field_name = filter_field ["name" ]
121
+ if field_name in protected_field_names :
122
+ raise ValueError (
123
+ f"{ field_name } is a reserved field name for the semantic cache schema"
124
+ )
125
+ # Add to schema
126
+ schema .add_field (filter_field )
127
+ # Add to return fields too
128
+ self .return_fields .append (field_name )
129
+
130
+ return schema
131
+
123
132
@property
124
133
def index (self ) -> SearchIndex :
125
134
"""The underlying SearchIndex for the cache.
0 commit comments