adds semantic cache scoped access and additional functionality #180
Changes from all commits
06d0d73
3578bd4
4e4212a
7e286bd
b292d3b
2f28ec6
b9022d1
4b91de4
6684b0b
5c7893f
7f471f2
f112466
2443549
bb662e5
32aa30c
@@ -1,21 +1,53 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from redis import Redis

 from redisvl.extensions.llmcache.base import BaseLLMCache
 from redisvl.index import SearchIndex
 from redisvl.query import RangeQuery
+from redisvl.query.filter import FilterExpression, Tag
 from redisvl.redis.utils import array_to_buffer
-from redisvl.schema.schema import IndexSchema
+from redisvl.schema import IndexSchema
+from redisvl.utils.utils import current_timestamp, deserialize, serialize
 from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer


+class SemanticCacheIndexSchema(IndexSchema):
+
+    @classmethod
+    def from_params(cls, name: str, vector_dims: int):
+
+        return cls(
+            index={"name": name, "prefix": name},  # type: ignore
+            fields=[  # type: ignore
+                {"name": "prompt", "type": "text"},
+                {"name": "response", "type": "text"},
+                {"name": "inserted_at", "type": "numeric"},
+                {"name": "updated_at", "type": "numeric"},
+                {"name": "label", "type": "tag"},
+                {
+                    "name": "prompt_vector",
+                    "type": "vector",
+                    "attrs": {
+                        "dims": vector_dims,
+                        "datatype": "float32",
+                        "distance_metric": "cosine",
+                        "algorithm": "flat",
+                    },
+                },
+            ],
+        )
+
+
 class SemanticCache(BaseLLMCache):
     """Semantic Cache for Large Language Models."""

     entry_id_field_name: str = "_id"
     prompt_field_name: str = "prompt"
     vector_field_name: str = "prompt_vector"
+    inserted_at_field_name: str = "inserted_at"
+    updated_at_field_name: str = "updated_at"
+    tag_field_name: str = "label"
     response_field_name: str = "response"
     metadata_field_name: str = "metadata"
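For illustration, here is a minimal sketch of what the new schema factory builds. The class is internal to this module (the SemanticCache constructor calls it for you), and the cache name "llmcache" and 768-dim embedding size are hypothetical:

from redisvl.index import SearchIndex

# Hypothetical: build the cache schema for a 768-dim embedding model.
# Note the new "label" tag field (enabling scoped access) and the numeric
# "inserted_at"/"updated_at" fields for entry timestamps.
schema = SemanticCacheIndexSchema.from_params("llmcache", 768)
index = SearchIndex(schema=schema)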
@@ -69,27 +101,7 @@ def __init__(
             model="sentence-transformers/all-mpnet-base-v2"
         )

-        # build cache index schema
-        schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})
-        # add fields
-        schema.add_fields(
-            [
-                {"name": self.prompt_field_name, "type": "text"},
-                {"name": self.response_field_name, "type": "text"},
-                {
-                    "name": self.vector_field_name,
-                    "type": "vector",
-                    "attrs": {
-                        "dims": vectorizer.dims,
-                        "datatype": "float32",
-                        "distance_metric": "cosine",
-                        "algorithm": "flat",
-                    },
-                },
-            ]
-        )
-
         # build search index
+        schema = SemanticCacheIndexSchema.from_params(name, vectorizer.dims)
         self._index = SearchIndex(schema=schema)

         # handle redis connection
@@ -103,12 +115,12 @@ def __init__(
             self.entry_id_field_name,
             self.prompt_field_name,
             self.response_field_name,
+            self.tag_field_name,
             self.vector_field_name,
             self.metadata_field_name,
         ]
         self.set_vectorizer(vectorizer)
         self.set_threshold(distance_threshold)

         self._index.create(overwrite=False)

     @property
@@ -182,6 +194,14 @@ def delete(self) -> None:
         index."""
         self._index.delete(drop=True)

+    def drop(self, document_ids: Union[str, List[str]]) -> None:
+        """Remove one or more entries from the cache by document ID.
+
+        Args:
+            document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache.
+        """
+        self._index.drop_keys(document_ids)
+
     def _refresh_ttl(self, key: str) -> None:
         """Refresh the time-to-live for the specified key."""
         if self._ttl:

[Review comment on drop] So for this, are we expecting users to pass the full redis key? or just the id portion (without prefix)? I think the terminology we try to use for this throughout the library is
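A minimal usage sketch for the new drop method (the cache setup and prompt values are hypothetical):

# Store an entry, then remove it by the key that store() returned.
key = cache.store("what is the capital of France?", "Paris")
cache.drop(key)

# Or remove several entries at once.
cache.drop([key_one, key_two])  # hypothetical keys from earlier store() calls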
@@ -195,7 +215,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
         return self._vectorizer.embed(prompt)

     def _search_cache(
-        self, vector: List[float], num_results: int, return_fields: Optional[List[str]]
+        self,
+        vector: List[float],
+        num_results: int,
+        return_fields: Optional[List[str]],
+        tag_filter: Optional[FilterExpression],
     ) -> List[Dict[str, Any]]:
         """Searches the semantic cache for similar prompt vectors and returns
         the specified return fields for each cache hit."""
@@ -217,6 +241,8 @@ def _search_cache(
             num_results=num_results,
             return_score=True,
         )
+        if tag_filter:
+            query.set_filter(tag_filter)  # type: ignore

         # Gather and return the cache hits
         cache_hits: List[Dict[str, Any]] = self._index.query(query)
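The tag_filter applied here is a standard RedisVL FilterExpression. A minimal sketch of constructing one against the new "label" field (the tag values are hypothetical):

from redisvl.query.filter import Tag

# Match only cache entries stored with label == "customer-123".
tag_filter = Tag("label") == "customer-123"

# Tag filters compose with the usual boolean operators as well.
combined = (Tag("label") == "customer-123") | (Tag("label") == "global")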
@@ -226,7 +252,7 @@ def _search_cache(
                 self._refresh_ttl(key)
                 # Check for metadata and deserialize
                 if self.metadata_field_name in hit:
-                    hit[self.metadata_field_name] = self.deserialize(
+                    hit[self.metadata_field_name] = deserialize(
                         hit[self.metadata_field_name]
                     )
         return cache_hits
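The switch from self.deserialize to the module-level helper pairs with the serialize call in store(). A sketch of the round trip, assuming the redisvl.utils.utils helpers serialize plain dicts to strings and back:

from redisvl.utils.utils import deserialize, serialize

metadata = {"hit_count": 1, "model_name": "Llama-2-7b"}
stored = serialize(metadata)            # dict -> string stored in the Redis hash
assert deserialize(stored) == metadata  # recovered on each cache hit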
@@ -248,6 +274,7 @@ def check(
         vector: Optional[List[float]] = None,
         num_results: int = 1,
         return_fields: Optional[List[str]] = None,
+        tag_filter: Optional[FilterExpression] = None,
     ) -> List[Dict[str, Any]]:
         """Checks the semantic cache for results similar to the specified prompt
         or vector.
@@ -267,6 +294,8 @@ def check(
             return_fields (Optional[List[str]], optional): The fields to include
                 in each returned result. If None, defaults to all available
                 fields in the cached entry.
+            tag_filter (Optional[FilterExpression]): The tag filter to scope
+                results by. Defaults to None, in which case the full cache is searched.

         Returns:
             List[Dict[str, Any]]: A list of dicts containing the requested
@@ -291,7 +320,7 @@ def check(
             self._check_vector_dims(vector)

         # Check for cache hits by searching the cache
-        cache_hits = self._search_cache(vector, num_results, return_fields)
+        cache_hits = self._search_cache(vector, num_results, return_fields, tag_filter)
         return cache_hits

     def store(
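A sketch of scoped retrieval with the extended check signature (the prompt and tag values are hypothetical):

from redisvl.query.filter import Tag

# Only consider cache entries stored under this user's tag.
hits = cache.check(
    prompt="what is the capital of France?",
    num_results=1,
    tag_filter=Tag("label") == "user-abc",
)
if hits:
    print(hits[0]["response"])  # cached response; metadata deserialized if present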
@@ -300,6 +329,7 @@ def store(
         response: str,
         vector: Optional[List[float]] = None,
         metadata: Optional[dict] = None,
+        tag: Optional[str] = None,
     ) -> str:
         """Stores the specified key-value pair in the cache along with metadata.

@@ -311,6 +341,8 @@ def store(
                 demand.
             metadata (Optional[dict], optional): The optional metadata to cache
                 alongside the prompt and response. Defaults to None.
+            tag (Optional[str]): The optional tag to assign to the cache entry.
+                Defaults to None.

         Returns:
             str: The Redis key for the entries added to the semantic cache.
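And the corresponding write path, a sketch of storing an entry scoped to a tag (values are hypothetical):

key = cache.store(
    prompt="what is the capital of France?",
    response="Paris",
    metadata={"model_name": "Llama-2-7b"},  # serialized into the hash
    tag="user-abc",                          # indexed under the "label" tag field
)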
@@ -333,19 +365,67 @@ def store(
             self._check_vector_dims(vector)

         # Construct semantic cache payload
+        now = current_timestamp()
         id_field = self.entry_id_field_name
         payload = {
             id_field: self.hash_input(prompt),
             self.prompt_field_name: prompt,
             self.response_field_name: response,
             self.vector_field_name: array_to_buffer(vector),
+            self.inserted_at_field_name: now,
+            self.updated_at_field_name: now,
         }
         if metadata is not None:
             if not isinstance(metadata, dict):
                 raise TypeError("If specified, cached metadata must be a dictionary.")
             # Serialize the metadata dict and add to cache payload
-            payload[self.metadata_field_name] = self.serialize(metadata)
+            payload[self.metadata_field_name] = serialize(metadata)
+        if tag is not None:
+            payload[self.tag_field_name] = tag

         # Load LLMCache entry with TTL
         keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
         return keys[0]

+    def update(self, key: str, **kwargs) -> None:
+        """Update specific fields within an existing cache entry. If no fields
+        are passed, then only the document TTL is refreshed.
+
+        Args:
+            key (str): the key of the document to update.
+            kwargs: the field-value pairs to update on the cache entry.
+
+        Raises:
+            ValueError: if an invalid field is provided as a kwarg.
+            TypeError: if metadata is provided and is not of type dict.
+
+        .. code-block:: python
+
+            key = cache.store('this is a prompt', 'this is a response')
+            cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
+        """
+        if not kwargs:
+            self._refresh_ttl(key)
+            return
+
+        for _key, val in kwargs.items():
+            if _key not in {
+                self.prompt_field_name,
+                self.vector_field_name,
+                self.response_field_name,
+                self.tag_field_name,
+                self.metadata_field_name,
+            }:
+                raise ValueError(f"{_key} is not a valid field within the document")
+
+            # Check for metadata and serialize
+            if _key == self.metadata_field_name:
+                if isinstance(val, dict):
+                    kwargs[_key] = serialize(val)
+                else:
+                    raise TypeError(
+                        "If specified, cached metadata must be a dictionary."
+                    )
+        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        self._index.client.hset(key, mapping=kwargs)  # type: ignore
+        self._refresh_ttl(key)

justin-cechmanek marked this conversation as resolved.
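Putting the pieces together, a hypothetical end-to-end flow of the functionality this PR adds (import path and constructor arguments assumed from the surrounding code):

from redisvl.extensions.llmcache import SemanticCache
from redisvl.query.filter import Tag

cache = SemanticCache(name="llmcache")

# Store a tagged entry, then retrieve it with a matching tag filter.
key = cache.store("what is the capital of France?", "Paris", tag="user-abc")
hits = cache.check(
    prompt="what is the capital of France?",
    tag_filter=Tag("label") == "user-abc",
)

# Update fields in place (also refreshes updated_at), then remove the entry.
cache.update(key, metadata={"hit_count": 1})
cache.drop(key)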