Commit d589c3c

committed
server: tests: embeddings, add dedicated feature and real model, if KV cache size exceeds batch size, embeddings differs
1 parent 61b6370 commit d589c3c
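The regression this commit guards against can be reproduced against a running server with a short script. The sketch below is not part of the commit; the server address, endpoint path, and payload shape are assumptions based on the llama.cpp server examples. It requests the embedding of the same prompt twice and applies the same cosine-similarity criterion the new `all embeddings are the same` step uses:

```python
# Minimal reproduction sketch (assumes a llama.cpp server listening on
# localhost:8080 that exposes an /embedding endpoint taking {"content": ...}).
import asyncio

import aiohttp
import numpy as np


async def fetch_embedding(session, base_url, content):
    # Ask the server for the embedding of a single prompt.
    async with session.post(f"{base_url}/embedding", json={"content": content}) as resp:
        assert resp.status == 200
        return np.array((await resp.json())["embedding"])


async def main():
    base_url = "http://localhost:8080"   # assumed test server address
    prompt = "0" * 1024                  # same fixed prompt as the new 'fixed prompts' step
    async with aiohttp.ClientSession() as session:
        e1 = await fetch_embedding(session, base_url, prompt)
        e2 = await fetch_embedding(session, base_url, prompt)
    similarity = np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2))
    # Identical prompts should produce (numerically) identical embeddings.
    assert np.isclose(similarity, 1.0), f"embeddings differ: similarity={similarity:.10f}"


if __name__ == "__main__":
    asyncio.run(main())
```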

File tree

5 files changed: +64 -93 lines changed

.github/workflows/server.yml

Lines changed: 2 additions & 1 deletion
@@ -58,7 +58,8 @@ jobs:
             cmake \
             python3-pip \
             wget \
-            psmisc
+            psmisc \
+            language-pack-en
 
       - name: Build
         id: cmake_build

examples/server/tests/features/parallel.feature

Lines changed: 0 additions & 46 deletions
@@ -9,7 +9,6 @@ Feature: Parallel
     And 512 as batch size
     And 64 KV cache size
     And 2 slots
-    And embeddings extraction
     And continuous batching
     Then the server is starting
     Then the server is healthy
@@ -99,48 +98,3 @@ Feature: Parallel
     Then the server is busy
     Then the server is idle
     Then all prompts are predicted
-
-  Scenario: Multi users embeddings
-    Given a prompt:
-      """
-      Write a very long story about AI.
-      """
-    And a prompt:
-      """
-      Write another very long music lyrics.
-      """
-    And a prompt:
-      """
-      Write a very long poem.
-      """
-    And a prompt:
-      """
-      Write a very long joke.
-      """
-    Given concurrent embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated
-
-  Scenario: Multi users OAI compatibility embeddings
-    Given a prompt:
-      """
-      In which country Paris is located ?
-      """
-    And a prompt:
-      """
-      Is Madrid the capital of Spain ?
-      """
-    And a prompt:
-      """
-      What is the biggest US city ?
-      """
-    And a prompt:
-      """
-      What is the capital of Bulgaria ?
-      """
-    And a model tinyllama-2
-    Given concurrent OAI embedding requests
-    Then the server is busy
-    Then the server is idle
-    Then all embeddings are generated

examples/server/tests/features/server.feature

Lines changed: 0 additions & 28 deletions
@@ -49,34 +49,6 @@ Feature: llama.cpp server
       | llama-2      | Book                        | What is the best book                | 8  | (Mom\|what)+           | 8  | disabled |
       | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled  |
 
-  Scenario: Embedding
-    When embeddings are computed for:
-      """
-      What is the capital of Bulgaria ?
-      """
-    Then embeddings are generated
-
-  Scenario: OAI Embeddings compatibility
-    Given a model tinyllama-2
-    When an OAI compatible embeddings computation request for:
-      """
-      What is the capital of Spain ?
-      """
-    Then embeddings are generated
-
-  Scenario: OAI Embeddings compatibility with multiple inputs
-    Given a model tinyllama-2
-    Given a prompt:
-      """
-      In which country Paris is located ?
-      """
-    And a prompt:
-      """
-      Is Madrid the capital of Spain ?
-      """
-    When an OAI compatible embeddings computation request for multiple inputs
-    Then embeddings are generated
-
   Scenario: Tokenize / Detokenize
     When tokenizing:
       """
examples/server/tests/features/steps/steps.py

Lines changed: 61 additions & 18 deletions
@@ -10,6 +10,7 @@
 from re import RegexFlag
 
 import aiohttp
+import numpy as np
 import openai
 from behave import step
 from behave.api.async_step import async_run_until_complete
@@ -34,6 +35,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_ga_w = None
     context.n_gpu_layer = None
     context.n_predict = None
+    context.n_prompts = 0
     context.n_server_predict = None
     context.n_slots = None
     context.prompt_prefix = None
@@ -202,6 +204,7 @@ def step_n_tokens_predicted(context, predicted_n):
 @step(u'a user prompt {user_prompt}')
 def step_user_prompt(context, user_prompt):
     context.prompts.append(user_prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a system prompt {system_prompt}')
@@ -289,6 +292,11 @@ def step_impl(context, n_ga_w):
 def step_prompt_passkey(context):
     context.prompt_passkey = context.text
 
+@step(u'{n_prompts:d} fixed prompts')
+def step_fixed_prompts(context, n_prompts):
+    context.prompts.extend([str(0)*1024 for i in range(n_prompts)])
+    context.n_prompts = n_prompts
+
 
 @step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
 def step_prompt_passkey(context, passkey, i_pos):
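The new `{n_prompts:d} fixed prompts` step queues n identical 1024-character prompts (`str(0)*1024` is the string "0" repeated 1024 times), presumably so the dedicated embeddings feature can check that identical inputs yield identical embeddings. A tiny illustration of what the step registers:

```python
# Illustrative only: the step extends context.prompts with n identical prompts.
n_prompts = 4
prompts = [str(0) * 1024 for _ in range(n_prompts)]
assert len(prompts) == n_prompts
assert all(p == "0" * 1024 for p in prompts)
```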
@@ -301,6 +309,7 @@ def step_prompt_passkey(context, passkey, i_pos):
     passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
     print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'an OAI compatible chat completions request with {api_error} api error')
@@ -341,11 +350,13 @@ async def step_oai_chat_completions(context, api_error):
 @step(u'a prompt')
 def step_a_prompt(context):
     context.prompts.append(context.text)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a prompt {prompt}')
 def step_a_prompt_prompt(context, prompt):
     context.prompts.append(prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'concurrent completion requests')
@@ -430,25 +441,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
+    context.n_prompts = 1
     context.embeddings = await request_embedding(context.text, base_url=context.base_url)
 
 
+@step(u'all embeddings are the same')
+@async_run_until_complete
+async def step_all_embeddings_are_the_same(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    embeddings = []
+    for i in range(n_embedding_requests):
+        embedding = context.tasks_result.pop().pop()
+        embeddings.append(embedding)
+        assert_embeddings(embedding)
+    n = len(embeddings)
+    for i in range(n-1):
+        for j in range(i+1, n):
+            embedding1 = np.array(embeddings[i])
+            embedding2 = np.array(embeddings[j])
+            if context.debug:
+                print(f"embedding1: {embedding1[-8:]}\n")
+                print(f"embedding2: {embedding2[-8:]}\n")
+            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
+            if context.debug:
+                print(f"{msg}\n")
+            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
+
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    if len(context.prompts) == 0:
-        assert_embeddings(context.embeddings)
-    else:
-        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
-                                                                 f"context.prompts={context.prompts}\n"
-                                                                 f"context.embeddings={context.embeddings}")
-        for embedding in context.embeddings:
-            context.prompts.pop()
-            assert_embeddings(embedding)
+    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
+                                                          f"context.n_prompts={context.n_prompts}\n"
+                                                          f"context.embeddings={context.embeddings}")
+    for embedding in context.embeddings:
+        assert_embeddings(embedding)
 
 
 @step(u'an OAI compatible embeddings computation request for')
 @async_run_until_complete
 async def step_oai_compute_embeddings(context):
+    context.n_prompts = 1
     context.embeddings = await request_oai_embeddings(context.text,
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
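For reference, `np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08)` accepts a pair of embeddings when `abs(similarity - 1.0) <= atol + rtol * abs(1.0)`, so the new step tolerates only tiny floating-point noise between embeddings of identical prompts. A small standalone sketch with made-up vectors (not from the commit) shows the criterion in action:

```python
import numpy as np

# Two embeddings that differ only by floating-point noise still count as "the same" ...
a = np.array([0.25, -0.5, 0.75])
b = a + 1e-9
cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
assert np.isclose(cos, 1.0, rtol=1e-05, atol=1e-08)

# ... while a genuinely different vector fails the check.
c = np.array([0.5, -0.25, 0.75])
cos_bad = np.dot(a, c) / (np.linalg.norm(a) * np.linalg.norm(c))
assert not np.isclose(cos_bad, 1.0, rtol=1e-05, atol=1e-08)
```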
@@ -462,6 +495,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
                                                       model=context.model)
+    context.prompts.clear()
 
 
 @step(u'concurrent embedding requests')
@@ -488,9 +522,9 @@ async def step_concurrent_oai_embedding_requests(context):
 @async_run_until_complete()
 async def all_embeddings_are_generated(context):
     n_embedding_requests = await gather_tasks_results(context)
-    assert n_embedding_requests > 0
+    assert n_embedding_requests == context.n_prompts
     for i in range(n_embedding_requests):
-        assert_embeddings(context.tasks_result.pop())
+        assert_embeddings(context.tasks_result.pop().pop())
 
 
 @step(u'tokenizing')
@@ -588,11 +622,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):
 
 
 async def concurrent_requests(context, f_completion, *args, **kwargs):
-    n_prompts = len(context.prompts)
+    context.n_prompts = len(context.prompts)
     if context.debug:
-        print(f"starting {n_prompts} concurrent completion requests...")
-    assert n_prompts > 0
-    for prompt_no in range(n_prompts):
+        print(f"starting {context.n_prompts} concurrent completion requests...")
+    assert context.n_prompts > 0
+    for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
         await asyncio.sleep(0.1)
@@ -765,7 +799,7 @@ async def request_embedding(content, base_url=None):
                                 }) as response:
             assert response.status == 200
             response_json = await response.json()
-            return response_json['embedding']
+            return [response_json['embedding']]
 
 
 async def request_oai_embeddings(input,
@@ -775,6 +809,7 @@ async def request_oai_embeddings(input,
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
         origin = 'llama.cpp'
+        headers=[]
         if user_api_key is not None:
             headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
@@ -790,7 +825,13 @@ async def request_oai_embeddings(input,
                 response_json = await response.json()
                 assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                 assert response_json['object'] == 'list'
-                return response_json['data']
+                if isinstance(input, collections.abc.Sequence):
+                    embeddings = []
+                    for an_oai_embeddings in response_json['data']:
+                        embeddings.append(an_oai_embeddings['embedding'])
+                else:
+                    embeddings = [response_json['data']['embedding']]
+                return embeddings
     else:
         openai.api_key = user_api_key
         openai.api_base = f'{base_url}/v1'
@@ -804,7 +845,7 @@ async def request_oai_embeddings(input,
         for an_oai_embeddings in oai_embeddings.data:
            embeddings.append(an_oai_embeddings.embedding)
     else:
-        embeddings = oai_embeddings.data.embedding
+        embeddings = [oai_embeddings.data.embedding]
     return embeddings
 
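After this change, `request_embedding` and both branches of `request_oai_embeddings` always return a list of embedding vectors, one per input prompt. That is what lets `all_embeddings_are_generated` above pop a request result and then pop the single vector out of it with `context.tasks_result.pop().pop()`. A minimal sketch of consuming the normalized shape (the helper below is illustrative only, not part of steps.py):

```python
# Illustrative only: with the normalized return shape, every response is a
# list of embedding vectors, so the same loop handles one prompt or many.
def check_all(embeddings_per_request):
    for embeddings in embeddings_per_request:      # one entry per request
        for embedding in embeddings:               # one vector per input prompt
            assert all(isinstance(x, float) for x in embedding)
            assert any(x != 0 for x in embedding)

# A single /embedding call now yields [[...floats...]] instead of [...floats...],
# and an OAI call with multiple inputs yields [[...], [...], ...].
check_all([[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
```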
@@ -899,6 +940,8 @@ def assert_embeddings(embeddings):
     assert len(embeddings) > 0
     embeddings_computed = False
     for emb in embeddings:
+        if not isinstance(emb, float):
+            assert False, f"Bad embeddings: {embeddings}"
         if emb != 0:
             embeddings_computed = True
     assert embeddings_computed, f"Embeddings: {embeddings}"
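The type guard added to `assert_embeddings` matters precisely because of that list-of-lists shape: a non-empty inner list satisfies the old `emb != 0` check, so a result accidentally left un-flattened would previously have passed. A tiny illustration, not part of the commit:

```python
# Before the guard, a nested (un-flattened) result would slip through,
# because a non-empty list is "!= 0" in Python.
nested = [[0.1, 0.2, 0.3]]
assert nested[0] != 0          # old-style check: passes even though it is not a float

# With the guard, anything that is not a float is rejected immediately.
flat = [0.1, 0.2, 0.3]
assert all(isinstance(x, float) for x in flat)
assert not all(isinstance(x, float) for x in nested)
```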
examples/server/tests/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 aiohttp~=3.9.3
 behave~=1.2.6
 huggingface_hub~=0.20.3
+numpy~=1.24.4
 openai~=0.25.0
 prometheus-client~=0.20.0
