From 06d0d73fffbf9616adcf492039e796f58b624d4f Mon Sep 17 00:00:00 2001
From: Justin Cechmanek
Date: Wed, 26 Jun 2024 14:32:26 -0700
Subject: [PATCH 01/10] adds inserted & updated_at fields

---
 redisvl/extensions/llmcache/semantic.py | 24 +++++++-
 tests/integration/test_llmcache.py      | 75 +++++++++++++++++++++++++
 2 files changed, 98 insertions(+), 1 deletion(-)

diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 07602ac7..99b40073 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Optional
+from time import time
+from typing import Any, Dict, List, Optional, Union
 
 from redis import Redis
 
@@ -16,6 +17,8 @@ class SemanticCache(BaseLLMCache):
     entry_id_field_name: str = "id"
     prompt_field_name: str = "prompt"
     vector_field_name: str = "prompt_vector"
+    inserted_at_field_name: str = "inserted_at"
+    updated_at_field_name: str = "updated_at"
     response_field_name: str = "response"
     metadata_field_name: str = "metadata"
 
@@ -77,6 +80,8 @@ def __init__(
             [
                 {"name": self.prompt_field_name, "type": "text"},
                 {"name": self.response_field_name, "type": "text"},
+                {"name": self.inserted_at_field_name, "type": "numeric"},
+                {"name": self.updated_at_field_name, "type": "numeric"},
                 {
                     "name": self.vector_field_name,
                     "type": "vector",
@@ -186,6 +191,20 @@ def delete(self) -> None:
         index."""
         self._index.delete(drop=True)
 
+    def drop(self, document_ids: Union[str, List[str]]) -> None:
+        """Remove a specific entry or entries from the cache by its ID.
+
+        Args:
+            document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache.
+        """
+        if isinstance(document_ids, List):
+            with self._index.client.pipeline(transaction=False) as pipe:  # type: ignore
+                for key in document_ids:  # type: ignore
+                    pipe.delete(key)
+                pipe.execute()
+        else:
+            self._index.client.delete(document_ids)  # type: ignore
+
     def _refresh_ttl(self, key: str) -> None:
         """Refresh the time-to-live for the specified key."""
         if self.ttl:
@@ -320,12 +339,15 @@ def store(
         # Vectorize prompt if necessary and create cache payload
         vector = vector or self._vectorize_prompt(prompt)
         # Construct semantic cache payload
+        now = time()
         id_field = self.entry_id_field_name
         payload = {
             id_field: self.hash_input(prompt),
             self.prompt_field_name: prompt,
             self.response_field_name: response,
             self.vector_field_name: array_to_buffer(vector),
+            self.inserted_at_field_name: now,
+            self.updated_at_field_name: now,
         }
         if metadata is not None:
             if not isinstance(metadata, dict):
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 373ca8da..af8e9bd4 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -69,6 +69,43 @@ def test_store_and_check(cache, vectorizer):
     assert "metadata" not in check_result[0]
 
 
+def test_return_fields(cache, vectorizer):
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    vector = vectorizer.embed(prompt)
+
+    cache.store(prompt, response, vector=vector)
+
+    # check default return fields
+    check_result = cache.check(vector=vector)
+    assert set(check_result[0].keys()) == {
+        "id",
+        "prompt",
+        "response",
+        "prompt_vector",
+        "vector_distance",
+    }
+
+    # check all return fields
+    fields = [
+        "id",
+        "prompt",
+        "response",
+        "inserted_at",
+        "updated_at",
+        "prompt_vector",
+        "vector_distance",
+    ]
+    check_result = cache.check(vector=vector, return_fields=fields[:])
+    assert set(check_result[0].keys()) == set(fields)
+
+    # check only some return fields
+    fields = ["inserted_at", "updated_at"]
+    check_result = cache.check(vector=vector, return_fields=fields[:])
+    fields.extend(["id", "vector_distance"])  # id and vector_distance always returned
+    assert set(check_result[0].keys()) == set(fields)
+
+
 # Test clearing the cache
 def test_clear(cache, vectorizer):
     prompt = "This is a test prompt."
@@ -95,6 +132,44 @@ def test_ttl_expiration(cache_with_ttl, vectorizer):
     assert len(check_result) == 0
 
 
+# Test manual expiration of single document
+def test_drop_document(cache, vectorizer):
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    vector = vectorizer.embed(prompt)
+
+    cache.store(prompt, response, vector=vector)
+    check_result = cache.check(vector=vector)
+
+    cache.drop(check_result[0]["id"])
+    recheck_result = cache.check(vector=vector)
+    assert len(recheck_result) == 0
+
+
+# Test manual expiration of multiple documents
+def test_drop_documents(cache, vectorizer):
+    prompts = [
+        "This is a test prompt.",
+        "This is also a test prompt.",
+        "This is another test prompt.",
+    ]
+    responses = [
+        "This is a test response.",
+        "This is also a test response.",
+        "This is another test response.",
+    ]
+    for prompt, response in zip(prompts, responses):
+        vector = vectorizer.embed(prompt)
+        cache.store(prompt, response, vector=vector)
+
+    check_result = cache.check(vector=vector, num_results=3)
+    keys = [r["id"] for r in check_result[0:2]]  # drop first 2 entries
+    cache.drop(keys)
+
+    recheck_result = cache.check(vector=vector, num_results=3)
+    assert len(recheck_result) == 1
+
+
 # Test check behavior with no match
 def test_check_no_match(cache, vectorizer):
     vector = vectorizer.embed("Some random sentence.")
From 3578bd4a3348d17d53d408436d4e2902bde07c36 Mon Sep 17 00:00:00 2001
From: Justin Cechmanek
Date: Thu, 27 Jun 2024 17:24:22 -0700
Subject: [PATCH 02/10] adds method to update cache entries

---
 redisvl/extensions/llmcache/semantic.py | 41 +++++++++++++++++++++++++
 tests/integration/test_llmcache.py      | 29 +++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 99b40073..a41ec5dc 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -358,3 +358,44 @@ def store(
         # Load LLMCache entry with TTL
         keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
         return keys[0]
+
+    def update(self, key: str, **kwargs) -> None:
+        """Update specific fields within an existing cache entry. If no fields
+        are passed, then only the document TTL is refreshed.
+
+        Args:
+            key (str): the key of the document to update.
+            kwargs: the fields and their new values to set on the cache entry.
+
+        Raises:
+            ValueError: if an invalid field is provided as a kwarg.
+            TypeError: if metadata is provided and is not of type dict.
+
+        .. code-block:: python
+
+            key = cache.store('this is a prompt', 'this is a response')
+            cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
+        """
+        if not kwargs:
+            self._refresh_ttl(key)
+            return
+
+        for _key, val in kwargs.items():
+            if _key not in {
+                self.prompt_field_name,
+                self.vector_field_name,
+                self.response_field_name,
+                self.metadata_field_name,
+            }:
+                raise ValueError(f"{_key} is not a valid field within the document")
+
+            # Check for metadata and deserialize
+            if _key == self.metadata_field_name:
+                if isinstance(val, dict):
+                    kwargs[_key] = self.serialize(val)
+                else:
+                    raise TypeError(
+                        "If specified, cached metadata must be a dictionary."
+                    )
+        kwargs.update({self.updated_at_field_name: time()})
+        self._index.client.hset(key, mapping=kwargs)  # type: ignore
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index af8e9bd4..5622f911 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -170,6 +170,35 @@ def test_drop_documents(cache, vectorizer):
     assert len(recheck_result) == 1
 
 
+# Test updating document fields
+def test_updating_document(cache):
+    prompts = [
+        "This is a test prompt.",
+        "This is also a test prompt.",
+    ]
+    responses = [
+        "This is a test response.",
+        "This is also a test response.",
+    ]
+    for prompt, response in zip(prompts, responses):
+        cache.store(prompt, response)
+
+    check_result = cache.check(prompt=prompt, return_fields=["updated_at"])
+    key = check_result[0]["id"]
+
+    sleep(1)
+
+    metadata = {"foo": "bar"}
+    cache.update(key=key, metadata=metadata)
+
+    updated_result = cache.check(
+        prompt=prompt, return_fields=["updated_at", "metadata"]
+    )
+    assert updated_result[0]["id"] == check_result[0]["id"]
+    assert updated_result[0]["metadata"] == metadata
+    assert updated_result[0]["updated_at"] > check_result[0]["updated_at"]
+
+
 # Test check behavior with no match
 def test_check_no_match(cache, vectorizer):
     vector = vectorizer.embed("Some random sentence.")
From 4e4212af0c942e65c5de1c7d9f81a0652c8f085a Mon Sep 17 00:00:00 2001
From: Justin Cechmanek
Date: Fri, 28 Jun 2024 08:58:32 -0700
Subject: [PATCH 03/10] cleans up test

---
 tests/integration/test_llmcache.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 5622f911..1fc25d46 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -172,16 +172,9 @@ def test_drop_documents(cache, vectorizer):
 
 # Test updating document fields
 def test_updating_document(cache):
-    prompts = [
-        "This is a test prompt.",
-        "This is also a test prompt.",
-    ]
-    responses = [
-        "This is a test response.",
-        "This is also a test response.",
-    ]
-    for prompt, response in zip(prompts, responses):
-        cache.store(prompt, response)
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+ cache.store(prompt=prompt, response=response) check_result = cache.check(prompt=prompt, return_fields=["updated_at"]) key = check_result[0]["id"] @@ -194,7 +187,6 @@ def test_updating_document(cache): updated_result = cache.check( prompt=prompt, return_fields=["updated_at", "metadata"] ) - assert updated_result[0]["id"] == check_result[0]["id"] assert updated_result[0]["metadata"] == metadata assert updated_result[0]["updated_at"] > check_result[0]["updated_at"] From b292d3bc559d30d40cdf9de5e927771ac1616e1f Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Tue, 9 Jul 2024 16:55:41 -0700 Subject: [PATCH 04/10] adds manual document removal, updating, and scoped access control --- redisvl/extensions/llmcache/semantic.py | 42 +++++++++++++++++++++++-- tests/integration/test_llmcache.py | 35 +++++++++++++++++++++ 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index b72493e8..5df8bf0d 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -6,6 +6,7 @@ from redisvl.extensions.llmcache.base import BaseLLMCache from redisvl.index import SearchIndex from redisvl.query import RangeQuery +from redisvl.query.filter import Tag from redisvl.redis.utils import array_to_buffer from redisvl.schema.schema import IndexSchema from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer @@ -19,6 +20,7 @@ class SemanticCache(BaseLLMCache): vector_field_name: str = "prompt_vector" inserted_at_field_name: str = "inserted_at" updated_at_field_name: str = "updated_at" + tag_field_name: str = "scope_tag" response_field_name: str = "response" metadata_field_name: str = "metadata" @@ -82,6 +84,7 @@ def __init__( {"name": self.response_field_name, "type": "text"}, {"name": self.inserted_at_field_name, "type": "numeric"}, {"name": self.updated_at_field_name, "type": "numeric"}, + {"name": self.tag_field_name, "type": "tag"}, { "name": self.vector_field_name, "type": "vector", @@ -109,12 +112,12 @@ def __init__( self.entry_id_field_name, self.prompt_field_name, self.response_field_name, + self.tag_field_name, self.vector_field_name, self.metadata_field_name, ] self.set_vectorizer(vectorizer) self.set_threshold(distance_threshold) - self._index.create(overwrite=False) @property @@ -218,7 +221,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: return self._vectorizer.embed(prompt) def _search_cache( - self, vector: List[float], num_results: int, return_fields: Optional[List[str]] + self, + vector: List[float], + num_results: int, + return_fields: Optional[List[str]], + tags: Optional[Union[List[str], str]], ) -> List[Dict[str, Any]]: """Searches the semantic cache for similar prompt vectors and returns the specified return fields for each cache hit.""" @@ -240,6 +247,8 @@ def _search_cache( num_results=num_results, return_score=True, ) + if tags: + query.set_filter(self.get_filter(tags)) # type: ignore # Gather and return the cache hits cache_hits: List[Dict[str, Any]] = self._index.query(query) @@ -270,6 +279,7 @@ def check( vector: Optional[List[float]] = None, num_results: int = 1, return_fields: Optional[List[str]] = None, + tags: Optional[Union[List[str], str]] = None, ) -> List[Dict[str, Any]]: """Checks the semantic cache for results similar to the specified prompt or vector. @@ -289,6 +299,8 @@ def check( return_fields (Optional[List[str]], optional): The fields to include in each returned result. 
If None, defaults to all available fields in the cached entry.
+            tags (Optional[Union[List[str], str]]): the tag or tags to filter
+                results by. Default is None, in which case the full cache is searched.
 
         Returns:
             List[Dict[str, Any]]: A list of dicts containing the requested
@@ -313,7 +325,7 @@ def check(
         self._check_vector_dims(vector)
 
         # Check for cache hits by searching the cache
-        cache_hits = self._search_cache(vector, num_results, return_fields)
+        cache_hits = self._search_cache(vector, num_results, return_fields, tags)
         return cache_hits
 
     def store(
@@ -322,6 +334,7 @@ def store(
         response: str,
         vector: Optional[List[float]] = None,
         metadata: Optional[dict] = None,
+        tag: Optional[str] = None,
     ) -> str:
         """Stores the specified key-value pair in the cache along with metadata.
 
@@ -333,6 +346,8 @@ def store(
             demand.
             metadata (Optional[dict], optional): The optional metadata to cache
                 alongside the prompt and response. Defaults to None.
+            tag (Optional[str]): The optional tag to assign to the cache entry.
+                Defaults to None.
 
         Returns:
             str: The Redis key for the entries added to the semantic cache.
@@ -370,6 +385,8 @@ def store(
             raise TypeError("If specified, cached metadata must be a dictionary.")
             # Serialize the metadata dict and add to cache payload
             payload[self.metadata_field_name] = self.serialize(metadata)
+        if tag is not None:
+            payload[self.tag_field_name] = tag
 
         # Load LLMCache entry with TTL
         keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
@@ -401,6 +418,7 @@ def update(self, key: str, **kwargs) -> None:
             self.prompt_field_name,
             self.vector_field_name,
             self.response_field_name,
+            self.tag_field_name,
             self.metadata_field_name,
         }:
             raise ValueError(f"{_key} is not a valid field within the document")
@@ -415,3 +433,21 @@ def update(self, key: str, **kwargs) -> None:
             )
         kwargs.update({self.updated_at_field_name: time()})
         self._index.client.hset(key, mapping=kwargs)  # type: ignore
+
+    def get_filter(
+        self,
+        tags: Optional[Union[List[str], str]] = None,
+    ) -> Tag:
+        """Set the tag filter to apply to queries based on the desired scope.
+
+        Args:
+            tags (Optional[Union[List[str], str]]): name of the specific tag or
+                tags to filter to. Default is None, which means all cache data
+                will be in scope.
+ """ + default_filter = Tag(self.tag_field_name) == [] + if not (tags): + return default_filter + + tag_filter = Tag(self.tag_field_name) == tags + return tag_filter diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 663bd1e4..66db61d2 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -334,3 +334,38 @@ def test_vector_size(cache, vectorizer): with pytest.raises(ValueError): cache.check(vector=[1, 2, 3]) + + +# test we can pass a list of tags and we'll include all results that match +def test_multiple_tags(cache): + tag_1 = "group 0" + tag_2 = "group 1" + tag_3 = "group 2" + tag_4 = "group 3" + tags = [tag_1, tag_2, tag_3, tag_4] + + for i in range(4): + prompt = f"test prompt {i}" + response = f"test response {i}" + cache.store(prompt, response, tag=tags[i]) + + # test we can specify one specific tag + results = cache.check("test prompt 1", tags=tag_1, num_results=5) + assert len(results) == 1 + assert results[0]["prompt"] == "test prompt 0" + + # test we can pass a list of tags + results = cache.check("test prompt 1", tags=[tag_1, tag_2, tag_3], num_results=5) + assert len(results) == 3 + + # test that default tag param searches full cache + results = cache.check("test prompt 1", num_results=5) + assert len(results) == 4 + + # test we can get all matches with empty tag list + results = cache.check("test prompt 1", tags=[], num_results=5) + assert len(results) == 4 + + # test no results are returned if we pass a nonexistant tag + results = cache.check("test prompt 1", tags=["bad tag"], num_results=5) + assert len(results) == 0 From b9022d132502b38d8ed28045dec59a73849b3c7e Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 11 Jul 2024 17:31:33 -0700 Subject: [PATCH 05/10] moves schema creation to separate class --- redisvl/extensions/llmcache/semantic.py | 55 ++++++++++++++----------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index bbaaf777..06a4d7c3 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -8,10 +8,38 @@ from redisvl.query import RangeQuery from redisvl.query.filter import Tag from redisvl.redis.utils import array_to_buffer -from redisvl.schema.schema import IndexSchema +from redisvl.schema import IndexSchema from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +class SemanticCacheIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, vector_dims: int): + + return cls( + index={"name": name, "prefix": name}, + fields=[ + {"name": "cache_name", "type": "tag"}, + {"name": "prompt", "type": "text"}, + {"name": "response", "type": "text"}, + {"name": "inserted_at", "type": "numeric"}, + {"name": "updated_at", "type": "numeric"}, + {"name": "scope_tag", "type": "tag"}, + { + "name": "prompt_vector", + "type": "vector", + "attrs": { + "dims": vector_dims, + "datatype": "float32", + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) + + class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" @@ -75,30 +103,7 @@ def __init__( model="sentence-transformers/all-mpnet-base-v2" ) - # build cache index schema - schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}}) - # add fields - schema.add_fields( - [ - {"name": self.prompt_field_name, "type": "text"}, - {"name": self.response_field_name, "type": "text"}, - {"name": self.inserted_at_field_name, 
"type": "numeric"}, - {"name": self.updated_at_field_name, "type": "numeric"}, - {"name": self.tag_field_name, "type": "tag"}, - { - "name": self.vector_field_name, - "type": "vector", - "attrs": { - "dims": vectorizer.dims, - "datatype": "float32", - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ] - ) - - # build search index + schema = SemanticCacheIndexSchema.from_params(name, vectorizer.dims) self._index = SearchIndex(schema=schema) # handle redis connection From 4b91de49c3d90df4837c0be37a7e48cf9c2cb21a Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 12 Jul 2024 15:32:32 -0700 Subject: [PATCH 06/10] wip: adding support for filter expressions --- redisvl/extensions/llmcache/semantic.py | 21 +++++++++++++-------- tests/integration/test_llmcache.py | 15 +++++++++++---- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 06a4d7c3..6bc03f8c 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -6,7 +6,7 @@ from redisvl.extensions.llmcache.base import BaseLLMCache from redisvl.index import SearchIndex from redisvl.query import RangeQuery -from redisvl.query.filter import Tag +from redisvl.query.filter import Tag, FilterExpression from redisvl.redis.utils import array_to_buffer from redisvl.schema import IndexSchema from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer @@ -18,8 +18,8 @@ class SemanticCacheIndexSchema(IndexSchema): def from_params(cls, name: str, vector_dims: int): return cls( - index={"name": name, "prefix": name}, - fields=[ + index={"name": name, "prefix": name}, # type: ignore + fields=[ # type: ignore {"name": "cache_name", "type": "tag"}, {"name": "prompt", "type": "text"}, {"name": "response", "type": "text"}, @@ -227,7 +227,8 @@ def _search_cache( vector: List[float], num_results: int, return_fields: Optional[List[str]], - tags: Optional[Union[List[str], str]], + ##tags: Optional[Union[List[str], str]], + filters: Optional[FilterExpression], ) -> List[Dict[str, Any]]: """Searches the semantic cache for similar prompt vectors and returns the specified return fields for each cache hit.""" @@ -249,8 +250,10 @@ def _search_cache( num_results=num_results, return_score=True, ) - if tags: - query.set_filter(self.get_filter(tags)) # type: ignore + ##if tags: + ## query.set_filter(self.get_filter(tags)) # type: ignore + if filters: + query.set_filter(filters) # type: ignore # Gather and return the cache hits cache_hits: List[Dict[str, Any]] = self._index.query(query) @@ -281,7 +284,8 @@ def check( vector: Optional[List[float]] = None, num_results: int = 1, return_fields: Optional[List[str]] = None, - tags: Optional[Union[List[str], str]] = None, + ##tags: Optional[Union[List[str], str]] = None, + filters: Optional[FilterExpression] = None, ) -> List[Dict[str, Any]]: """Checks the semantic cache for results similar to the specified prompt or vector. 
@@ -327,7 +331,8 @@ def check(
         self._check_vector_dims(vector)
 
         # Check for cache hits by searching the cache
-        cache_hits = self._search_cache(vector, num_results, return_fields, tags)
+        ##cache_hits = self._search_cache(vector, num_results, return_fields, tags)
+        cache_hits = self._search_cache(vector, num_results, return_fields, filters)
         return cache_hits
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 6b7adda6..d7a65561 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -6,6 +6,7 @@
 from redisvl.extensions.llmcache import SemanticCache
 from redisvl.index.index import SearchIndex
 from redisvl.utils.vectorize import HFTextVectorizer
+from redisvl.query.filter import Tag, FilterExpression
 
 
 @pytest.fixture
@@ -369,18 +370,22 @@ def test_multiple_tags(cache):
     tag_4 = "group 3"
     tags = [tag_1, tag_2, tag_3, tag_4]
 
+    filter_1 = Tag('scope_tag') == tag_1
+
     for i in range(4):
         prompt = f"test prompt {i}"
         response = f"test response {i}"
         cache.store(prompt, response, tag=tags[i])
 
     # test we can specify one specific tag
-    results = cache.check("test prompt 1", tags=tag_1, num_results=5)
+    ##results = cache.check("test prompt 1", tags=tag_1, num_results=5)
+    results = cache.check("test prompt 1", filters=filter_1, num_results=5)
     assert len(results) == 1
     assert results[0]["prompt"] == "test prompt 0"
 
     # test we can pass a list of tags
-    results = cache.check("test prompt 1", tags=[tag_1, tag_2, tag_3], num_results=5)
+    ##results = cache.check("test prompt 1", tags=[tag_1, tag_2, tag_3], num_results=5)
+    results = cache.check("test prompt 1", filters=[filter_1, filter_2, filter_3], num_results=5)
     assert len(results) == 3
 
     # test that default tag param searches full cache
@@ -388,9 +393,11 @@ def test_multiple_tags(cache):
     assert len(results) == 4
 
     # test we can get all matches with empty tag list
-    results = cache.check("test prompt 1", tags=[], num_results=5)
+    ##results = cache.check("test prompt 1", tags=[], num_results=5)
+    results = cache.check("test prompt 1", filters=[], num_results=5)
     assert len(results) == 4
 
     # test no results are returned if we pass a nonexistent tag
-    results = cache.check("test prompt 1", tags=["bad tag"], num_results=5)
+    ##results = cache.check("test prompt 1", tags=["bad tag"], num_results=5)
+    results = cache.check("test prompt 1", filters=["bad tag"], num_results=5)
     assert len(results) == 0
From 5c7893fa3ecec7c9470230f00dcbdab7443c202a Mon Sep 17 00:00:00 2001
From: Justin Cechmanek
Date: Fri, 26 Jul 2024 11:18:39 -0700
Subject: [PATCH 07/10] adds support for filtering on cache check calls

---
 redisvl/extensions/llmcache/semantic.py | 45 ++++++------------------
 tests/integration/test_llmcache.py      | 44 ++++++++++++++--------
 2 files changed, 41 insertions(+), 48 deletions(-)

diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 6bc03f8c..1ac181ff 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -6,7 +6,7 @@
 from redisvl.extensions.llmcache.base import BaseLLMCache
 from redisvl.index import SearchIndex
 from redisvl.query import RangeQuery
-from redisvl.query.filter import Tag, FilterExpression
+from redisvl.query.filter import FilterExpression, Tag
 from redisvl.redis.utils import array_to_buffer
 from redisvl.schema import IndexSchema
 from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
@@ -18,14 +18,14 @@ class SemanticCacheIndexSchema(IndexSchema):
 
     def
from_params(cls, name: str, vector_dims: int):
 
         return cls(
-            index={"name": name, "prefix": name}, # type: ignore
-            fields=[ # type: ignore
+            index={"name": name, "prefix": name},  # type: ignore
+            fields=[  # type: ignore
                 {"name": "cache_name", "type": "tag"},
                 {"name": "prompt", "type": "text"},
                 {"name": "response", "type": "text"},
                 {"name": "inserted_at", "type": "numeric"},
                 {"name": "updated_at", "type": "numeric"},
-                {"name": "scope_tag", "type": "tag"},
+                {"name": "label", "type": "tag"},
                 {
                     "name": "prompt_vector",
                     "type": "vector",
@@ -48,7 +48,7 @@ class SemanticCache(BaseLLMCache):
     vector_field_name: str = "prompt_vector"
     inserted_at_field_name: str = "inserted_at"
     updated_at_field_name: str = "updated_at"
-    tag_field_name: str = "scope_tag"
+    tag_field_name: str = "label"
     response_field_name: str = "response"
     metadata_field_name: str = "metadata"
 
@@ -227,8 +227,7 @@ def _search_cache(
         vector: List[float],
         num_results: int,
         return_fields: Optional[List[str]],
-        ##tags: Optional[Union[List[str], str]],
-        filters: Optional[FilterExpression],
+        tag_filter: Optional[FilterExpression],
     ) -> List[Dict[str, Any]]:
         """Searches the semantic cache for similar prompt vectors and returns
         the specified return fields for each cache hit."""
@@ -250,10 +249,8 @@ def _search_cache(
             num_results=num_results,
             return_score=True,
         )
-        ##if tags:
-        ##    query.set_filter(self.get_filter(tags))  # type: ignore
-        if filters:
-            query.set_filter(filters)  # type: ignore
+        if tag_filter:
+            query.set_filter(tag_filter)  # type: ignore
 
         # Gather and return the cache hits
         cache_hits: List[Dict[str, Any]] = self._index.query(query)
@@ -284,8 +281,7 @@ def check(
         vector: Optional[List[float]] = None,
         num_results: int = 1,
         return_fields: Optional[List[str]] = None,
-        ##tags: Optional[Union[List[str], str]] = None,
-        filters: Optional[FilterExpression] = None,
+        tag_filter: Optional[FilterExpression] = None,
     ) -> List[Dict[str, Any]]:
         """Checks the semantic cache for results similar to the specified
         prompt or vector.
@@ -305,7 +301,7 @@ def check(
             return_fields (Optional[List[str]], optional): The fields to include
                 in each returned result. If None, defaults to all available
                 fields in the cached entry.
-            tags (Optional[Union[List[str], str]]): the tag or tags to filter
+            tag_filter (Optional[FilterExpression]): the tag filter to filter
                 results by. Default is None, in which case the full cache is searched.
 
         Returns:
@@ -331,8 +327,7 @@ def check(
         self._check_vector_dims(vector)
 
         # Check for cache hits by searching the cache
-        ##cache_hits = self._search_cache(vector, num_results, return_fields, tags)
-        cache_hits = self._search_cache(vector, num_results, return_fields, filters)
+        cache_hits = self._search_cache(vector, num_results, return_fields, tag_filter)
         return cache_hits
 
     def store(
@@ -440,21 +435,3 @@ def update(self, key: str, **kwargs) -> None:
         kwargs.update({self.updated_at_field_name: time()})
         self._index.client.hset(key, mapping=kwargs)  # type: ignore
-
-    def get_filter(
-        self,
-        tags: Optional[Union[List[str], str]] = None,
-    ) -> Tag:
-        """Set the tag filter to apply to queries based on the desired scope.
-
-        Args:
-            tags (Optional[Union[List[str], str]]): name of the specific tag or
-                tags to filter to. Default is None, which means all cache data
-                will be in scope.
- """ - default_filter = Tag(self.tag_field_name) == [] - if not (tags): - return default_filter - - tag_filter = Tag(self.tag_field_name) == tags - return tag_filter diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 9d55be50..0947fc58 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,12 +1,12 @@ from collections import namedtuple -from time import sleep +from time import sleep, time import pytest from redisvl.extensions.llmcache import SemanticCache from redisvl.index.index import SearchIndex +from redisvl.query.filter import Num, Tag, Text from redisvl.utils.vectorize import HFTextVectorizer -from redisvl.query.filter import Tag, FilterExpression @pytest.fixture @@ -376,7 +376,9 @@ def test_multiple_tags(cache): tag_4 = "group 3" tags = [tag_1, tag_2, tag_3, tag_4] - filter_1 = Tag('scope_tag') == tag_1 + filter_1 = Tag("label") == tag_1 + filter_2 = Tag("label") == tag_2 + filter_3 = Tag("label") == tag_3 for i in range(4): prompt = f"test prompt {i}" @@ -384,26 +386,40 @@ def test_multiple_tags(cache): cache.store(prompt, response, tag=tags[i]) # test we can specify one specific tag - ##results = cache.check("test prompt 1", tags=tag_1, num_results=5) - results = cache.check("test prompt 1", filters=filter_1, num_results=5) + results = cache.check("test prompt 1", tag_filter=filter_1, num_results=5) assert len(results) == 1 assert results[0]["prompt"] == "test prompt 0" # test we can pass a list of tags - ##results = cache.check("test prompt 1", tags=[tag_1, tag_2, tag_3], num_results=5) - results = cache.check("test prompt 1", filters=[filter_1, filter_2, filter_3], num_results=5) + combined_filter = filter_1 | filter_2 | filter_3 + results = cache.check("test prompt 1", tag_filter=combined_filter, num_results=5) assert len(results) == 3 # test that default tag param searches full cache results = cache.check("test prompt 1", num_results=5) assert len(results) == 4 - # test we can get all matches with empty tag list - ##results = cache.check("test prompt 1", tags=[], num_results=5) - results = cache.check("test prompt 1", filters=[], num_results=5) - assert len(results) == 4 - # test no results are returned if we pass a nonexistant tag - ##results = cache.check("test prompt 1", tags=["bad tag"], num_results=5) - results = cache.check("test prompt 1", filters=["bad tag"], num_results=5) + bad_filter = Tag("label") == "bad tag" + results = cache.check("test prompt 1", tag_filter=bad_filter, num_results=5) assert len(results) == 0 + + +def test_complex_filters(cache): + cache.store(prompt="prompt 1", response="response 1") + cache.store(prompt="prompt 2", response="response 2") + sleep(1) + current_timestamp = time() + cache.store(prompt="prompt 3", response="response 3") + + # test we can do range filters on inserted_at and updated_at fields + range_filter = Num("inserted_at") < current_timestamp + results = cache.check("prompt 1", tag_filter=range_filter, num_results=5) + assert len(results) == 2 + + # test we can combine range filters and text filters + prompt_filter = Text("prompt") % "*pt 1" + combined_filter = prompt_filter & range_filter + + results = cache.check("prompt 1", tag_filter=combined_filter, num_results=5) + assert len(results) == 1 From 7f471f23974b765d74a23bd1e6110c6b409b0d93 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 26 Jul 2024 11:21:34 -0700 Subject: [PATCH 08/10] simplifies drop() method --- redisvl/extensions/llmcache/semantic.py | 8 +------- 1 file changed, 1 
insertion(+), 7 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 1ac181ff..0a8a39f1 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -202,13 +202,7 @@ def drop(self, document_ids: Union[str, List[str]]) -> None: Args: document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache. """ - if isinstance(document_ids, List): - with self._index.client.pipeline(transaction=False) as pipe: # type: ignore - for key in document_ids: # type: ignore - pipe.delete(key) - pipe.execute() - else: - self._index.client.delete(document_ids) # type: ignore + self._index.drop_keys(document_ids) def _refresh_ttl(self, key: str) -> None: """Refresh the time-to-live for the specified key.""" From bb662e527bf40fd9786cd9ae832cd7a84056d930 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 31 Jul 2024 14:34:50 -0700 Subject: [PATCH 09/10] replaces class helper methods with util functions --- redisvl/extensions/llmcache/base.py | 8 -------- redisvl/extensions/llmcache/semantic.py | 13 ++++++------- redisvl/utils/utils.py | 2 +- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index 5ca7abb9..a1c88466 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -58,11 +58,3 @@ def store( def hash_input(self, prompt: str): """Hashes the input using SHA256.""" return hashify(prompt) - - def serialize(self, metadata: Dict[str, Any]) -> str: - """Serlize the input into a string.""" - return json.dumps(metadata) - - def deserialize(self, metadata: str) -> Dict[str, Any]: - """Deserialize the input from a string.""" - return json.loads(metadata) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 2ae4678f..0008b397 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -1,4 +1,3 @@ -from time import time from typing import Any, Dict, List, Optional, Union from redis import Redis @@ -9,6 +8,7 @@ from redisvl.query.filter import FilterExpression, Tag from redisvl.redis.utils import array_to_buffer from redisvl.schema import IndexSchema +from redisvl.utils.utils import current_timestamp, deserialize, serialize from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer @@ -20,7 +20,6 @@ def from_params(cls, name: str, vector_dims: int): return cls( index={"name": name, "prefix": name}, # type: ignore fields=[ # type: ignore - {"name": "cache_name", "type": "tag"}, {"name": "prompt", "type": "text"}, {"name": "response", "type": "text"}, {"name": "inserted_at", "type": "numeric"}, @@ -253,7 +252,7 @@ def _search_cache( self._refresh_ttl(key) # Check for metadata and deserialize if self.metadata_field_name in hit: - hit[self.metadata_field_name] = self.deserialize( + hit[self.metadata_field_name] = deserialize( hit[self.metadata_field_name] ) return cache_hits @@ -366,7 +365,7 @@ def store( self._check_vector_dims(vector) # Construct semantic cache payload - now = time() + now = current_timestamp() id_field = self.entry_id_field_name payload = { id_field: self.hash_input(prompt), @@ -380,7 +379,7 @@ def store( if not isinstance(metadata, dict): raise TypeError("If specified, cached metadata must be a dictionary.") # Serialize the metadata dict and add to cache payload - payload[self.metadata_field_name] = self.serialize(metadata) + 
payload[self.metadata_field_name] = serialize(metadata) if tag is not None: payload[self.tag_field_name] = tag @@ -422,10 +421,10 @@ def update(self, key: str, **kwargs) -> None: # Check for metadata and deserialize if _key == self.metadata_field_name: if isinstance(val, dict): - kwargs[_key] = self.serialize(val) + kwargs[_key] = serialize(val) else: raise TypeError( "If specified, cached metadata must be a dictionary." ) - kwargs.update({self.updated_at_field_name: time()}) + kwargs.update({self.updated_at_field_name: current_timestamp()}) self._index.client.hset(key, mapping=kwargs) # type: ignore diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 5f5cc882..eafb47ad 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -54,6 +54,6 @@ def serialize(data: Dict[str, Any]) -> str: return json.dumps(data) -def deserialize(self, data: str) -> Dict[str, Any]: +def deserialize(data: str) -> Dict[str, Any]: """Deserialize the input from a string.""" return json.loads(data) From 32aa30ce3fc6d8c7a3bd1252e6822a8229cba647 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 31 Jul 2024 14:42:27 -0700 Subject: [PATCH 10/10] resets ttl whenever cache entry is updated --- redisvl/extensions/llmcache/semantic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 0008b397..b17c18c9 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -428,3 +428,4 @@ def update(self, key: str, **kwargs) -> None: ) kwargs.update({self.updated_at_field_name: current_timestamp()}) self._index.client.hset(key, mapping=kwargs) # type: ignore + self._refresh_ttl(key)
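
Series note: taken together, these patches leave SemanticCache with timestamped
entries (inserted_at / updated_at), an optional per-entry tag on store() indexed
under the "label" tag field, FilterExpression-based filtering on check() via the
tag_filter parameter, field-level update() that resets updated_at and (per PATCH
10) refreshes the TTL, and manual eviction with drop(). A minimal usage sketch of
that post-series API follows; the Redis URL, tag value, prompts, and metadata are
illustrative assumptions, not taken from the patches:

    from redisvl.extensions.llmcache import SemanticCache
    from redisvl.query.filter import Num, Tag

    # Assumes a locally running Redis instance; connection details are illustrative.
    cache = SemanticCache(name="llmcache", redis_url="redis://localhost:6379")

    # Store an entry scoped by a tag (indexed under the "label" tag field).
    key = cache.store(
        prompt="What is the capital of France?",
        response="Paris",
        metadata={"source": "example"},
        tag="user-123",
    )

    # Check with a FilterExpression so only this user's entries are in scope.
    scope = (Tag("label") == "user-123") & (Num("inserted_at") > 0)
    hits = cache.check(prompt="capital of France?", tag_filter=scope, num_results=3)

    # Update fields in place; updated_at is reset and the entry TTL is refreshed.
    cache.update(key, metadata={"source": "example", "hit_count": 1})

    # Manually evict one or more entries by document ID.
    cache.drop(key)

Despite its name, tag_filter accepts any FilterExpression: the Num and Text
filters exercised in test_complex_filters pass through the same parameter.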