Skip to content

adds semantic cache scoped access and additional functionality #180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 103 additions & 4 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Dict, List, Optional
from time import time
from typing import Any, Dict, List, Optional, Union

from redis import Redis

from redisvl.extensions.llmcache.base import BaseLLMCache
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.query.filter import Tag
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.schema import IndexSchema
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
Expand All @@ -16,6 +18,9 @@ class SemanticCache(BaseLLMCache):
entry_id_field_name: str = "id"
prompt_field_name: str = "prompt"
vector_field_name: str = "prompt_vector"
inserted_at_field_name: str = "inserted_at"
updated_at_field_name: str = "updated_at"
tag_field_name: str = "scope_tag"
response_field_name: str = "response"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to understand this data model a bit more. As it stands here, it looks like there would be 1 field for scoping (scope_tag) and filtering. What if I have both year, location, and user ID data that I'd like to filter my cache check on?

I think we might need to generalize a bit more and support a set of optional/customizable filterable fields (providable to the class on init??)

So if staying with Hash, the schema could be something like:

  • id
  • prompt
  • prompt_vector
  • response
  • inserted_at
  • updated_at
  • metadata
  • scope_tag???
  • ?

Just thinking out loud!

metadata_field_name: str = "metadata"

Expand Down Expand Up @@ -77,6 +82,9 @@ def __init__(
[
{"name": self.prompt_field_name, "type": "text"},
{"name": self.response_field_name, "type": "text"},
{"name": self.inserted_at_field_name, "type": "numeric"},
{"name": self.updated_at_field_name, "type": "numeric"},
{"name": self.tag_field_name, "type": "tag"},
{
"name": self.vector_field_name,
"type": "vector",
Expand Down Expand Up @@ -104,12 +112,12 @@ def __init__(
self.entry_id_field_name,
self.prompt_field_name,
self.response_field_name,
self.tag_field_name,
self.vector_field_name,
self.metadata_field_name,
]
self.set_vectorizer(vectorizer)
self.set_threshold(distance_threshold)

self._index.create(overwrite=False)

@property
Expand Down Expand Up @@ -183,6 +191,20 @@ def delete(self) -> None:
index."""
self._index.delete(drop=True)

def drop(self, document_ids: Union[str, List[str]]) -> None:
    """Remove one or more specific entries from the cache by ID.

    Args:
        document_ids (Union[str, List[str]]): The document ID, or list of
            document IDs, to remove from the cache.
    """
    if isinstance(document_ids, str):
        # Single entry: one DEL round trip is the cheapest path.
        self._index.client.delete(document_ids)  # type: ignore
    else:
        # Batch the deletes in a single pipeline to avoid one network
        # round trip per key.
        with self._index.client.pipeline(transaction=False) as pipe:  # type: ignore
            for document_id in document_ids:
                pipe.delete(document_id)
            pipe.execute()

def _refresh_ttl(self, key: str) -> None:
"""Refresh the time-to-live for the specified key."""
if self._ttl:
Expand All @@ -196,7 +218,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
return self._vectorizer.embed(prompt)

def _search_cache(
self, vector: List[float], num_results: int, return_fields: Optional[List[str]]
self,
vector: List[float],
num_results: int,
return_fields: Optional[List[str]],
tags: Optional[Union[List[str], str]],
) -> List[Dict[str, Any]]:
"""Searches the semantic cache for similar prompt vectors and returns
the specified return fields for each cache hit."""
Expand All @@ -218,6 +244,8 @@ def _search_cache(
num_results=num_results,
return_score=True,
)
if tags:
query.set_filter(self.get_filter(tags)) # type: ignore

# Gather and return the cache hits
cache_hits: List[Dict[str, Any]] = self._index.query(query)
Expand Down Expand Up @@ -248,6 +276,7 @@ def check(
vector: Optional[List[float]] = None,
num_results: int = 1,
return_fields: Optional[List[str]] = None,
tags: Optional[Union[List[str], str]] = None,
) -> List[Dict[str, Any]]:
"""Checks the semantic cache for results similar to the specified prompt
or vector.
Expand All @@ -267,6 +296,8 @@ def check(
return_fields (Optional[List[str]], optional): The fields to include
in each returned result. If None, defaults to all available
fields in the cached entry.
tags (Optional[Union[List[str], str]]): The tag or tags to filter
results by. Default is None, in which case the full cache is searched.

Returns:
List[Dict[str, Any]]: A list of dicts containing the requested
Expand All @@ -291,7 +322,7 @@ def check(
self._check_vector_dims(vector)

# Check for cache hits by searching the cache
cache_hits = self._search_cache(vector, num_results, return_fields)
cache_hits = self._search_cache(vector, num_results, return_fields, tags)
return cache_hits

def store(
Expand All @@ -300,6 +331,7 @@ def store(
response: str,
vector: Optional[List[float]] = None,
metadata: Optional[dict] = None,
tag: Optional[str] = None,
) -> str:
"""Stores the specified key-value pair in the cache along with metadata.

Expand All @@ -311,6 +343,8 @@ def store(
demand.
metadata (Optional[dict], optional): The optional metadata to cache
alongside the prompt and response. Defaults to None.
tag (Optional[str]): The optional tag to assign to the cache entry.
Defaults to None.

Returns:
str: The Redis key for the entries added to the semantic cache.
Expand All @@ -333,19 +367,84 @@ def store(
self._check_vector_dims(vector)

# Construct semantic cache payload
now = time()
id_field = self.entry_id_field_name
payload = {
id_field: self.hash_input(prompt),
self.prompt_field_name: prompt,
self.response_field_name: response,
self.vector_field_name: array_to_buffer(vector),
self.inserted_at_field_name: now,
self.updated_at_field_name: now,
}
if metadata is not None:
if not isinstance(metadata, dict):
raise TypeError("If specified, cached metadata must be a dictionary.")
# Serialize the metadata dict and add to cache payload
payload[self.metadata_field_name] = self.serialize(metadata)
if tag is not None:
payload[self.tag_field_name] = tag

# Load LLMCache entry with TTL
keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
return keys[0]

def update(self, key: str, **kwargs) -> None:
    """Update specific fields within an existing cache entry.

    If no fields are passed, then only the document TTL is refreshed.

    Args:
        key (str): The Redis key of the document to update.
        kwargs: Field/value pairs to overwrite on the cached entry. Valid
            fields are the prompt, prompt vector, response, scope tag, and
            metadata fields.

    Raises:
        ValueError: If a kwarg is not a valid field within the document.
        TypeError: If metadata is provided and is not a dictionary.

    .. code-block:: python
        key = cache.store('this is a prompt', 'this is a response')
        cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
    """
    if not kwargs:
        # Nothing to write -- just extend the entry's lifetime.
        self._refresh_ttl(key)
        return

    # Hoisted out of the loop: the set of updatable fields is invariant.
    allowed_fields = {
        self.prompt_field_name,
        self.vector_field_name,
        self.response_field_name,
        self.tag_field_name,
        self.metadata_field_name,
    }
    for field, value in kwargs.items():
        if field not in allowed_fields:
            # Report the offending *field* name, not the document key.
            raise ValueError(f"{field} is not a valid field within document")

        # Metadata must be a dict and is serialized before storage.
        if field == self.metadata_field_name:
            if isinstance(value, dict):
                kwargs[field] = self.serialize(value)
            else:
                raise TypeError(
                    "If specified, cached metadata must be a dictionary."
                )

    # Stamp the modification time so callers can track freshness.
    kwargs.update({self.updated_at_field_name: time()})
    self._index.client.hset(key, mapping=kwargs)  # type: ignore

def get_filter(
    self,
    tags: Optional[Union[List[str], str]] = None,
) -> Tag:
    """Build the tag filter to apply to queries for the desired scope.

    Args:
        tags (Optional[Union[List[str], str]]): Name of the specific tag
            or tags to filter to. Default is None, which means all cache
            data will be in scope.
    """
    if not tags:
        # An empty or missing tag set yields a match-all filter.
        return Tag(self.tag_field_name) == []
    return Tag(self.tag_field_name) == tags
131 changes: 131 additions & 0 deletions tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,43 @@ def test_store_and_check(cache, vectorizer):
assert "metadata" not in check_result[0]


def test_return_fields(cache, vectorizer):
    prompt = "This is a test prompt."
    response = "This is a test response."
    vector = vectorizer.embed(prompt)

    cache.store(prompt, response, vector=vector)

    # Default return fields.
    default_result = cache.check(vector=vector)
    assert set(default_result[0].keys()) == {
        "id",
        "prompt",
        "response",
        "prompt_vector",
        "vector_distance",
    }

    # Explicitly requesting every field returns them all.
    all_fields = [
        "id",
        "prompt",
        "response",
        "inserted_at",
        "updated_at",
        "prompt_vector",
        "vector_distance",
    ]
    full_result = cache.check(vector=vector, return_fields=list(all_fields))
    assert set(full_result[0].keys()) == set(all_fields)

    # A subset of fields: id and vector_distance are always included.
    subset = ["inserted_at", "updated_at"]
    partial_result = cache.check(vector=vector, return_fields=list(subset))
    assert set(partial_result[0].keys()) == set(subset) | {"id", "vector_distance"}


# Test clearing the cache
def test_clear(cache, vectorizer):
prompt = "This is a test prompt."
Expand All @@ -111,6 +148,65 @@ def test_ttl_expiration(cache_with_ttl, vectorizer):
assert len(check_result) == 0


# Test manual expiration of a single document
def test_drop_document(cache, vectorizer):
    prompt = "This is a test prompt."
    response = "This is a test response."
    vector = vectorizer.embed(prompt)

    cache.store(prompt, response, vector=vector)
    stored = cache.check(vector=vector)

    # Dropping the entry by its id should make it unfindable.
    cache.drop(stored[0]["id"])
    assert len(cache.check(vector=vector)) == 0


# Test manual expiration of multiple documents
def test_drop_documents(cache, vectorizer):
    prompts = [
        "This is a test prompt.",
        "This is also test prompt.",
        "This is another test prompt.",
    ]
    responses = [
        "This is a test response.",
        "This is also test response.",
        "This is a another test response.",
    ]
    last_vector = None
    for prompt, response in zip(prompts, responses):
        last_vector = vectorizer.embed(prompt)
        cache.store(prompt, response, vector=last_vector)

    hits = cache.check(vector=last_vector, num_results=3)
    # Drop the first two hits and verify exactly one entry remains.
    ids_to_drop = [hit["id"] for hit in hits[:2]]
    cache.drop(ids_to_drop)

    remaining = cache.check(vector=last_vector, num_results=3)
    assert len(remaining) == 1


# Test updating document fields
def test_updating_document(cache):
    prompt = "This is a test prompt."
    response = "This is a test response."
    cache.store(prompt=prompt, response=response)

    original = cache.check(prompt=prompt, return_fields=["updated_at"])
    key = original[0]["id"]

    # Give the updated_at timestamp room to advance.
    sleep(1)

    metadata = {"foo": "bar"}
    cache.update(key=key, metadata=metadata)

    refreshed = cache.check(prompt=prompt, return_fields=["updated_at", "metadata"])
    assert refreshed[0]["metadata"] == metadata
    assert refreshed[0]["updated_at"] > original[0]["updated_at"]


def test_ttl_expiration_after_update(cache_with_ttl, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand Down Expand Up @@ -263,3 +359,38 @@ def test_vector_size(cache, vectorizer):

with pytest.raises(ValueError):
cache.check(vector=[1, 2, 3])


# test we can pass a list of tags and we'll include all results that match
def test_multiple_tags(cache):
    tags = ["group 0", "group 1", "group 2", "group 3"]

    for index, tag in enumerate(tags):
        cache.store(f"test prompt {index}", f"test response {index}", tag=tag)

    # A single tag scopes the check to entries stored under that tag.
    results = cache.check("test prompt 1", tags="group 0", num_results=5)
    assert len(results) == 1
    assert results[0]["prompt"] == "test prompt 0"

    # A list of tags widens the scope to all matching entries.
    results = cache.check(
        "test prompt 1", tags=["group 0", "group 1", "group 2"], num_results=5
    )
    assert len(results) == 3

    # Omitting the tags param searches the full cache.
    results = cache.check("test prompt 1", num_results=5)
    assert len(results) == 4

    # An empty tag list also matches everything.
    results = cache.check("test prompt 1", tags=[], num_results=5)
    assert len(results) == 4

    # A tag that was never stored matches nothing.
    results = cache.check("test prompt 1", tags=["bad tag"], num_results=5)
    assert len(results) == 0
Loading