Skip to content

Commit d1bd692

Browse files
Add scoped access and basic filter support to semantic cache (#180)
This PR is the first of several that add functionality to RedisVL's existing semantic cache class for dropping and updating cache entries. It also adds scoped access control and metadata fields.
1 parent 8bbd1b0 commit d1bd692

File tree

4 files changed

+267
-38
lines changed

4 files changed

+267
-38
lines changed

redisvl/extensions/llmcache/base.py

-8
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,3 @@ def store(
5858
def hash_input(self, prompt: str):
5959
"""Hashes the input using SHA256."""
6060
return hashify(prompt)
61-
62-
def serialize(self, metadata: Dict[str, Any]) -> str:
63-
"""Serlize the input into a string."""
64-
return json.dumps(metadata)
65-
66-
def deserialize(self, metadata: str) -> Dict[str, Any]:
67-
"""Deserialize the input from a string."""
68-
return json.loads(metadata)

redisvl/extensions/llmcache/semantic.py

+108-28
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,53 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional, Union
22

33
from redis import Redis
44

55
from redisvl.extensions.llmcache.base import BaseLLMCache
66
from redisvl.index import SearchIndex
77
from redisvl.query import RangeQuery
8+
from redisvl.query.filter import FilterExpression, Tag
89
from redisvl.redis.utils import array_to_buffer
9-
from redisvl.schema.schema import IndexSchema
10+
from redisvl.schema import IndexSchema
11+
from redisvl.utils.utils import current_timestamp, deserialize, serialize
1012
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
1113

1214

15+
class SemanticCacheIndexSchema(IndexSchema):
16+
17+
@classmethod
18+
def from_params(cls, name: str, vector_dims: int):
19+
20+
return cls(
21+
index={"name": name, "prefix": name}, # type: ignore
22+
fields=[ # type: ignore
23+
{"name": "prompt", "type": "text"},
24+
{"name": "response", "type": "text"},
25+
{"name": "inserted_at", "type": "numeric"},
26+
{"name": "updated_at", "type": "numeric"},
27+
{"name": "label", "type": "tag"},
28+
{
29+
"name": "prompt_vector",
30+
"type": "vector",
31+
"attrs": {
32+
"dims": vector_dims,
33+
"datatype": "float32",
34+
"distance_metric": "cosine",
35+
"algorithm": "flat",
36+
},
37+
},
38+
],
39+
)
40+
41+
1342
class SemanticCache(BaseLLMCache):
1443
"""Semantic Cache for Large Language Models."""
1544

1645
entry_id_field_name: str = "_id"
1746
prompt_field_name: str = "prompt"
1847
vector_field_name: str = "prompt_vector"
48+
inserted_at_field_name: str = "inserted_at"
49+
updated_at_field_name: str = "updated_at"
50+
tag_field_name: str = "label"
1951
response_field_name: str = "response"
2052
metadata_field_name: str = "metadata"
2153

@@ -69,27 +101,7 @@ def __init__(
69101
model="sentence-transformers/all-mpnet-base-v2"
70102
)
71103

72-
# build cache index schema
73-
schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})
74-
# add fields
75-
schema.add_fields(
76-
[
77-
{"name": self.prompt_field_name, "type": "text"},
78-
{"name": self.response_field_name, "type": "text"},
79-
{
80-
"name": self.vector_field_name,
81-
"type": "vector",
82-
"attrs": {
83-
"dims": vectorizer.dims,
84-
"datatype": "float32",
85-
"distance_metric": "cosine",
86-
"algorithm": "flat",
87-
},
88-
},
89-
]
90-
)
91-
92-
# build search index
104+
schema = SemanticCacheIndexSchema.from_params(name, vectorizer.dims)
93105
self._index = SearchIndex(schema=schema)
94106

95107
# handle redis connection
@@ -103,12 +115,12 @@ def __init__(
103115
self.entry_id_field_name,
104116
self.prompt_field_name,
105117
self.response_field_name,
118+
self.tag_field_name,
106119
self.vector_field_name,
107120
self.metadata_field_name,
108121
]
109122
self.set_vectorizer(vectorizer)
110123
self.set_threshold(distance_threshold)
111-
112124
self._index.create(overwrite=False)
113125

114126
@property
@@ -182,6 +194,14 @@ def delete(self) -> None:
182194
index."""
183195
self._index.delete(drop=True)
184196

197+
def drop(self, document_ids: Union[str, List[str]]) -> None:
198+
"""Remove a specific entry or entries from the cache by its ID.
199+
200+
Args:
201+
document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache.
202+
"""
203+
self._index.drop_keys(document_ids)
204+
185205
def _refresh_ttl(self, key: str) -> None:
186206
"""Refresh the time-to-live for the specified key."""
187207
if self._ttl:
@@ -195,7 +215,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
195215
return self._vectorizer.embed(prompt)
196216

197217
def _search_cache(
198-
self, vector: List[float], num_results: int, return_fields: Optional[List[str]]
218+
self,
219+
vector: List[float],
220+
num_results: int,
221+
return_fields: Optional[List[str]],
222+
tag_filter: Optional[FilterExpression],
199223
) -> List[Dict[str, Any]]:
200224
"""Searches the semantic cache for similar prompt vectors and returns
201225
the specified return fields for each cache hit."""
@@ -217,6 +241,8 @@ def _search_cache(
217241
num_results=num_results,
218242
return_score=True,
219243
)
244+
if tag_filter:
245+
query.set_filter(tag_filter) # type: ignore
220246

221247
# Gather and return the cache hits
222248
cache_hits: List[Dict[str, Any]] = self._index.query(query)
@@ -226,7 +252,7 @@ def _search_cache(
226252
self._refresh_ttl(key)
227253
# Check for metadata and deserialize
228254
if self.metadata_field_name in hit:
229-
hit[self.metadata_field_name] = self.deserialize(
255+
hit[self.metadata_field_name] = deserialize(
230256
hit[self.metadata_field_name]
231257
)
232258
return cache_hits
@@ -248,6 +274,7 @@ def check(
248274
vector: Optional[List[float]] = None,
249275
num_results: int = 1,
250276
return_fields: Optional[List[str]] = None,
277+
tag_filter: Optional[FilterExpression] = None,
251278
) -> List[Dict[str, Any]]:
252279
"""Checks the semantic cache for results similar to the specified prompt
253280
or vector.
@@ -267,6 +294,8 @@ def check(
267294
return_fields (Optional[List[str]], optional): The fields to include
268295
in each returned result. If None, defaults to all available
269296
fields in the cached entry.
297+
tag_filter (Optional[FilterExpression]) : the tag filter to filter
298+
results by. Default is None and full cache is searched.
270299
271300
Returns:
272301
List[Dict[str, Any]]: A list of dicts containing the requested
@@ -291,7 +320,7 @@ def check(
291320
self._check_vector_dims(vector)
292321

293322
# Check for cache hits by searching the cache
294-
cache_hits = self._search_cache(vector, num_results, return_fields)
323+
cache_hits = self._search_cache(vector, num_results, return_fields, tag_filter)
295324
return cache_hits
296325

297326
def store(
@@ -300,6 +329,7 @@ def store(
300329
response: str,
301330
vector: Optional[List[float]] = None,
302331
metadata: Optional[dict] = None,
332+
tag: Optional[str] = None,
303333
) -> str:
304334
"""Stores the specified key-value pair in the cache along with metadata.
305335
@@ -311,6 +341,8 @@ def store(
311341
demand.
312342
metadata (Optional[dict], optional): The optional metadata to cache
313343
alongside the prompt and response. Defaults to None.
344+
tag (Optional[str]): The optional tag to assign to the cache entry.
345+
Defaults to None.
314346
315347
Returns:
316348
str: The Redis key for the entries added to the semantic cache.
@@ -333,19 +365,67 @@ def store(
333365
self._check_vector_dims(vector)
334366

335367
# Construct semantic cache payload
368+
now = current_timestamp()
336369
id_field = self.entry_id_field_name
337370
payload = {
338371
id_field: self.hash_input(prompt),
339372
self.prompt_field_name: prompt,
340373
self.response_field_name: response,
341374
self.vector_field_name: array_to_buffer(vector),
375+
self.inserted_at_field_name: now,
376+
self.updated_at_field_name: now,
342377
}
343378
if metadata is not None:
344379
if not isinstance(metadata, dict):
345380
raise TypeError("If specified, cached metadata must be a dictionary.")
346381
# Serialize the metadata dict and add to cache payload
347-
payload[self.metadata_field_name] = self.serialize(metadata)
382+
payload[self.metadata_field_name] = serialize(metadata)
383+
if tag is not None:
384+
payload[self.tag_field_name] = tag
348385

349386
# Load LLMCache entry with TTL
350387
keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
351388
return keys[0]
389+
390+
def update(self, key: str, **kwargs) -> None:
391+
"""Update specific fields within an existing cache entry. If no fields
392+
are passed, then only the document TTL is refreshed.
393+
394+
Args:
395+
key (str): the key of the document to update.
396+
kwargs:
397+
398+
Raises:
399+
ValueError if an incorrect mapping is provided as a kwarg.
400+
TypeError if metadata is provided and not of type dict.
401+
402+
.. code-block:: python
403+
key = cache.store('this is a prompt', 'this is a response')
404+
cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
405+
)
406+
"""
407+
if not kwargs:
408+
self._refresh_ttl(key)
409+
return
410+
411+
for _key, val in kwargs.items():
412+
if _key not in {
413+
self.prompt_field_name,
414+
self.vector_field_name,
415+
self.response_field_name,
416+
self.tag_field_name,
417+
self.metadata_field_name,
418+
}:
419+
raise ValueError(f" {key} is not a valid field within document")
420+
421+
# Check for metadata and deserialize
422+
if _key == self.metadata_field_name:
423+
if isinstance(val, dict):
424+
kwargs[_key] = serialize(val)
425+
else:
426+
raise TypeError(
427+
"If specified, cached metadata must be a dictionary."
428+
)
429+
kwargs.update({self.updated_at_field_name: current_timestamp()})
430+
self._index.client.hset(key, mapping=kwargs) # type: ignore
431+
self._refresh_ttl(key)

redisvl/utils/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ def serialize(data: Dict[str, Any]) -> str:
5454
return json.dumps(data)
5555

5656

57-
def deserialize(self, data: str) -> Dict[str, Any]:
57+
def deserialize(data: str) -> Dict[str, Any]:
5858
"""Deserialize the input from a string."""
5959
return json.loads(data)

0 commit comments

Comments
 (0)