From 1aea7f1b55f2918b065b259ee0ecb98d624368df Mon Sep 17 00:00:00 2001
From: Justin Cechmanek
Date: Tue, 2 Jul 2024 15:25:41 -0700
Subject: [PATCH 1/2] adds check to ensure correct embedding vector dimensions are used

---
Review notes (this section is ignored by `git am`):
  * `_check_vector_dims` previously had a stray comma after the
    "Vector has dims defined as ..." f-string. That comma ended the implicit
    string concatenation, so ValueError was built with two positional
    arguments and str(exc) rendered as a tuple repr instead of a sentence.
    The three literals are now concatenated into one message argument.
  * Line breaks and indentation were restored from a whitespace-collapsed
    copy of this series. Context-line indentation is a best-effort
    reconstruction; use `git am --ignore-whitespace` if a hunk is rejected.
  * The blob ids on the `index` lines predate the message fix.

 redisvl/extensions/llmcache/semantic.py | 14 +++++++++++++
 tests/integration/test_llmcache.py      | 28 +++++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 07602ac7..3c55e7e5 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -234,6 +234,15 @@ def _search_cache(
                 )
         return cache_hits
 
+    def _check_vector_dims(self, vector: List[float]):
+        schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims  # type: ignore
+        if schema_vector_dims != len(vector):
+            raise ValueError(
+                "Invalid vector dimensions! "
+                f"Vector has dims defined as {len(vector)}. "
+                f"Vector field has dims defined as {schema_vector_dims}",
+            )
+
     def check(
         self,
         prompt: Optional[str] = None,
@@ -266,6 +275,7 @@ def check(
 
         Raises:
             ValueError: If neither a `prompt` nor a `vector` is specified.
+            ValueError: if 'vector' has incorrect dimensions.
             TypeError: If `return_fields` is not a list when provided.
 
         .. code-block:: python
@@ -279,6 +289,7 @@ def check(
 
         # Use provided vector or create from prompt
         vector = vector or self._vectorize_prompt(prompt)
+        self._check_vector_dims(vector)
 
         # Check for cache hits by searching the cache
         cache_hits = self._search_cache(vector, num_results, return_fields)
@@ -307,6 +318,7 @@ def store(
 
         Raises:
             ValueError: If neither prompt nor vector is specified.
+            ValueError: if vector has incorrect dimensions.
             TypeError: If provided metadata is not a dictionary.
 
         .. code-block:: python
@@ -319,6 +331,8 @@ def store(
         """
         # Vectorize prompt if necessary and create cache payload
         vector = vector or self._vectorize_prompt(prompt)
+        self._check_vector_dims(vector)
+
         # Construct semantic cache payload
         id_field = self.entry_id_field_name
         payload = {
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 373ca8da..2bb107fd 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -210,3 +210,31 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
 def test_delete(cache_no_cleanup):
     cache_no_cleanup.delete()
     assert not cache_no_cleanup.index.exists()
+
+
+# Test we can only store and check vectors of correct dimensions
+def test_vector_size(cache, vectorizer):
+    prompt = "This is test prompt."
+    response = "This is a test response."
+
+    vector = vectorizer.embed(prompt)
+    cache.store(prompt=prompt, response=response, vector=vector)
+
+    # Test we can query with modified embeddings of correct size
+    vector_2 = [v * 0.99 for v in vector]  # same dimensions
+    check_result = cache.check(vector=vector_2)
+    assert check_result[0]["prompt"] == prompt
+
+    # Test that error is raised when we try to load wrong size vectors
+    with pytest.raises(ValueError):
+        cache.store(prompt=prompt, response=response, vector=vector[0:-1])
+
+    with pytest.raises(ValueError):
+        cache.store(prompt=prompt, response=response, vector=[1, 2, 3])
+
+    # Test that error is raised when we try to query with wrong size vector
+    with pytest.raises(ValueError):
+        cache.check(vector=vector[0:-1])
+
+    with pytest.raises(ValueError):
+        cache.check(vector=[1, 2, 3])

From 538757e21b72e6ab9f14ceddedb0e3a6d990d399 Mon Sep 17 00:00:00 2001
From: Justin Cechmanek
Date: Tue, 2 Jul 2024 16:26:12 -0700
Subject: [PATCH 2/2] adds doc string to vector dimension check

---
 redisvl/extensions/llmcache/semantic.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 3c55e7e5..3956e7d6 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -235,6 +235,8 @@ def _search_cache(
         return cache_hits
 
     def _check_vector_dims(self, vector: List[float]):
+        """Checks the size of the provided vector and raises an error if it
+        doesn't match the search index vector dimensions."""
         schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims  # type: ignore
         if schema_vector_dims != len(vector):
             raise ValueError(