|  | 
|  | 1 | +import openai | 
|  | 2 | +import pytest | 
|  | 3 | +import ray | 
|  | 4 | + | 
|  | 5 | +from ..utils import VLLM_PATH, RemoteOpenAIServer | 
|  | 6 | + | 
|  | 7 | +EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" | 
|  | 8 | + | 
|  | 9 | +pytestmark = pytest.mark.openai | 
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +@pytest.fixture(scope="module") | 
|  | 13 | +def ray_ctx(): | 
|  | 14 | +    ray.init(runtime_env={"working_dir": VLLM_PATH}) | 
|  | 15 | +    yield | 
|  | 16 | +    ray.shutdown() | 
|  | 17 | + | 
|  | 18 | + | 
|  | 19 | +@pytest.fixture(scope="module") | 
|  | 20 | +def embedding_server(ray_ctx): | 
|  | 21 | +    return RemoteOpenAIServer([ | 
|  | 22 | +        "--model", | 
|  | 23 | +        EMBEDDING_MODEL_NAME, | 
|  | 24 | +        # use half precision for speed and memory savings in CI environment | 
|  | 25 | +        "--dtype", | 
|  | 26 | +        "bfloat16", | 
|  | 27 | +        "--enforce-eager", | 
|  | 28 | +        "--max-model-len", | 
|  | 29 | +        "8192", | 
|  | 30 | +        "--enforce-eager", | 
|  | 31 | +    ]) | 
|  | 32 | + | 
|  | 33 | + | 
|  | 34 | +@pytest.mark.asyncio | 
|  | 35 | +@pytest.fixture(scope="module") | 
|  | 36 | +def embedding_client(embedding_server): | 
|  | 37 | +    return embedding_server.get_async_client() | 
|  | 38 | + | 
|  | 39 | + | 
|  | 40 | +@pytest.mark.asyncio | 
|  | 41 | +@pytest.mark.parametrize( | 
|  | 42 | +    "model_name", | 
|  | 43 | +    [EMBEDDING_MODEL_NAME], | 
|  | 44 | +) | 
|  | 45 | +async def test_single_embedding(embedding_client: openai.AsyncOpenAI, | 
|  | 46 | +                                model_name: str): | 
|  | 47 | +    input_texts = [ | 
|  | 48 | +        "The chef prepared a delicious meal.", | 
|  | 49 | +    ] | 
|  | 50 | + | 
|  | 51 | +    # test single embedding | 
|  | 52 | +    embeddings = await embedding_client.embeddings.create( | 
|  | 53 | +        model=model_name, | 
|  | 54 | +        input=input_texts, | 
|  | 55 | +        encoding_format="float", | 
|  | 56 | +    ) | 
|  | 57 | +    assert embeddings.id is not None | 
|  | 58 | +    assert len(embeddings.data) == 1 | 
|  | 59 | +    assert len(embeddings.data[0].embedding) == 4096 | 
|  | 60 | +    assert embeddings.usage.completion_tokens == 0 | 
|  | 61 | +    assert embeddings.usage.prompt_tokens == 9 | 
|  | 62 | +    assert embeddings.usage.total_tokens == 9 | 
|  | 63 | + | 
|  | 64 | +    # test using token IDs | 
|  | 65 | +    input_tokens = [1, 1, 1, 1, 1] | 
|  | 66 | +    embeddings = await embedding_client.embeddings.create( | 
|  | 67 | +        model=model_name, | 
|  | 68 | +        input=input_tokens, | 
|  | 69 | +        encoding_format="float", | 
|  | 70 | +    ) | 
|  | 71 | +    assert embeddings.id is not None | 
|  | 72 | +    assert len(embeddings.data) == 1 | 
|  | 73 | +    assert len(embeddings.data[0].embedding) == 4096 | 
|  | 74 | +    assert embeddings.usage.completion_tokens == 0 | 
|  | 75 | +    assert embeddings.usage.prompt_tokens == 5 | 
|  | 76 | +    assert embeddings.usage.total_tokens == 5 | 
|  | 77 | + | 
|  | 78 | + | 
|  | 79 | +@pytest.mark.asyncio | 
|  | 80 | +@pytest.mark.parametrize( | 
|  | 81 | +    "model_name", | 
|  | 82 | +    [EMBEDDING_MODEL_NAME], | 
|  | 83 | +) | 
|  | 84 | +async def test_batch_embedding(embedding_client: openai.AsyncOpenAI, | 
|  | 85 | +                               model_name: str): | 
|  | 86 | +    # test List[str] | 
|  | 87 | +    input_texts = [ | 
|  | 88 | +        "The cat sat on the mat.", "A feline was resting on a rug.", | 
|  | 89 | +        "Stars twinkle brightly in the night sky." | 
|  | 90 | +    ] | 
|  | 91 | +    embeddings = await embedding_client.embeddings.create( | 
|  | 92 | +        model=model_name, | 
|  | 93 | +        input=input_texts, | 
|  | 94 | +        encoding_format="float", | 
|  | 95 | +    ) | 
|  | 96 | +    assert embeddings.id is not None | 
|  | 97 | +    assert len(embeddings.data) == 3 | 
|  | 98 | +    assert len(embeddings.data[0].embedding) == 4096 | 
|  | 99 | + | 
|  | 100 | +    # test List[List[int]] | 
|  | 101 | +    input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], | 
|  | 102 | +                    [25, 32, 64, 77]] | 
|  | 103 | +    embeddings = await embedding_client.embeddings.create( | 
|  | 104 | +        model=model_name, | 
|  | 105 | +        input=input_tokens, | 
|  | 106 | +        encoding_format="float", | 
|  | 107 | +    ) | 
|  | 108 | +    assert embeddings.id is not None | 
|  | 109 | +    assert len(embeddings.data) == 4 | 
|  | 110 | +    assert len(embeddings.data[0].embedding) == 4096 | 
|  | 111 | +    assert embeddings.usage.completion_tokens == 0 | 
|  | 112 | +    assert embeddings.usage.prompt_tokens == 17 | 
|  | 113 | +    assert embeddings.usage.total_tokens == 17 | 
0 commit comments