Skip to content

Commit aa05797

Browse files
Check to ensure correct embedding vector dimensions are used (#177)
Currently our semantic cache allows for specifying the vector in calls to store() and check(), but if the vector dimension does not match the schema dimensions this fails silently. This PR adds a check to verify correct vector dimensions and raises an error if they do not match.
1 parent ccc039f commit aa05797

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

redisvl/extensions/llmcache/semantic.py

+16
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,17 @@ def _search_cache(
234234
)
235235
return cache_hits
236236

237+
def _check_vector_dims(self, vector: List[float]):
    """Checks the size of the provided vector and raises an error if it
    doesn't match the search index vector dimensions.

    Args:
        vector (List[float]): The embedding vector to validate against
            the index schema.

    Raises:
        ValueError: If the vector's length differs from the dims declared
            on the index schema's vector field.
    """
    # Pull the expected dims straight off the index schema's vector field.
    schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims  # type: ignore
    if schema_vector_dims != len(vector):
        # Fix: the original passed three comma-separated arguments to
        # ValueError, which renders as a tuple; adjacent string literals
        # concatenate into one readable message instead.
        raise ValueError(
            "Invalid vector dimensions! "
            f"Vector has dims defined as {len(vector)}, "
            f"Vector field has dims defined as {schema_vector_dims}"
        )
247+
237248
def check(
238249
self,
239250
prompt: Optional[str] = None,
@@ -266,6 +277,7 @@ def check(
266277
267278
Raises:
268279
ValueError: If neither a `prompt` nor a `vector` is specified.
280+
ValueError: if 'vector' has incorrect dimensions.
269281
TypeError: If `return_fields` is not a list when provided.
270282
271283
.. code-block:: python
@@ -279,6 +291,7 @@ def check(
279291

280292
# Use provided vector or create from prompt
281293
vector = vector or self._vectorize_prompt(prompt)
294+
self._check_vector_dims(vector)
282295

283296
# Check for cache hits by searching the cache
284297
cache_hits = self._search_cache(vector, num_results, return_fields)
@@ -307,6 +320,7 @@ def store(
307320
308321
Raises:
309322
ValueError: If neither prompt nor vector is specified.
323+
ValueError: if vector has incorrect dimensions.
310324
TypeError: If provided metadata is not a dictionary.
311325
312326
.. code-block:: python
@@ -319,6 +333,8 @@ def store(
319333
"""
320334
# Vectorize prompt if necessary and create cache payload
321335
vector = vector or self._vectorize_prompt(prompt)
336+
self._check_vector_dims(vector)
337+
322338
# Construct semantic cache payload
323339
id_field = self.entry_id_field_name
324340
payload = {

tests/integration/test_llmcache.py

+28
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,31 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
210210
def test_delete(cache_no_cleanup):
211211
cache_no_cleanup.delete()
212212
assert not cache_no_cleanup.index.exists()
213+
214+
215+
# Verify the cache rejects vectors whose dimensions don't match the schema.
def test_vector_size(cache, vectorizer):
    prompt = "This is test prompt."
    response = "This is a test response."

    embedding = vectorizer.embed(prompt)
    cache.store(prompt=prompt, response=response, vector=embedding)

    # A perturbed vector of the same length must still be accepted on check.
    perturbed = [value * 0.99 for value in embedding]
    hits = cache.check(vector=perturbed)
    assert hits[0]["prompt"] == prompt

    # Vectors of the wrong length must be rejected by store() ...
    bad_vectors = (embedding[0:-1], [1, 2, 3])
    for bad in bad_vectors:
        with pytest.raises(ValueError):
            cache.store(prompt=prompt, response=response, vector=bad)

    # ... and by check() as well.
    for bad in bad_vectors:
        with pytest.raises(ValueError):
            cache.check(vector=bad)

0 commit comments

Comments
 (0)