adds semantic cache scoped access and additional functionality #180
@@ -1,10 +1,12 @@
-from typing import Any, Dict, List, Optional
+from time import time
+from typing import Any, Dict, List, Optional, Union

 from redis import Redis

 from redisvl.extensions.llmcache.base import BaseLLMCache
 from redisvl.index import SearchIndex
 from redisvl.query import RangeQuery
+from redisvl.query.filter import Tag
 from redisvl.redis.utils import array_to_buffer
 from redisvl.schema.schema import IndexSchema
 from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
@@ -16,6 +18,9 @@ class SemanticCache(BaseLLMCache):
     entry_id_field_name: str = "id"
     prompt_field_name: str = "prompt"
     vector_field_name: str = "prompt_vector"
+    inserted_at_field_name: str = "inserted_at"
+    updated_at_field_name: str = "updated_at"
+    tag_field_name: str = "scope_tag"
     response_field_name: str = "response"
     metadata_field_name: str = "metadata"

@@ -77,6 +82,9 @@ def __init__(
             [
                 {"name": self.prompt_field_name, "type": "text"},
                 {"name": self.response_field_name, "type": "text"},
+                {"name": self.inserted_at_field_name, "type": "numeric"},
+                {"name": self.updated_at_field_name, "type": "numeric"},
+                {"name": self.tag_field_name, "type": "tag"},
                 {
                     "name": self.vector_field_name,
                     "type": "vector",
@@ -104,12 +112,12 @@ def __init__(
             self.entry_id_field_name,
             self.prompt_field_name,
             self.response_field_name,
+            self.tag_field_name,
             self.vector_field_name,
             self.metadata_field_name,
         ]
         self.set_vectorizer(vectorizer)
         self.set_threshold(distance_threshold)

         self._index.create(overwrite=False)

     @property
@@ -183,6 +191,20 @@ def delete(self) -> None:
         index."""
         self._index.delete(drop=True)

+    def drop(self, document_ids: Union[str, List[str]]) -> None:
+        """Remove a specific entry or entries from the cache by its ID.
+
+        Args:
+            document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache.
+        """
+        if isinstance(document_ids, List):
+            with self._index.client.pipeline(transaction=False) as pipe:  # type: ignore
+                for key in document_ids:  # type: ignore
+                    pipe.delete(key)
+                pipe.execute()
+        else:
+            self._index.client.delete(document_ids)  # type: ignore
+
     def _refresh_ttl(self, key: str) -> None:
         """Refresh the time-to-live for the specified key."""
         if self._ttl:

Review comment: So for this, are we expecting users to pass the full redis key? Or just the id portion (without prefix)? I think the terminology we try to use for this throughout the library is …
Review comment: In a similar theme to the work around …
Review comment: I like the idea of combining methods. What about something like …
Review comment: Will move this pipeline code into index.
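For illustration, a minimal usage sketch of the new drop() method, assuming a constructed SemanticCache named `cache` and assuming drop() takes the full Redis keys returned by store() (the open review question above notwithstanding):

    key_a = cache.store("prompt A", "response A")
    key_b = cache.store("prompt B", "response B")

    cache.drop(key_a)            # remove a single entry by key
    cache.drop([key_a, key_b])   # or several at once, via a non-transactional pipeline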
@@ -196,7 +218,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
         return self._vectorizer.embed(prompt)

     def _search_cache(
-        self, vector: List[float], num_results: int, return_fields: Optional[List[str]]
+        self,
+        vector: List[float],
+        num_results: int,
+        return_fields: Optional[List[str]],
+        tags: Optional[Union[List[str], str]],
     ) -> List[Dict[str, Any]]:
         """Searches the semantic cache for similar prompt vectors and returns
         the specified return fields for each cache hit."""
@@ -218,6 +244,8 @@ def _search_cache(
             num_results=num_results,
             return_score=True,
         )
+        if tags:
+            query.set_filter(self.get_filter(tags))  # type: ignore

         # Gather and return the cache hits
         cache_hits: List[Dict[str, Any]] = self._index.query(query)
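For context on what get_filter(tags) feeds into set_filter here, a small sketch of redisvl's Tag filter semantics (the field name is taken from this diff; the rendered query strings are to the best of my knowledge):

    from redisvl.query.filter import Tag

    # A single tag value restricts the query to one scope:
    single = Tag("scope_tag") == "user_123"
    print(str(single))   # @scope_tag:{user_123}

    # A list of values matches any of the given scopes:
    several = Tag("scope_tag") == ["user_123", "user_456"]
    print(str(several))  # @scope_tag:{user_123|user_456}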
@@ -248,6 +276,7 @@ def check(
         vector: Optional[List[float]] = None,
         num_results: int = 1,
         return_fields: Optional[List[str]] = None,
+        tags: Optional[Union[List[str], str]] = None,
     ) -> List[Dict[str, Any]]:
         """Checks the semantic cache for results similar to the specified prompt
         or vector.
@@ -267,6 +296,8 @@ def check(
             return_fields (Optional[List[str]], optional): The fields to include
                 in each returned result. If None, defaults to all available
                 fields in the cached entry.
+            tags (Optional[Union[List[str], str]]): The tag or tags to filter
+                results by. Default is None, and the full cache is searched.

         Returns:
             List[Dict[str, Any]]: A list of dicts containing the requested
@@ -291,7 +322,7 @@ def check(
             self._check_vector_dims(vector)

         # Check for cache hits by searching the cache
-        cache_hits = self._search_cache(vector, num_results, return_fields)
+        cache_hits = self._search_cache(vector, num_results, return_fields, tags)
         return cache_hits

     def store(
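A hedged usage sketch of the scoped check() call added here (assumes a populated cache and that tag values like "user_123" were supplied at store time; see the store() sketch further down):

    hits = cache.check(
        prompt="What is the capital of France?",
        tags="user_123",   # only consider entries stored in this scope
    )
    if hits:
        print(hits[0]["response"])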
@@ -300,6 +331,7 @@ def store(
         response: str,
         vector: Optional[List[float]] = None,
         metadata: Optional[dict] = None,
+        tag: Optional[str] = None,
     ) -> str:
         """Stores the specified key-value pair in the cache along with metadata.

@@ -311,6 +343,8 @@ def store(
                 demand.
             metadata (Optional[dict], optional): The optional metadata to cache
                 alongside the prompt and response. Defaults to None.
+            tag (Optional[str]): The optional tag to assign to the cache entry.
+                Defaults to None.

         Returns:
             str: The Redis key for the entries added to the semantic cache.
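And a matching sketch of store() with the new tag argument (all values illustrative):

    key_a = cache.store(
        prompt="What is the capital of France?",
        response="Paris",
        metadata={"model_name": "gpt-3.5-turbo"},  # optional, serialized into the entry
        tag="user_123",                            # written to the scope_tag field
    )
    key_b = cache.store(
        prompt="What is the capital of France?",
        response="Paris, of course!",
        tag="user_456",
    )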
@@ -333,19 +367,84 @@ def store(
         self._check_vector_dims(vector)

         # Construct semantic cache payload
+        now = time()
         id_field = self.entry_id_field_name
         payload = {
             id_field: self.hash_input(prompt),
             self.prompt_field_name: prompt,
             self.response_field_name: response,
             self.vector_field_name: array_to_buffer(vector),
+            self.inserted_at_field_name: now,
+            self.updated_at_field_name: now,
         }
         if metadata is not None:
             if not isinstance(metadata, dict):
                 raise TypeError("If specified, cached metadata must be a dictionary.")
             # Serialize the metadata dict and add to cache payload
             payload[self.metadata_field_name] = self.serialize(metadata)
+        if tag is not None:
+            payload[self.tag_field_name] = tag

         # Load LLMCache entry with TTL
         keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
         return keys[0]

+    def update(self, key: str, **kwargs) -> None:
+        """Update specific fields within an existing cache entry. If no fields
+        are passed, then only the document TTL is refreshed.
+
+        Args:
+            key (str): The key of the document to update.
+            kwargs: Field-value pairs to update on the entry.
+
+        Raises:
+            ValueError: If an invalid field is provided as a kwarg.
+            TypeError: If metadata is provided and is not of type dict.
+
+        .. code-block:: python
+
+            key = cache.store("this is a prompt", "this is a response")
+            cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
+        """
+        if not kwargs:
+            self._refresh_ttl(key)
+            return
+
+        for _key, val in kwargs.items():
+            if _key not in {
+                self.prompt_field_name,
+                self.vector_field_name,
+                self.response_field_name,
+                self.tag_field_name,
+                self.metadata_field_name,
+            }:
+                raise ValueError(f"{_key} is not a valid field within the document.")
+
+            # Check for metadata and serialize it before writing
+            if _key == self.metadata_field_name:
+                if isinstance(val, dict):
+                    kwargs[_key] = self.serialize(val)
+                else:
+                    raise TypeError(
+                        "If specified, cached metadata must be a dictionary."
+                    )
+        kwargs.update({self.updated_at_field_name: time()})
+        self._index.client.hset(key, mapping=kwargs)  # type: ignore
+
+    def get_filter(
+        self,
+        tags: Optional[Union[List[str], str]] = None,
+    ) -> Tag:
+        """Build the tag filter to apply to queries based on the desired scope.
+
+        Args:
+            tags (Optional[Union[List[str], str]]): Name of the specific tag or
+                tags to filter to. Default is None, which means all cache data
+                will be in scope.
+        """
+        default_filter = Tag(self.tag_field_name) == []
+        if not tags:
+            return default_filter
+
+        tag_filter = Tag(self.tag_field_name) == tags
+        return tag_filter
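A short sketch of how update() behaves per the code above (field names follow the class attributes in this diff):

    key = cache.store("this is a prompt", "this is a response", tag="user_123")

    cache.update(key)                                # no kwargs: only refreshes the TTL
    cache.update(key, response="a better response")  # patch a single valid field
    cache.update(key, metadata={"hit_count": 2})     # dict metadata is serialized first
    # updated_at is stamped automatically on every non-empty update() call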
Review comment: Just want to understand this data model a bit more. As it stands here, it looks like there would be one field for scoping (scope_tag) and filtering. What if I have year, location, and user ID data that I'd like to filter my cache check on? I think we might need to generalize a bit more and support a set of optional/customizable filterable fields (providable to the class on init??).

So if staying with Hash, the schema could be something like:

    id
    prompt
    prompt_vector
    response
    inserted_at
    updated_at
    metadata
    scope_tag
    ???

Just thinking out loud!
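One way the reviewer's generalization could look, as a hedged sketch: a hypothetical build_scope_filter() helper that ANDs together one predicate per user-declared filterable field. The filterable_fields init argument and the helper name are invented for illustration and are not part of this diff:

    from typing import Dict, List, Optional, Union

    from redisvl.query.filter import FilterExpression, Num, Tag

    # Hypothetical: the class would accept at init, e.g.
    #   filterable_fields=[{"name": "user_id", "type": "tag"},
    #                      {"name": "year", "type": "numeric"}]
    # and check() would accept a dict of filter values.
    def build_scope_filter(
        filters: Dict[str, Union[str, List[str], int]]
    ) -> Optional[FilterExpression]:
        """AND together one predicate per filterable field (illustrative only)."""
        expr: Optional[FilterExpression] = None
        for field, value in filters.items():
            # Numeric values get a Num filter; strings/lists get a Tag filter.
            pred = (Num(field) == value) if isinstance(value, int) else (Tag(field) == value)
            expr = pred if expr is None else (expr & pred)
        return expr  # None means "no scoping", i.e. search the full cache

    # e.g. build_scope_filter({"user_id": "u42", "year": 2024})
    #      -> roughly (@user_id:{u42} @year:[2024 2024])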