Skip to content

Fix checking data_type before converting queries to bytes #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion engine/base_client/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import numpy as np
import tqdm
import os
from ml_dtypes import bfloat16

from dataset_reader.base_reader import Query
from engine.base_client.utils import check_data_type

DEFAULT_TOP = 10
MAX_QUERIES = int(os.getenv("MAX_QUERIES", -1))
Expand Down Expand Up @@ -66,6 +68,11 @@ def search_all(
):
parallel = self.search_params.get("parallel", 1)
top = self.search_params.get("top", None)
single_search_params = self.search_params.get("search_params", None)
if single_search_params:
data_type = check_data_type(single_search_params.get("data_type", "FLOAT32").upper())
else:
data_type = np.float32 # Default data type if not specified
# setup_search may require initialized client
self.init_client(
self.host, distance, self.connection_params, self.search_params
Expand All @@ -78,7 +85,7 @@ def search_all(
# Also, converts query vectors to bytes beforehand, preparing them for sending to client without affecting search time measurements
queries_list = []
for query in queries:
query.vector = np.array(query.vector).astype(np.float32).tobytes()
query.vector = np.array(query.vector).astype(data_type).tobytes()
queries_list.append(query)

# Handle MAX_QUERIES environment variable
Expand Down
19 changes: 19 additions & 0 deletions engine/base_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Iterable

from ml_dtypes import bfloat16
import numpy as np

from dataset_reader.base_reader import Record


Expand All @@ -18,3 +21,19 @@ def iter_batches(records: Iterable[Record], n: int) -> Iterable[Any]:
ids, vectors, metadata = [], [], []
if len(ids) > 0:
yield [ids, vectors, metadata]


def check_data_type(data_type: str):
valid_data_types = ["FLOAT32", "FLOAT64", "FLOAT16", "BFLOAT16"]
if data_type.upper() not in valid_data_types:
raise ValueError(
f"Invalid data type: {data_type}. Valid options are: {valid_data_types}"
)
if data_type == "FLOAT32":
return np.float32
if data_type == "FLOAT64":
return np.float64
if data_type == "FLOAT16":
return np.float16
if data_type == "BFLOAT16":
return bfloat16
10 changes: 2 additions & 8 deletions engine/clients/redis/search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import random
from typing import List, Tuple
from ml_dtypes import bfloat16
import numpy as np
from redis import Redis, RedisCluster
from redis.commands.search.query import Query
from engine.base_client.utils import check_data_type
from engine.base_client.search import BaseSearcher
from engine.clients.redis.config import (
REDIS_PORT,
Expand Down Expand Up @@ -44,13 +44,7 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
cls.data_type = (
cls.search_params["search_params"].get("data_type", "FLOAT32").upper()
)
cls.np_data_type = np.float32
if cls.data_type == "FLOAT64":
cls.np_data_type = np.float64
if cls.data_type == "FLOAT16":
cls.np_data_type = np.float16
if cls.data_type == "BFLOAT16":
cls.np_data_type = bfloat16
cls.np_data_type = check_data_type(cls.data_type)
cls._is_cluster = True if REDIS_CLUSTER else False

# In the case of CLUSTER API enabled we randomly select the starting primary shard
Expand Down
9 changes: 2 additions & 7 deletions engine/clients/redis/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from redis import Redis, RedisCluster
from engine.base_client.upload import BaseUploader
from engine.base_client.utils import check_data_type
from engine.clients.redis.config import (
REDIS_PORT,
REDIS_AUTH,
Expand Down Expand Up @@ -42,13 +43,7 @@ def init_client(cls, host, distance, connection_params, upload_params):
cls.upload_params = upload_params
cls.algorithm = cls.upload_params.get("algorithm", "hnsw").upper()
cls.data_type = cls.upload_params.get("data_type", "FLOAT32").upper()
cls.np_data_type = np.float32
if cls.data_type == "FLOAT64":
cls.np_data_type = np.float64
if cls.data_type == "FLOAT16":
cls.np_data_type = np.float16
if cls.data_type == "BFLOAT16":
cls.np_data_type = bfloat16
cls.np_data_type = check_data_type(cls.data_type)
cls._is_cluster = True if REDIS_CLUSTER else False

@classmethod
Expand Down