diff --git a/docs/user_guide/llmcache_03.ipynb b/docs/user_guide/llmcache_03.ipynb index 8d326f8f..3403b287 100644 --- a/docs/user_guide/llmcache_03.ipynb +++ b/docs/user_guide/llmcache_03.ipynb @@ -83,7 +83,6 @@ "\n", "llmcache = SemanticCache(\n", " name=\"llmcache\", # underlying search index name\n", - " prefix=\"llmcache\", # redis key prefix for hash entries\n", " redis_url=\"redis://localhost:6379\", # redis connection url string\n", " distance_threshold=0.1 # semantic cache distance threshold\n", ")" @@ -107,13 +106,15 @@ "│ llmcache │ HASH │ ['llmcache'] │ [] │ 0 │\n", "╰──────────────┴────────────────┴──────────────┴─────────────────┴────────────╯\n", "Index Fields:\n", - "╭───────────────┬───────────────┬────────┬────────────────┬────────────────╮\n", - "│ Name │ Attribute │ Type │ Field Option │ Option Value │\n", - "├───────────────┼───────────────┼────────┼────────────────┼────────────────┤\n", - "│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │\n", - "│ response │ response │ TEXT │ WEIGHT │ 1 │\n", - "│ prompt_vector │ prompt_vector │ VECTOR │ │ │\n", - "╰───────────────┴───────────────┴────────┴────────────────┴────────────────╯\n" + "╭───────────────┬───────────────┬─────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", + "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", + "├───────────────┼───────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", + "│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ response │ response │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ inserted_at │ inserted_at │ NUMERIC │ │ │ │ │ │ │ │ │\n", + "│ updated_at │ updated_at │ NUMERIC │ │ │ │ │ │ │ │ │\n", + "│ prompt_vector │ prompt_vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n", + "╰───────────────┴───────────────┴─────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" ] } ], @@ -208,7 +209,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[{'id': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545', 'vector_distance': '9.53674316406e-07', 'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}}]\n" + "[{'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}, 'key': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545'}]\n" ] } ], @@ -384,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -408,14 +409,14 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Without caching, a call to openAI to answer this simple question took 1.460299015045166 seconds.\n" + "Without caching, a call to openAI to answer this simple question took 0.9312698841094971 seconds.\n" ] }, { @@ -424,7 +425,7 @@ "'llmcache:67e0f6e28fe2a61c0022fd42bf734bb8ffe49d3e375fd69d692574295a20fc1a'" ] }, - "execution_count": 18, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -451,8 +452,8 @@ 
"name": "stdout", "output_type": "stream", "text": [ - "Avg time taken with LLM cache enabled: 0.2560166358947754\n", - "Percentage of time saved: 82.47%\n" + "Avg time taken with LLM cache enabled: 0.4896167993545532\n", + "Percentage of time saved: 47.42%\n" ] } ], @@ -515,7 +516,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -540,7 +541,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.14" }, "orig_nbformat": 4 }, diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index a1c88466..d11a404f 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -1,4 +1,3 @@ -import json from typing import Any, Dict, List, Optional from redisvl.redis.utils import hashify diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py new file mode 100644 index 00000000..8075496b --- /dev/null +++ b/redisvl/extensions/llmcache/schema.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, List, Optional + +from pydantic.v1 import BaseModel, Field, root_validator, validator + +from redisvl.redis.utils import array_to_buffer, hashify +from redisvl.schema import IndexSchema +from redisvl.utils.utils import current_timestamp, deserialize, serialize + + +class CacheEntry(BaseModel): + """A single cache entry in Redis""" + + entry_id: Optional[str] = Field(default=None) + """Cache entry identifier""" + prompt: str + """Input prompt or question cached in Redis""" + response: str + """Response or answer to the question, cached in Redis""" + prompt_vector: List[float] + """Text embedding representation of the prompt""" + inserted_at: float = Field(default_factory=current_timestamp) + """Timestamp of when the entry was added to the cache""" + updated_at: float = Field(default_factory=current_timestamp) + """Timestamp of when the entry was updated in the cache""" + metadata: Optional[Dict[str, Any]] = Field(default=None) + """Optional metadata stored on the cache entry""" + filters: Optional[Dict[str, Any]] = Field(default=None) + """Optional filter data stored on the cache entry for customizing retrieval""" + + @root_validator(pre=True) + @classmethod + def generate_id(cls, values): + # Ensure entry_id is set + if not values.get("entry_id"): + values["entry_id"] = hashify(values["prompt"]) + return values + + @validator("metadata") + def non_empty_metadata(cls, v): + if v is not None and not isinstance(v, dict): + raise TypeError("Metadata must be a dictionary.") + return v + + def to_dict(self) -> Dict: + data = self.dict(exclude_none=True) + data["prompt_vector"] = array_to_buffer(self.prompt_vector) + if self.metadata: + data["metadata"] = serialize(self.metadata) + if self.filters: + data.update(self.filters) + del data["filters"] + return data + + +class CacheHit(BaseModel): + """A cache hit based on some input query""" + + entry_id: str + """Cache entry identifier""" + prompt: str + """Input prompt or question cached in Redis""" + response: str + """Response or answer to the question, cached in Redis""" + vector_distance: float + """The semantic distance between the query vector and the stored prompt vector""" + inserted_at: float + """Timestamp of when the entry was added to the cache""" + updated_at: float + """Timestamp of when the entry was updated in the cache""" + metadata: Optional[Dict[str, Any]] = Field(default=None) + """Optional metadata stored on the 
cache entry""" + filters: Optional[Dict[str, Any]] = Field(default=None) + """Optional filter data stored on the cache entry for customizing retrieval""" + + @root_validator(pre=True) + @classmethod + def validate_cache_hit(cls, values): + # Deserialize metadata if necessary + if "metadata" in values and isinstance(values["metadata"], str): + values["metadata"] = deserialize(values["metadata"]) + + # Separate filters from other fields + known_fields = set(cls.__fields__.keys()) + filters = {k: v for k, v in values.items() if k not in known_fields} + + # Add filters to values + if filters: + values["filters"] = filters + + # Remove filter fields from the main values + for k in filters: + values.pop(k) + + return values + + def to_dict(self) -> Dict: + data = self.dict(exclude_none=True) + if self.filters: + data.update(self.filters) + del data["filters"] + + return data + + +class SemanticCacheIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str, vector_dims: int): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "prompt", "type": "text"}, + {"name": "response", "type": "text"}, + {"name": "inserted_at", "type": "numeric"}, + {"name": "updated_at", "type": "numeric"}, + { + "name": "prompt_vector", + "type": "vector", + "attrs": { + "dims": vector_dims, + "datatype": "float32", + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index b17c18c9..f9518614 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -1,63 +1,39 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from redis import Redis from redisvl.extensions.llmcache.base import BaseLLMCache +from redisvl.extensions.llmcache.schema import ( + CacheEntry, + CacheHit, + SemanticCacheIndexSchema, +) from redisvl.index import SearchIndex from redisvl.query import RangeQuery -from redisvl.query.filter import FilterExpression, Tag -from redisvl.redis.utils import array_to_buffer -from redisvl.schema import IndexSchema -from redisvl.utils.utils import current_timestamp, deserialize, serialize +from redisvl.query.filter import FilterExpression +from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -class SemanticCacheIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, vector_dims: int): - - return cls( - index={"name": name, "prefix": name}, # type: ignore - fields=[ # type: ignore - {"name": "prompt", "type": "text"}, - {"name": "response", "type": "text"}, - {"name": "inserted_at", "type": "numeric"}, - {"name": "updated_at", "type": "numeric"}, - {"name": "label", "type": "tag"}, - { - "name": "prompt_vector", - "type": "vector", - "attrs": { - "dims": vector_dims, - "datatype": "float32", - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ], - ) - - class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" - entry_id_field_name: str = "_id" + redis_key_field_name: str = "key" + entry_id_field_name: str = "entry_id" prompt_field_name: str = "prompt" + response_field_name: str = "response" vector_field_name: str = "prompt_vector" inserted_at_field_name: str = "inserted_at" updated_at_field_name: str = "updated_at" - tag_field_name: str = "label" - response_field_name: str = 
"response" metadata_field_name: str = "metadata" def __init__( self, name: str = "llmcache", - prefix: Optional[str] = None, distance_threshold: float = 0.1, ttl: Optional[int] = None, vectorizer: Optional[BaseVectorizer] = None, + filterable_fields: Optional[List[Dict[str, Any]]] = None, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, @@ -68,15 +44,14 @@ def __init__( Args: name (str, optional): The name of the semantic cache search index. Defaults to "llmcache". - prefix (Optional[str], optional): The prefix for Redis keys - associated with the semantic cache search index. Defaults to - None, and the index name will be used as the key prefix. distance_threshold (float, optional): Semantic threshold for the cache. Defaults to 0.1. ttl (Optional[int], optional): The time-to-live for records cached in Redis. Defaults to None. vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache. Defaults to HFTextVectorizer. + filterable_fields (Optional[List[Dict[str, Any]]]): An optional list of RedisVL fields + that can be used to customize cache retrieval with filters. redis_client(Optional[Redis], optional): A redis client connection instance. Defaults to None. redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. @@ -87,12 +62,13 @@ def __init__( TypeError: If an invalid vectorizer is provided. TypeError: If the TTL value is not an int. ValueError: If the threshold is not between 0 and 1. - ValueError: If the index name is not provided """ super().__init__(ttl) # Use the index name as the key prefix by default - if prefix is None: + if "prefix" in kwargs: + prefix = kwargs["prefix"] + else: prefix = name # Set vectorizer default @@ -101,28 +77,57 @@ def __init__( model="sentence-transformers/all-mpnet-base-v2" ) - schema = SemanticCacheIndexSchema.from_params(name, vectorizer.dims) + # Process fields + self.return_fields = [ + self.entry_id_field_name, + self.prompt_field_name, + self.response_field_name, + self.inserted_at_field_name, + self.updated_at_field_name, + self.metadata_field_name, + ] + + # Create semantic cache schema and index + schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims) + schema = self._modify_schema(schema, filterable_fields) + self._index = SearchIndex(schema=schema) - # handle redis connection + # Handle redis connection if redis_client: self._index.set_client(redis_client) elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) - # initialize other components - self.default_return_fields = [ - self.entry_id_field_name, - self.prompt_field_name, - self.response_field_name, - self.tag_field_name, - self.vector_field_name, - self.metadata_field_name, - ] - self.set_vectorizer(vectorizer) + # Initialize other components + self._set_vectorizer(vectorizer) self.set_threshold(distance_threshold) self._index.create(overwrite=False) + def _modify_schema( + self, + schema: SemanticCacheIndexSchema, + filterable_fields: Optional[List[Dict[str, Any]]] = None, + ) -> SemanticCacheIndexSchema: + """Modify the base cache schema using the provided filterable fields""" + + if filterable_fields is not None: + protected_field_names = set( + self.return_fields + [self.redis_key_field_name] + ) + for filter_field in filterable_fields: + field_name = filter_field["name"] + if field_name in protected_field_names: + raise ValueError( + f"{field_name} is a reserved field name for the semantic cache schema" + ) + # Add to schema + 
schema.add_field(filter_field) + # Add to return fields too + self.return_fields.append(field_name) + + return schema + @property def index(self) -> SearchIndex: """The underlying SearchIndex for the cache. @@ -157,7 +162,7 @@ def set_threshold(self, distance_threshold: float) -> None: ) self._distance_threshold = float(distance_threshold) - def set_vectorizer(self, vectorizer: BaseVectorizer) -> None: + def _set_vectorizer(self, vectorizer: BaseVectorizer) -> None: """Sets the vectorizer for the LLM cache. Must be a valid subclass of BaseVectorizer and have equivalent @@ -175,14 +180,7 @@ def set_vectorizer(self, vectorizer: BaseVectorizer) -> None: raise TypeError("Must provide a valid redisvl.vectorizer class.") schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore - - if schema_vector_dims != vectorizer.dims: - raise ValueError( - "Invalid vector dimensions! " - f"Vectorizer has dims defined as {vectorizer.dims}", - f"Vector field has dims defined as {schema_vector_dims}", - ) - + validate_vector_dims(vectorizer.dims, schema_vector_dims) self._vectorizer = vectorizer def clear(self) -> None: @@ -194,13 +192,20 @@ def delete(self) -> None: index.""" self._index.delete(drop=True) - def drop(self, document_ids: Union[str, List[str]]) -> None: - """Remove a specific entry or entries from the cache by it's ID. + def drop( + self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None + ) -> None: + """Manually expire specific entries from the cache by id or specific + Redis key. Args: - document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache. + ids (Optional[str]): The document ID or IDs to remove from the cache. + keys (Optional[str]): """ - self._index.drop_keys(document_ids) + if ids is not None: + self._index.drop_keys([self._index.key(id) for id in ids]) + if keys is not None: + self._index.drop_keys(keys) def _refresh_ttl(self, key: str) -> None: """Refresh the time-to-live for the specified key.""" @@ -212,61 +217,14 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: configured vectorizer.""" if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt) - - def _search_cache( - self, - vector: List[float], - num_results: int, - return_fields: Optional[List[str]], - tag_filter: Optional[FilterExpression], - ) -> List[Dict[str, Any]]: - """Searches the semantic cache for similar prompt vectors and returns - the specified return fields for each cache hit.""" - # Setup and type checks - if not isinstance(vector, list): - raise TypeError("Vector must be a list of floats") - - return_fields = return_fields or self.default_return_fields - if not isinstance(return_fields, list): - raise TypeError("return_fields must be a list of field names") - - # Construct vector RangeQuery for the cache check - query = RangeQuery( - vector=vector, - vector_field_name=self.vector_field_name, - return_fields=return_fields, - distance_threshold=self._distance_threshold, - num_results=num_results, - return_score=True, - ) - if tag_filter: - query.set_filter(tag_filter) # type: ignore - - # Gather and return the cache hits - cache_hits: List[Dict[str, Any]] = self._index.query(query) - # Process cache hits - for hit in cache_hits: - key = hit["id"] - self._refresh_ttl(key) - # Check for metadata and deserialize - if self.metadata_field_name in hit: - hit[self.metadata_field_name] = deserialize( - hit[self.metadata_field_name] - ) - return cache_hits + 
return self._vectorizer.embed(prompt) def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it doesn't match the search index vector dimensions.""" schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore - if schema_vector_dims != len(vector): - raise ValueError( - "Invalid vector dimensions! " - f"Vector has dims defined as {len(vector)}", - f"Vector field has dims defined as {schema_vector_dims}", - ) + validate_vector_dims(len(vector), schema_vector_dims) def check( self, @@ -274,7 +232,7 @@ def check( vector: Optional[List[float]] = None, num_results: int = 1, return_fields: Optional[List[str]] = None, - tag_filter: Optional[FilterExpression] = None, + filter_expression: Optional[FilterExpression] = None, ) -> List[Dict[str, Any]]: """Checks the semantic cache for results similar to the specified prompt or vector. @@ -294,8 +252,9 @@ def check( return_fields (Optional[List[str]], optional): The fields to include in each returned result. If None, defaults to all available fields in the cached entry. - tag_filter (Optional[FilterExpression]) : the tag filter to filter - results by. Default is None and full cache is searched. + filter_expression (Optional[FilterExpression]) : Optional filter expression + that can be used to filter cache results. Defaults to None and + the full cache will be searched. Returns: List[Dict[str, Any]]: A list of dicts containing the requested @@ -315,12 +274,40 @@ def check( if not (prompt or vector): raise ValueError("Either prompt or vector must be specified.") - # Use provided vector or create from prompt vector = vector or self._vectorize_prompt(prompt) self._check_vector_dims(vector) + return_fields = return_fields or self.return_fields + + if not isinstance(return_fields, list): + raise TypeError("return_fields must be a list of field names") + + query = RangeQuery( + vector=vector, + vector_field_name=self.vector_field_name, + return_fields=self.return_fields, + distance_threshold=self._distance_threshold, + num_results=num_results, + return_score=True, + filter_expression=filter_expression, + ) + + cache_hits: List[Dict[Any, str]] = [] + + # Search the cache! + cache_search_results = self._index.query(query) + + for cache_search_result in cache_search_results: + key = cache_search_result["id"] + self._refresh_ttl(key) + + # Create cache hit + cache_hit = CacheHit(**cache_search_result) + cache_hit_dict = { + k: v for k, v in cache_hit.to_dict().items() if k in return_fields + } + cache_hit_dict["key"] = key + cache_hits.append(cache_hit_dict) - # Check for cache hits by searching the cache - cache_hits = self._search_cache(vector, num_results, return_fields, tag_filter) return cache_hits def store( @@ -328,8 +315,8 @@ def store( prompt: str, response: str, vector: Optional[List[float]] = None, - metadata: Optional[dict] = None, - tag: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + filters: Optional[Dict[str, Any]] = None, ) -> str: """Stores the specified key-value pair in the cache along with metadata. @@ -339,9 +326,9 @@ def store( vector (Optional[List[float]], optional): The prompt vector to cache. Defaults to None, and the prompt vector is generated on demand. - metadata (Optional[dict], optional): The optional metadata to cache + metadata (Optional[Dict[str, Any]], optional): The optional metadata to cache alongside the prompt and response. Defaults to None. 
- tag (Optional[str]): The optional tag to assign to the cache entry. + filters (Optional[Dict[str, Any]]): The optional tag to assign to the cache entry. Defaults to None. Returns: @@ -362,29 +349,24 @@ def store( """ # Vectorize prompt if necessary and create cache payload vector = vector or self._vectorize_prompt(prompt) + self._check_vector_dims(vector) - # Construct semantic cache payload - now = current_timestamp() - id_field = self.entry_id_field_name - payload = { - id_field: self.hash_input(prompt), - self.prompt_field_name: prompt, - self.response_field_name: response, - self.vector_field_name: array_to_buffer(vector), - self.inserted_at_field_name: now, - self.updated_at_field_name: now, - } - if metadata is not None: - if not isinstance(metadata, dict): - raise TypeError("If specified, cached metadata must be a dictionary.") - # Serialize the metadata dict and add to cache payload - payload[self.metadata_field_name] = serialize(metadata) - if tag is not None: - payload[self.tag_field_name] = tag - - # Load LLMCache entry with TTL - keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field) + # Build cache entry for the cache + cache_entry = CacheEntry( + prompt=prompt, + response=response, + prompt_vector=vector, + metadata=metadata, + filters=filters, + ) + + # Load cache entry with TTL + keys = self._index.load( + data=[cache_entry.to_dict()], + ttl=self._ttl, + id_field=self.entry_id_field_name, + ) return keys[0] def update(self, key: str, **kwargs) -> None: @@ -392,8 +374,7 @@ def update(self, key: str, **kwargs) -> None: are passed, then only the document TTL is refreshed. Args: - key (str): the key of the document to update. - kwargs: + key (str): the key of the document to update using kwargs. Raises: ValueError if an incorrect mapping is provided as a kwarg. @@ -404,28 +385,26 @@ def update(self, key: str, **kwargs) -> None: cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"}) ) """ - if not kwargs: - self._refresh_ttl(key) - return - - for _key, val in kwargs.items(): - if _key not in { - self.prompt_field_name, - self.vector_field_name, - self.response_field_name, - self.tag_field_name, - self.metadata_field_name, - }: - raise ValueError(f" {key} is not a valid field within document") - - # Check for metadata and deserialize - if _key == self.metadata_field_name: - if isinstance(val, dict): - kwargs[_key] = serialize(val) - else: - raise TypeError( - "If specified, cached metadata must be a dictionary." - ) - kwargs.update({self.updated_at_field_name: current_timestamp()}) - self._index.client.hset(key, mapping=kwargs) # type: ignore + if kwargs: + for k, v in kwargs.items(): + + # Make sure the item is in the index schema + if k not in set( + self._index.schema.field_names + [self.metadata_field_name] + ): + raise ValueError(f"{k} is not a valid field within the cache entry") + + # Check for metadata and deserialize + if k == self.metadata_field_name: + if isinstance(v, dict): + kwargs[k] = serialize(v) + else: + raise TypeError( + "If specified, cached metadata must be a dictionary." 
+ ) + + kwargs.update({self.updated_at_field_name: current_timestamp()}) + + self._index.client.hset(key, mapping=kwargs) # type: ignore + self._refresh_ttl(key) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 6a01c1ce..f5e6b4a6 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -193,11 +193,6 @@ def storage_type(self) -> StorageType: hash or json.""" return self.schema.index.storage_type - @property - def client(self) -> Optional[Union[redis.Redis, aredis.Redis]]: - """The underlying redis-py client object.""" - return self._redis_client - @classmethod def from_yaml(cls, schema_path: str, **kwargs): """Create a SearchIndex from a YAML schema file. @@ -364,6 +359,11 @@ def from_existing( schema = IndexSchema.from_dict(schema_dict) return cls(schema, redis_client, **kwargs) + @property + def client(self) -> Optional[redis.Redis]: + """The underlying redis-py client object.""" + return self._redis_client + def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance using the provided `redis_url`, falling back to the `REDIS_URL` environment variable (if available). @@ -843,6 +843,11 @@ async def from_existing( await index.set_client(redis_client) return index + @property + def client(self) -> Optional[aredis.Redis]: + """The underlying redis-py client object.""" + return self._redis_client + async def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance using the provided `redis_url`, falling back to the `REDIS_URL` environment variable (if available). diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index b272ac30..34c15113 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -2,6 +2,7 @@ from time import sleep, time import pytest +from pydantic.v1 import ValidationError from redis.exceptions import ConnectionError from redisvl.extensions.llmcache import SemanticCache @@ -24,6 +25,18 @@ def cache(vectorizer, redis_url): cache_instance._index.delete(True) # Clean up index +@pytest.fixture +def cache_with_filters(vectorizer, redis_url): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + filterable_fields=[{"name": "label", "type": "tag"}], + redis_url=redis_url, + ) + yield cache_instance + cache_instance._index.delete(True) # Clean up index + + @pytest.fixture def cache_no_cleanup(vectorizer, redis_url): cache_instance = SemanticCache( @@ -100,32 +113,30 @@ def test_return_fields(cache, vectorizer): # check default return fields check_result = cache.check(vector=vector) assert set(check_result[0].keys()) == { - "id", - "_id", + "key", + "entry_id", "prompt", "response", - "prompt_vector", "vector_distance", + "inserted_at", + "updated_at", } - # check all return fields + # check specific return fields fields = [ - "id", - "_id", + "key", + "entry_id", "prompt", "response", - "inserted_at", - "updated_at", - "prompt_vector", "vector_distance", ] - check_result = cache.check(vector=vector, return_fields=fields[:]) + check_result = cache.check(vector=vector, return_fields=fields) assert set(check_result[0].keys()) == set(fields) # check only some return fields fields = ["inserted_at", "updated_at"] - check_result = cache.check(vector=vector, return_fields=fields[:]) - fields.extend(["id", "vector_distance"]) # id and vector_distance always returned + check_result = cache.check(vector=vector, return_fields=fields) + fields.append("key") assert set(check_result[0].keys()) == 
set(fields) @@ -178,7 +189,7 @@ def test_drop_document(cache, vectorizer): cache.store(prompt, response, vector=vector) check_result = cache.check(vector=vector) - cache.drop(check_result[0]["id"]) + cache.drop(ids=[check_result[0]["entry_id"]]) recheck_result = cache.check(vector=vector) assert len(recheck_result) == 0 @@ -200,8 +211,9 @@ def test_drop_documents(cache, vectorizer): cache.store(prompt, response, vector=vector) check_result = cache.check(vector=vector, num_results=3) - keys = [r["id"] for r in check_result[0:2]] # drop first 2 entries - cache.drop(keys) + print(check_result, flush=True) + ids = [r["entry_id"] for r in check_result[0:2]] # drop first 2 entries + cache.drop(ids=ids) recheck_result = cache.check(vector=vector, num_results=3) assert len(recheck_result) == 1 @@ -214,7 +226,7 @@ def test_updating_document(cache): cache.store(prompt=prompt, response=response) check_result = cache.check(prompt=prompt, return_fields=["updated_at"]) - key = check_result[0]["id"] + key = check_result[0]["key"] sleep(1) @@ -290,9 +302,7 @@ def test_store_with_invalid_metadata(cache, vectorizer): vector = vectorizer.embed(prompt) - with pytest.raises( - TypeError, match=r"If specified, cached metadata must be a dictionary." - ): + with pytest.raises(ValidationError): cache.store(prompt, response, vector=vector, metadata=metadata) @@ -381,8 +391,11 @@ def test_vector_size(cache, vectorizer): cache.check(vector=[1, 2, 3]) -# test we can pass a list of tags and we'll include all results that match -def test_multiple_tags(cache): +def test_cache_with_filters(cache_with_filters): + assert "label" in cache_with_filters._index.schema.fields + + +def test_cache_filtering(cache_with_filters): tag_1 = "group 0" tag_2 = "group 1" tag_3 = "group 2" @@ -396,43 +409,91 @@ def test_multiple_tags(cache): for i in range(4): prompt = f"test prompt {i}" response = f"test response {i}" - cache.store(prompt, response, tag=tags[i]) + cache_with_filters.store(prompt, response, filters={"label": tags[i]}) # test we can specify one specific tag - results = cache.check("test prompt 1", tag_filter=filter_1, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=filter_1, num_results=5 + ) assert len(results) == 1 assert results[0]["prompt"] == "test prompt 0" # test we can pass a list of tags combined_filter = filter_1 | filter_2 | filter_3 - results = cache.check("test prompt 1", tag_filter=combined_filter, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=combined_filter, num_results=5 + ) assert len(results) == 3 # test that default tag param searches full cache - results = cache.check("test prompt 1", num_results=5) + results = cache_with_filters.check("test prompt 1", num_results=5) assert len(results) == 4 # test no results are returned if we pass a nonexistant tag bad_filter = Tag("label") == "bad tag" - results = cache.check("test prompt 1", tag_filter=bad_filter, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=bad_filter, num_results=5 + ) assert len(results) == 0 -def test_complex_filters(cache): - cache.store(prompt="prompt 1", response="response 1") - cache.store(prompt="prompt 2", response="response 2") +def test_cache_bad_filters(vectorizer, redis_url): + with pytest.raises(ValueError): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + # invalid field type + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "test", "type": 
"nothing"}, + ], + redis_url=redis_url, + ) + + with pytest.raises(ValueError): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + # duplicate field type + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "label", "type": "tag"}, + ], + redis_url=redis_url, + ) + + with pytest.raises(ValueError): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + # reserved field name + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "metadata", "type": "tag"}, + ], + redis_url=redis_url, + ) + + +def test_complex_filters(cache_with_filters): + cache_with_filters.store(prompt="prompt 1", response="response 1") + cache_with_filters.store(prompt="prompt 2", response="response 2") sleep(1) current_timestamp = time() - cache.store(prompt="prompt 3", response="response 3") + cache_with_filters.store(prompt="prompt 3", response="response 3") # test we can do range filters on inserted_at and updated_at fields range_filter = Num("inserted_at") < current_timestamp - results = cache.check("prompt 1", tag_filter=range_filter, num_results=5) + results = cache_with_filters.check( + "prompt 1", filter_expression=range_filter, num_results=5 + ) assert len(results) == 2 # test we can combine range filters and text filters prompt_filter = Text("prompt") % "*pt 1" combined_filter = prompt_filter & range_filter - results = cache.check("prompt 1", tag_filter=combined_filter, num_results=5) + results = cache_with_filters.check( + "prompt 1", filter_expression=combined_filter, num_results=5 + ) assert len(results) == 1 diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py new file mode 100644 index 00000000..e3961e6b --- /dev/null +++ b/tests/unit/test_llmcache_schema.py @@ -0,0 +1,128 @@ +import json + +import pytest +from pydantic.v1 import ValidationError + +from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit +from redisvl.redis.utils import array_to_buffer, hashify + + +def test_valid_cache_entry_creation(): + entry = CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3], + ) + assert entry.entry_id == hashify("What is AI?") + assert entry.prompt == "What is AI?" + assert entry.response == "AI is artificial intelligence." 
+ assert entry.prompt_vector == [0.1, 0.2, 0.3] + + +def test_cache_entry_with_given_entry_id(): + entry = CacheEntry( + entry_id="custom_id", + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3], + ) + assert entry.entry_id == "custom_id" + + +def test_cache_entry_with_invalid_metadata(): + with pytest.raises(ValidationError): + CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3], + metadata="invalid_metadata", + ) + + +def test_cache_entry_to_dict(): + entry = CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3], + metadata={"author": "John"}, + filters={"category": "technology"}, + ) + result = entry.to_dict() + assert result["entry_id"] == hashify("What is AI?") + assert result["metadata"] == json.dumps({"author": "John"}) + assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3]) + assert result["category"] == "technology" + assert "filters" not in result + + +def test_valid_cache_hit_creation(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123, + ) + assert hit.entry_id == "entry_1" + assert hit.prompt == "What is AI?" + assert hit.response == "AI is artificial intelligence." + assert hit.vector_distance == 0.1 + assert hit.inserted_at == hit.updated_at == 1625819123.123 + + +def test_cache_hit_with_serialized_metadata(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123, + metadata=json.dumps({"author": "John"}), + ) + assert hit.metadata == {"author": "John"} + + +def test_cache_hit_to_dict(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123, + filters={"category": "technology"}, + ) + result = hit.to_dict() + assert result["entry_id"] == "entry_1" + assert result["prompt"] == "What is AI?" + assert result["response"] == "AI is artificial intelligence." + assert result["vector_distance"] == 0.1 + assert result["category"] == "technology" + assert "filters" not in result + + +def test_cache_entry_with_empty_optional_fields(): + entry = CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3], + ) + result = entry.to_dict() + assert "metadata" not in result + assert "filters" not in result + + +def test_cache_hit_with_empty_optional_fields(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123, + ) + result = hit.to_dict() + assert "metadata" not in result + assert "filters" not in result