Skip to content

Commit 51b6af3

Browse files
Use pydantic for cache entries and hits
1 parent d1bd692 commit 51b6af3

File tree

7 files changed

+453
-224
lines changed

7 files changed

+453
-224
lines changed

docs/user_guide/llmcache_03.ipynb

+18-17
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@
8383
"\n",
8484
"llmcache = SemanticCache(\n",
8585
" name=\"llmcache\", # underlying search index name\n",
86-
" prefix=\"llmcache\", # redis key prefix for hash entries\n",
8786
" redis_url=\"redis://localhost:6379\", # redis connection url string\n",
8887
" distance_threshold=0.1 # semantic cache distance threshold\n",
8988
")"
@@ -107,13 +106,15 @@
107106
"│ llmcache │ HASH │ ['llmcache'] │ [] │ 0 │\n",
108107
"╰──────────────┴────────────────┴──────────────┴─────────────────┴────────────╯\n",
109108
"Index Fields:\n",
110-
"╭───────────────┬───────────────┬────────┬────────────────┬────────────────╮\n",
111-
"│ Name │ Attribute │ Type │ Field Option │ Option Value │\n",
112-
"├───────────────┼───────────────┼────────┼────────────────┼────────────────┤\n",
113-
"│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │\n",
114-
"│ response │ response │ TEXT │ WEIGHT │ 1 │\n",
115-
"│ prompt_vector │ prompt_vector │ VECTOR │ │ │\n",
116-
"╰───────────────┴───────────────┴────────┴────────────────┴────────────────╯\n"
109+
"╭───────────────┬───────────────┬─────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n",
110+
"│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n",
111+
"├───────────────┼───────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n",
112+
"│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n",
113+
"│ response │ response │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n",
114+
"│ inserted_at │ inserted_at │ NUMERIC │ │ │ │ │ │ │ │ │\n",
115+
"│ updated_at │ updated_at │ NUMERIC │ │ │ │ │ │ │ │ │\n",
116+
"│ prompt_vector │ prompt_vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n",
117+
"╰───────────────┴───────────────┴─────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n"
117118
]
118119
}
119120
],
@@ -208,7 +209,7 @@
208209
"name": "stdout",
209210
"output_type": "stream",
210211
"text": [
211-
"[{'id': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545', 'vector_distance': '9.53674316406e-07', 'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}}]\n"
212+
"[{'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}, 'key': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545'}]\n"
212213
]
213214
}
214215
],
@@ -384,7 +385,7 @@
384385
},
385386
{
386387
"cell_type": "code",
387-
"execution_count": 17,
388+
"execution_count": 16,
388389
"metadata": {},
389390
"outputs": [],
390391
"source": [
@@ -408,14 +409,14 @@
408409
},
409410
{
410411
"cell_type": "code",
411-
"execution_count": 18,
412+
"execution_count": 17,
412413
"metadata": {},
413414
"outputs": [
414415
{
415416
"name": "stdout",
416417
"output_type": "stream",
417418
"text": [
418-
"Without caching, a call to openAI to answer this simple question took 1.460299015045166 seconds.\n"
419+
"Without caching, a call to openAI to answer this simple question took 0.9312698841094971 seconds.\n"
419420
]
420421
},
421422
{
@@ -424,7 +425,7 @@
424425
"'llmcache:67e0f6e28fe2a61c0022fd42bf734bb8ffe49d3e375fd69d692574295a20fc1a'"
425426
]
426427
},
427-
"execution_count": 18,
428+
"execution_count": 17,
428429
"metadata": {},
429430
"output_type": "execute_result"
430431
}
@@ -451,8 +452,8 @@
451452
"name": "stdout",
452453
"output_type": "stream",
453454
"text": [
454-
"Avg time taken with LLM cache enabled: 0.2560166358947754\n",
455-
"Percentage of time saved: 82.47%\n"
455+
"Avg time taken with LLM cache enabled: 0.4896167993545532\n",
456+
"Percentage of time saved: 47.42%\n"
456457
]
457458
}
458459
],
@@ -515,7 +516,7 @@
515516
},
516517
{
517518
"cell_type": "code",
518-
"execution_count": 21,
519+
"execution_count": 20,
519520
"metadata": {},
520521
"outputs": [],
521522
"source": [
@@ -540,7 +541,7 @@
540541
"name": "python",
541542
"nbconvert_exporter": "python",
542543
"pygments_lexer": "ipython3",
543-
"version": "3.9.12"
544+
"version": "3.10.14"
544545
},
545546
"orig_nbformat": 4
546547
},

redisvl/extensions/llmcache/base.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
from typing import Any, Dict, List, Optional
32

43
from redisvl.redis.utils import hashify

redisvl/extensions/llmcache/schema.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from typing import Any, Dict, List, Optional
2+
from pydantic.v1 import BaseModel, Field, root_validator, validator
3+
from redisvl.redis.utils import array_to_buffer, hashify
4+
from redisvl.utils.utils import current_timestamp, deserialize, serialize
5+
from redisvl.schema import IndexSchema
6+
7+
8+
class CacheEntry(BaseModel):
    """A single entry to be written into the semantic cache.

    Stores the prompt/response pair, the prompt embedding, timestamps, and
    optional metadata/filterable fields.
    """

    # Unique identifier; auto-derived from a hash of the prompt when not given.
    entry_id: str
    # The original user prompt.
    prompt: str
    # The cached LLM response for the prompt.
    response: str
    # Embedding of the prompt, stored as a list of floats.
    prompt_vector: List[float]
    # Creation time (epoch seconds); defaults to "now".
    inserted_at: float = Field(default_factory=current_timestamp)
    # Last update time (epoch seconds); defaults to "now".
    updated_at: float = Field(default_factory=current_timestamp)
    # Optional free-form metadata, serialized on write.
    metadata: Optional[Dict[str, Any]] = Field(default=None)
    # Optional filterable fields, flattened onto the stored hash.
    filters: Optional[Dict[str, Any]] = Field(default=None)

    @root_validator(pre=True)
    @classmethod
    def generate_id(cls, values):
        """Ensure ``entry_id`` is set, deriving it from the prompt hash.

        Uses ``values.get("prompt")`` rather than direct indexing so a
        missing prompt surfaces as a normal pydantic "field required"
        validation error instead of a bare ``KeyError``.
        """
        if not values.get("entry_id"):
            prompt = values.get("prompt")
            if prompt is not None:
                values["entry_id"] = hashify(prompt)
        return values

    @validator("metadata")
    def non_empty_metadata(cls, v):
        """Reject non-dict metadata values."""
        if v is not None and not isinstance(v, dict):
            raise TypeError("Metadata must be a dictionary.")
        return v

    def to_dict(self) -> Dict:
        """Return a Redis-ready dict: vector as bytes, metadata serialized,
        and filter fields flattened to the top level."""
        data = self.dict(exclude_none=True)
        # Redis stores the embedding as a raw byte buffer, not a float list.
        data["prompt_vector"] = array_to_buffer(self.prompt_vector)
        if self.metadata:
            data["metadata"] = serialize(self.metadata)
        if self.filters:
            # Promote each filter to a top-level hash field so it is indexable.
            data.update(self.filters)
            del data["filters"]
        return data
41+
42+
43+
class CacheHit(BaseModel):
    """A cache hit returned from a semantic cache lookup."""

    # Identifier of the matched cache entry.
    entry_id: str
    # The cached prompt that matched the query.
    prompt: str
    # The cached LLM response.
    response: str
    # Semantic distance between the query vector and the cached prompt vector.
    vector_distance: float
    # Entry creation time (epoch seconds).
    inserted_at: float
    # Entry last-update time (epoch seconds).
    updated_at: float
    # Optional metadata attached to the entry.
    metadata: Optional[Dict[str, Any]] = Field(default=None)
    # Any extra (filterable) fields returned with the hit.
    filters: Optional[Dict[str, Any]] = Field(default=None)

    @root_validator(pre=True)
    @classmethod
    def validate_cache_hit(cls, values):
        """Deserialize metadata and bucket unknown fields under ``filters``."""
        raw_metadata = values.get("metadata")
        if isinstance(raw_metadata, str):
            values["metadata"] = deserialize(raw_metadata)

        # Any key not declared on the model is treated as a filter field.
        declared = set(cls.__fields__.keys())
        extras = [field for field in values if field not in declared]
        if extras:
            values["filters"] = {field: values.pop(field) for field in extras}

        return values

    def to_dict(self) -> Dict:
        """Return a flat dict with filter fields merged back to the top level."""
        flat = self.dict(exclude_none=True)
        if self.filters:
            flat.update(flat.pop("filters"))
        return flat
81+
82+
83+
class SemanticCacheIndexSchema(IndexSchema):
    """Index schema used by the semantic cache."""

    @classmethod
    def from_params(cls, name: str, prefix: str, vector_dims: int):
        """Build the cache index schema.

        Args:
            name: Underlying search index name.
            prefix: Redis key prefix for cache entries.
            vector_dims: Dimensionality of the prompt embedding vectors.
        """
        # Flat (brute-force) cosine index over float32 prompt embeddings.
        vector_field = {
            "name": "prompt_vector",
            "type": "vector",
            "attrs": {
                "dims": vector_dims,
                "datatype": "float32",
                "distance_metric": "cosine",
                "algorithm": "flat",
            },
        }
        return cls(
            index={"name": name, "prefix": prefix},  # type: ignore
            fields=[  # type: ignore
                {"name": "prompt", "type": "text"},
                {"name": "response", "type": "text"},
                {"name": "inserted_at", "type": "numeric"},
                {"name": "updated_at", "type": "numeric"},
                vector_field,
            ],
        )

0 commit comments

Comments
 (0)