1010from re import RegexFlag
1111
1212import aiohttp
13+ import numpy as np
1314import openai
1415from behave import step
1516from behave .api .async_step import async_run_until_complete
@@ -34,6 +35,7 @@ def step_server_config(context, server_fqdn, server_port):
3435 context .n_ga_w = None
3536 context .n_gpu_layer = None
3637 context .n_predict = None
38+ context .n_prompts = 0
3739 context .n_server_predict = None
3840 context .n_slots = None
3941 context .prompt_prefix = None
@@ -202,6 +204,7 @@ def step_n_tokens_predicted(context, predicted_n):
@step(u'a user prompt {user_prompt}')
def step_user_prompt(context, user_prompt):
    """Queue ``user_prompt`` on ``context.prompts`` and refresh the prompt counter.

    Args:
        context: behave scenario context; ``context.prompts`` is the pending
            prompt list.
        user_prompt: prompt text captured from the step sentence.
    """
    context.prompts.append(user_prompt)
    # Keep the counter equal to the queue length; embedding checks later
    # compare it with the number of results they receive.
    context.n_prompts = len(context.prompts)
205208
206209
207210@step (u'a system prompt {system_prompt}' )
@@ -289,6 +292,11 @@ def step_impl(context, n_ga_w):
def step_prompt_passkey(context):
    """Copy ``context.text`` into ``context.prompt_passkey``.

    Args:
        context: behave scenario context carrying the step's ``text``.
    """
    context.prompt_passkey = context.text
291294
@step(u'{n_prompts:d} fixed prompts')
def step_fixed_prompts(context, n_prompts):
    """Queue ``n_prompts`` identical prompts, each made of 1024 ``"0"`` characters.

    Args:
        context: behave scenario context; ``context.prompts`` is the pending
            prompt list.
        n_prompts: number of fixed prompts to add (parsed as an integer from
            the step sentence).
    """
    # Strings are immutable, so repeating one object is equivalent to
    # building a fresh string per entry.
    context.prompts.extend(["0" * 1024] * n_prompts)
    # Count the whole queue, as the other prompt steps do, so prompts queued
    # before this step are not left out of the expected total.
    context.n_prompts = len(context.prompts)
299+
292300
293301@step (u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk' )
294302def step_prompt_passkey (context , passkey , i_pos ):
@@ -301,6 +309,7 @@ def step_prompt_passkey(context, passkey, i_pos):
301309 passkey_highlight = "\x1b [33m" + passkey + "\x1b [0m"
302310 print (f"Passkey challenge:\n ```{ prompt .replace (passkey , passkey_highlight )} ```\n " )
303311 context .prompts .append (context .prompt_prefix + prompt + context .prompt_suffix )
312+ context .n_prompts = len (context .prompts )
304313
305314
306315@step (u'an OAI compatible chat completions request with {api_error} api error' )
@@ -341,11 +350,13 @@ async def step_oai_chat_completions(context, api_error):
@step(u'a prompt')
def step_a_prompt(context):
    """Queue ``context.text`` as a prompt and refresh the prompt counter.

    Args:
        context: behave scenario context; ``context.text`` supplies the prompt
            and ``context.prompts`` is the pending prompt list.
    """
    context.prompts.append(context.text)
    # Counter mirrors the queue length after the append.
    context.n_prompts = len(context.prompts)
344354
345355
@step(u'a prompt {prompt}')
def step_a_prompt_prompt(context, prompt):
    """Queue the inline ``prompt`` and refresh the prompt counter.

    Args:
        context: behave scenario context; ``context.prompts`` is the pending
            prompt list.
        prompt: prompt text captured from the step sentence.
    """
    context.prompts.append(prompt)
    # Counter mirrors the queue length after the append.
    context.n_prompts = len(context.prompts)
349360
350361
351362@step (u'concurrent completion requests' )
@@ -430,25 +441,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
@step(u'embeddings are computed for')
@async_run_until_complete
async def step_compute_embedding(context):
    """Request an embedding for ``context.text`` and store it on the context.

    The result of ``request_embedding`` is saved as ``context.embeddings``.

    Args:
        context: behave scenario context providing ``text`` and ``base_url``.
    """
    # One input is sent, so one embedding is expected when the result is
    # later compared against ``context.n_prompts``.
    context.n_prompts = 1
    context.embeddings = await request_embedding(context.text, base_url=context.base_url)
434446
435447
@step(u'all embeddings are the same')
@async_run_until_complete
async def step_all_embeddings_are_the_same(context):
    """Assert that all gathered embeddings are pairwise equivalent.

    Gathers the pending task results, validates each embedding with
    ``assert_embeddings``, then requires the cosine similarity of every pair
    to be close to 1.0 (``np.isclose`` with ``rtol=1e-05``, ``atol=1e-08``).

    Args:
        context: behave scenario context; reads ``tasks_result`` and ``debug``.

    Raises:
        AssertionError: if no result was gathered, an embedding is rejected by
            ``assert_embeddings``, or any pair is not similar enough.
    """
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests > 0

    # Convert each embedding to an array once, instead of once per pair.
    vectors = []
    for _ in range(n_embedding_requests):
        # Results are consumed from the end of the list; the last item of each
        # result is taken as the embedding vector.
        embedding = context.tasks_result.pop().pop()
        assert_embeddings(embedding)
        vectors.append(np.array(embedding))
    norms = [np.linalg.norm(vector) for vector in vectors]

    n = len(vectors)
    for i in range(n - 1):
        for j in range(i + 1, n):
            embedding1 = vectors[i]
            embedding2 = vectors[j]
            if context.debug:
                print(f"embedding1: {embedding1[-8:]}\n")
                print(f"embedding2: {embedding2[-8:]}\n")
            similarity = np.dot(embedding1, embedding2) / (norms[i] * norms[j])
            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
            if context.debug:
                print(f"{msg}\n")
            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
471+
@step(u'embeddings are generated')
def step_assert_embeddings(context):
    """Check that one valid embedding was returned per expected prompt.

    Args:
        context: behave scenario context; reads ``n_prompts`` and
            ``embeddings``.

    Raises:
        AssertionError: if the number of embeddings differs from
            ``context.n_prompts`` or an embedding is rejected by
            ``assert_embeddings``.
    """
    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
                                                          f"context.n_prompts={context.n_prompts}\n"
                                                          f"context.embeddings={context.embeddings}")
    for embedding in context.embeddings:
        assert_embeddings(embedding)
447479
448480
449481@step (u'an OAI compatible embeddings computation request for' )
450482@async_run_until_complete
451483async def step_oai_compute_embeddings (context ):
484+ context .n_prompts = 1
452485 context .embeddings = await request_oai_embeddings (context .text ,
453486 base_url = context .base_url ,
454487 user_api_key = context .user_api_key ,
@@ -462,6 +495,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
462495 base_url = context .base_url ,
463496 user_api_key = context .user_api_key ,
464497 model = context .model )
498+ context .prompts .clear ()
465499
466500
467501@step (u'concurrent embedding requests' )
@@ -488,9 +522,9 @@ async def step_concurrent_oai_embedding_requests(context):
@async_run_until_complete()
async def all_embeddings_are_generated(context):
    """Check that every concurrent embedding request produced a valid embedding.

    Args:
        context: behave scenario context; reads ``n_prompts`` and
            ``tasks_result``.

    Raises:
        AssertionError: if the gathered result count differs from
            ``context.n_prompts`` or an embedding is rejected by
            ``assert_embeddings``.
    """
    n_embedding_requests = await gather_tasks_results(context)
    assert n_embedding_requests == context.n_prompts, (
        f"expected {context.n_prompts} embedding results, got {n_embedding_requests}")
    for _ in range(n_embedding_requests):
        # The last item of each popped result is taken as the embedding.
        assert_embeddings(context.tasks_result.pop().pop())
494528
495529
496530@step (u'tokenizing' )
@@ -588,11 +622,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):
588622
589623
590624async def concurrent_requests (context , f_completion , * args , ** kwargs ):
591- n_prompts = len (context .prompts )
625+ context . n_prompts = len (context .prompts )
592626 if context .debug :
593- print (f"starting { n_prompts } concurrent completion requests..." )
594- assert n_prompts > 0
595- for prompt_no in range (n_prompts ):
627+ print (f"starting { context . n_prompts } concurrent completion requests..." )
628+ assert context . n_prompts > 0
629+ for prompt_no in range (context . n_prompts ):
596630 shifted_args = [context .prompts .pop (), * args ]
597631 context .concurrent_tasks .append (asyncio .create_task (f_completion (* shifted_args , ** kwargs )))
598632 await asyncio .sleep (0.1 )
@@ -765,7 +799,7 @@ async def request_embedding(content, base_url=None):
765799 }) as response :
766800 assert response .status == 200
767801 response_json = await response .json ()
768- return response_json ['embedding' ]
802+ return [ response_json ['embedding' ] ]
769803
770804
771805async def request_oai_embeddings (input ,
@@ -775,6 +809,7 @@ async def request_oai_embeddings(input,
775809 user_api_key = user_api_key if user_api_key is not None else 'nope'
776810 if async_client :
777811 origin = 'llama.cpp'
812+ headers = []
778813 if user_api_key is not None :
779814 headers = {'Authorization' : f'Bearer { user_api_key } ' , 'Origin' : origin }
780815 async with aiohttp .ClientSession () as session :
@@ -790,7 +825,13 @@ async def request_oai_embeddings(input,
790825 response_json = await response .json ()
791826 assert response_json ['model' ] == model , f"invalid model received: { response_json ['model' ]} "
792827 assert response_json ['object' ] == 'list'
793- return response_json ['data' ]
828+ if isinstance (input , collections .abc .Sequence ):
829+ embeddings = []
830+ for an_oai_embeddings in response_json ['data' ]:
831+ embeddings .append (an_oai_embeddings ['embedding' ])
832+ else :
833+ embeddings = [response_json ['data' ]['embedding' ]]
834+ return embeddings
794835 else :
795836 openai .api_key = user_api_key
796837 openai .api_base = f'{ base_url } /v1'
@@ -804,7 +845,7 @@ async def request_oai_embeddings(input,
804845 for an_oai_embeddings in oai_embeddings .data :
805846 embeddings .append (an_oai_embeddings .embedding )
806847 else :
807- embeddings = oai_embeddings .data .embedding
848+ embeddings = [ oai_embeddings .data .embedding ]
808849 return embeddings
809850
810851
@@ -899,6 +940,8 @@ def assert_embeddings(embeddings):
899940 assert len (embeddings ) > 0
900941 embeddings_computed = False
901942 for emb in embeddings :
943+ if not isinstance (emb , float ):
944+ assert False , f"Bad embeddings: { embeddings } "
902945 if emb != 0 :
903946 embeddings_computed = True
904947 assert embeddings_computed , f"Embeddings: { embeddings } "
0 commit comments