From 624175d8e32fc1eab03e085bcb04947759c14e63 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 20 Sep 2024 14:15:33 -0700 Subject: [PATCH 1/3] removes unused hash method --- redisvl/extensions/llmcache/base.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index c3a1b269..8fce67ca 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional -from redisvl.redis.utils import hashify - class BaseLLMCache: def __init__(self, ttl: Optional[int] = None): @@ -79,14 +77,3 @@ async def astore( """Async store the specified key-value pair in the cache along with metadata.""" raise NotImplementedError - - def hash_input(self, prompt: str) -> str: - """Hashes the input prompt using SHA256. - - Args: - prompt (str): Input string to be hashed. - - Returns: - str: Hashed string. - """ - return hashify(prompt) From 40220c422c9d71d4e4d4e530fc8be61a6e34d2d2 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 20 Sep 2024 16:02:26 -0700 Subject: [PATCH 2/3] includes cache filter fields in key hash --- redisvl/extensions/llmcache/schema.py | 2 +- redisvl/redis/utils.py | 9 ++++-- tests/integration/test_llmcache.py | 43 +++++++++++++++++++++++++++ tests/unit/test_llmcache_schema.py | 2 +- 4 files changed, 51 insertions(+), 5 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 515b1421..95fe753a 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -32,7 +32,7 @@ class CacheEntry(BaseModel): def generate_id(cls, values): # Ensure entry_id is set if not values.get("entry_id"): - values["entry_id"] = hashify(values["prompt"]) + values["entry_id"] = hashify(values["prompt"], values.get("filters")) return values @validator("metadata") diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index a421022b..d6e4be06 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -1,5 +1,5 @@ import hashlib -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import numpy as np @@ -40,6 +40,9 @@ def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: return np.frombuffer(buffer, dtype=dtype).tolist() -def hashify(content: str) -> str: - """Create a secure hash of some arbitrary input text.""" +def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str: + """Create a secure hash of some arbitrary input text and optional dictionary.""" + if extras: + extra_string = " ".join([str(k) + str(v) for k, v in extras.items()]) + content = content + extra_string return hashlib.sha256(content.encode("utf-8")).hexdigest() diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index e3aef3dc..6c106e87 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -800,3 +800,46 @@ def test_index_updating(redis_url): filter_expression=tag_filter, ) assert len(response) == 1 + + +def test_no_key_collision_on_identical_prompts(redis_url): + private_cache = SemanticCache( + name="private_cache", + redis_url=redis_url, + filterable_fields=[ + {"name": "user_id", "type": "tag"}, + {"name": "zip_code", "type": "numeric"}, + ], + ) + + private_cache.store( + prompt="What is the phone number linked to my account?", + response="The number on file is 123-555-0000", + filters={"user_id": "gabs"}, + ) + + private_cache.store( + prompt="What's the phone number linked in my account?", + response="The number on file is 123-555-9999", + ###filters={"user_id": "cerioni"}, + filters={"user_id": "cerioni", "zip_code": 90210}, + ) + + private_cache.store( + prompt="What's the phone number linked in my account?", + response="The number on file is 123-555-1111", + filters={"user_id": "bart"}, + ) + + results = private_cache.check( + "What's the phone number linked in my account?", num_results=5 + ) + assert len(results) == 3 + + zip_code_filter = Num("zip_code") != 90210 + filtered_results = private_cache.check( + "what's the phone number linked in my account?", + num_results=5, + filter_expression=zip_code_filter, + ) + assert len(filtered_results) == 2 diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index e3961e6b..aa3a3add 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -48,7 +48,7 @@ def test_cache_entry_to_dict(): filters={"category": "technology"}, ) result = entry.to_dict() - assert result["entry_id"] == hashify("What is AI?") + assert result["entry_id"] == hashify("What is AI?", {"category": "technology"}) assert result["metadata"] == json.dumps({"author": "John"}) assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3]) assert result["category"] == "technology" From c3937eef155ab33984eeb369e728cf3313d80995 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 25 Sep 2024 13:56:27 -0700 Subject: [PATCH 3/3] sorts cache filter fields when hashing --- redisvl/redis/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index d6e4be06..28d15509 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -43,6 +43,6 @@ def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str: """Create a secure hash of some arbitrary input text and optional dictionary.""" if extras: - extra_string = " ".join([str(k) + str(v) for k, v in extras.items()]) + extra_string = " ".join([str(k) + str(v) for k, v in sorted(extras.items())]) content = content + extra_string return hashlib.sha256(content.encode("utf-8")).hexdigest()