Skip to content

Commit 6379951

Browse files
sets the default datatype in our vectorizers to float32 if not specified by users (#253)
1 parent 0287885 commit 6379951

File tree

11 files changed

+202
-49
lines changed

11 files changed

+202
-49
lines changed

docs/user_guide/vectorizers_04.ipynb

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,43 @@
823823
" print(doc[\"text\"], doc[\"vector_distance\"])"
824824
]
825825
},
826+
{
827+
"cell_type": "markdown",
828+
"metadata": {},
829+
"source": [
830+
"## Selecting your float data type\n",
831+
"When embedding text as byte arrays, RedisVL supports four floating point data types: `float16`, `float32`, `float64`, and `bfloat16`.\n",
832+
"The dtype set for your vectorizer must match the one defined in your search index. If one is not explicitly set, the default is `float32`."
833+
]
834+
},
835+
{
836+
"cell_type": "code",
837+
"execution_count": null,
838+
"metadata": {},
839+
"outputs": [
840+
{
841+
"data": {
842+
"text/plain": [
843+
"True"
844+
]
845+
},
846+
"execution_count": 4,
847+
"metadata": {},
848+
"output_type": "execute_result"
849+
}
850+
],
851+
"source": [
852+
"vectorizer = HFTextVectorizer(dtype=\"float16\")\n",
853+
"\n",
854+
"# subsequent calls to embed('', as_buffer=True) and embed_many('', as_buffer=True) will now encode as float16\n",
855+
"float16_bytes = vectorizer.embed('test sentence', as_buffer=True)\n",
856+
"\n",
857+
"# you can override this setting on each individual method call\n",
858+
"float64_bytes = vectorizer.embed('test sentence', as_buffer=True, dtype=\"float64\")\n",
859+
"\n",
860+
"float16_bytes != float64_bytes"
861+
]
862+
},
826863
{
827864
"cell_type": "code",
828865
"execution_count": null,

redisvl/utils/vectorize/base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic.v1 import BaseModel, validator
66

77
from redisvl.redis.utils import array_to_buffer
8+
from redisvl.schema.fields import VectorDataType
89

910

1011
class Vectorizers(Enum):
@@ -19,11 +20,22 @@ class Vectorizers(Enum):
1920
class BaseVectorizer(BaseModel, ABC):
2021
model: str
2122
dims: int
23+
dtype: str
2224

2325
@property
2426
def type(self) -> str:
2527
return "base"
2628

29+
@validator("dtype")
30+
def check_dtype(dtype):
31+
try:
32+
VectorDataType(dtype.upper())
33+
except ValueError:
34+
raise ValueError(
35+
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
36+
)
37+
return dtype
38+
2739
@validator("dims")
2840
@classmethod
2941
def check_dims(cls, value):
@@ -81,13 +93,7 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
8193
else:
8294
yield seq[pos : pos + size]
8395

84-
def _process_embedding(
85-
self, embedding: List[float], as_buffer: bool, dtype: Optional[str]
86-
):
96+
def _process_embedding(self, embedding: List[float], as_buffer: bool, dtype: str):
8797
if as_buffer:
88-
if not dtype:
89-
raise RuntimeError(
90-
"dtype is required if converting from float to byte string."
91-
)
9298
return array_to_buffer(embedding, dtype)
9399
return embedding

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
5252
_aclient: Any = PrivateAttr()
5353

5454
def __init__(
55-
self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
55+
self,
56+
model: str = "text-embedding-ada-002",
57+
api_config: Optional[Dict] = None,
58+
dtype: str = "float32",
5659
):
5760
"""Initialize the AzureOpenAI vectorizer.
5861
@@ -63,13 +66,17 @@ def __init__(
6366
api_config (Optional[Dict], optional): Dictionary containing the
6467
API key, API version, Azure endpoint, and any other API options.
6568
Defaults to None.
69+
dtype (str): the default datatype to use when embedding text as byte arrays.
70+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
71+
Defaults to 'float32'.
6672
6773
Raises:
6874
ImportError: If the openai library is not installed.
6975
ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
76+
ValueError: If an invalid dtype is provided.
7077
"""
7178
self._initialize_clients(api_config)
72-
super().__init__(model=model, dims=self._set_model_dims(model))
79+
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
7380

7481
def _initialize_clients(self, api_config: Optional[Dict]):
7582
"""
@@ -190,7 +197,7 @@ def embed_many(
190197
if len(texts) > 0 and not isinstance(texts[0], str):
191198
raise TypeError("Must pass in a list of str values to embed.")
192199

193-
dtype = kwargs.pop("dtype", None)
200+
dtype = kwargs.pop("dtype", self.dtype)
194201

195202
embeddings: List = []
196203
for batch in self.batchify(texts, batch_size, preprocess):
@@ -234,7 +241,7 @@ def embed(
234241
if preprocess:
235242
text = preprocess(text)
236243

237-
dtype = kwargs.pop("dtype", None)
244+
dtype = kwargs.pop("dtype", self.dtype)
238245

239246
result = self._client.embeddings.create(input=[text], model=self.model)
240247
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@@ -274,7 +281,7 @@ async def aembed_many(
274281
if len(texts) > 0 and not isinstance(texts[0], str):
275282
raise TypeError("Must pass in a list of str values to embed.")
276283

277-
dtype = kwargs.pop("dtype", None)
284+
dtype = kwargs.pop("dtype", self.dtype)
278285

279286
embeddings: List = []
280287
for batch in self.batchify(texts, batch_size, preprocess):
@@ -320,7 +327,7 @@ async def aembed(
320327
if preprocess:
321328
text = preprocess(text)
322329

323-
dtype = kwargs.pop("dtype", None)
330+
dtype = kwargs.pop("dtype", self.dtype)
324331

325332
result = await self._aclient.embeddings.create(input=[text], model=self.model)
326333
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
self,
5050
model: str = "amazon.titan-embed-text-v2:0",
5151
api_config: Optional[Dict[str, str]] = None,
52+
dtype: str = "float32",
5253
) -> None:
5354
"""Initialize the AWS Bedrock Vectorizer.
5455
@@ -57,10 +58,14 @@ def __init__(
5758
api_config (Optional[Dict[str, str]]): AWS credentials and config.
5859
Can include: aws_access_key_id, aws_secret_access_key, aws_region
5960
If not provided, will use environment variables.
61+
dtype (str): the default datatype to use when embedding text as byte arrays.
62+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
63+
Defaults to 'float32'.
6064
6165
Raises:
6266
ValueError: If credentials are not provided in config or environment.
6367
ImportError: If boto3 is not installed.
68+
ValueError: If an invalid dtype is provided.
6469
"""
6570
try:
6671
import boto3 # type: ignore
@@ -94,7 +99,7 @@ def __init__(
9499
region_name=aws_region,
95100
)
96101

97-
super().__init__(model=model, dims=self._set_model_dims(model))
102+
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
98103

99104
def _set_model_dims(self, model: str) -> int:
100105
"""Initialize model and determine embedding dimensions."""
@@ -145,7 +150,7 @@ def embed(
145150
response_body = json.loads(response["body"].read())
146151
embedding = response_body["embedding"]
147152

148-
dtype = kwargs.pop("dtype", None)
153+
dtype = kwargs.pop("dtype", self.dtype)
149154
return self._process_embedding(embedding, as_buffer, dtype)
150155

151156
@retry(
@@ -181,7 +186,7 @@ def embed_many(
181186
raise TypeError("Texts must be a list of strings")
182187

183188
embeddings: List[List[float]] = []
184-
dtype = kwargs.pop("dtype", None)
189+
dtype = kwargs.pop("dtype", self.dtype)
185190

186191
for batch in self.batchify(texts, batch_size, preprocess):
187192
# Process each text in the batch individually since Bedrock

redisvl/utils/vectorize/text/cohere.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ class CohereTextVectorizer(BaseVectorizer):
4747
_client: Any = PrivateAttr()
4848

4949
def __init__(
50-
self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None
50+
self,
51+
model: str = "embed-english-v3.0",
52+
api_config: Optional[Dict] = None,
53+
dtype: str = "float32",
5154
):
5255
"""Initialize the Cohere vectorizer.
5356
@@ -57,14 +60,17 @@ def __init__(
5760
model (str): Model to use for embedding. Defaults to 'embed-english-v3.0'.
5861
api_config (Optional[Dict], optional): Dictionary containing the API key.
5962
Defaults to None.
63+
dtype (str): the default datatype to use when embedding text as byte arrays.
64+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
65+
Defaults to 'float32'.
6066
6167
Raises:
6268
ImportError: If the cohere library is not installed.
6369
ValueError: If the API key is not provided.
64-
70+
ValueError: If an invalid dtype is provided.
6571
"""
6672
self._initialize_client(api_config)
67-
super().__init__(model=model, dims=self._set_model_dims(model))
73+
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
6874

6975
def _initialize_client(self, api_config: Optional[Dict]):
7076
"""
@@ -159,7 +165,7 @@ def embed(
159165
if preprocess:
160166
text = preprocess(text)
161167

162-
dtype = kwargs.pop("dtype", None)
168+
dtype = kwargs.pop("dtype", self.dtype)
163169

164170
embedding = self._client.embed(
165171
texts=[text], model=self.model, input_type=input_type
@@ -228,7 +234,7 @@ def embed_many(
228234
See https://docs.cohere.com/reference/embed."
229235
)
230236

231-
dtype = kwargs.pop("dtype", None)
237+
dtype = kwargs.pop("dtype", self.dtype)
232238

233239
embeddings: List = []
234240
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/custom.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class CustomTextVectorizer(BaseVectorizer):
10-
"""The CustomTextVectorizer class wraps user-defined embeding methods to create
10+
"""The CustomTextVectorizer class wraps user-defined embedding methods to create
1111
embeddings for text data.
1212
1313
This vectorizer is designed to accept a provided callable text vectorizer and
@@ -44,6 +44,7 @@ def __init__(
4444
embed_many: Optional[Callable] = None,
4545
aembed: Optional[Callable] = None,
4646
aembed_many: Optional[Callable] = None,
47+
dtype: str = "float32",
4748
):
4849
"""Initialize the Custom vectorizer.
4950
@@ -52,10 +53,14 @@ def __init__(
5253
embed_many (Optional[Callable]): a Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
5354
aembed (Optional[Callable]): an asynchronous Callable function that accepts a string object and returns a list of floats. Defaults to None.
5455
aembed_many (Optional[Callable]): an asynchronous Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
56+
dtype (str): the default datatype to use when embedding text as byte arrays.
57+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
58+
Defaults to 'float32'.
5559
5660
Raises:
57-
ValueError if any of the provided functions accept or return incorrect types.
58-
TypeError if any of the provided functions are not Callable objects.
61+
ValueError: if any of the provided functions accept or return incorrect types.
62+
TypeError: if any of the provided functions are not Callable objects.
63+
ValueError: If an invalid dtype is provided.
5964
"""
6065

6166
self._validate_embed(embed)
@@ -71,7 +76,7 @@ def __init__(
7176
self._validate_aembed_many(aembed_many)
7277
self._aembed_many_func = aembed_many
7378

74-
super().__init__(model=self.type, dims=self._set_model_dims())
79+
super().__init__(model=self.type, dims=self._set_model_dims(), dtype=dtype)
7580

7681
def _validate_embed(self, func: Callable):
7782
"""calls the func with dummy input and validates that it returns a vector"""
@@ -173,7 +178,7 @@ def embed(
173178
if preprocess:
174179
text = preprocess(text)
175180

176-
dtype = kwargs.pop("dtype", None)
181+
dtype = kwargs.pop("dtype", self.dtype)
177182

178183
result = self._embed_func(text, **kwargs)
179184
return self._process_embedding(result, as_buffer, dtype)
@@ -212,7 +217,7 @@ def embed_many(
212217
if not self._embed_many_func:
213218
raise NotImplementedError
214219

215-
dtype = kwargs.pop("dtype", None)
220+
dtype = kwargs.pop("dtype", self.dtype)
216221

217222
embeddings: List = []
218223
for batch in self.batchify(texts, batch_size, preprocess):
@@ -254,7 +259,7 @@ async def aembed(
254259
if preprocess:
255260
text = preprocess(text)
256261

257-
dtype = kwargs.pop("dtype", None)
262+
dtype = kwargs.pop("dtype", self.dtype)
258263

259264
result = await self._aembed_func(text, **kwargs)
260265
return self._process_embedding(result, as_buffer, dtype)
@@ -293,7 +298,7 @@ async def aembed_many(
293298
if not self._aembed_many_func:
294299
raise NotImplementedError
295300

296-
dtype = kwargs.pop("dtype", None)
301+
dtype = kwargs.pop("dtype", self.dtype)
297302

298303
embeddings: List = []
299304
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/huggingface.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,28 @@ class HFTextVectorizer(BaseVectorizer):
3333
_client: Any = PrivateAttr()
3434

3535
def __init__(
36-
self, model: str = "sentence-transformers/all-mpnet-base-v2", **kwargs
36+
self,
37+
model: str = "sentence-transformers/all-mpnet-base-v2",
38+
dtype: str = "float32",
39+
**kwargs,
3740
):
3841
"""Initialize the Hugging Face text vectorizer.
3942
4043
Args:
4144
model (str): The pre-trained model from Hugging Face's Sentence
4245
Transformers to be used for embedding. Defaults to
4346
'sentence-transformers/all-mpnet-base-v2'.
47+
dtype (str): the default datatype to use when embedding text as byte arrays.
48+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
49+
Defaults to 'float32'.
4450
4551
Raises:
4652
ImportError: If the sentence-transformers library is not installed.
4753
ValueError: If there is an error setting the embedding model dimensions.
54+
ValueError: If an invalid dtype is provided.
4855
"""
4956
self._initialize_client(model)
50-
super().__init__(model=model, dims=self._set_model_dims())
57+
super().__init__(model=model, dims=self._set_model_dims(), dtype=dtype)
5158

5259
def _initialize_client(self, model: str):
5360
"""Setup the HuggingFace client"""
@@ -100,7 +107,7 @@ def embed(
100107
if preprocess:
101108
text = preprocess(text)
102109

103-
dtype = kwargs.pop("dtype", None)
110+
dtype = kwargs.pop("dtype", self.dtype)
104111

105112
embedding = self._client.encode([text], **kwargs)[0]
106113
return self._process_embedding(embedding.tolist(), as_buffer, dtype)
@@ -136,7 +143,7 @@ def embed_many(
136143
if len(texts) > 0 and not isinstance(texts[0], str):
137144
raise TypeError("Must pass in a list of str values to embed.")
138145

139-
dtype = kwargs.pop("dtype", None)
146+
dtype = kwargs.pop("dtype", self.dtype)
140147

141148
embeddings: List = []
142149
for batch in self.batchify(texts, batch_size, preprocess):

0 commit comments

Comments
 (0)