# SPDX-License-Identifier: Apache-2.0
"""A basic correctness check for TPUs

Run `pytest tests/v1/tpu/test_basic.py`.
"""
import pytest

from vllm.platforms import current_platform

from ...conftest import VllmRunner

# Models exercised by the basic TPU smoke test. Commented-out entries are
# candidates that are currently disabled.
MODELS = [
    # "Qwen/Qwen2-7B-Instruct",
    "meta-llama/Llama-3.1-8B",
    # TODO: Add models here as necessary
]

# Tensor-parallel degrees to parametrize over; single-device only for now.
TENSOR_PARALLEL_SIZES = [1]

# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
| 23 | + |
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models(
    monkeypatch,
    model: str,
    max_tokens: int,
    enforce_eager: bool,
    tensor_parallel_size: int,
) -> None:
    """Greedy-decode a long counting prompt and check the continuation.

    The prompt enumerates 0..1023 and asks for the next numbers; a correct
    greedy continuation must contain the token "1024".
    """
    counting = ", ".join(str(n) for n in range(1024))
    example_prompts = [f"The next numbers of the sequence {counting} are:"]

    with monkeypatch.context() as m:
        # Force the V1 engine path for this run.
        m.setenv("VLLM_USE_V1", "1")

        with VllmRunner(
                model,
                max_model_len=8192,
                enforce_eager=enforce_eager,
                gpu_memory_utilization=0.7,
                max_num_seqs=16,
                tensor_parallel_size=tensor_parallel_size) as runner:
            outputs = runner.generate_greedy(example_prompts, max_tokens)
        # outputs[0] is the result for the single prompt; index 1 is
        # presumably the generated text (ids, text) — matches original usage.
        generated = outputs[0][1]
        assert "1024" in generated