Commit 0a2e432

Use pydantic for cache entries and hits (#195)
This PR handles two things:

1) Uses pydantic for both the `CacheEntry` and `CacheHit` models. It ended up being better and cleaner to have two separate models: one for what we write, and one for what we load.

2) Filtering. @justin-cechmanek I want your feedback on this. I landed on a technique that lets the user define a list of `filterable_fields` as part of the semantic cache class init. This gives us a way to support arbitrary filters (scope, permissions, tags, numerics, anything within reason). You can then build any `FilterExpression` and pass it through at query time, as sketched below. This extends your initial implementation.
1 parent d1bd692 commit 0a2e432
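
To make the filtering idea concrete, here is a rough sketch of the intended usage. The `user_id` field, its values, and the exact keyword names (`filterable_fields`, `filters`, `filter_expression`) are illustrative assumptions for review, not final API:

from redisvl.extensions.llmcache import SemanticCache
from redisvl.query.filter import Tag

# Declare the extra fields you want to filter on at init time
# (hypothetical example field "user_id" of type tag).
llmcache = SemanticCache(
    name="llmcache",
    redis_url="redis://localhost:6379",
    distance_threshold=0.1,
    filterable_fields=[{"name": "user_id", "type": "tag"}],
)

# Attach filter data when writing to the cache.
llmcache.store(
    prompt="What is the capital of France?",
    response="Paris",
    filters={"user_id": "abc"},
)

# Scope the lookup with any FilterExpression at query time.
hits = llmcache.check(
    prompt="What actually is the capital of France?",
    filter_expression=Tag("user_id") == "abc",
)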

File tree

7 files changed: +528 −227 lines changed

docs/user_guide/llmcache_03.ipynb

+18 −17
@@ -83,7 +83,6 @@
 "\n",
 "llmcache = SemanticCache(\n",
 "    name=\"llmcache\",  # underlying search index name\n",
-"    prefix=\"llmcache\",  # redis key prefix for hash entries\n",
 "    redis_url=\"redis://localhost:6379\",  # redis connection url string\n",
 "    distance_threshold=0.1  # semantic cache distance threshold\n",
 ")"
@@ -107,13 +106,15 @@
 "│ llmcache     │ HASH           │ ['llmcache'] │ []              │ 0          │\n",
 "╰──────────────┴────────────────┴──────────────┴─────────────────┴────────────╯\n",
 "Index Fields:\n",
-"╭───────────────┬───────────────┬────────┬────────────────┬────────────────╮\n",
-"│ Name          │ Attribute     │ Type   │ Field Option   │ Option Value   │\n",
-"├───────────────┼───────────────┼────────┼────────────────┼────────────────┤\n",
-"│ prompt        │ prompt        │ TEXT   │ WEIGHT         │ 1              │\n",
-"│ response      │ response      │ TEXT   │ WEIGHT         │ 1              │\n",
-"│ prompt_vector │ prompt_vector │ VECTOR │                │                │\n",
-"╰───────────────┴───────────────┴────────┴────────────────┴────────────────╯\n"
+"╭───────────────┬───────────────┬─────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n",
+"│ Name          │ Attribute     │ Type    │ Field Option   │ Option Value   │ Field Option   │ Option Value   │ Field Option   │ Option Value   │ Field Option    │ Option Value   │\n",
+"├───────────────┼───────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n",
+"│ prompt        │ prompt        │ TEXT    │ WEIGHT         │ 1              │                │                │                │                │                 │                │\n",
+"│ response      │ response      │ TEXT    │ WEIGHT         │ 1              │                │                │                │                │                 │                │\n",
+"│ inserted_at   │ inserted_at   │ NUMERIC │                │                │                │                │                │                │                 │                │\n",
+"│ updated_at    │ updated_at    │ NUMERIC │                │                │                │                │                │                │                 │                │\n",
+"│ prompt_vector │ prompt_vector │ VECTOR  │ algorithm      │ FLAT           │ data_type      │ FLOAT32        │ dim            │ 768            │ distance_metric │ COSINE         │\n",
+"╰───────────────┴───────────────┴─────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n"
 ]
 }
 ],
@@ -208,7 +209,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"[{'id': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545', 'vector_distance': '9.53674316406e-07', 'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}}]\n"
+"[{'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}, 'key': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545'}]\n"
 ]
 }
 ],
@@ -384,7 +385,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 16,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -408,14 +409,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 18,
+"execution_count": 17,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Without caching, a call to openAI to answer this simple question took 1.460299015045166 seconds.\n"
+"Without caching, a call to openAI to answer this simple question took 0.9312698841094971 seconds.\n"
 ]
 },
 {
@@ -424,7 +425,7 @@
 "'llmcache:67e0f6e28fe2a61c0022fd42bf734bb8ffe49d3e375fd69d692574295a20fc1a'"
 ]
 },
-"execution_count": 18,
+"execution_count": 17,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -451,8 +452,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Avg time taken with LLM cache enabled: 0.2560166358947754\n",
-"Percentage of time saved: 82.47%\n"
+"Avg time taken with LLM cache enabled: 0.4896167993545532\n",
+"Percentage of time saved: 47.42%\n"
 ]
 }
 ],
@@ -515,7 +516,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": 20,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -540,7 +541,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.12"
+"version": "3.10.14"
 },
 "orig_nbformat": 4
 },
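
Side note on the new `inserted_at` / `updated_at` numeric fields visible in the index output above: they make time-based scoping possible. A small sketch, assuming `check()` accepts a `filter_expression` as described in the PR summary:

from redisvl.query.filter import Num
from redisvl.utils.utils import current_timestamp

# Only consider cache entries written within the last hour;
# inserted_at is stored as a unix timestamp.
recent_only = Num("inserted_at") > (current_timestamp() - 3600)
hits = llmcache.check(
    prompt="What is the capital of France?",
    filter_expression=recent_only,
)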

redisvl/extensions/llmcache/base.py

-1
@@ -1,4 +1,3 @@
-import json
 from typing import Any, Dict, List, Optional

 from redisvl.redis.utils import hashify
redisvl/extensions/llmcache/schema.py

+128
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List, Optional
+
+from pydantic.v1 import BaseModel, Field, root_validator, validator
+
+from redisvl.redis.utils import array_to_buffer, hashify
+from redisvl.schema import IndexSchema
+from redisvl.utils.utils import current_timestamp, deserialize, serialize
+
+
+class CacheEntry(BaseModel):
+    """A single cache entry in Redis"""
+
+    entry_id: Optional[str] = Field(default=None)
+    """Cache entry identifier"""
+    prompt: str
+    """Input prompt or question cached in Redis"""
+    response: str
+    """Response or answer to the question, cached in Redis"""
+    prompt_vector: List[float]
+    """Text embedding representation of the prompt"""
+    inserted_at: float = Field(default_factory=current_timestamp)
+    """Timestamp of when the entry was added to the cache"""
+    updated_at: float = Field(default_factory=current_timestamp)
+    """Timestamp of when the entry was updated in the cache"""
+    metadata: Optional[Dict[str, Any]] = Field(default=None)
+    """Optional metadata stored on the cache entry"""
+    filters: Optional[Dict[str, Any]] = Field(default=None)
+    """Optional filter data stored on the cache entry for customizing retrieval"""
+
+    @root_validator(pre=True)
+    @classmethod
+    def generate_id(cls, values):
+        # Ensure entry_id is set
+        if not values.get("entry_id"):
+            values["entry_id"] = hashify(values["prompt"])
+        return values
+
+    @validator("metadata")
+    def non_empty_metadata(cls, v):
+        if v is not None and not isinstance(v, dict):
+            raise TypeError("Metadata must be a dictionary.")
+        return v
+
+    def to_dict(self) -> Dict:
+        data = self.dict(exclude_none=True)
+        data["prompt_vector"] = array_to_buffer(self.prompt_vector)
+        if self.metadata:
+            data["metadata"] = serialize(self.metadata)
+        if self.filters:
+            data.update(self.filters)
+            del data["filters"]
+        return data
+
+
+class CacheHit(BaseModel):
+    """A cache hit based on some input query"""
+
+    entry_id: str
+    """Cache entry identifier"""
+    prompt: str
+    """Input prompt or question cached in Redis"""
+    response: str
+    """Response or answer to the question, cached in Redis"""
+    vector_distance: float
+    """The semantic distance between the query vector and the stored prompt vector"""
+    inserted_at: float
+    """Timestamp of when the entry was added to the cache"""
+    updated_at: float
+    """Timestamp of when the entry was updated in the cache"""
+    metadata: Optional[Dict[str, Any]] = Field(default=None)
+    """Optional metadata stored on the cache entry"""
+    filters: Optional[Dict[str, Any]] = Field(default=None)
+    """Optional filter data stored on the cache entry for customizing retrieval"""
+
+    @root_validator(pre=True)
+    @classmethod
+    def validate_cache_hit(cls, values):
+        # Deserialize metadata if necessary
+        if "metadata" in values and isinstance(values["metadata"], str):
+            values["metadata"] = deserialize(values["metadata"])
+
+        # Separate filters from other fields
+        known_fields = set(cls.__fields__.keys())
+        filters = {k: v for k, v in values.items() if k not in known_fields}
+
+        # Add filters to values
+        if filters:
+            values["filters"] = filters
+
+        # Remove filter fields from the main values
+        for k in filters:
+            values.pop(k)
+
+        return values
+
+    def to_dict(self) -> Dict:
+        data = self.dict(exclude_none=True)
+        if self.filters:
+            data.update(self.filters)
+            del data["filters"]
+
+        return data
+
+
+class SemanticCacheIndexSchema(IndexSchema):
+
+    @classmethod
+    def from_params(cls, name: str, prefix: str, vector_dims: int):
+
+        return cls(
+            index={"name": name, "prefix": prefix},  # type: ignore
+            fields=[  # type: ignore
+                {"name": "prompt", "type": "text"},
+                {"name": "response", "type": "text"},
+                {"name": "inserted_at", "type": "numeric"},
+                {"name": "updated_at", "type": "numeric"},
+                {
+                    "name": "prompt_vector",
+                    "type": "vector",
+                    "attrs": {
+                        "dims": vector_dims,
+                        "datatype": "float32",
+                        "distance_metric": "cosine",
+                        "algorithm": "flat",
+                    },
+                },
+            ],
+        )
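
For reviewers, a minimal sketch of the write/load round trip these two models give us. Values are illustrative and nothing here touches Redis; `user_id` is a made-up filterable field:

from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit

# Write path: CacheEntry auto-generates entry_id by hashing the prompt,
# and to_dict() serializes metadata and flattens filters into the payload.
entry = CacheEntry(
    prompt="What is the capital of France?",
    response="Paris",
    prompt_vector=[0.1, 0.2, 0.3],  # toy embedding
    metadata={"city": "Paris"},
    filters={"user_id": "abc"},
)
payload = entry.to_dict()
# payload now has a top-level "user_id" key and serialized "metadata",
# ready to write as a flat Redis hash.

# Load path: the pre-validator on CacheHit routes any non-model fields
# returned by the search (like user_id) back into `filters`.
hit = CacheHit(
    entry_id=entry.entry_id,
    prompt=entry.prompt,
    response=entry.response,
    vector_distance=0.01,
    inserted_at=entry.inserted_at,
    updated_at=entry.updated_at,
    user_id="abc",
)
assert hit.filters == {"user_id": "abc"}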
