From 84d41411483e08b005a0bc6bd5666a1a1d9399ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Po=C5=BAniak?= Date: Fri, 8 Aug 2025 03:07:17 -0700 Subject: [PATCH 1/2] Add checking data_type function to baseclient/utils, check data_type in search_params before converting queries to bytes --- engine/base_client/search.py | 9 ++++++++- engine/base_client/utils.py | 19 +++++++++++++++++++ engine/clients/redis/search.py | 9 ++------- engine/clients/redis/upload.py | 9 ++------- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/engine/base_client/search.py b/engine/base_client/search.py index f63047b1..73774187 100644 --- a/engine/base_client/search.py +++ b/engine/base_client/search.py @@ -7,8 +7,10 @@ import numpy as np import tqdm import os +from ml_dtypes import bfloat16 from dataset_reader.base_reader import Query +from engine.base_client.utils import check_data_type DEFAULT_TOP = 10 MAX_QUERIES = int(os.getenv("MAX_QUERIES", -1)) @@ -66,6 +68,11 @@ def search_all( ): parallel = self.search_params.get("parallel", 1) top = self.search_params.get("top", None) + single_search_params = self.search_params.get("search_params", None) + if single_search_params: + data_type = check_data_type(single_search_params.get("data_type", "FLOAT32").upper()) + else: + data_type = np.float32 # Default data type if not specified # setup_search may require initialized client self.init_client( self.host, distance, self.connection_params, self.search_params @@ -78,7 +85,7 @@ def search_all( # Also, converts query vectors to bytes beforehand, preparing them for sending to client without affecting search time measurements queries_list = [] for query in queries: - query.vector = np.array(query.vector).astype(np.float32).tobytes() + query.vector = np.array(query.vector).astype(data_type).tobytes() queries_list.append(query) # Handle MAX_QUERIES environment variable diff --git a/engine/base_client/utils.py b/engine/base_client/utils.py index 4b6b8ad5..e0517178 100644 --- 
def check_data_type(data_type: str):
    """Map a vector data-type name to the scalar type used to encode vectors.

    The name is matched case-insensitively, so ``"float32"`` and
    ``"FLOAT32"`` resolve to the same type.

    Args:
        data_type: One of ``"FLOAT32"``, ``"FLOAT64"``, ``"FLOAT16"`` or
            ``"BFLOAT16"``, in any letter case.

    Returns:
        ``np.float32``, ``np.float64``, ``np.float16`` or ``bfloat16``,
        matching the requested name.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
    """
    valid_data_types = ["FLOAT32", "FLOAT64", "FLOAT16", "BFLOAT16"]
    # Normalise once so validation and dispatch agree. Validating the
    # upper-cased name but comparing the raw one let "float32" pass the check
    # and then fall off the end of the function, returning None.
    normalized = data_type.upper()
    if normalized not in valid_data_types:
        raise ValueError(
            f"Invalid data type: {data_type}. Valid options are: {valid_data_types}"
        )
    if normalized == "FLOAT32":
        return np.float32
    if normalized == "FLOAT64":
        return np.float64
    if normalized == "FLOAT16":
        return np.float16
    # Only "BFLOAT16" remains once the membership check above has passed.
    return bfloat16
2022290c..4d055abb 100644 --- a/engine/clients/redis/upload.py +++ b/engine/clients/redis/upload.py @@ -7,6 +7,7 @@ import numpy as np from redis import Redis, RedisCluster from engine.base_client.upload import BaseUploader +from engine.base_client.utils import check_data_type from engine.clients.redis.config import ( REDIS_PORT, REDIS_AUTH, @@ -42,13 +43,7 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.upload_params = upload_params cls.algorithm = cls.upload_params.get("algorithm", "hnsw").upper() cls.data_type = cls.upload_params.get("data_type", "FLOAT32").upper() - cls.np_data_type = np.float32 - if cls.data_type == "FLOAT64": - cls.np_data_type = np.float64 - if cls.data_type == "FLOAT16": - cls.np_data_type = np.float16 - if cls.data_type == "BFLOAT16": - cls.np_data_type = bfloat16 + cls.np_data_type = check_data_type(cls.data_type) cls._is_cluster = True if REDIS_CLUSTER else False @classmethod From f007b79b8e56cedf9fd5231e545922ec5308ea47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Po=C5=BAniak?= Date: Fri, 8 Aug 2025 03:10:32 -0700 Subject: [PATCH 2/2] Remove unused import --- engine/clients/redis/search.py | 1 - 1 file changed, 1 deletion(-) diff --git a/engine/clients/redis/search.py b/engine/clients/redis/search.py index 8acd44b6..a39fcb84 100644 --- a/engine/clients/redis/search.py +++ b/engine/clients/redis/search.py @@ -1,6 +1,5 @@ import random from typing import List, Tuple -from ml_dtypes import bfloat16 import numpy as np from redis import Redis, RedisCluster from redis.commands.search.query import Query