Skip to content

Commit 139eb64

Browse files
Vector object automatically converts vectors to byte strings (#411)
1 parent 6fb2dbd commit 139eb64

File tree

3 files changed

+55
-8
lines changed

3 files changed

+55
-8
lines changed

redisvl/query/aggregate.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Any, Dict, List, Optional, Set, Tuple, Union
22

3-
from pydantic import BaseModel, field_validator
3+
from pydantic import BaseModel, field_validator, model_validator
44
from redis.commands.search.aggregation import AggregateRequest, Desc
5+
from typing_extensions import Self
56

67
from redisvl.query.filter import FilterExpression
78
from redisvl.redis.utils import array_to_buffer
@@ -32,9 +33,16 @@ def validate_dtype(cls, dtype: str) -> str:
3233
raise ValueError(
3334
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
3435
)
35-
3636
return dtype
3737

38+
@model_validator(mode="after")
39+
def validate_vector(self) -> Self:
40+
"""If the vector passed in is an array of float convert it to a byte string."""
41+
if isinstance(self.vector, bytes):
42+
return self
43+
self.vector = array_to_buffer(self.vector, self.dtype)
44+
return self
45+
3846

3947
class AggregationQuery(AggregateRequest):
4048
"""
@@ -364,12 +372,8 @@ def params(self) -> Dict[str, Any]:
364372
Dict[str, Any]: The parameters for the aggregation.
365373
"""
366374
params = {}
367-
for i, (vector, dtype) in enumerate(
368-
[(v.vector, v.dtype) for v in self._vectors]
369-
):
370-
if isinstance(vector, list):
371-
vector = array_to_buffer(vector, dtype=dtype) # type: ignore
372-
params[f"vector_{i}"] = vector
375+
for i, v in enumerate(self._vectors):
376+
params[f"vector_{i}"] = v.vector
373377
return params
374378

375379
def _build_query_string(self) -> str:

tests/integration/test_aggregation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,31 @@ def test_multivector_query(index):
365365
)
366366

367367

368+
def test_multivector_query_accepts_bytes(index):
369+
skip_if_redis_version_below(index.client, "7.2.0")
370+
371+
vector_bytes = [
372+
array_to_buffer([0.1, 0.1, 0.5], "float32"),
373+
array_to_buffer([0.3, 0.4, 0.7, 0.2, -0.3, 0.25], "float64"),
374+
]
375+
vector_fields = ["user_embedding", "audio_embedding"]
376+
dtypes = ["float32", "float64"]
377+
vectors = []
378+
for vector, field, dtype in zip(vector_bytes, vector_fields, dtypes):
379+
vectors.append(Vector(vector=vector, field_name=field, dtype=dtype))
380+
381+
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
382+
383+
multi_query = MultiVectorQuery(
384+
vectors=vectors,
385+
return_fields=return_fields,
386+
)
387+
388+
results = index.query(multi_query)
389+
assert isinstance(results, list)
390+
assert len(results) == 7
391+
392+
368393
def test_multivector_query_with_filter(index):
369394
skip_if_redis_version_below(index.client, "7.2.0")
370395

tests/unit/test_aggregation_types.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from redisvl.index.index import process_results
77
from redisvl.query.aggregate import HybridQuery, MultiVectorQuery, Vector
88
from redisvl.query.filter import Tag
9+
from redisvl.redis.utils import array_to_buffer
910

1011
# Sample data for testing
1112
sample_vector = [0.1, 0.2, 0.3, 0.4]
@@ -314,3 +315,20 @@ def test_vector_object_validation():
314315
for dtype in ["bfloat16", "float16", "float32", "float64", "int8", "uint8"]:
315316
vec = Vector(vector=sample_vector, field_name="text embedding", dtype=dtype)
316317
assert isinstance(vec, Vector)
318+
319+
320+
def test_vector_object_handles_byte_conversion():
321+
# test that passing an array of floats gets converted to bytes
322+
vec = Vector(vector=sample_vector, field_name="field 1", dtype="float16")
323+
assert vec.vector == array_to_buffer(sample_vector, dtype="float16")
324+
325+
# test we can pass an array of floats and convert to all supported dtypes
326+
for datatype in ["bfloat16", "float16", "float32", "float64"]:
327+
vec = Vector(vector=sample_vector, field_name="field 1", dtype=datatype)
328+
assert vec.vector == array_to_buffer(sample_vector, dtype=datatype)
329+
330+
# test that passing in a byte string it is stored unchanged
331+
for datatype in ["bfloat16", "float16", "float32", "float64"]:
332+
byte_string = array_to_buffer(sample_vector, datatype)
333+
vec = Vector(vector=byte_string, field_name="field 1")
334+
assert vec.vector == byte_string

0 commit comments

Comments
 (0)