Skip to content

Commit f808b80

Browse files
adds filter fields to cache key name hash (#224)
Prevent cache entries from being overwritten when an identical prompt is used but other cache entry data (such as the user ID or some other field) differs.
1 parent 51e58aa commit f808b80

File tree

5 files changed

+51
-18
lines changed

5 files changed

+51
-18
lines changed

redisvl/extensions/llmcache/base.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Any, Dict, List, Optional
22

3-
from redisvl.redis.utils import hashify
4-
53

64
class BaseLLMCache:
75
def __init__(self, ttl: Optional[int] = None):
@@ -79,14 +77,3 @@ async def astore(
7977
"""Async store the specified key-value pair in the cache along with
8078
metadata."""
8179
raise NotImplementedError
82-
83-
def hash_input(self, prompt: str) -> str:
84-
"""Hashes the input prompt using SHA256.
85-
86-
Args:
87-
prompt (str): Input string to be hashed.
88-
89-
Returns:
90-
str: Hashed string.
91-
"""
92-
return hashify(prompt)

redisvl/extensions/llmcache/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class CacheEntry(BaseModel):
3939
def generate_id(cls, values):
4040
# Ensure entry_id is set
4141
if not values.get("entry_id"):
42-
values["entry_id"] = hashify(values["prompt"])
42+
values["entry_id"] = hashify(values["prompt"], values.get("filters"))
4343
return values
4444

4545
@validator("metadata")

redisvl/redis/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import hashlib
2-
from typing import Any, Dict, List
2+
from typing import Any, Dict, List, Optional
33

44
import numpy as np
55

@@ -40,6 +40,9 @@ def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]:
4040
return np.frombuffer(buffer, dtype=dtype).tolist()
4141

4242

43-
def hashify(content: str) -> str:
44-
"""Create a secure hash of some arbitrary input text."""
43+
def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str:
44+
"""Create a secure hash of some arbitrary input text and optional dictionary."""
45+
if extras:
46+
extra_string = " ".join([str(k) + str(v) for k, v in sorted(extras.items())])
47+
content = content + extra_string
4548
return hashlib.sha256(content.encode("utf-8")).hexdigest()

tests/integration/test_llmcache.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,3 +800,46 @@ def test_index_updating(redis_url):
800800
filter_expression=tag_filter,
801801
)
802802
assert len(response) == 1
803+
804+
805+
def test_no_key_collision_on_identical_prompts(redis_url):
806+
private_cache = SemanticCache(
807+
name="private_cache",
808+
redis_url=redis_url,
809+
filterable_fields=[
810+
{"name": "user_id", "type": "tag"},
811+
{"name": "zip_code", "type": "numeric"},
812+
],
813+
)
814+
815+
private_cache.store(
816+
prompt="What is the phone number linked to my account?",
817+
response="The number on file is 123-555-0000",
818+
filters={"user_id": "gabs"},
819+
)
820+
821+
private_cache.store(
822+
prompt="What's the phone number linked in my account?",
823+
response="The number on file is 123-555-9999",
824+
###filters={"user_id": "cerioni"},
825+
filters={"user_id": "cerioni", "zip_code": 90210},
826+
)
827+
828+
private_cache.store(
829+
prompt="What's the phone number linked in my account?",
830+
response="The number on file is 123-555-1111",
831+
filters={"user_id": "bart"},
832+
)
833+
834+
results = private_cache.check(
835+
"What's the phone number linked in my account?", num_results=5
836+
)
837+
assert len(results) == 3
838+
839+
zip_code_filter = Num("zip_code") != 90210
840+
filtered_results = private_cache.check(
841+
"what's the phone number linked in my account?",
842+
num_results=5,
843+
filter_expression=zip_code_filter,
844+
)
845+
assert len(filtered_results) == 2

tests/unit/test_llmcache_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_cache_entry_to_dict():
4848
filters={"category": "technology"},
4949
)
5050
result = entry.to_dict()
51-
assert result["entry_id"] == hashify("What is AI?")
51+
assert result["entry_id"] == hashify("What is AI?", {"category": "technology"})
5252
assert result["metadata"] == json.dumps({"author": "John"})
5353
assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3])
5454
assert result["category"] == "technology"

0 commit comments

Comments
 (0)