@@ -243,48 +243,58 @@ def bad_return_type(text: str) -> str:
243243 )
244244
245245
246- # @pytest.mark.requires_api_keys
247- # def test_dtypes(vectorizer):
248- # # # test dtype defaults to float32
249- # # if issubclass(vectorizer, CustomTextVectorizer):
250- # # vectorizer = vectorizer(embed=lambda x, input_type=None: [1.0, 2.0, 3.0])
251- # # elif issubclass(vectorizer, AzureOpenAITextVectorizer):
252- # # vectorizer = vectorizer(
253- # # model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002")
254- # # )
255- # # else:
256- # # vectorizer = vector_class()
257-
258- # assert vectorizer.dtype == "float32"
259-
260- # # test initializing dtype in constructor
261- # for dtype in ["float16", "float32", "float64", "bfloat16"]:
262- # if issubclass(vectorizer, CustomTextVectorizer):
263- # vectorizer = vectorizer(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype)
264- # elif issubclass(vectorizer, AzureOpenAITextVectorizer):
265- # vectorizer = vectorizer(
266- # model=os.getenv(
267- # "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
268- # ),
269- # dtype=dtype,
270- # )
271- # else:
272- # vectorizer = vectorizer(dtype=dtype)
273-
274- # assert vectorizer.dtype == dtype
275-
276- # # test validation of dtype on init
277- # if issubclass(vectorizer, CustomTextVectorizer):
278- # pytest.skip("skipping custom text vectorizer")
279-
280- # with pytest.raises(ValueError):
281- # vectorizer = vectorizer(dtype="float25")
282-
283- # with pytest.raises(ValueError):
284- # vectorizer = vectorizer(dtype=7)
285-
286- # with pytest.raises(ValueError):
287- # vectorizer = vectorizer(dtype=None)
246+ @pytest .mark .requires_api_keys
247+ @pytest .mark .parametrize (
248+ "vectorizer_" ,
249+ [
250+ AzureOpenAITextVectorizer ,
251+ BedrockTextVectorizer ,
252+ CohereTextVectorizer ,
253+ CustomTextVectorizer ,
254+ HFTextVectorizer ,
255+ MistralAITextVectorizer ,
256+ OpenAITextVectorizer ,
257+ VertexAITextVectorizer ,
258+ VoyageAITextVectorizer ,
259+ ],
260+ )
261+ def test_dtypes (vectorizer_ ):
262+ # test dtype defaults to float32
263+ if issubclass (vectorizer_ , CustomTextVectorizer ):
264+ vectorizer = vectorizer_ (embed = lambda x , input_type = None : [1.0 , 2.0 , 3.0 ])
265+ elif issubclass (vectorizer , AzureOpenAITextVectorizer ):
266+ vectorizer = vectorizer_ (
267+ model = os .getenv ("AZURE_OPENAI_DEPLOYMENT_NAME" , "text-embedding-ada-002" )
268+ )
269+ else :
270+ vectorizer = vectorizer_ ()
271+
272+ assert vectorizer .dtype == "float32"
273+
274+ # test initializing dtype in constructor
275+ for dtype in ["float16" , "float32" , "float64" , "bfloat16" ]:
276+ if issubclass (vectorizer_ , CustomTextVectorizer ):
277+ vectorizer = vectorizer_ (embed = lambda x : [1.0 , 2.0 , 3.0 ], dtype = dtype )
278+ elif issubclass (vectorizer_ , AzureOpenAITextVectorizer ):
279+ vectorizer = vectorizer_ (
280+ model = os .getenv (
281+ "AZURE_OPENAI_DEPLOYMENT_NAME" , "text-embedding-ada-002"
282+ ),
283+ dtype = dtype ,
284+ )
285+ else :
286+ vectorizer = vectorizer_ (dtype = dtype )
287+
288+ assert vectorizer .dtype == dtype
289+
290+ with pytest .raises (ValueError ):
291+ vectorizer = vectorizer (dtype = "float25" )
292+
293+ with pytest .raises (ValueError ):
294+ vectorizer = vectorizer (dtype = 7 )
295+
296+ with pytest .raises (ValueError ):
297+ vectorizer = vectorizer (dtype = None )
288298
289299
290300@pytest .fixture (
0 commit comments