Skip to content

Commit 81e04c7

Browse files
committed
Refactor test_openai_embeddings
1 parent 1d4d263 commit 81e04c7

File tree

1 file changed

+73
-184
lines changed

1 file changed

+73
-184
lines changed

tests/integration/inference/test_openai_embeddings.py

Lines changed: 73 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@
1212

1313
from llama_stack.core.library_client import LlamaStackAsLibraryClient
1414

15-
16-
def decode_base64_to_floats(base64_string: str) -> list[float]:
17-
"""Helper function to decode base64 string to list of float32 values."""
18-
embedding_bytes = base64.b64decode(base64_string)
19-
float_count = len(embedding_bytes) // 4 # 4 bytes per float32
20-
embedding_floats = struct.unpack(f"{float_count}f", embedding_bytes)
21-
return list(embedding_floats)
22-
23-
2415
ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER = {
2516
"remote::nvidia": [
2617
"nvidia/llama-3.2-nv-embedqa-1b-v2",
@@ -31,6 +22,14 @@ def decode_base64_to_floats(base64_string: str) -> list[float]:
3122
}
3223

3324

25+
def decode_base64_to_floats(base64_string: str) -> list[float]:
26+
"""Helper function to decode base64 string to list of float32 values."""
27+
embedding_bytes = base64.b64decode(base64_string)
28+
float_count = len(embedding_bytes) // 4 # 4 bytes per float32
29+
embedding_floats = struct.unpack(f"{float_count}f", embedding_bytes)
30+
return list(embedding_floats)
31+
32+
3433
def provider_from_model(client_with_models, model_id):
3534
models = {m.identifier: m for m in client_with_models.models.list()}
3635
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
@@ -50,6 +49,9 @@ def is_asymmetric_model(client_with_models, model_id):
5049

5150

5251
def get_extra_body_for_model(client_with_models, model_id, input_type="query"):
52+
if not is_asymmetric_model(client_with_models, model_id):
53+
return None
54+
5355
provider = provider_from_model(client_with_models, model_id)
5456

5557
if provider.provider_type == "remote::nvidia":
@@ -149,27 +151,13 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embe
149151
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
150152

151153
input_text = "Hello, world!"
152-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
153154

154-
# For asymmetric models, verify that calling without extra_body raises an error
155-
if is_asymmetric_model(client_with_models, embedding_model_id):
156-
kwargs_without_extra = {
157-
"model": embedding_model_id,
158-
"input": input_text,
159-
"encoding_format": "float",
160-
}
161-
with pytest.raises(Exception): # noqa: B017
162-
compat_client.embeddings.create(**kwargs_without_extra)
163-
164-
kwargs = {
165-
"model": embedding_model_id,
166-
"input": input_text,
167-
"encoding_format": "float",
168-
}
169-
if is_asymmetric_model(client_with_models, embedding_model_id):
170-
kwargs["extra_body"] = extra_body
171-
172-
response = compat_client.embeddings.create(**kwargs)
155+
response = compat_client.embeddings.create(
156+
model=embedding_model_id,
157+
input=input_text,
158+
encoding_format="float",
159+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
160+
)
173161

174162
assert response.object == "list"
175163

@@ -188,26 +176,13 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, e
188176
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
189177

190178
input_texts = ["Hello, world!", "How are you today?", "This is a test."]
191-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
192179

193-
if is_asymmetric_model(client_with_models, embedding_model_id):
194-
kwargs_without_extra = {
195-
"model": embedding_model_id,
196-
"input": input_texts,
197-
"encoding_format": "float",
198-
}
199-
with pytest.raises(Exception): # noqa: B017
200-
compat_client.embeddings.create(**kwargs_without_extra)
201-
202-
kwargs = {
203-
"model": embedding_model_id,
204-
"input": input_texts,
205-
"encoding_format": "float",
206-
}
207-
if is_asymmetric_model(client_with_models, embedding_model_id):
208-
kwargs["extra_body"] = extra_body
209-
210-
response = compat_client.embeddings.create(**kwargs)
180+
response = compat_client.embeddings.create(
181+
model=embedding_model_id,
182+
input=input_texts,
183+
encoding_format="float",
184+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
185+
)
211186

212187
assert response.object == "list"
213188

@@ -228,26 +203,13 @@ def test_openai_embeddings_with_encoding_format_float(compat_client, client_with
228203
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
229204

230205
input_text = "Test encoding format"
231-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
232206

233-
if is_asymmetric_model(client_with_models, embedding_model_id):
234-
kwargs_without_extra = {
235-
"model": embedding_model_id,
236-
"input": input_text,
237-
"encoding_format": "float",
238-
}
239-
with pytest.raises(Exception): # noqa: B017
240-
compat_client.embeddings.create(**kwargs_without_extra)
241-
242-
kwargs = {
243-
"model": embedding_model_id,
244-
"input": input_text,
245-
"encoding_format": "float",
246-
}
247-
if is_asymmetric_model(client_with_models, embedding_model_id):
248-
kwargs["extra_body"] = extra_body
249-
250-
response = compat_client.embeddings.create(**kwargs)
207+
response = compat_client.embeddings.create(
208+
model=embedding_model_id,
209+
input=input_text,
210+
encoding_format="float",
211+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
212+
)
251213

252214
assert response.object == "list"
253215
assert len(response.data) == 1
@@ -262,26 +224,13 @@ def test_openai_embeddings_with_dimensions(compat_client, client_with_models, em
262224

263225
input_text = "Test dimensions parameter"
264226
dimensions = 16
265-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
266227

267-
if is_asymmetric_model(client_with_models, embedding_model_id):
268-
kwargs_without_extra = {
269-
"model": embedding_model_id,
270-
"input": input_text,
271-
"dimensions": dimensions,
272-
}
273-
with pytest.raises(Exception): # noqa: B017
274-
compat_client.embeddings.create(**kwargs_without_extra)
275-
276-
kwargs = {
277-
"model": embedding_model_id,
278-
"input": input_text,
279-
"dimensions": dimensions,
280-
}
281-
if is_asymmetric_model(client_with_models, embedding_model_id):
282-
kwargs["extra_body"] = extra_body
283-
284-
response = compat_client.embeddings.create(**kwargs)
228+
response = compat_client.embeddings.create(
229+
model=embedding_model_id,
230+
input=input_text,
231+
dimensions=dimensions,
232+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
233+
)
285234

286235
assert response.object == "list"
287236
assert len(response.data) == 1
@@ -297,26 +246,13 @@ def test_openai_embeddings_with_user_parameter(compat_client, client_with_models
297246

298247
input_text = "Test user parameter"
299248
user_id = "test-user-123"
300-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
301249

302-
if is_asymmetric_model(client_with_models, embedding_model_id):
303-
kwargs_without_extra = {
304-
"model": embedding_model_id,
305-
"input": input_text,
306-
"user": user_id,
307-
}
308-
with pytest.raises(Exception): # noqa: B017
309-
compat_client.embeddings.create(**kwargs_without_extra)
310-
311-
kwargs = {
312-
"model": embedding_model_id,
313-
"input": input_text,
314-
"user": user_id,
315-
}
316-
if is_asymmetric_model(client_with_models, embedding_model_id):
317-
kwargs["extra_body"] = extra_body
318-
319-
response = compat_client.embeddings.create(**kwargs)
250+
response = compat_client.embeddings.create(
251+
model=embedding_model_id,
252+
input=input_text,
253+
user=user_id,
254+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
255+
)
320256

321257
assert response.object == "list"
322258
assert len(response.data) == 1
@@ -328,17 +264,12 @@ def test_openai_embeddings_empty_list_error(compat_client, client_with_models, e
328264
"""Test that empty list input raises an appropriate error."""
329265
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
330266

331-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
332-
333-
kwargs = {
334-
"model": embedding_model_id,
335-
"input": [],
336-
}
337-
if is_asymmetric_model(client_with_models, embedding_model_id):
338-
kwargs["extra_body"] = extra_body
339-
340267
with pytest.raises(Exception): # noqa: B017
341-
compat_client.embeddings.create(**kwargs)
268+
compat_client.embeddings.create(
269+
model=embedding_model_id,
270+
input=[],
271+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
272+
)
342273

343274

344275
def test_openai_embeddings_invalid_model_error(compat_client, client_with_models, embedding_model_id):
@@ -349,6 +280,7 @@ def test_openai_embeddings_invalid_model_error(compat_client, client_with_models
349280
compat_client.embeddings.create(
350281
model="invalid-model-id",
351282
input="Test text",
283+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
352284
)
353285

354286

@@ -358,35 +290,20 @@ def test_openai_embeddings_different_inputs_different_outputs(compat_client, cli
358290

359291
input_text1 = "This is the first text"
360292
input_text2 = "This is completely different content"
361-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
362293

363-
if is_asymmetric_model(client_with_models, embedding_model_id):
364-
kwargs_without_extra = {
365-
"model": embedding_model_id,
366-
"input": input_text1,
367-
"encoding_format": "float",
368-
}
369-
with pytest.raises(Exception): # noqa: B017
370-
compat_client.embeddings.create(**kwargs_without_extra)
371-
372-
kwargs1 = {
373-
"model": embedding_model_id,
374-
"input": input_text1,
375-
"encoding_format": "float",
376-
}
377-
if is_asymmetric_model(client_with_models, embedding_model_id):
378-
kwargs1["extra_body"] = extra_body
379-
380-
kwargs2 = {
381-
"model": embedding_model_id,
382-
"input": input_text2,
383-
"encoding_format": "float",
384-
}
385-
if is_asymmetric_model(client_with_models, embedding_model_id):
386-
kwargs2["extra_body"] = extra_body
387-
388-
response1 = compat_client.embeddings.create(**kwargs1)
389-
response2 = compat_client.embeddings.create(**kwargs2)
294+
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
295+
response1 = compat_client.embeddings.create(
296+
model=embedding_model_id,
297+
input=input_text1,
298+
encoding_format="float",
299+
extra_body=extra_body,
300+
)
301+
response2 = compat_client.embeddings.create(
302+
model=embedding_model_id,
303+
input=input_text2,
304+
encoding_format="float",
305+
extra_body=extra_body,
306+
)
390307

391308
embedding1 = response1.data[0].embedding
392309
embedding2 = response2.data[0].embedding
@@ -404,28 +321,14 @@ def test_openai_embeddings_with_encoding_format_base64(compat_client, client_wit
404321

405322
input_text = "Test base64 encoding format"
406323
dimensions = 12
407-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
408324

409-
if is_asymmetric_model(client_with_models, embedding_model_id):
410-
kwargs_without_extra = {
411-
"model": embedding_model_id,
412-
"input": input_text,
413-
"encoding_format": "base64",
414-
"dimensions": dimensions,
415-
}
416-
with pytest.raises(Exception): # noqa: B017
417-
compat_client.embeddings.create(**kwargs_without_extra)
418-
419-
kwargs = {
420-
"model": embedding_model_id,
421-
"input": input_text,
422-
"encoding_format": "base64",
423-
"dimensions": dimensions,
424-
}
425-
if is_asymmetric_model(client_with_models, embedding_model_id):
426-
kwargs["extra_body"] = extra_body
427-
428-
response = compat_client.embeddings.create(**kwargs)
325+
response = compat_client.embeddings.create(
326+
model=embedding_model_id,
327+
input=input_text,
328+
encoding_format="base64",
329+
dimensions=dimensions,
330+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
331+
)
429332

430333
# Validate response structure
431334
assert response.object == "list"
@@ -451,27 +354,13 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
451354
skip_if_model_doesnt_support_encoding_format_base64(client_with_models, embedding_model_id)
452355

453356
input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
454-
extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
455-
456-
if is_asymmetric_model(client_with_models, embedding_model_id):
457-
kwargs_without_extra = {
458-
"model": embedding_model_id,
459-
"input": input_texts,
460-
"encoding_format": "base64",
461-
}
462-
with pytest.raises(Exception): # noqa: B017
463-
compat_client.embeddings.create(**kwargs_without_extra)
464-
465-
kwargs = {
466-
"model": embedding_model_id,
467-
"input": input_texts,
468-
"encoding_format": "base64",
469-
}
470-
if is_asymmetric_model(client_with_models, embedding_model_id):
471-
kwargs["extra_body"] = extra_body
472-
473-
response = compat_client.embeddings.create(**kwargs)
474357

358+
response = compat_client.embeddings.create(
359+
model=embedding_model_id,
360+
input=input_texts,
361+
encoding_format="base64",
362+
extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
363+
)
475364
# Validate response structure
476365
assert response.object == "list"
477366

0 commit comments

Comments
 (0)