Skip to content

Commit 44df263

Browse files
uncomment test
1 parent bf12a61 commit 44df263

File tree

1 file changed

+52
-42
lines changed

1 file changed

+52
-42
lines changed

tests/integration/test_vectorizers.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -243,48 +243,58 @@ def bad_return_type(text: str) -> str:
243243
)
244244

245245

246-
# @pytest.mark.requires_api_keys
247-
# def test_dtypes(vectorizer):
248-
# # # test dtype defaults to float32
249-
# # if issubclass(vectorizer, CustomTextVectorizer):
250-
# # vectorizer = vectorizer(embed=lambda x, input_type=None: [1.0, 2.0, 3.0])
251-
# # elif issubclass(vectorizer, AzureOpenAITextVectorizer):
252-
# # vectorizer = vectorizer(
253-
# # model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002")
254-
# # )
255-
# # else:
256-
# # vectorizer = vector_class()
257-
258-
# assert vectorizer.dtype == "float32"
259-
260-
# # test initializing dtype in constructor
261-
# for dtype in ["float16", "float32", "float64", "bfloat16"]:
262-
# if issubclass(vectorizer, CustomTextVectorizer):
263-
# vectorizer = vectorizer(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype)
264-
# elif issubclass(vectorizer, AzureOpenAITextVectorizer):
265-
# vectorizer = vectorizer(
266-
# model=os.getenv(
267-
# "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
268-
# ),
269-
# dtype=dtype,
270-
# )
271-
# else:
272-
# vectorizer = vectorizer(dtype=dtype)
273-
274-
# assert vectorizer.dtype == dtype
275-
276-
# # test validation of dtype on init
277-
# if issubclass(vectorizer, CustomTextVectorizer):
278-
# pytest.skip("skipping custom text vectorizer")
279-
280-
# with pytest.raises(ValueError):
281-
# vectorizer = vectorizer(dtype="float25")
282-
283-
# with pytest.raises(ValueError):
284-
# vectorizer = vectorizer(dtype=7)
285-
286-
# with pytest.raises(ValueError):
287-
# vectorizer = vectorizer(dtype=None)
246+
@pytest.mark.requires_api_keys
247+
@pytest.mark.parametrize(
248+
"vectorizer_",
249+
[
250+
AzureOpenAITextVectorizer,
251+
BedrockTextVectorizer,
252+
CohereTextVectorizer,
253+
CustomTextVectorizer,
254+
HFTextVectorizer,
255+
MistralAITextVectorizer,
256+
OpenAITextVectorizer,
257+
VertexAITextVectorizer,
258+
VoyageAITextVectorizer,
259+
],
260+
)
261+
def test_dtypes(vectorizer_):
262+
# test dtype defaults to float32
263+
if issubclass(vectorizer_, CustomTextVectorizer):
264+
vectorizer = vectorizer_(embed=lambda x, input_type=None: [1.0, 2.0, 3.0])
265+
elif issubclass(vectorizer, AzureOpenAITextVectorizer):
266+
vectorizer = vectorizer_(
267+
model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002")
268+
)
269+
else:
270+
vectorizer = vectorizer_()
271+
272+
assert vectorizer.dtype == "float32"
273+
274+
# test initializing dtype in constructor
275+
for dtype in ["float16", "float32", "float64", "bfloat16"]:
276+
if issubclass(vectorizer_, CustomTextVectorizer):
277+
vectorizer = vectorizer_(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype)
278+
elif issubclass(vectorizer_, AzureOpenAITextVectorizer):
279+
vectorizer = vectorizer_(
280+
model=os.getenv(
281+
"AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
282+
),
283+
dtype=dtype,
284+
)
285+
else:
286+
vectorizer = vectorizer_(dtype=dtype)
287+
288+
assert vectorizer.dtype == dtype
289+
290+
with pytest.raises(ValueError):
291+
vectorizer = vectorizer(dtype="float25")
292+
293+
with pytest.raises(ValueError):
294+
vectorizer = vectorizer(dtype=7)
295+
296+
with pytest.raises(ValueError):
297+
vectorizer = vectorizer(dtype=None)
288298

289299

290300
@pytest.fixture(

0 commit comments

Comments
 (0)