diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 931057e6c197..05c4d2616990 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -503,6 +503,7 @@ steps: - entrypoints/llm/test_collective_rpc.py commands: - pytest -v -s entrypoints/llm/test_collective_rpc.py + - VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index a092a548a59c..1c6c28b4ed35 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -48,6 +48,12 @@ def test_consistent_across_ranks(obj): test_consistent_across_ranks( llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) +# make sure we can access the model parameters from the calling process +# of the `LLM` instance. +params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. + model.parameters()) +test_consistent_across_ranks(len(params)) + # all ranks should have the same outputs for output in outputs: prompt = output.prompt diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index d035668098eb..8c2998e58892 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -5,6 +5,7 @@ import time import uuid from concurrent.futures import Future +from typing import List import pytest from transformers import AutoTokenizer @@ -211,8 +212,9 @@ def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest: class DummyExecutor(UniProcExecutor): - def initialize(self, kv_cache_config: KVCacheConfig) -> None: - super().initialize(kv_cache_config) + def initialize_from_config( + self, kv_cache_configs: List[KVCacheConfig]) -> None: + super().initialize_from_config(kv_cache_configs) # This executor actually can only run 1 batch at a time self.semaphore = threading.Semaphore(1) diff --git a/vllm/config.py b/vllm/config.py index d3139b5fd84e..6bcf34c3cff9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1407,6 +1407,11 @@ def __post_init__(self) -> None: self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT self.world_size_across_dp = self.world_size * self.data_parallel_size + if self.distributed_executor_backend == "external_launcher": + import os + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + logger.info("Disabling V1 multiprocessing for external launcher.") + ray_only_devices = ["tpu"] from vllm.platforms import current_platform if (current_platform.device_type in ray_only_devices diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 79ca45d55d96..b866413e3a62 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -541,7 +541,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool): # and the TP group executes in SPMD fashion. if self.use_v1: outputs = [ - worker.execute_model. + worker.execute_model_ray. bind( # type: ignore[attr-defined] outputs[i]) for i, worker in enumerate(tp_group) ] diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 8ad466a5572e..1734c670bf10 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -112,10 +112,12 @@ def setup_device_if_necessary(self): torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - def execute_model( + def execute_model_ray( self, scheduler_output: "SchedulerOutput", ) -> "ModelRunnerOutput": + # this method is used to compile ray CG, + # and it needs a special logic of self.setup_device_if_necessary() self.setup_device_if_necessary() assert self.worker is not None, "Worker is not initialized" if isinstance(scheduler_output, tuple): diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 94db232240d5..e041215de660 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -93,9 +93,10 @@ def _init_executor(self) -> None: ("ExecutorWithExternalLauncher needs deterministic " "execution, so it" "does not support delay_factor in scheduling") - assert not envs.VLLM_USE_V1, \ - ("V1 architecture cannot guarantee deterministic execution, " - "so it is not supported in ExecutorWithExternalLauncher.") + if envs.VLLM_USE_V1: + assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ + ("To get deterministic execution in V1, " + "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) # engines are launched in torchrun-compatible launchers diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 981d23237e2a..85c97293af8b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -110,7 +110,7 @@ def _initialize_kv_caches(self, num_cpu_blocks = 0 # Initialize kv cache and warmup the execution - self.model_executor.initialize(kv_cache_configs) + self.model_executor.initialize_from_config(kv_cache_configs) elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 04c7ee109e0b..33b1ddc0f6fe 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -4,10 +4,10 @@ from typing_extensions import TypeVar +import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase -from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -44,6 +44,7 @@ def __init__( use_cached_outputs: bool = False, multiprocess_mode: bool = False, ) -> None: + self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -83,6 +84,10 @@ def __init__( log_stats=False, # FIXME: implement ) + if not multiprocess_mode: + # for v0 compatibility + self.model_executor = self.engine_core.engine_core.model_executor # type: ignore + @classmethod def from_engine_args( cls, @@ -97,7 +102,7 @@ def from_engine_args( vllm_config = engine_args.create_engine_config(usage_context) executor_class = Executor.get_class(vllm_config) - if VLLM_ENABLE_V1_MULTIPROCESSING: + if envs.VLLM_ENABLE_V1_MULTIPROCESSING: logger.debug("Enabling multiprocessing for LLMEngine.") enable_multiprocessing = True diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 3663cbd08aec..11002ad0022d 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -3,6 +3,9 @@ from concurrent.futures import Future from typing import List, Type, Union +import torch +import torch.distributed as dist + from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase from vllm.executor.uniproc_executor import ( # noqa @@ -49,12 +52,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]: f"{distributed_executor_backend}") return executor_class - def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None: + def initialize_from_config(self, + kv_cache_configs: List[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. """ - self.collective_rpc("initialize_cache", args=(kv_cache_configs, )) + self.collective_rpc("initialize_from_config", + args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") def determine_available_memory(self) -> int: # in bytes @@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - pass + + def determine_available_memory(self) -> int: # in bytes + # same as determine_num_available_blocks in v0, + # we need to get the min across all ranks. + memory = super().determine_available_memory() + from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group + memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) + dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + return memory_tensor.item() diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 14492f273ed3..d4582122fa6d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -216,9 +216,10 @@ def __init__( "local_rank": local_rank, "rank": rank, "distributed_init_method": distributed_init_method, + "is_driver_worker": rank == 0, } wrapper.init_worker(all_kwargs) - self.worker = wrapper.worker + self.worker = wrapper pid = os.getpid() _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid) @@ -239,7 +240,7 @@ def __init__( ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send(payload) - wrapper.init_device() + self.worker.init_device() self.worker.load_model() @staticmethod diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ece0fa555342..d9a415aee528 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional import torch import torch.distributed @@ -185,9 +185,8 @@ def determine_available_memory(self) -> int: def get_kv_cache_spec(self) -> KVCacheSpec: return self.model_runner.get_kv_cache_spec() - def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" - kv_cache_config = kv_cache_configs[self.rank] if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") @@ -225,7 +224,7 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) - return output if self.rank == 0 else None + return output if self.is_driver_worker else None def profile(self, is_start: bool = True): if self.profiler is None: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index f29edd34ede3..c236f263eddb 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -36,6 +36,7 @@ def __init__( distributed_init_method: str, is_driver_worker: bool = False, ): + self.is_driver_worker = is_driver_worker self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -151,7 +152,7 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) - return output if self.rank == 0 else None + return output if self.is_driver_worker else None def load_model(self) -> None: self.model_runner.load_model() @@ -170,9 +171,8 @@ def get_model(self) -> nn.Module: def get_kv_cache_spec(self) -> KVCacheSpec: return self.model_runner.get_kv_cache_spec() - def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" - kv_cache_config = kv_cache_configs[self.rank] self.model_runner.initialize_kv_cache(kv_cache_config) def check_health(self) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 44c26ed350a8..445c0d3285bf 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -567,6 +567,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: self.worker = worker_class(**kwargs) assert self.worker is not None + def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: + kv_cache_config = kv_cache_configs[self.rpc_rank] + self.worker.initialize_from_config(kv_cache_config) # type: ignore + def init_device(self): with set_current_vllm_config(self.vllm_config): # To make vLLM config available during device initialization @@ -574,8 +578,11 @@ def init_device(self): def execute_method(self, method: Union[str, bytes], *args, **kwargs): try: - target = self if self.worker is None else self.worker - return run_method(target, method, args, kwargs) + # method resolution order: + # if a method is defined in this class, it will be called directly. + # otherwise, since we define `__getattr__` and redirect attribute + # query to `self.worker`, the method will be called on the worker. + return run_method(self, method, args, kwargs) except Exception as e: # if the driver worker also execute methods, # exceptions in the rest worker may cause deadlock in rpc like ray