diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py
index 620355923b47..3ebc5a44d80c 100644
--- a/tests/entrypoints/llm/test_accuracy.py
+++ b/tests/entrypoints/llm/test_accuracy.py
@@ -42,6 +42,10 @@ def run_test(more_args=None):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
 
 
+# TODO: [AlexM] Fix it with new CI/CD tests
+TPU_TP_TEST_STR = ""  #"tensor_parallel_size=4"
+
+
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
             # Limit compilation time for TPU V1
             more_args = "max_num_seqs=64"
 
+            # Add TP test (if provided)
+            if TPU_TP_TEST_STR:
+                more_args += ",{}".format(TPU_TP_TEST_STR)
+
         run_test(more_args)
diff --git a/tests/v1/tpu/__init__.py b/tests/v1/tpu/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py
new file mode 100644
index 000000000000..0309f545ea49
--- /dev/null
+++ b/tests/v1/tpu/test_basic.py
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A basic correctness check for TPUs
+
+Run `pytest tests/v1/tpu/test_basic.py`.
+"""
+import pytest
+
+from vllm.platforms import current_platform
+
+from ...conftest import VllmRunner
+
+MODELS = [
+    # "Qwen/Qwen2-7B-Instruct",
+    "meta-llama/Llama-3.1-8B",
+    # TODO: Add models here as necessary
+]
+
+TENSOR_PARALLEL_SIZES = [1]
+
+# TODO: Enable when CI/CD will have a multi-tpu instance
+# TENSOR_PARALLEL_SIZES = [1, 4]
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+def test_models(
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with VllmRunner(
+                model,
+                max_model_len=8192,
+                enforce_eager=enforce_eager,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=16,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+            output = vllm_outputs[0][1]
+            assert "1024" in output
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index 3b1735fdcf7a..18ff32155c5f 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -73,9 +73,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
         if envs.VLLM_USE_V1:
-            # v1 always uses the compiled DAG and SPMD worker.
+            # V1 uses SPMD worker and compiled DAG
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
+
+            # For TPU, avoid compiling NVIDIA's NCCL
+            if current_platform.is_tpu():
+                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
+
         # If the env var is set, it uses the Ray's compiled DAG API
         # which optimizes the control plane overhead.
         # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 5d8b48ac67b1..c1bf2fb316d9 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -11,6 +11,7 @@
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +107,15 @@ def setup_device_if_necessary(self):
         # on a background thread, so we need to reset torch's current
         # device.
         # We can remove this API after it is fixed in compiled graph.
-        import torch
         assert self.worker is not None, "Worker is not initialized"
         if not self.compiled_dag_cuda_device_set:
-            torch.cuda.set_device(self.worker.device)
+            if current_platform.is_tpu():
+                # Not needed
+                pass
+            else:
+                import torch
+                torch.cuda.set_device(self.worker.device)
+
             self.compiled_dag_cuda_device_set = True
 
     def execute_model_ray(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f661412d9378..2d7ccd7f6b0c 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -21,6 +21,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
@@ -543,6 +544,7 @@ def _gather_encoder_outputs(
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index d09f5dd84007..500dc5b20ce8 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -93,7 +93,8 @@ def init_device(self):
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
-        xm.set_rng_state(self.model_config.seed, self.device)
+        if self.model_config.seed is not None:
+            xm.set_rng_state(self.model_config.seed, self.device)
 
         # Increase the cache size limit, which is the maximum number of
         # dynamo graphs that can be compiled.
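Note: the RayDistributedExecutor change above only flips environment variables before the Ray workers are created. A minimal standalone sketch of that gating follows; the configure_ray_dag_env helper is hypothetical (not part of vLLM), and a plain boolean stands in for current_platform.is_tpu().

import os


def configure_ray_dag_env(is_tpu: bool) -> None:
    # V1 always runs with the SPMD worker and Ray's compiled DAG.
    os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
    if is_tpu:
        # TPU hosts have no NVIDIA NCCL, so keep the compiled DAG on the
        # default (non-NCCL) channel.
        os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"


configure_ray_dag_env(is_tpu=True)
assert os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] == "0"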