From a3542207b50d420a4fd1fbafc5c9b8be3cc5245e Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 27 Feb 2025 09:54:10 -0500 Subject: [PATCH 1/5] add kwargs support to all vectorizer embed methods --- redisvl/utils/vectorize/text/azureopenai.py | 14 ++++++++++---- redisvl/utils/vectorize/text/bedrock.py | 4 ++-- redisvl/utils/vectorize/text/cohere.py | 8 ++++---- redisvl/utils/vectorize/text/mistral.py | 12 ++++++++---- redisvl/utils/vectorize/text/openai.py | 14 ++++++++++---- redisvl/utils/vectorize/text/vertexai.py | 4 ++-- redisvl/utils/vectorize/text/voyageai.py | 13 ++++++------- 7 files changed, 42 insertions(+), 27 deletions(-) diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 7b3b7d01..61ba7425 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -205,7 +205,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -248,7 +250,9 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -292,7 +296,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings.create( - input=batch, model=self.model + input=batch, model=self.model, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -336,7 +340,9 @@ async def aembed( dtype = kwargs.pop("dtype", self.dtype) - result = await self._aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py index 5858aff8..c985909a 100644 --- a/redisvl/utils/vectorize/text/bedrock.py +++ b/redisvl/utils/vectorize/text/bedrock.py @@ -156,7 +156,7 @@ def embed( text = preprocess(text) response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}) + modelId=self.model, body=json.dumps({"inputText": text}), **kwargs ) response_body = json.loads(response["body"].read()) embedding = response_body["embedding"] @@ -206,7 +206,7 @@ def embed_many( batch_embeddings = [] for text in batch: response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}) + modelId=self.model, body=json.dumps({"inputText": text}), **kwargs ) response_body = json.loads(response["body"].read()) batch_embeddings.append(response_body["embedding"]) diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index bd6481fe..d298931c 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -156,7 +156,7 @@ def embed( TypeError: In an invalid input_type is provided. """ - input_type = kwargs.get("input_type") + input_type = kwargs.pop("input_type", None) if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") @@ -172,7 +172,7 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) embedding = self._client.embed( - texts=[text], model=self.model, input_type=input_type + texts=[text], model=self.model, input_type=input_type, **kwargs ).embeddings[0] return self._process_embedding(embedding, as_buffer, dtype) @@ -227,7 +227,7 @@ def embed_many( TypeError: In an invalid input_type is provided. """ - input_type = kwargs.get("input_type") + input_type = kwargs.pop("input_type", None) if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") @@ -244,7 +244,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, model=self.model, input_type=input_type, **kwargs ) embeddings += [ self._process_embedding(embedding, as_buffer, dtype) diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index e930b3a4..aabd4234 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -155,7 +155,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(model=self.model, inputs=batch) + response = self._client.embeddings.create( + model=self.model, inputs=batch, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -198,7 +200,9 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(model=self.model, inputs=[text]) + result = self._client.embeddings.create( + model=self.model, inputs=[text], **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -242,7 +246,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._client.embeddings.create_async( - model=self.model, inputs=batch + model=self.model, inputs=batch, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -287,7 +291,7 @@ async def aembed( dtype = kwargs.pop("dtype", self.dtype) result = await self._client.embeddings.create_async( - model=self.model, inputs=[text] + model=self.model, inputs=[text], **kwargs ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 25b21c67..da8c76d3 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -156,7 +156,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create(input=batch, model=self.model) + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -199,7 +201,9 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings.create(input=[text], model=self.model) + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -243,7 +247,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings.create( - input=batch, model=self.model + input=batch, model=self.model, **kwargs ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) @@ -287,7 +291,9 @@ async def aembed( dtype = kwargs.pop("dtype", self.dtype) - result = await self._aclient.embeddings.create(input=[text], model=self.model) + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 6d455c67..36962e87 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -168,7 +168,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.get_embeddings(batch) + response = self._client.get_embeddings(batch, **kwargs) embeddings += [ self._process_embedding(r.values, as_buffer, dtype) for r in response ] @@ -210,7 +210,7 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - result = self._client.get_embeddings([text]) + result = self._client.get_embeddings([text], **kwargs) return self._process_embedding(result[0].values, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/voyageai.py b/redisvl/utils/vectorize/text/voyageai.py index fbcbfd9e..67a8015a 100644 --- a/redisvl/utils/vectorize/text/voyageai.py +++ b/redisvl/utils/vectorize/text/voyageai.py @@ -204,8 +204,8 @@ def embed_many( TypeError: If an invalid input_type is provided. """ - input_type = kwargs.get("input_type") - truncation = kwargs.get("truncation") + input_type = kwargs.pop("input_type", None) + truncation = kwargs.pop("truncation", None) dtype = kwargs.pop("dtype", self.dtype) if not isinstance(texts, list): @@ -235,7 +235,7 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, model=self.model, input_type=input_type, **kwargs ) embeddings += [ self._process_embedding(embedding, as_buffer, dtype) @@ -284,8 +284,8 @@ async def aembed_many( TypeError: In an invalid input_type is provided. """ - input_type = kwargs.get("input_type") - truncation = kwargs.get("truncation") + input_type = kwargs.pop("input_type", None) + truncation = kwargs.pop("truncation", None) dtype = kwargs.pop("dtype", self.dtype) if not isinstance(texts, list): @@ -315,7 +315,7 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embed( - texts=batch, model=self.model, input_type=input_type + texts=batch, model=self.model, input_type=input_type, **kwargs ) embeddings += [ self._process_embedding(embedding, as_buffer, dtype) @@ -360,7 +360,6 @@ async def aembed( Raises: TypeError: In an invalid input_type is provided. """ - result = await self.aembed_many( texts=[text], preprocess=preprocess, as_buffer=as_buffer, **kwargs ) From d51e4f3f8f9fdb17eb9c5fbf35ba4857d0611f76 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 27 Feb 2025 14:13:33 -0500 Subject: [PATCH 2/5] vectorizer typing changes --- redisvl/extensions/llmcache/semantic.py | 6 +- redisvl/extensions/router/semantic.py | 8 +- .../session_manager/semantic_session.py | 2 +- redisvl/utils/vectorize/base.py | 66 +- redisvl/utils/vectorize/text/azureopenai.py | 22 +- redisvl/utils/vectorize/text/bedrock.py | 18 +- redisvl/utils/vectorize/text/cohere.py | 93 ++- redisvl/utils/vectorize/text/custom.py | 18 +- redisvl/utils/vectorize/text/huggingface.py | 14 +- redisvl/utils/vectorize/text/mistral.py | 12 +- redisvl/utils/vectorize/text/openai.py | 22 +- redisvl/utils/vectorize/text/vertexai.py | 16 +- redisvl/utils/vectorize/text/voyageai.py | 12 +- tests/integration/test_vectorizers.py | 778 ++++++++++-------- 14 files changed, 658 insertions(+), 429 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 41e6e214..c741ca45 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -310,7 +310,8 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt) + result = self._vectorizer.embed(prompt) + return result # type: ignore async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the @@ -318,7 +319,8 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return await self._vectorizer.aembed(prompt) + result = await self._vectorizer.aembed(prompt) + return result # type: ignore def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 4a7e72c3..8aff7524 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -366,14 +366,14 @@ def __call__( if not vector: if not statement: raise ValueError("Must provide a vector or statement to the router") - vector = self.vectorizer.embed(statement) + vector = self.vectorizer.embed(statement) # type: ignore aggregation_method = ( aggregation_method or self.routing_config.aggregation_method ) # perform route classification - top_route_match = self._classify_route(vector, aggregation_method) + top_route_match = self._classify_route(vector, aggregation_method) # type: ignore return top_route_match @deprecated_argument("distance_threshold") @@ -400,7 +400,7 @@ def route_many( if not vector: if not statement: raise ValueError("Must provide a vector or statement to the router") - vector = self.vectorizer.embed(statement) + vector = self.vectorizer.embed(statement) # type: ignore max_k = max_k or self.routing_config.max_k aggregation_method = ( @@ -409,7 +409,7 @@ def route_many( # classify routes top_route_matches = self._classify_multi_route( - vector, max_k, aggregation_method + vector, max_k, aggregation_method # type: ignore ) return top_route_matches diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 6825afa9..1aa15315 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -349,7 +349,7 @@ def add_messages( role=message[ROLE_FIELD_NAME], content=message[CONTENT_FIELD_NAME], session_tag=session_tag, - vector_field=content_vector, + vector_field=content_vector, # type: ignore ) if TOOL_FIELD_NAME in message: diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index b3a63fa9..c8a80677 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -49,24 +49,47 @@ def check_dims(cls, value): return value @abstractmethod - def embed_many( + def embed( self, - texts: List[str], + text: str, preprocess: Optional[Callable] = None, - batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[float], bytes]: + """Embed a chunk of text. + + Args: + text: Text to embed + preprocess: Optional function to preprocess text + as_buffer: If True, returns a bytes object instead of a list + + Returns: + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True + """ raise NotImplementedError @abstractmethod - def embed( + def embed_many( self, - text: str, + texts: List[str], preprocess: Optional[Callable] = None, + batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[List[float]], List[bytes]]: + """Embed multiple chunks of text. + + Args: + texts: List of texts to embed + preprocess: Optional function to preprocess text + batch_size: Number of texts to process in each batch + as_buffer: If True, returns each embedding as a bytes object + + Returns: + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True + """ raise NotImplementedError async def aembed_many( @@ -76,7 +99,19 @@ async def aembed_many( batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: + """Asynchronously embed multiple chunks of text. + + Args: + texts: List of texts to embed + preprocess: Optional function to preprocess text + batch_size: Number of texts to process in each batch + as_buffer: If True, returns each embedding as a bytes object + + Returns: + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True + """ # Fallback to standard embedding call if no async support return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs) @@ -86,7 +121,18 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: + """Asynchronously embed a chunk of text. + + Args: + text: Text to embed + preprocess: Optional function to preprocess text + as_buffer: If True, returns a bytes object instead of a list + + Returns: + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True + """ # Fallback to standard embedding call if no async support return self.embed(text, preprocess, as_buffer, **kwargs) diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 61ba7425..ad4b7210 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -178,7 +178,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the AzureOpenAI API. Args: @@ -191,7 +191,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -226,7 +227,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the AzureOpenAI API. Args: @@ -237,7 +238,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -268,7 +270,7 @@ async def aembed_many( batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the AzureOpenAI API. Args: @@ -281,7 +283,8 @@ async def aembed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -316,7 +319,7 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Asynchronously embed a chunk of text using the OpenAI API. Args: @@ -327,7 +330,8 @@ async def aembed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py index c985909a..2d40685d 100644 --- a/redisvl/utils/vectorize/text/bedrock.py +++ b/redisvl/utils/vectorize/text/bedrock.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -135,8 +135,8 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: - """Embed a chunk of text using Amazon Bedrock. + ) -> Union[List[float], bytes]: + """Embed a chunk of text using the AWS Bedrock Embeddings API. Args: text (str): Text to embed. @@ -144,7 +144,8 @@ def embed( as_buffer (bool): Whether to return as byte buffer. Returns: - List[float]: The embedding vector. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If text is not a string. @@ -177,17 +178,18 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: - """Embed multiple texts using Amazon Bedrock. + ) -> Union[List[List[float]], List[bytes]]: + """Embed many chunks of text using the AWS Bedrock Embeddings API. Args: texts (List[str]): List of texts to embed. preprocess (Optional[Callable]): Optional preprocessing function. - batch_size (int): Size of batches for processing. + batch_size (int): Size of batches for processing. Defaults to 10. as_buffer (bool): Whether to return as byte buffers. Returns: - List[List[float]]: List of embedding vectors. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If texts is not a list of strings. diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index d298931c..4e6192e2 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, Dict, List, Optional +import warnings +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -64,7 +65,8 @@ def __init__( Defaults to None. dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). - Defaults to 'float32'. + 'float32' will use Cohere's float embeddings, 'int8' and 'uint8' will map + to Cohere's corresponding embedding types. Defaults to 'float32'. Raises: ImportError: If the cohere library is not installed. @@ -114,6 +116,15 @@ def _set_model_dims(self) -> int: raise ValueError(f"Error setting embedding model dimensions: {str(e)}") return len(embedding) + def _get_cohere_embedding_type(self, dtype: str) -> List[str]: + """Map dtype to appropriate Cohere embedding_types value.""" + if dtype == "int8": + return ["int8"] + elif dtype == "uint8": + return ["uint8"] + else: + return ["float"] + @deprecated_argument("dtype") def embed( self, @@ -121,7 +132,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], List[int], bytes]: """Embed a chunk of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method @@ -150,7 +161,11 @@ def embed( Required for embedding models v3 and higher. Returns: - List[float]: Embedding. + Union[List[float], List[int], bytes]: + - If as_buffer=True: Returns a bytes object + - If as_buffer=False: + - For dtype="float32": Returns a list of floats + - For dtype="int8" or "uint8": Returns a list of integers Raises: TypeError: In an invalid input_type is provided. @@ -171,9 +186,34 @@ def embed( dtype = kwargs.pop("dtype", self.dtype) - embedding = self._client.embed( - texts=[text], model=self.model, input_type=input_type, **kwargs - ).embeddings[0] + # Check if embedding_types was provided and warn user + if "embedding_types" in kwargs: + warnings.warn( + "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " + "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.", + UserWarning, + stacklevel=2, + ) + kwargs.pop("embedding_types") + + # Map dtype to appropriate embedding_type + embedding_types = self._get_cohere_embedding_type(dtype) + + response = self._client.embed( + texts=[text], + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, + ) + + # Extract the appropriate embedding based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + embedding = getattr(response.embeddings, embed_type)[0] + else: + embedding = response.embeddings[0] # Fallback for older API versions + return self._process_embedding(embedding, as_buffer, dtype) @retry( @@ -189,7 +229,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[List[int]], List[bytes]]: """Embed many chunks of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method @@ -221,7 +261,11 @@ def embed_many( Required for embedding models v3 and higher. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[List[int]], List[bytes]]: + - If as_buffer=True: Returns a list of bytes objects + - If as_buffer=False: + - For dtype="float32": Returns a list of lists of floats + - For dtype="int8" or "uint8": Returns a list of lists of integers Raises: TypeError: In an invalid input_type is provided. @@ -241,14 +285,41 @@ def embed_many( dtype = kwargs.pop("dtype", self.dtype) + # Check if embedding_types was provided and warn user + if "embedding_types" in kwargs: + warnings.warn( + "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " + "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.", + UserWarning, + stacklevel=2, + ) + kwargs.pop("embedding_types") + + # Map dtype to appropriate embedding_type + embedding_types = self._get_cohere_embedding_type(dtype) + embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( - texts=batch, model=self.model, input_type=input_type, **kwargs + texts=batch, + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, ) + + # Extract the appropriate embeddings based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + batch_embeddings = getattr(response.embeddings, embed_type) + else: + batch_embeddings = ( + response.embeddings + ) # Fallback for older API versions + embeddings += [ self._process_embedding(embedding, as_buffer, dtype) - for embedding in response.embeddings + for embedding in batch_embeddings ] return embeddings diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index 4558d4d7..154468b9 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union from pydantic import PrivateAttr @@ -162,7 +162,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """ Generate an embedding for a single piece of text using your sync embed function. @@ -172,7 +172,7 @@ def embed( as_buffer (bool): If True, return the embedding as a byte buffer. Returns: - List[float]: The embedding of the input text. + Union[List[float], bytes]: The embedding of the input text. Raises: TypeError: If the input is not a string. @@ -200,7 +200,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """ Generate embeddings for multiple pieces of text in batches using your sync embed_many function. @@ -211,7 +211,7 @@ def embed_many( as_buffer (bool): If True, convert each embedding to a byte buffer. Returns: - List[List[float]]: A list of embeddings, where each embedding is a list of floats. + Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. Raises: TypeError: If the input is not a list of strings. @@ -226,7 +226,7 @@ def embed_many( raise NotImplementedError("No embed_many function was provided.") dtype = kwargs.pop("dtype", self.dtype) - embeddings: List[List[float]] = [] + embeddings: Union[List[List[float]], List[bytes]] = [] try: for batch in self.batchify(texts, batch_size, preprocess): @@ -291,7 +291,7 @@ async def aembed_many( batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """ Asynchronously generate embeddings for multiple pieces of text in batches. @@ -302,7 +302,7 @@ async def aembed_many( as_buffer (bool): If True, convert each embedding to a byte buffer. Returns: - List[List[float]]: A list of embeddings, where each embedding is a list of floats. + Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. Raises: TypeError: If the input is not a list of strings. @@ -317,7 +317,7 @@ async def aembed_many( raise NotImplementedError("No aembed_many function was provided.") dtype = kwargs.pop("dtype", self.dtype) - embeddings: List[List[float]] = [] + embeddings: Union[List[List[float]], List[bytes]] = [] try: for batch in self.batchify(texts, batch_size, preprocess): diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index 8f81b85c..e085cb9f 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union from pydantic.v1 import PrivateAttr @@ -89,7 +89,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the Hugging Face sentence transformer. Args: @@ -100,7 +100,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -124,7 +125,7 @@ def embed_many( batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the Hugging Face sentence transformer. @@ -133,12 +134,13 @@ def embed_many( preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. + embeddings. Defaults to 1000. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index aabd4234..db2c47ee 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -128,7 +128,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the Mistral API. Args: @@ -141,7 +141,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -176,7 +177,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the Mistral API. Args: @@ -187,7 +188,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index da8c76d3..950993ce 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -129,7 +129,7 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of texts using the OpenAI API. Args: @@ -142,7 +142,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -177,7 +178,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the OpenAI API. Args: @@ -188,7 +189,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -219,7 +221,7 @@ async def aembed_many( batch_size: int = 1000, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Asynchronously embed many chunks of texts using the OpenAI API. Args: @@ -232,7 +234,8 @@ async def aembed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. @@ -267,7 +270,7 @@ async def aembed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Asynchronously embed a chunk of text using the OpenAI API. Args: @@ -278,7 +281,8 @@ async def aembed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the text. diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 36962e87..ebe2a625 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -141,8 +141,8 @@ def embed_many( batch_size: int = 10, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: - """Embed many chunks of texts using the VertexAI API. + ) -> Union[List[List[float]], List[bytes]]: + """Embed many chunks of text using the VertexAI Embeddings API. Args: texts (List[str]): List of text chunks to embed. @@ -154,7 +154,8 @@ def embed_many( to a byte string. Defaults to False. Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. @@ -186,8 +187,8 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: - """Embed a chunk of text using the VertexAI API. + ) -> Union[List[float], bytes]: + """Embed a chunk of text using the VertexAI Embeddings API. Args: text (str): Chunk of text to embed. @@ -197,7 +198,8 @@ def embed( to a byte string. Defaults to False. Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If the wrong input type is passed in for the test. diff --git a/redisvl/utils/vectorize/text/voyageai.py b/redisvl/utils/vectorize/text/voyageai.py index 67a8015a..9d015a81 100644 --- a/redisvl/utils/vectorize/text/voyageai.py +++ b/redisvl/utils/vectorize/text/voyageai.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -124,7 +124,7 @@ def embed( preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, - ) -> List[float]: + ) -> Union[List[float], bytes]: """Embed a chunk of text using the VoyageAI Embeddings API. Can provide the embedding `input_type` as a `kwarg` to this method @@ -149,7 +149,8 @@ def embed( Check https://docs.voyageai.com/docs/embeddings Returns: - List[float]: Embedding. + Union[List[float], bytes]: Embedding as a list of floats, or as a bytes + object if as_buffer=True Raises: TypeError: If an invalid input_type is provided. @@ -171,7 +172,7 @@ def embed_many( batch_size: Optional[int] = None, as_buffer: bool = False, **kwargs, - ) -> List[List[float]]: + ) -> Union[List[List[float]], List[bytes]]: """Embed many chunks of text using the VoyageAI Embeddings API. Can provide the embedding `input_type` as a `kwarg` to this method @@ -198,7 +199,8 @@ def embed_many( Check https://docs.voyageai.com/docs/embeddings Returns: - List[List[float]]: List of embeddings. + Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, + or as bytes objects if as_buffer=True Raises: TypeError: If an invalid input_type is provided. diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index e1de4a46..7859f5e9 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest from redisvl.utils.vectorize import ( @@ -14,381 +15,472 @@ VoyageAITextVectorizer, ) - -@pytest.fixture( - params=[ - HFTextVectorizer, - OpenAITextVectorizer, - VertexAITextVectorizer, - CohereTextVectorizer, - AzureOpenAITextVectorizer, - BedrockTextVectorizer, - MistralAITextVectorizer, - CustomTextVectorizer, - VoyageAITextVectorizer, - ] -) -def vectorizer(request): - if request.param == HFTextVectorizer: - return request.param() - elif request.param == OpenAITextVectorizer: - return request.param() - elif request.param == VertexAITextVectorizer: - return request.param() - elif request.param == CohereTextVectorizer: - return request.param() - elif request.param == MistralAITextVectorizer: - return request.param() - elif request.param == VoyageAITextVectorizer: - return request.param(model="voyage-large-2") - elif request.param == AzureOpenAITextVectorizer: - return request.param( - model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") - ) - elif request.param == BedrockTextVectorizer: - return request.param( - model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") - ) - elif request.param == CustomTextVectorizer: - - def embed(text): - return [1.1, 2.2, 3.3, 4.4] - - def embed_many(texts): - return [[1.1, 2.2, 3.3, 4.4]] * len(texts) - - return request.param(embed=embed, embed_many=embed_many) +# @pytest.fixture( +# params=[ +# HFTextVectorizer, +# OpenAITextVectorizer, +# VertexAITextVectorizer, +# CohereTextVectorizer, +# AzureOpenAITextVectorizer, +# BedrockTextVectorizer, +# MistralAITextVectorizer, +# CustomTextVectorizer, +# VoyageAITextVectorizer, +# ] +# ) +# def vectorizer(request): +# if request.param == HFTextVectorizer: +# return request.param() +# elif request.param == OpenAITextVectorizer: +# return request.param() +# elif request.param == VertexAITextVectorizer: +# return request.param() +# elif request.param == CohereTextVectorizer: +# return request.param() +# elif request.param == MistralAITextVectorizer: +# return request.param() +# elif request.param == VoyageAITextVectorizer: +# return request.param(model="voyage-large-2") +# elif request.param == AzureOpenAITextVectorizer: +# return request.param( +# model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") +# ) +# elif request.param == BedrockTextVectorizer: +# return request.param( +# model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") +# ) +# elif request.param == CustomTextVectorizer: + +# def embed(text): +# return [1.1, 2.2, 3.3, 4.4] + +# def embed_many(texts): +# return [[1.1, 2.2, 3.3, 4.4]] * len(texts) + +# return request.param(embed=embed, embed_many=embed_many) + + +# @pytest.fixture +# def bedrock_vectorizer(): +# return BedrockTextVectorizer( +# model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") +# ) + + +# @pytest.fixture +# def custom_embed_func(): +# def embed(text: str): +# return [1.1, 2.2, 3.3, 4.4] + +# return embed + + +# @pytest.fixture +# def custom_embed_class(): +# class MyEmbedder: +# def embed(self, text: str): +# return [1.1, 2.2, 3.3, 4.4] + +# def embed_with_args(self, text: str, max_len=None): +# return [1.1, 2.2, 3.3, 4.4][0:max_len] + +# def embed_many(self, text_list): +# return [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] + +# def embed_many_with_args(self, texts, param=True): +# if param: +# return [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] +# else: +# return [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] + +# return MyEmbedder + + +# @pytest.mark.requires_api_keys +# def test_vectorizer_embed(vectorizer): +# text = "This is a test sentence." +# if isinstance(vectorizer, CohereTextVectorizer): +# embedding = vectorizer.embed(text, input_type="search_document") +# elif isinstance(vectorizer, VoyageAITextVectorizer): +# embedding = vectorizer.embed(text, input_type="document") +# else: +# embedding = vectorizer.embed(text) + +# assert isinstance(embedding, list) +# assert len(embedding) == vectorizer.dims + + +# @pytest.mark.requires_api_keys +# def test_vectorizer_embed_many(vectorizer): +# texts = ["This is the first test sentence.", "This is the second test sentence."] +# if isinstance(vectorizer, CohereTextVectorizer): +# embeddings = vectorizer.embed_many(texts, input_type="search_document") +# elif isinstance(vectorizer, VoyageAITextVectorizer): +# embeddings = vectorizer.embed_many(texts, input_type="document") +# else: +# embeddings = vectorizer.embed_many(texts) + +# assert isinstance(embeddings, list) +# assert len(embeddings) == len(texts) +# assert all( +# isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings +# ) + + +# @pytest.mark.requires_api_keys +# def test_vectorizer_bad_input(vectorizer): +# with pytest.raises(TypeError): +# vectorizer.embed(1) + +# with pytest.raises(TypeError): +# vectorizer.embed({"foo": "bar"}) + +# with pytest.raises(TypeError): +# vectorizer.embed_many(42) -@pytest.fixture -def bedrock_vectorizer(): - return BedrockTextVectorizer( - model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") - ) - - -@pytest.fixture -def custom_embed_func(): - def embed(text: str): - return [1.1, 2.2, 3.3, 4.4] - - return embed +# @pytest.mark.requires_api_keys +# def test_bedrock_bad_credentials(): +# with pytest.raises(ValueError): +# BedrockTextVectorizer( +# api_config={ +# "aws_access_key_id": "invalid", +# "aws_secret_access_key": "invalid", +# } +# ) -@pytest.fixture -def custom_embed_class(): - class MyEmbedder: - def embed(self, text: str): - return [1.1, 2.2, 3.3, 4.4] +# @pytest.mark.requires_api_keys +# def test_bedrock_invalid_model(bedrock_vectorizer): +# with pytest.raises(ValueError): +# bedrock = BedrockTextVectorizer(model="invalid-model") +# bedrock.embed("test") - def embed_with_args(self, text: str, max_len=None): - return [1.1, 2.2, 3.3, 4.4][0:max_len] - def embed_many(self, text_list): - return [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] - - def embed_many_with_args(self, texts, param=True): - if param: - return [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - else: - return [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] - - return MyEmbedder +# def test_custom_vectorizer_embed(custom_embed_class, custom_embed_func): +# custom_wrapper = CustomTextVectorizer(embed=custom_embed_func) +# embedding = custom_wrapper.embed("This is a test sentence.") +# assert embedding == [1.1, 2.2, 3.3, 4.4] + +# custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed) +# embedding = custom_wrapper.embed("This is a test sentence.") +# assert embedding == [1.1, 2.2, 3.3, 4.4] + +# custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed_with_args) +# embedding = custom_wrapper.embed("This is a test sentence.", max_len=4) +# assert embedding == [1.1, 2.2, 3.3, 4.4] +# embedding = custom_wrapper.embed("This is a test sentence.", max_len=2) +# assert embedding == [1.1, 2.2] + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer(embed="hello") + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer(embed=42) + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer(embed={"foo": "bar"}) + +# def bad_arg_type(value: int): +# return [value] + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer(embed=bad_arg_type) + +# def bad_return_type(text: str) -> str: +# return text + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer(embed=bad_return_type) + + +# def test_custom_vectorizer_embed_many(custom_embed_class, custom_embed_func): +# custom_wrapper = CustomTextVectorizer( +# custom_embed_func, embed_many=custom_embed_class().embed_many +# ) +# embeddings = custom_wrapper.embed_many(["test one.", "test two"]) +# assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] + +# custom_wrapper = CustomTextVectorizer( +# custom_embed_func, embed_many=custom_embed_class().embed_many +# ) +# embeddings = custom_wrapper.embed_many(["test one.", "test two"]) +# assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] + +# custom_wrapper = CustomTextVectorizer( +# custom_embed_func, embed_many=custom_embed_class().embed_many_with_args +# ) +# embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=True) +# assert embeddings == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] +# embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=False) +# assert embeddings == [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many="hello") + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many=42) + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer( +# custom_embed_func, embed_many={"foo": "bar"} +# ) + +# def bad_arg_type(value: int): +# return [value] + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer( +# custom_embed_func, embed_many=bad_arg_type +# ) + +# def bad_return_type(text: str) -> str: +# return text + +# with pytest.raises(ValueError): +# invalid_vectorizer = CustomTextVectorizer( +# custom_embed_func, embed_many=bad_return_type +# ) + + +# @pytest.mark.requires_api_keys +# @pytest.mark.parametrize( +# "vectorizer_", +# [ +# AzureOpenAITextVectorizer, +# BedrockTextVectorizer, +# CohereTextVectorizer, +# CustomTextVectorizer, +# HFTextVectorizer, +# MistralAITextVectorizer, +# OpenAITextVectorizer, +# VertexAITextVectorizer, +# VoyageAITextVectorizer, +# ], +# ) +# def test_default_dtype(vectorizer_): +# # test dtype defaults to float32 +# if isinstance(vectorizer_, CustomTextVectorizer): +# vectorizer = vectorizer_(embed=lambda x, input_type=None: [1.0, 2.0, 3.0]) +# elif isinstance(vectorizer_, AzureOpenAITextVectorizer): +# vectorizer = vectorizer_( +# model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") +# ) +# else: +# vectorizer = vectorizer_() + +# assert vectorizer.dtype == "float32" + + +# @pytest.mark.requires_api_keys +# @pytest.mark.parametrize( +# "vectorizer_", +# [ +# AzureOpenAITextVectorizer, +# BedrockTextVectorizer, +# CohereTextVectorizer, +# CustomTextVectorizer, +# HFTextVectorizer, +# MistralAITextVectorizer, +# OpenAITextVectorizer, +# VertexAITextVectorizer, +# VoyageAITextVectorizer, +# ], +# ) +# def test_vectorizer_dtype_assignment(vectorizer_): +# # test initializing dtype in constructor +# for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]: +# if isinstance(vectorizer_, CustomTextVectorizer): +# vectorizer = vectorizer_(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype) +# elif isinstance(vectorizer_, AzureOpenAITextVectorizer): +# vectorizer = vectorizer_( +# model=os.getenv( +# "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002" +# ), +# dtype=dtype, +# ) +# else: +# vectorizer = vectorizer_(dtype=dtype) + +# assert vectorizer.dtype == dtype + + +# @pytest.mark.requires_api_keys +# @pytest.mark.parametrize( +# "vectorizer_", +# [ +# AzureOpenAITextVectorizer, +# BedrockTextVectorizer, +# CohereTextVectorizer, +# HFTextVectorizer, +# MistralAITextVectorizer, +# OpenAITextVectorizer, +# VertexAITextVectorizer, +# VoyageAITextVectorizer, +# ], +# ) +# def test_non_supported_dtypes(vectorizer_): +# with pytest.raises(ValueError): +# vectorizer_(dtype="float25") + +# with pytest.raises(ValueError): +# vectorizer_(dtype=7) + +# with pytest.raises(ValueError): +# vectorizer_(dtype=None) + + +# @pytest.fixture( +# params=[ +# OpenAITextVectorizer, +# BedrockTextVectorizer, +# MistralAITextVectorizer, +# CustomTextVectorizer, +# VoyageAITextVectorizer, +# ] +# ) +# def avectorizer(request): +# if request.param == CustomTextVectorizer: + +# def embed_func(text): +# return [1.1, 2.2, 3.3, 4.4] + +# async def aembed_func(text): +# return [1.1, 2.2, 3.3, 4.4] + +# async def aembed_many_func(texts): +# return [[1.1, 2.2, 3.3, 4.4]] * len(texts) + +# return request.param( +# embed=embed_func, aembed=aembed_func, aembed_many=aembed_many_func +# ) +# else: +# return request.param() + + +# @pytest.mark.requires_api_keys +# @pytest.mark.asyncio +# async def test_vectorizer_aembed(avectorizer): +# text = "This is a test sentence." +# embedding = await avectorizer.aembed(text) + +# assert isinstance(embedding, list) +# assert len(embedding) == avectorizer.dims + + +# @pytest.mark.requires_api_keys +# @pytest.mark.asyncio +# async def test_vectorizer_aembed_many(avectorizer): +# texts = ["This is the first test sentence.", "This is the second test sentence."] +# embeddings = await avectorizer.aembed_many(texts) + +# assert isinstance(embeddings, list) +# assert len(embeddings) == len(texts) +# assert all( +# isinstance(emb, list) and len(emb) == avectorizer.dims for emb in embeddings +# ) + + +# @pytest.mark.requires_api_keys +# @pytest.mark.asyncio +# async def test_avectorizer_bad_input(avectorizer): +# with pytest.raises(TypeError): +# avectorizer.embed(1) + +# with pytest.raises(TypeError): +# avectorizer.embed({"foo": "bar"}) + +# with pytest.raises(TypeError): +# avectorizer.embed_many(42) @pytest.mark.requires_api_keys -def test_vectorizer_embed(vectorizer): +@pytest.mark.parametrize( + "dtype,expected_type", + [ + ("float32", float), # Float dtype should return floats + ("int8", int), # Int8 dtype should return ints + ("uint8", int), # Uint8 dtype should return ints + ], +) +def test_cohere_dtype_support(dtype, expected_type): + """Test that CohereTextVectorizer properly handles different dtypes for embeddings.""" text = "This is a test sentence." - if isinstance(vectorizer, CohereTextVectorizer): - embedding = vectorizer.embed(text, input_type="search_document") - elif isinstance(vectorizer, VoyageAITextVectorizer): - embedding = vectorizer.embed(text, input_type="document") + texts = ["First test sentence.", "Second test sentence."] + + # Create vectorizer with specified dtype + vectorizer = CohereTextVectorizer(dtype=dtype) + + # Verify the correct mapping of dtype to Cohere embedding_types + if dtype == "int8": + assert vectorizer._get_cohere_embedding_type(dtype) == ["int8"] + elif dtype == "uint8": + assert vectorizer._get_cohere_embedding_type(dtype) == ["uint8"] else: - embedding = vectorizer.embed(text) + # All other dtypes should map to float + assert vectorizer._get_cohere_embedding_type(dtype) == ["float"] + # Test single embedding + embedding = vectorizer.embed(text, input_type="search_document") assert isinstance(embedding, list) assert len(embedding) == vectorizer.dims + # Check that all elements are of the expected type + assert all( + isinstance(val, expected_type) for val in embedding + ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}" -@pytest.mark.requires_api_keys -def test_vectorizer_embed_many(vectorizer): - texts = ["This is the first test sentence.", "This is the second test sentence."] - if isinstance(vectorizer, CohereTextVectorizer): - embeddings = vectorizer.embed_many(texts, input_type="search_document") - elif isinstance(vectorizer, VoyageAITextVectorizer): - embeddings = vectorizer.embed_many(texts, input_type="document") - else: - embeddings = vectorizer.embed_many(texts) - + # Test multiple embeddings + embeddings = vectorizer.embed_many(texts, input_type="search_document") assert isinstance(embeddings, list) assert len(embeddings) == len(texts) assert all( isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings ) + # Check that all elements in all embeddings are of the expected type + for emb in embeddings: + assert all( + isinstance(val, expected_type) for val in emb + ), f"Expected all elements to be {expected_type.__name__} for dtype {dtype}" -@pytest.mark.requires_api_keys -def test_vectorizer_bad_input(vectorizer): - with pytest.raises(TypeError): - vectorizer.embed(1) - - with pytest.raises(TypeError): - vectorizer.embed({"foo": "bar"}) - - with pytest.raises(TypeError): - vectorizer.embed_many(42) - - -@pytest.mark.requires_api_keys -def test_bedrock_bad_credentials(): - with pytest.raises(ValueError): - BedrockTextVectorizer( - api_config={ - "aws_access_key_id": "invalid", - "aws_secret_access_key": "invalid", - } - ) - - -@pytest.mark.requires_api_keys -def test_bedrock_invalid_model(bedrock_vectorizer): - with pytest.raises(ValueError): - bedrock = BedrockTextVectorizer(model="invalid-model") - bedrock.embed("test") - - -def test_custom_vectorizer_embed(custom_embed_class, custom_embed_func): - custom_wrapper = CustomTextVectorizer(embed=custom_embed_func) - embedding = custom_wrapper.embed("This is a test sentence.") - assert embedding == [1.1, 2.2, 3.3, 4.4] - - custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed) - embedding = custom_wrapper.embed("This is a test sentence.") - assert embedding == [1.1, 2.2, 3.3, 4.4] - - custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed_with_args) - embedding = custom_wrapper.embed("This is a test sentence.", max_len=4) - assert embedding == [1.1, 2.2, 3.3, 4.4] - embedding = custom_wrapper.embed("This is a test sentence.", max_len=2) - assert embedding == [1.1, 2.2] - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer(embed="hello") - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer(embed=42) - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer(embed={"foo": "bar"}) - - def bad_arg_type(value: int): - return [value] - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer(embed=bad_arg_type) - - def bad_return_type(text: str) -> str: - return text - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer(embed=bad_return_type) - - -def test_custom_vectorizer_embed_many(custom_embed_class, custom_embed_func): - custom_wrapper = CustomTextVectorizer( - custom_embed_func, embed_many=custom_embed_class().embed_many + # Test as_buffer output format + embedding_buffer = vectorizer.embed( + text, input_type="search_document", as_buffer=True ) - embeddings = custom_wrapper.embed_many(["test one.", "test two"]) - assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] + assert isinstance(embedding_buffer, bytes) - custom_wrapper = CustomTextVectorizer( - custom_embed_func, embed_many=custom_embed_class().embed_many + # Test embed_many with as_buffer=True + buffer_embeddings = vectorizer.embed_many( + texts, input_type="search_document", as_buffer=True ) - embeddings = custom_wrapper.embed_many(["test one.", "test two"]) - assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] - - custom_wrapper = CustomTextVectorizer( - custom_embed_func, embed_many=custom_embed_class().embed_many_with_args - ) - embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=True) - assert embeddings == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=False) - assert embeddings == [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many="hello") - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many=42) - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer( - custom_embed_func, embed_many={"foo": "bar"} - ) - - def bad_arg_type(value: int): - return [value] + assert all(isinstance(emb, bytes) for emb in buffer_embeddings) - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer( - custom_embed_func, embed_many=bad_arg_type - ) - - def bad_return_type(text: str) -> str: - return text - - with pytest.raises(ValueError): - invalid_vectorizer = CustomTextVectorizer( - custom_embed_func, embed_many=bad_return_type - ) - - -@pytest.mark.requires_api_keys -@pytest.mark.parametrize( - "vectorizer_", - [ - AzureOpenAITextVectorizer, - BedrockTextVectorizer, - CohereTextVectorizer, - CustomTextVectorizer, - HFTextVectorizer, - MistralAITextVectorizer, - OpenAITextVectorizer, - VertexAITextVectorizer, - VoyageAITextVectorizer, - ], -) -def test_default_dtype(vectorizer_): - # test dtype defaults to float32 - if issubclass(vectorizer_, CustomTextVectorizer): - vectorizer = vectorizer_(embed=lambda x, input_type=None: [1.0, 2.0, 3.0]) - elif issubclass(vectorizer_, AzureOpenAITextVectorizer): - vectorizer = vectorizer_( - model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") - ) - else: - vectorizer = vectorizer_() - - assert vectorizer.dtype == "float32" + # Compare dimensions between buffer and list formats + assert len(np.frombuffer(embedding_buffer, dtype=dtype)) == len(embedding) @pytest.mark.requires_api_keys -@pytest.mark.parametrize( - "vectorizer_", - [ - AzureOpenAITextVectorizer, - BedrockTextVectorizer, - CohereTextVectorizer, - CustomTextVectorizer, - HFTextVectorizer, - MistralAITextVectorizer, - OpenAITextVectorizer, - VertexAITextVectorizer, - VoyageAITextVectorizer, - ], -) -def test_other_dtypes(vectorizer_): - # test initializing dtype in constructor - for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]: - if issubclass(vectorizer_, CustomTextVectorizer): - vectorizer = vectorizer_(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype) - elif issubclass(vectorizer_, AzureOpenAITextVectorizer): - vectorizer = vectorizer_( - model=os.getenv( - "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002" - ), - dtype=dtype, - ) - else: - vectorizer = vectorizer_(dtype=dtype) - - assert vectorizer.dtype == dtype - - -@pytest.mark.requires_api_keys -@pytest.mark.parametrize( - "vectorizer_", - [ - AzureOpenAITextVectorizer, - BedrockTextVectorizer, - CohereTextVectorizer, - HFTextVectorizer, - MistralAITextVectorizer, - OpenAITextVectorizer, - VertexAITextVectorizer, - VoyageAITextVectorizer, - ], -) -def test_bad_dtypes(vectorizer_): - with pytest.raises(ValueError): - vectorizer_(dtype="float25") - - with pytest.raises(ValueError): - vectorizer_(dtype=7) - - with pytest.raises(ValueError): - vectorizer_(dtype=None) - - -@pytest.fixture( - params=[ - OpenAITextVectorizer, - BedrockTextVectorizer, - MistralAITextVectorizer, - CustomTextVectorizer, - VoyageAITextVectorizer, - ] -) -def avectorizer(request): - if request.param == CustomTextVectorizer: - - def embed_func(text): - return [1.1, 2.2, 3.3, 4.4] - - async def aembed_func(text): - return [1.1, 2.2, 3.3, 4.4] - - async def aembed_many_func(texts): - return [[1.1, 2.2, 3.3, 4.4]] * len(texts) - - return request.param( - embed=embed_func, aembed=aembed_func, aembed_many=aembed_many_func - ) - else: - return request.param() - - -@pytest.mark.requires_api_keys -@pytest.mark.asyncio -async def test_vectorizer_aembed(avectorizer): +def test_cohere_embedding_types_warning(): + """Test that a warning is raised when embedding_types parameter is passed.""" text = "This is a test sentence." - embedding = await avectorizer.aembed(text) - + texts = ["First test sentence.", "Second test sentence."] + vectorizer = CohereTextVectorizer() + + # Test warning for single embedding + with pytest.warns(UserWarning, match="embedding_types.*not supported"): + embedding = vectorizer.embed( + text, + input_type="search_document", + embedding_types=["uint8"], # explicitly testing the anti-pattern here + ) assert isinstance(embedding, list) - assert len(embedding) == avectorizer.dims - - -@pytest.mark.requires_api_keys -@pytest.mark.asyncio -async def test_vectorizer_aembed_many(avectorizer): - texts = ["This is the first test sentence.", "This is the second test sentence."] - embeddings = await avectorizer.aembed_many(texts) + assert len(embedding) == vectorizer.dims + # Test warning for multiple embeddings + with pytest.warns(UserWarning, match="embedding_types.*not supported"): + embeddings = vectorizer.embed_many( + texts, input_type="search_document", embedding_types=["uint8"] + ) assert isinstance(embeddings, list) assert len(embeddings) == len(texts) - assert all( - isinstance(emb, list) and len(emb) == avectorizer.dims for emb in embeddings - ) - - -@pytest.mark.requires_api_keys -@pytest.mark.asyncio -async def test_avectorizer_bad_input(avectorizer): - with pytest.raises(TypeError): - avectorizer.embed(1) - - with pytest.raises(TypeError): - avectorizer.embed({"foo": "bar"}) - - with pytest.raises(TypeError): - avectorizer.embed_many(42) From 9a3eb08b582af33a8ad2d8c83a94809348d22337 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 27 Feb 2025 14:15:10 -0500 Subject: [PATCH 3/5] bring back vectorizer tests --- tests/integration/test_vectorizers.py | 743 +++++++++++++------------- 1 file changed, 372 insertions(+), 371 deletions(-) diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index 7859f5e9..330ffe1d 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -15,383 +15,384 @@ VoyageAITextVectorizer, ) -# @pytest.fixture( -# params=[ -# HFTextVectorizer, -# OpenAITextVectorizer, -# VertexAITextVectorizer, -# CohereTextVectorizer, -# AzureOpenAITextVectorizer, -# BedrockTextVectorizer, -# MistralAITextVectorizer, -# CustomTextVectorizer, -# VoyageAITextVectorizer, -# ] -# ) -# def vectorizer(request): -# if request.param == HFTextVectorizer: -# return request.param() -# elif request.param == OpenAITextVectorizer: -# return request.param() -# elif request.param == VertexAITextVectorizer: -# return request.param() -# elif request.param == CohereTextVectorizer: -# return request.param() -# elif request.param == MistralAITextVectorizer: -# return request.param() -# elif request.param == VoyageAITextVectorizer: -# return request.param(model="voyage-large-2") -# elif request.param == AzureOpenAITextVectorizer: -# return request.param( -# model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") -# ) -# elif request.param == BedrockTextVectorizer: -# return request.param( -# model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") -# ) -# elif request.param == CustomTextVectorizer: - -# def embed(text): -# return [1.1, 2.2, 3.3, 4.4] - -# def embed_many(texts): -# return [[1.1, 2.2, 3.3, 4.4]] * len(texts) - -# return request.param(embed=embed, embed_many=embed_many) - - -# @pytest.fixture -# def bedrock_vectorizer(): -# return BedrockTextVectorizer( -# model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") -# ) - - -# @pytest.fixture -# def custom_embed_func(): -# def embed(text: str): -# return [1.1, 2.2, 3.3, 4.4] - -# return embed - - -# @pytest.fixture -# def custom_embed_class(): -# class MyEmbedder: -# def embed(self, text: str): -# return [1.1, 2.2, 3.3, 4.4] - -# def embed_with_args(self, text: str, max_len=None): -# return [1.1, 2.2, 3.3, 4.4][0:max_len] - -# def embed_many(self, text_list): -# return [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] - -# def embed_many_with_args(self, texts, param=True): -# if param: -# return [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -# else: -# return [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] - -# return MyEmbedder - - -# @pytest.mark.requires_api_keys -# def test_vectorizer_embed(vectorizer): -# text = "This is a test sentence." -# if isinstance(vectorizer, CohereTextVectorizer): -# embedding = vectorizer.embed(text, input_type="search_document") -# elif isinstance(vectorizer, VoyageAITextVectorizer): -# embedding = vectorizer.embed(text, input_type="document") -# else: -# embedding = vectorizer.embed(text) - -# assert isinstance(embedding, list) -# assert len(embedding) == vectorizer.dims - - -# @pytest.mark.requires_api_keys -# def test_vectorizer_embed_many(vectorizer): -# texts = ["This is the first test sentence.", "This is the second test sentence."] -# if isinstance(vectorizer, CohereTextVectorizer): -# embeddings = vectorizer.embed_many(texts, input_type="search_document") -# elif isinstance(vectorizer, VoyageAITextVectorizer): -# embeddings = vectorizer.embed_many(texts, input_type="document") -# else: -# embeddings = vectorizer.embed_many(texts) - -# assert isinstance(embeddings, list) -# assert len(embeddings) == len(texts) -# assert all( -# isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings -# ) - - -# @pytest.mark.requires_api_keys -# def test_vectorizer_bad_input(vectorizer): -# with pytest.raises(TypeError): -# vectorizer.embed(1) - -# with pytest.raises(TypeError): -# vectorizer.embed({"foo": "bar"}) - -# with pytest.raises(TypeError): -# vectorizer.embed_many(42) +@pytest.fixture( + params=[ + HFTextVectorizer, + OpenAITextVectorizer, + VertexAITextVectorizer, + CohereTextVectorizer, + AzureOpenAITextVectorizer, + BedrockTextVectorizer, + MistralAITextVectorizer, + CustomTextVectorizer, + VoyageAITextVectorizer, + ] +) +def vectorizer(request): + if request.param == HFTextVectorizer: + return request.param() + elif request.param == OpenAITextVectorizer: + return request.param() + elif request.param == VertexAITextVectorizer: + return request.param() + elif request.param == CohereTextVectorizer: + return request.param() + elif request.param == MistralAITextVectorizer: + return request.param() + elif request.param == VoyageAITextVectorizer: + return request.param(model="voyage-large-2") + elif request.param == AzureOpenAITextVectorizer: + return request.param( + model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") + ) + elif request.param == BedrockTextVectorizer: + return request.param( + model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") + ) + elif request.param == CustomTextVectorizer: + + def embed(text): + return [1.1, 2.2, 3.3, 4.4] + + def embed_many(texts): + return [[1.1, 2.2, 3.3, 4.4]] * len(texts) + + return request.param(embed=embed, embed_many=embed_many) + + +@pytest.fixture +def bedrock_vectorizer(): + return BedrockTextVectorizer( + model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") + ) + + +@pytest.fixture +def custom_embed_func(): + def embed(text: str): + return [1.1, 2.2, 3.3, 4.4] + + return embed + + +@pytest.fixture +def custom_embed_class(): + class MyEmbedder: + def embed(self, text: str): + return [1.1, 2.2, 3.3, 4.4] + + def embed_with_args(self, text: str, max_len=None): + return [1.1, 2.2, 3.3, 4.4][0:max_len] + + def embed_many(self, text_list): + return [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] -# @pytest.mark.requires_api_keys -# def test_bedrock_bad_credentials(): -# with pytest.raises(ValueError): -# BedrockTextVectorizer( -# api_config={ -# "aws_access_key_id": "invalid", -# "aws_secret_access_key": "invalid", -# } -# ) + def embed_many_with_args(self, texts, param=True): + if param: + return [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + else: + return [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] + return MyEmbedder -# @pytest.mark.requires_api_keys -# def test_bedrock_invalid_model(bedrock_vectorizer): -# with pytest.raises(ValueError): -# bedrock = BedrockTextVectorizer(model="invalid-model") -# bedrock.embed("test") +@pytest.mark.requires_api_keys +def test_vectorizer_embed(vectorizer): + text = "This is a test sentence." + if isinstance(vectorizer, CohereTextVectorizer): + embedding = vectorizer.embed(text, input_type="search_document") + elif isinstance(vectorizer, VoyageAITextVectorizer): + embedding = vectorizer.embed(text, input_type="document") + else: + embedding = vectorizer.embed(text) + + assert isinstance(embedding, list) + assert len(embedding) == vectorizer.dims + + +@pytest.mark.requires_api_keys +def test_vectorizer_embed_many(vectorizer): + texts = ["This is the first test sentence.", "This is the second test sentence."] + if isinstance(vectorizer, CohereTextVectorizer): + embeddings = vectorizer.embed_many(texts, input_type="search_document") + elif isinstance(vectorizer, VoyageAITextVectorizer): + embeddings = vectorizer.embed_many(texts, input_type="document") + else: + embeddings = vectorizer.embed_many(texts) + + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + assert all( + isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings + ) + + +@pytest.mark.requires_api_keys +def test_vectorizer_bad_input(vectorizer): + with pytest.raises(TypeError): + vectorizer.embed(1) + + with pytest.raises(TypeError): + vectorizer.embed({"foo": "bar"}) + + with pytest.raises(TypeError): + vectorizer.embed_many(42) + + +@pytest.mark.requires_api_keys +def test_bedrock_bad_credentials(): + with pytest.raises(ValueError): + BedrockTextVectorizer( + api_config={ + "aws_access_key_id": "invalid", + "aws_secret_access_key": "invalid", + } + ) + + +@pytest.mark.requires_api_keys +def test_bedrock_invalid_model(bedrock_vectorizer): + with pytest.raises(ValueError): + bedrock = BedrockTextVectorizer(model="invalid-model") + bedrock.embed("test") + + +def test_custom_vectorizer_embed(custom_embed_class, custom_embed_func): + custom_wrapper = CustomTextVectorizer(embed=custom_embed_func) + embedding = custom_wrapper.embed("This is a test sentence.") + assert embedding == [1.1, 2.2, 3.3, 4.4] + + custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed) + embedding = custom_wrapper.embed("This is a test sentence.") + assert embedding == [1.1, 2.2, 3.3, 4.4] + + custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed_with_args) + embedding = custom_wrapper.embed("This is a test sentence.", max_len=4) + assert embedding == [1.1, 2.2, 3.3, 4.4] + embedding = custom_wrapper.embed("This is a test sentence.", max_len=2) + assert embedding == [1.1, 2.2] + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer(embed="hello") + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer(embed=42) + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer(embed={"foo": "bar"}) + + def bad_arg_type(value: int): + return [value] + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer(embed=bad_arg_type) + + def bad_return_type(text: str) -> str: + return text + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer(embed=bad_return_type) + + +def test_custom_vectorizer_embed_many(custom_embed_class, custom_embed_func): + custom_wrapper = CustomTextVectorizer( + custom_embed_func, embed_many=custom_embed_class().embed_many + ) + embeddings = custom_wrapper.embed_many(["test one.", "test two"]) + assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] + + custom_wrapper = CustomTextVectorizer( + custom_embed_func, embed_many=custom_embed_class().embed_many + ) + embeddings = custom_wrapper.embed_many(["test one.", "test two"]) + assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] + + custom_wrapper = CustomTextVectorizer( + custom_embed_func, embed_many=custom_embed_class().embed_many_with_args + ) + embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=True) + assert embeddings == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=False) + assert embeddings == [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many="hello") + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many=42) + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer( + custom_embed_func, embed_many={"foo": "bar"} + ) + + def bad_arg_type(value: int): + return [value] + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer( + custom_embed_func, embed_many=bad_arg_type + ) + + def bad_return_type(text: str) -> str: + return text + + with pytest.raises(ValueError): + invalid_vectorizer = CustomTextVectorizer( + custom_embed_func, embed_many=bad_return_type + ) + + +@pytest.mark.requires_api_keys +@pytest.mark.parametrize( + "vectorizer_", + [ + AzureOpenAITextVectorizer, + BedrockTextVectorizer, + CohereTextVectorizer, + CustomTextVectorizer, + HFTextVectorizer, + MistralAITextVectorizer, + OpenAITextVectorizer, + VertexAITextVectorizer, + VoyageAITextVectorizer, + ], +) +def test_default_dtype(vectorizer_): + # test dtype defaults to float32 + if isinstance(vectorizer_, CustomTextVectorizer): + vectorizer = vectorizer_(embed=lambda x, input_type=None: [1.0, 2.0, 3.0]) + elif isinstance(vectorizer_, AzureOpenAITextVectorizer): + vectorizer = vectorizer_( + model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") + ) + else: + vectorizer = vectorizer_() + + assert vectorizer.dtype == "float32" + + +@pytest.mark.requires_api_keys +@pytest.mark.parametrize( + "vectorizer_", + [ + AzureOpenAITextVectorizer, + BedrockTextVectorizer, + CohereTextVectorizer, + CustomTextVectorizer, + HFTextVectorizer, + MistralAITextVectorizer, + OpenAITextVectorizer, + VertexAITextVectorizer, + VoyageAITextVectorizer, + ], +) +def test_vectorizer_dtype_assignment(vectorizer_): + # test initializing dtype in constructor + for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]: + if isinstance(vectorizer_, CustomTextVectorizer): + vectorizer = vectorizer_(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype) + elif isinstance(vectorizer_, AzureOpenAITextVectorizer): + vectorizer = vectorizer_( + model=os.getenv( + "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002" + ), + dtype=dtype, + ) + else: + vectorizer = vectorizer_(dtype=dtype) + + assert vectorizer.dtype == dtype + + +@pytest.mark.requires_api_keys +@pytest.mark.parametrize( + "vectorizer_", + [ + AzureOpenAITextVectorizer, + BedrockTextVectorizer, + CohereTextVectorizer, + HFTextVectorizer, + MistralAITextVectorizer, + OpenAITextVectorizer, + VertexAITextVectorizer, + VoyageAITextVectorizer, + ], +) +def test_non_supported_dtypes(vectorizer_): + with pytest.raises(ValueError): + vectorizer_(dtype="float25") + + with pytest.raises(ValueError): + vectorizer_(dtype=7) + + with pytest.raises(ValueError): + vectorizer_(dtype=None) + + +@pytest.fixture( + params=[ + OpenAITextVectorizer, + BedrockTextVectorizer, + MistralAITextVectorizer, + CustomTextVectorizer, + VoyageAITextVectorizer, + ] +) +def avectorizer(request): + if request.param == CustomTextVectorizer: + + def embed_func(text): + return [1.1, 2.2, 3.3, 4.4] + + async def aembed_func(text): + return [1.1, 2.2, 3.3, 4.4] + + async def aembed_many_func(texts): + return [[1.1, 2.2, 3.3, 4.4]] * len(texts) + + return request.param( + embed=embed_func, aembed=aembed_func, aembed_many=aembed_many_func + ) + else: + return request.param() + + +@pytest.mark.requires_api_keys +@pytest.mark.asyncio +async def test_vectorizer_aembed(avectorizer): + text = "This is a test sentence." + embedding = await avectorizer.aembed(text) + + assert isinstance(embedding, list) + assert len(embedding) == avectorizer.dims + + +@pytest.mark.requires_api_keys +@pytest.mark.asyncio +async def test_vectorizer_aembed_many(avectorizer): + texts = ["This is the first test sentence.", "This is the second test sentence."] + embeddings = await avectorizer.aembed_many(texts) + + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + assert all( + isinstance(emb, list) and len(emb) == avectorizer.dims for emb in embeddings + ) + + +@pytest.mark.requires_api_keys +@pytest.mark.asyncio +async def test_avectorizer_bad_input(avectorizer): + with pytest.raises(TypeError): + avectorizer.embed(1) + + with pytest.raises(TypeError): + avectorizer.embed({"foo": "bar"}) -# def test_custom_vectorizer_embed(custom_embed_class, custom_embed_func): -# custom_wrapper = CustomTextVectorizer(embed=custom_embed_func) -# embedding = custom_wrapper.embed("This is a test sentence.") -# assert embedding == [1.1, 2.2, 3.3, 4.4] - -# custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed) -# embedding = custom_wrapper.embed("This is a test sentence.") -# assert embedding == [1.1, 2.2, 3.3, 4.4] - -# custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed_with_args) -# embedding = custom_wrapper.embed("This is a test sentence.", max_len=4) -# assert embedding == [1.1, 2.2, 3.3, 4.4] -# embedding = custom_wrapper.embed("This is a test sentence.", max_len=2) -# assert embedding == [1.1, 2.2] - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer(embed="hello") - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer(embed=42) - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer(embed={"foo": "bar"}) - -# def bad_arg_type(value: int): -# return [value] - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer(embed=bad_arg_type) - -# def bad_return_type(text: str) -> str: -# return text - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer(embed=bad_return_type) - - -# def test_custom_vectorizer_embed_many(custom_embed_class, custom_embed_func): -# custom_wrapper = CustomTextVectorizer( -# custom_embed_func, embed_many=custom_embed_class().embed_many -# ) -# embeddings = custom_wrapper.embed_many(["test one.", "test two"]) -# assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] - -# custom_wrapper = CustomTextVectorizer( -# custom_embed_func, embed_many=custom_embed_class().embed_many -# ) -# embeddings = custom_wrapper.embed_many(["test one.", "test two"]) -# assert embeddings == [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] - -# custom_wrapper = CustomTextVectorizer( -# custom_embed_func, embed_many=custom_embed_class().embed_many_with_args -# ) -# embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=True) -# assert embeddings == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -# embeddings = custom_wrapper.embed_many(["test one.", "test two"], param=False) -# assert embeddings == [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]] - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many="hello") - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer(custom_embed_func, embed_many=42) - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer( -# custom_embed_func, embed_many={"foo": "bar"} -# ) - -# def bad_arg_type(value: int): -# return [value] - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer( -# custom_embed_func, embed_many=bad_arg_type -# ) - -# def bad_return_type(text: str) -> str: -# return text - -# with pytest.raises(ValueError): -# invalid_vectorizer = CustomTextVectorizer( -# custom_embed_func, embed_many=bad_return_type -# ) - - -# @pytest.mark.requires_api_keys -# @pytest.mark.parametrize( -# "vectorizer_", -# [ -# AzureOpenAITextVectorizer, -# BedrockTextVectorizer, -# CohereTextVectorizer, -# CustomTextVectorizer, -# HFTextVectorizer, -# MistralAITextVectorizer, -# OpenAITextVectorizer, -# VertexAITextVectorizer, -# VoyageAITextVectorizer, -# ], -# ) -# def test_default_dtype(vectorizer_): -# # test dtype defaults to float32 -# if isinstance(vectorizer_, CustomTextVectorizer): -# vectorizer = vectorizer_(embed=lambda x, input_type=None: [1.0, 2.0, 3.0]) -# elif isinstance(vectorizer_, AzureOpenAITextVectorizer): -# vectorizer = vectorizer_( -# model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") -# ) -# else: -# vectorizer = vectorizer_() - -# assert vectorizer.dtype == "float32" - - -# @pytest.mark.requires_api_keys -# @pytest.mark.parametrize( -# "vectorizer_", -# [ -# AzureOpenAITextVectorizer, -# BedrockTextVectorizer, -# CohereTextVectorizer, -# CustomTextVectorizer, -# HFTextVectorizer, -# MistralAITextVectorizer, -# OpenAITextVectorizer, -# VertexAITextVectorizer, -# VoyageAITextVectorizer, -# ], -# ) -# def test_vectorizer_dtype_assignment(vectorizer_): -# # test initializing dtype in constructor -# for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]: -# if isinstance(vectorizer_, CustomTextVectorizer): -# vectorizer = vectorizer_(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype) -# elif isinstance(vectorizer_, AzureOpenAITextVectorizer): -# vectorizer = vectorizer_( -# model=os.getenv( -# "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002" -# ), -# dtype=dtype, -# ) -# else: -# vectorizer = vectorizer_(dtype=dtype) - -# assert vectorizer.dtype == dtype - - -# @pytest.mark.requires_api_keys -# @pytest.mark.parametrize( -# "vectorizer_", -# [ -# AzureOpenAITextVectorizer, -# BedrockTextVectorizer, -# CohereTextVectorizer, -# HFTextVectorizer, -# MistralAITextVectorizer, -# OpenAITextVectorizer, -# VertexAITextVectorizer, -# VoyageAITextVectorizer, -# ], -# ) -# def test_non_supported_dtypes(vectorizer_): -# with pytest.raises(ValueError): -# vectorizer_(dtype="float25") - -# with pytest.raises(ValueError): -# vectorizer_(dtype=7) - -# with pytest.raises(ValueError): -# vectorizer_(dtype=None) - - -# @pytest.fixture( -# params=[ -# OpenAITextVectorizer, -# BedrockTextVectorizer, -# MistralAITextVectorizer, -# CustomTextVectorizer, -# VoyageAITextVectorizer, -# ] -# ) -# def avectorizer(request): -# if request.param == CustomTextVectorizer: - -# def embed_func(text): -# return [1.1, 2.2, 3.3, 4.4] - -# async def aembed_func(text): -# return [1.1, 2.2, 3.3, 4.4] - -# async def aembed_many_func(texts): -# return [[1.1, 2.2, 3.3, 4.4]] * len(texts) - -# return request.param( -# embed=embed_func, aembed=aembed_func, aembed_many=aembed_many_func -# ) -# else: -# return request.param() - - -# @pytest.mark.requires_api_keys -# @pytest.mark.asyncio -# async def test_vectorizer_aembed(avectorizer): -# text = "This is a test sentence." -# embedding = await avectorizer.aembed(text) - -# assert isinstance(embedding, list) -# assert len(embedding) == avectorizer.dims - - -# @pytest.mark.requires_api_keys -# @pytest.mark.asyncio -# async def test_vectorizer_aembed_many(avectorizer): -# texts = ["This is the first test sentence.", "This is the second test sentence."] -# embeddings = await avectorizer.aembed_many(texts) - -# assert isinstance(embeddings, list) -# assert len(embeddings) == len(texts) -# assert all( -# isinstance(emb, list) and len(emb) == avectorizer.dims for emb in embeddings -# ) - - -# @pytest.mark.requires_api_keys -# @pytest.mark.asyncio -# async def test_avectorizer_bad_input(avectorizer): -# with pytest.raises(TypeError): -# avectorizer.embed(1) - -# with pytest.raises(TypeError): -# avectorizer.embed({"foo": "bar"}) - -# with pytest.raises(TypeError): -# avectorizer.embed_many(42) + with pytest.raises(TypeError): + avectorizer.embed_many(42) @pytest.mark.requires_api_keys From 7f3a0fe34a3bff4643aa51f6682e54d64f44b51a Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 27 Feb 2025 14:16:52 -0500 Subject: [PATCH 4/5] reset default batch_size to 10 for all vectorizers --- redisvl/utils/vectorize/base.py | 4 ++-- redisvl/utils/vectorize/text/azureopenai.py | 2 +- redisvl/utils/vectorize/text/custom.py | 2 +- redisvl/utils/vectorize/text/huggingface.py | 4 ++-- redisvl/utils/vectorize/text/mistral.py | 2 +- redisvl/utils/vectorize/text/openai.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index c8a80677..189b6e1a 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -74,7 +74,7 @@ def embed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> Union[List[List[float]], List[bytes]]: @@ -96,7 +96,7 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> Union[List[List[float]], List[bytes]]: diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index ad4b7210..410280e5 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -267,7 +267,7 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> Union[List[List[float]], List[bytes]]: diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index 154468b9..ed284d29 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -288,7 +288,7 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> Union[List[List[float]], List[bytes]]: diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index e085cb9f..bafba41d 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -122,7 +122,7 @@ def embed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> Union[List[List[float]], List[bytes]]: @@ -134,7 +134,7 @@ def embed_many( preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 1000. + embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index db2c47ee..05133b37 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -217,7 +217,7 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> List[List[float]]: diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 950993ce..eee0764a 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -218,7 +218,7 @@ async def aembed_many( self, texts: List[str], preprocess: Optional[Callable] = None, - batch_size: int = 1000, + batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> Union[List[List[float]], List[bytes]]: From 56931ac2a9eb0b1bd7f7ca26c2976bed6f83ff8e Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 27 Feb 2025 15:47:24 -0500 Subject: [PATCH 5/5] fix test --- tests/integration/test_vectorizers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index 330ffe1d..36e444de 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -261,9 +261,9 @@ def bad_return_type(text: str) -> str: ) def test_default_dtype(vectorizer_): # test dtype defaults to float32 - if isinstance(vectorizer_, CustomTextVectorizer): + if issubclass(vectorizer_, CustomTextVectorizer): vectorizer = vectorizer_(embed=lambda x, input_type=None: [1.0, 2.0, 3.0]) - elif isinstance(vectorizer_, AzureOpenAITextVectorizer): + elif issubclass(vectorizer_, AzureOpenAITextVectorizer): vectorizer = vectorizer_( model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") ) @@ -291,9 +291,9 @@ def test_default_dtype(vectorizer_): def test_vectorizer_dtype_assignment(vectorizer_): # test initializing dtype in constructor for dtype in ["float16", "float32", "float64", "bfloat16", "int8", "uint8"]: - if isinstance(vectorizer_, CustomTextVectorizer): + if issubclass(vectorizer_, CustomTextVectorizer): vectorizer = vectorizer_(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype) - elif isinstance(vectorizer_, AzureOpenAITextVectorizer): + elif issubclass(vectorizer_, AzureOpenAITextVectorizer): vectorizer = vectorizer_( model=os.getenv( "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"