1
- from typing import Any , Dict , List , Optional
1
+ from typing import Any , Dict , List , Optional , Union
2
2
3
3
from redis import Redis
4
4
5
5
from redisvl .extensions .llmcache .base import BaseLLMCache
6
6
from redisvl .index import SearchIndex
7
7
from redisvl .query import RangeQuery
8
+ from redisvl .query .filter import FilterExpression , Tag
8
9
from redisvl .redis .utils import array_to_buffer
9
- from redisvl .schema .schema import IndexSchema
10
+ from redisvl .schema import IndexSchema
11
+ from redisvl .utils .utils import current_timestamp , deserialize , serialize
10
12
from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
11
13
12
14
15
class SemanticCacheIndexSchema(IndexSchema):
    """Index schema used by the semantic cache: text fields for the prompt
    and response, numeric insert/update timestamps, a tag label, and a flat
    float32 cosine vector field for the prompt embedding."""

    @classmethod
    def from_params(cls, name: str, vector_dims: int):
        """Construct the cache index schema for a given index name and
        embedding dimensionality.

        Args:
            name (str): Used as both the index name and the key prefix.
            vector_dims (int): Dimensionality of the prompt embedding.
        """
        vector_field = {
            "name": "prompt_vector",
            "type": "vector",
            "attrs": {
                "dims": vector_dims,
                "datatype": "float32",
                "distance_metric": "cosine",
                "algorithm": "flat",
            },
        }
        scalar_fields = [
            {"name": "prompt", "type": "text"},
            {"name": "response", "type": "text"},
            {"name": "inserted_at", "type": "numeric"},
            {"name": "updated_at", "type": "numeric"},
            {"name": "label", "type": "tag"},
        ]
        return cls(
            index={"name": name, "prefix": name},  # type: ignore
            fields=scalar_fields + [vector_field],  # type: ignore
        )
40
+
41
+
13
42
class SemanticCache (BaseLLMCache ):
14
43
"""Semantic Cache for Large Language Models."""
15
44
16
45
entry_id_field_name : str = "_id"
17
46
prompt_field_name : str = "prompt"
18
47
vector_field_name : str = "prompt_vector"
48
+ inserted_at_field_name : str = "inserted_at"
49
+ updated_at_field_name : str = "updated_at"
50
+ tag_field_name : str = "label"
19
51
response_field_name : str = "response"
20
52
metadata_field_name : str = "metadata"
21
53
@@ -69,27 +101,7 @@ def __init__(
69
101
model = "sentence-transformers/all-mpnet-base-v2"
70
102
)
71
103
72
- # build cache index schema
73
- schema = IndexSchema .from_dict ({"index" : {"name" : name , "prefix" : prefix }})
74
- # add fields
75
- schema .add_fields (
76
- [
77
- {"name" : self .prompt_field_name , "type" : "text" },
78
- {"name" : self .response_field_name , "type" : "text" },
79
- {
80
- "name" : self .vector_field_name ,
81
- "type" : "vector" ,
82
- "attrs" : {
83
- "dims" : vectorizer .dims ,
84
- "datatype" : "float32" ,
85
- "distance_metric" : "cosine" ,
86
- "algorithm" : "flat" ,
87
- },
88
- },
89
- ]
90
- )
91
-
92
- # build search index
104
+ schema = SemanticCacheIndexSchema .from_params (name , vectorizer .dims )
93
105
self ._index = SearchIndex (schema = schema )
94
106
95
107
# handle redis connection
@@ -103,12 +115,12 @@ def __init__(
103
115
self .entry_id_field_name ,
104
116
self .prompt_field_name ,
105
117
self .response_field_name ,
118
+ self .tag_field_name ,
106
119
self .vector_field_name ,
107
120
self .metadata_field_name ,
108
121
]
109
122
self .set_vectorizer (vectorizer )
110
123
self .set_threshold (distance_threshold )
111
-
112
124
self ._index .create (overwrite = False )
113
125
114
126
@property
@@ -182,6 +194,14 @@ def delete(self) -> None:
182
194
index."""
183
195
self ._index .delete (drop = True )
184
196
197
    def drop(self, document_ids: Union[str, List[str]]) -> None:
        """Remove one or more entries from the cache by document ID.

        Delegates to the underlying search index's key-deletion method.

        Args:
            document_ids (Union[str, List[str]]): A single Redis document ID,
                or a list of IDs, to remove from the cache.
        """
        self._index.drop_keys(document_ids)
185
205
def _refresh_ttl (self , key : str ) -> None :
186
206
"""Refresh the time-to-live for the specified key."""
187
207
if self ._ttl :
@@ -195,7 +215,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
195
215
return self ._vectorizer .embed (prompt )
196
216
197
217
def _search_cache (
198
- self , vector : List [float ], num_results : int , return_fields : Optional [List [str ]]
218
+ self ,
219
+ vector : List [float ],
220
+ num_results : int ,
221
+ return_fields : Optional [List [str ]],
222
+ tag_filter : Optional [FilterExpression ],
199
223
) -> List [Dict [str , Any ]]:
200
224
"""Searches the semantic cache for similar prompt vectors and returns
201
225
the specified return fields for each cache hit."""
@@ -217,6 +241,8 @@ def _search_cache(
217
241
num_results = num_results ,
218
242
return_score = True ,
219
243
)
244
+ if tag_filter :
245
+ query .set_filter (tag_filter ) # type: ignore
220
246
221
247
# Gather and return the cache hits
222
248
cache_hits : List [Dict [str , Any ]] = self ._index .query (query )
@@ -226,7 +252,7 @@ def _search_cache(
226
252
self ._refresh_ttl (key )
227
253
# Check for metadata and deserialize
228
254
if self .metadata_field_name in hit :
229
- hit [self .metadata_field_name ] = self . deserialize (
255
+ hit [self .metadata_field_name ] = deserialize (
230
256
hit [self .metadata_field_name ]
231
257
)
232
258
return cache_hits
@@ -248,6 +274,7 @@ def check(
248
274
vector : Optional [List [float ]] = None ,
249
275
num_results : int = 1 ,
250
276
return_fields : Optional [List [str ]] = None ,
277
+ tag_filter : Optional [FilterExpression ] = None ,
251
278
) -> List [Dict [str , Any ]]:
252
279
"""Checks the semantic cache for results similar to the specified prompt
253
280
or vector.
@@ -267,6 +294,8 @@ def check(
267
294
return_fields (Optional[List[str]], optional): The fields to include
268
295
in each returned result. If None, defaults to all available
269
296
fields in the cached entry.
297
+ tag_filter (Optional[FilterExpression]) : the tag filter to filter
298
+ results by. Default is None and full cache is searched.
270
299
271
300
Returns:
272
301
List[Dict[str, Any]]: A list of dicts containing the requested
@@ -291,7 +320,7 @@ def check(
291
320
self ._check_vector_dims (vector )
292
321
293
322
# Check for cache hits by searching the cache
294
- cache_hits = self ._search_cache (vector , num_results , return_fields )
323
+ cache_hits = self ._search_cache (vector , num_results , return_fields , tag_filter )
295
324
return cache_hits
296
325
297
326
def store (
@@ -300,6 +329,7 @@ def store(
300
329
response : str ,
301
330
vector : Optional [List [float ]] = None ,
302
331
metadata : Optional [dict ] = None ,
332
+ tag : Optional [str ] = None ,
303
333
) -> str :
304
334
"""Stores the specified key-value pair in the cache along with metadata.
305
335
@@ -311,6 +341,8 @@ def store(
311
341
demand.
312
342
metadata (Optional[dict], optional): The optional metadata to cache
313
343
alongside the prompt and response. Defaults to None.
344
+ tag (Optional[str]): The optional tag to assign to the cache entry.
345
+ Defaults to None.
314
346
315
347
Returns:
316
348
str: The Redis key for the entries added to the semantic cache.
@@ -333,19 +365,67 @@ def store(
333
365
self ._check_vector_dims (vector )
334
366
335
367
# Construct semantic cache payload
368
+ now = current_timestamp ()
336
369
id_field = self .entry_id_field_name
337
370
payload = {
338
371
id_field : self .hash_input (prompt ),
339
372
self .prompt_field_name : prompt ,
340
373
self .response_field_name : response ,
341
374
self .vector_field_name : array_to_buffer (vector ),
375
+ self .inserted_at_field_name : now ,
376
+ self .updated_at_field_name : now ,
342
377
}
343
378
if metadata is not None :
344
379
if not isinstance (metadata , dict ):
345
380
raise TypeError ("If specified, cached metadata must be a dictionary." )
346
381
# Serialize the metadata dict and add to cache payload
347
- payload [self .metadata_field_name ] = self .serialize (metadata )
382
+ payload [self .metadata_field_name ] = serialize (metadata )
383
+ if tag is not None :
384
+ payload [self .tag_field_name ] = tag
348
385
349
386
# Load LLMCache entry with TTL
350
387
keys = self ._index .load (data = [payload ], ttl = self ._ttl , id_field = id_field )
351
388
return keys [0 ]
389
+
390
+ def update (self , key : str , ** kwargs ) -> None :
391
+ """Update specific fields within an existing cache entry. If no fields
392
+ are passed, then only the document TTL is refreshed.
393
+
394
+ Args:
395
+ key (str): the key of the document to update.
396
+ kwargs:
397
+
398
+ Raises:
399
+ ValueError if an incorrect mapping is provided as a kwarg.
400
+ TypeError if metadata is provided and not of type dict.
401
+
402
+ .. code-block:: python
403
+ key = cache.store('this is a prompt', 'this is a response')
404
+ cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
405
+ )
406
+ """
407
+ if not kwargs :
408
+ self ._refresh_ttl (key )
409
+ return
410
+
411
+ for _key , val in kwargs .items ():
412
+ if _key not in {
413
+ self .prompt_field_name ,
414
+ self .vector_field_name ,
415
+ self .response_field_name ,
416
+ self .tag_field_name ,
417
+ self .metadata_field_name ,
418
+ }:
419
+ raise ValueError (f" { key } is not a valid field within document" )
420
+
421
+ # Check for metadata and deserialize
422
+ if _key == self .metadata_field_name :
423
+ if isinstance (val , dict ):
424
+ kwargs [_key ] = serialize (val )
425
+ else :
426
+ raise TypeError (
427
+ "If specified, cached metadata must be a dictionary."
428
+ )
429
+ kwargs .update ({self .updated_at_field_name : current_timestamp ()})
430
+ self ._index .client .hset (key , mapping = kwargs ) # type: ignore
431
+ self ._refresh_ttl (key )
0 commit comments