10
10
from re import RegexFlag
11
11
12
12
import aiohttp
13
+ import numpy as np
13
14
import openai
14
15
from behave import step
15
16
from behave .api .async_step import async_run_until_complete
@@ -34,6 +35,7 @@ def step_server_config(context, server_fqdn, server_port):
34
35
context .n_ga_w = None
35
36
context .n_gpu_layer = None
36
37
context .n_predict = None
38
+ context .n_prompts = 0
37
39
context .n_server_predict = None
38
40
context .n_slots = None
39
41
context .prompt_prefix = None
@@ -202,6 +204,7 @@ def step_n_tokens_predicted(context, predicted_n):
202
204
@step(u'a user prompt {user_prompt}')
def step_user_prompt(context, user_prompt):
    """Queue a user prompt and keep the prompt counter in sync."""
    context.prompts += [user_prompt]
    context.n_prompts = len(context.prompts)
205
208
206
209
207
210
@step (u'a system prompt {system_prompt}' )
@@ -289,6 +292,11 @@ def step_impl(context, n_ga_w):
289
292
def step_prompt_passkey (context ):
290
293
context .prompt_passkey = context .text
291
294
295
@step(u'{n_prompts:d} fixed prompts')
def step_fixed_prompts(context, n_prompts):
    """Queue *n_prompts* identical placeholder prompts of 1024 chars each.

    Fix: `str(0) * 1024` was an obscure spelling of `"0" * 1024`; the loop
    variable was unused, so it is now `_`.
    """
    context.prompts.extend(["0" * 1024 for _ in range(n_prompts)])
    context.n_prompts = n_prompts
292
300
293
301
@step (u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk' )
294
302
def step_prompt_passkey (context , passkey , i_pos ):
@@ -301,6 +309,7 @@ def step_prompt_passkey(context, passkey, i_pos):
301
309
passkey_highlight = "\x1b [33m" + passkey + "\x1b [0m"
302
310
print (f"Passkey challenge:\n ```{ prompt .replace (passkey , passkey_highlight )} ```\n " )
303
311
context .prompts .append (context .prompt_prefix + prompt + context .prompt_suffix )
312
+ context .n_prompts = len (context .prompts )
304
313
305
314
306
315
@step (u'an OAI compatible chat completions request with {api_error} api error' )
@@ -341,11 +350,13 @@ async def step_oai_chat_completions(context, api_error):
341
350
@step(u'a prompt')
def step_a_prompt(context):
    """Queue the step's multiline text block as a prompt."""
    prompts = context.prompts
    prompts.append(context.text)
    context.n_prompts = len(prompts)
344
354
345
355
346
356
@step(u'a prompt {prompt}')
def step_a_prompt_prompt(context, prompt):
    """Queue an inline prompt and keep the prompt counter in sync."""
    prompts = context.prompts
    prompts.append(prompt)
    context.n_prompts = len(prompts)
349
360
350
361
351
362
@step (u'concurrent completion requests' )
@@ -430,25 +441,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
430
441
@step(u'embeddings are computed for')
@async_run_until_complete
async def step_compute_embedding(context):
    """Request the embedding of the step text (a single prompt).

    Sets the counter before awaiting so a failed request still records
    how many prompts were attempted.
    """
    context.n_prompts = 1
    context.embeddings = await request_embedding(context.text,
                                                 base_url=context.base_url)
434
446
435
447
448
@step(u'all embeddings are the same')
@async_run_until_complete
async def step_all_embeddings_are_the_same(context):
    """Check that every gathered embedding is pairwise identical.

    Each pending embedding task must yield a valid embedding, and every
    pair must have cosine similarity ~= 1.0.

    Fix: the original converted each embedding with `np.array` and
    recomputed `np.linalg.norm` inside the O(n^2) pair loop; both are
    loop-invariant per embedding and are now computed once up front.
    """
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests > 0
    embeddings = []
    for _ in range(n_embedding_requests):
        embedding = context.tasks_result.pop().pop()
        embeddings.append(embedding)
        assert_embeddings(embedding)
    # Hoisted: one conversion and one norm per embedding.
    vectors = [np.array(e) for e in embeddings]
    norms = [np.linalg.norm(v) for v in vectors]
    n = len(vectors)
    for i in range(n - 1):
        for j in range(i + 1, n):
            embedding1 = vectors[i]
            embedding2 = vectors[j]
            if context.debug:
                print(f"embedding1: {embedding1[-8:]}\n")
                print(f"embedding2: {embedding2[-8:]}\n")
            similarity = np.dot(embedding1, embedding2) / (norms[i] * norms[j])
            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
            if context.debug:
                print(f"{msg}\n")
            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
436
472
@step(u'embeddings are generated')
def step_assert_embeddings(context):
    """Check that exactly one embedding came back per prompt and each is valid."""
    msg = (f"unexpected response:\n"
           f"context.n_prompts={context.n_prompts}\n"
           f"context.embeddings={context.embeddings}")
    assert context.n_prompts == len(context.embeddings), msg
    for embedding in context.embeddings:
        assert_embeddings(embedding)
447
479
448
480
449
481
@step (u'an OAI compatible embeddings computation request for' )
450
482
@async_run_until_complete
451
483
async def step_oai_compute_embeddings (context ):
484
+ context .n_prompts = 1
452
485
context .embeddings = await request_oai_embeddings (context .text ,
453
486
base_url = context .base_url ,
454
487
user_api_key = context .user_api_key ,
@@ -462,6 +495,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
462
495
base_url = context .base_url ,
463
496
user_api_key = context .user_api_key ,
464
497
model = context .model )
498
+ context .prompts .clear ()
465
499
466
500
467
501
@step (u'concurrent embedding requests' )
@@ -488,9 +522,9 @@ async def step_concurrent_oai_embedding_requests(context):
488
522
@async_run_until_complete()
async def all_embeddings_are_generated(context):
    """Gather all pending embedding tasks and verify one valid embedding per prompt.

    Fix: the loop variable was unused (now `_`), and the bare count
    assertion gains a diagnostic message for easier failure triage.
    """
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests == context.n_prompts, (
        f"expected {context.n_prompts} embedding responses,"
        f" got {n_embedding_requests}")
    for _ in range(n_embedding_requests):
        # Each task result is a list of embeddings; take the (single) one.
        assert_embeddings(context.tasks_result.pop().pop())
494
528
495
529
496
530
@step (u'tokenizing' )
@@ -588,11 +622,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):
588
622
589
623
590
624
async def concurrent_requests(context, f_completion, *args, **kwargs):
    """Spawn one asyncio task per queued prompt, calling *f_completion*.

    Prompts are popped off `context.prompts` (last queued first); each
    task is appended to `context.concurrent_tasks`. A short sleep gives
    the tasks a chance to start before the step returns.
    """
    context.n_prompts = len(context.prompts)
    if context.debug:
        print(f"starting {context.n_prompts} concurrent completion requests...")
    assert context.n_prompts > 0
    for _ in range(context.n_prompts):
        call_args = [context.prompts.pop(), *args]
        task = asyncio.create_task(f_completion(*call_args, **kwargs))
        context.concurrent_tasks.append(task)
    await asyncio.sleep(0.1)
@@ -765,7 +799,7 @@ async def request_embedding(content, base_url=None):
765
799
}) as response :
766
800
assert response .status == 200
767
801
response_json = await response .json ()
768
- return response_json ['embedding' ]
802
+ return [ response_json ['embedding' ] ]
769
803
770
804
771
805
async def request_oai_embeddings (input ,
@@ -775,6 +809,7 @@ async def request_oai_embeddings(input,
775
809
user_api_key = user_api_key if user_api_key is not None else 'nope'
776
810
if async_client :
777
811
origin = 'llama.cpp'
812
+ headers = []
778
813
if user_api_key is not None :
779
814
headers = {'Authorization' : f'Bearer { user_api_key } ' , 'Origin' : origin }
780
815
async with aiohttp .ClientSession () as session :
@@ -790,7 +825,13 @@ async def request_oai_embeddings(input,
790
825
response_json = await response .json ()
791
826
assert response_json ['model' ] == model , f"invalid model received: { response_json ['model' ]} "
792
827
assert response_json ['object' ] == 'list'
793
- return response_json ['data' ]
828
+ if isinstance (input , collections .abc .Sequence ):
829
+ embeddings = []
830
+ for an_oai_embeddings in response_json ['data' ]:
831
+ embeddings .append (an_oai_embeddings ['embedding' ])
832
+ else :
833
+ embeddings = [response_json ['data' ]['embedding' ]]
834
+ return embeddings
794
835
else :
795
836
openai .api_key = user_api_key
796
837
openai .api_base = f'{ base_url } /v1'
@@ -804,7 +845,7 @@ async def request_oai_embeddings(input,
804
845
for an_oai_embeddings in oai_embeddings .data :
805
846
embeddings .append (an_oai_embeddings .embedding )
806
847
else :
807
- embeddings = oai_embeddings .data .embedding
848
+ embeddings = [ oai_embeddings .data .embedding ]
808
849
return embeddings
809
850
810
851
@@ -899,6 +940,8 @@ def assert_embeddings(embeddings):
899
940
assert len (embeddings ) > 0
900
941
embeddings_computed = False
901
942
for emb in embeddings :
943
+ if not isinstance (emb , float ):
944
+ assert False , f"Bad embeddings: { embeddings } "
902
945
if emb != 0 :
903
946
embeddings_computed = True
904
947
assert embeddings_computed , f"Embeddings: { embeddings } "
0 commit comments