diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index fcba253d159f..3a505baa8c26 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -19,8 +19,9 @@ from ..utils import multi_gpu_test
 MODELS = [
-    "google/gemma-2-2b-it",
-    "meta-llama/Llama-3.2-1B",
+    "facebook/opt-125m",
+    # "google/gemma-2-2b-it",
+    # "meta-llama/Llama-3.2-1B",
 ]
 
 TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@@ -37,10 +38,11 @@ def test_vllm_gc_ed():
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
+# @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN_VLLM_V1"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
-@pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("enforce_eager", [True])
 def test_models(
     hf_runner,
     model: str,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 75a77be750ac..990db4cf18d7 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -158,8 +158,6 @@ def step(self) -> List[RequestOutput]:
 
         return request_outputs
 
-    # TODO(rob): Can we get rid of these?
-
     def get_model_config(self):
         pass
diff --git a/vllm/v1/executor/distributed_gpu_executor.py b/vllm/v1/executor/distributed_gpu_executor.py
new file mode 100644
index 000000000000..4d01de2ef4ae
--- /dev/null
+++ b/vllm/v1/executor/distributed_gpu_executor.py
@@ -0,0 +1,89 @@
+from abc import abstractmethod
+from typing import Any, Optional, Tuple
+
+from vllm.logger import init_logger
+from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.v1.executor.gpu_executor import GPUExecutor
+from vllm.v1.outputs import ModelRunnerOutput
+
+logger = init_logger(__name__)
+
+
+class DistributedGPUExecutor(GPUExecutor):
+    """Abstract superclass of multi-GPU executor implementations."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available KV blocks.
+
+        This invokes `determine_num_available_blocks` on each worker and takes
+        the min of the results, guaranteeing that the selected cache sizes are
+        compatible with all workers.
+
+        Returns:
+            - tuple[num_gpu_blocks, num_cpu_blocks]
+        """
+        # Get the maximum number of blocks that can be allocated on GPU and
+        # CPU from each worker.
+        num_blocks = self._run_workers("determine_num_available_blocks")
+
+        # Since we use a shared centralized controller, we take the minimum
+        # number of blocks across all workers to make sure all the memory
+        # operations can be applied to all workers.
+        num_gpu_blocks = min(b[0] for b in num_blocks)
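+        # Only the GPU block count is propagated to the engine; this
+        # executor does not request any CPU swap blocks.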
+        return num_gpu_blocks, 0
+
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Initialize the KV cache in all workers."""
+        # NOTE: This is logged in the executor because there can be >1 worker
+        # with other executors. We could log at the engine level, but work
+        # remains to abstract away the device for non-GPU configurations.
+        logger.info("# GPU blocks: %d", num_gpu_blocks)
+        self._run_workers("initialize_cache", num_gpu_blocks)
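+        # Once the cache is allocated, let every worker warm up / compile
+        # the model.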
+        self._run_workers("compile_or_warm_up_model")
+
+    @abstractmethod
+    def execute_model(
+        self,
+        scheduler_output: SchedulerOutput,
+    ) -> ModelRunnerOutput:
+        raise NotImplementedError
+
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self._run_workers("save_sharded_state",
+                          path=path,
+                          pattern=pattern,
+                          max_size=max_size)
+
+    @abstractmethod
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        async_run_tensor_parallel_workers_only: bool = False,
+        max_concurrent_workers: Optional[int] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers.
+
+        Args:
+            async_run_tensor_parallel_workers_only: If True, the method is
+                run only on the remote TP workers, not the driver worker.
+                It is also run asynchronously and returns a list of futures
+                rather than blocking on the results.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def check_health(self) -> None:
+        # SANG-TODO
+        raise NotImplementedError
diff --git a/vllm/v1/executor/gpu_executor.py b/vllm/v1/executor/gpu_executor.py
index f71fa16b16e2..eef68d1c0e76 100644
--- a/vllm/v1/executor/gpu_executor.py
+++ b/vllm/v1/executor/gpu_executor.py
@@ -1,11 +1,12 @@
 import os
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_worker import Worker
 
 logger = init_logger(__name__)
@@ -75,3 +76,25 @@ def check_health(self) -> None:
         # GPUExecutor will always be healthy as long as
         # it's running.
         return
+
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        worker_module_name = "vllm.v1.worker.gpu_worker"
+        worker_class_name = "Worker"
+        return worker_module_name, worker_class_name
+
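+    # NOTE: The worker is referred to by module and class name (rather than
+    # by the class object) so that Ray actors can import and construct it
+    # lazily in their own processes; see WorkerWrapperBase in ray_utils.py.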
+    def _get_worker_kwargs(
+            self,
+            local_rank: int = 0,
+            rank: int = 0,
+            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+        )
diff --git a/vllm/v1/executor/ray_gpu_executor.py b/vllm/v1/executor/ray_gpu_executor.py
new file mode 100644
index 000000000000..db83bec4ea08
--- /dev/null
+++ b/vllm/v1/executor/ray_gpu_executor.py
@@ -0,0 +1,305 @@
+import os
+from collections import defaultdict
+from itertools import islice, repeat
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import vllm.envs as envs
+from vllm.logger import init_logger
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        get_vllm_instance_id)
+from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.v1.executor.distributed_gpu_executor import DistributedGPUExecutor
+from vllm.v1.executor.ray_utils import RayWorkerWrapper, ray
+from vllm.v1.outputs import ModelRunnerOutput
+
+if ray is not None:
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
+logger = init_logger(__name__)
+
+
+class RayGPUExecutor(DistributedGPUExecutor):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._init_executor()
+
+    def _init_executor(self) -> None:
+        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
+        placement_group = self.parallel_config.placement_group
+
+        # Disable Ray usage stats collection.
+        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+        if ray_usage != "1":
+            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+
+        # Create the parallel GPU workers.
+        self._init_workers_ray(placement_group)
+
+    def shutdown(self) -> None:
+        if hasattr(self, "forward_dag") and self.forward_dag is not None:
+            self.forward_dag.teardown()
+            import ray
+            for worker in self.workers:
+                ray.kill(worker)
+            self.forward_dag = None
+
+    def _configure_ray_workers_use_nsight(self,
+                                          ray_remote_kwargs) -> Dict[str, Any]:
+        # If nsight profiling is enabled, we need to set the profiling
+        # configuration for the ray workers as runtime env.
+        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
+        runtime_env.update({
+            "nsight": {
+                "t": "cuda,cudnn,cublas",
+                "o": "'worker_process_%p'",
+                "cuda-graph-trace": "node",
+            }
+        })
+
+        return ray_remote_kwargs
+
+    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
+        worker_module_name, worker_class_name = (
+            self._get_worker_module_and_class())
+
+        return dict(
+            worker_module_name=worker_module_name,
+            worker_class_name=worker_class_name,
+            trust_remote_code=self.model_config.trust_remote_code,
+        )
+
+    # Child classes can override this to return the env vars to update.
+    def _get_env_vars_to_be_updated(self):
+        return self._env_vars_for_all_workers
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
+        if (self.parallel_config.tensor_parallel_size == 1
+                and self.parallel_config.pipeline_parallel_size == 1):
+            # For the single-GPU case, we use a Ray worker with constrained
+            # memory.
+            # TODO(sang): Is this necessary?
+            num_gpus = self.cache_config.gpu_memory_utilization
+        else:
+            # Otherwise, the ray workers are allocated with a full GPU.
+            num_gpus = 1
+
+        # A list of workers to run a model.
+        self.workers: List[RayWorkerWrapper] = []
+        if self.parallel_config.ray_workers_use_nsight:
+            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
+                ray_remote_kwargs)
+
+        # Create the workers.
+        driver_ip = get_ip()
+        worker_wrapper_kwargs = self._get_worker_wrapper_args()
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+            if not bundle.get("GPU", 0):
+                continue
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
+
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=num_gpus,
+                scheduling_strategy=scheduling_strategy,
+                **ray_remote_kwargs,
+            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
+            self.workers.append(worker)
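+        # Note that every rank, including rank 0, runs in a Ray actor;
+        # this executor does not keep an in-process driver worker.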
+
+        logger.debug("workers: %s", self.workers)
+        worker_ips = [
+            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
+            for worker in self.workers
+        ]
+        ip_counts: Dict[str, int] = {}
+        for ip in worker_ips:
+            ip_counts[ip] = ip_counts.get(ip, 0) + 1
+
+        def sort_by_driver_then_worker_ip(worker):
+            """
+            Sort the workers based on 3 properties:
+            1. If the worker is on the same node as the driver (vllm engine),
+                it should be placed first.
+            2. Then, if the worker is on a node with fewer workers, it should
+                be placed first.
+            3. Finally, if the worker is on a node with a smaller IP address,
+                it should be placed first.
+            """
+            ip = ray.get(worker.get_node_ip.remote())
+            return (ip != driver_ip, ip_counts[ip], ip)
+
+        # After sorting, the workers on the same node will be
+        # close to each other, and the workers on the driver
+        # node will be placed first.
+        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
+
+        # Get the set of GPU IDs used on each node.
+        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")
+
+        node_workers = defaultdict(list)  # node id -> list of worker ranks
+        node_gpus = defaultdict(list)  # node id -> list of gpu ids
+
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
+            node_workers[node_id].append(i)
+            # `gpu_ids` can be a list of strings or integers.
+            # convert them to integers for consistency.
+            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
+            # string sorting is not sufficient.
+            # see https://github.com/vllm-project/vllm/issues/5590
+            gpu_ids = [int(x) for x in gpu_ids]
+            node_gpus[node_id].extend(gpu_ids)
+
+        for node_id, gpu_ids in node_gpus.items():
+            node_gpus[node_id] = sorted(gpu_ids)
+
+        all_ips = set(worker_ips)
+        n_ips = len(all_ips)
+        n_nodes = len(node_workers)
+
+        if n_nodes != n_ips:
+            raise RuntimeError(
+                f"Every node should have a unique IP address. Got {n_nodes}"
+                f" nodes with node ids {list(node_workers.keys())} and "
+                f"{n_ips} unique IP addresses {all_ips}. Please check your"
+                " network configuration. If you set `VLLM_HOST_IP` or "
+                "`HOST_IP` environment variable, make sure it is unique for"
+                " each node.")
+
+        VLLM_INSTANCE_ID = get_vllm_instance_id()
+
+        # Set environment variables for the driver and workers.
+        all_args_to_update_environment_variables = [({
+            "CUDA_VISIBLE_DEVICES":
+            ",".join(map(str, node_gpus[node_id])),
+            "VLLM_INSTANCE_ID":
+            VLLM_INSTANCE_ID,
+            "VLLM_TRACE_FUNCTION":
+            str(envs.VLLM_TRACE_FUNCTION),
+            **({
+                "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
+            } if envs.VLLM_ATTENTION_BACKEND is not None else {})
+        }, ) for (node_id, _) in worker_node_and_gpu_ids]
+
+        self._env_vars_for_all_workers = (
+            all_args_to_update_environment_variables)
+
+        self._run_workers("update_environment_variables",
+                          all_args=self._get_env_vars_to_be_updated())
+
+        if len(node_gpus) == 1:
+            # in single node case, we don't need to get the IP address.
+            # the loopback address is sufficient
+            # NOTE: a node may have several IP addresses, one for each
+            # network interface. `get_ip()` might return any of them,
+            # while they might not work for communication inside the node
+            # if the network setup is complicated. Using the loopback address
+            # solves this issue, as it always works for communication inside
+            # the node.
+            driver_ip = "127.0.0.1"
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port())
+
+        # Initialize the actual workers inside worker wrapper.
+        init_worker_all_kwargs = [
+            self._get_worker_kwargs(
+                local_rank=node_workers[node_id].index(rank),
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
+        ]
+        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
+        self._run_workers("initialize")
+        self._run_workers("load_model")
+
+    def execute_model(
+        self,
+        scheduler_output: SchedulerOutput,
+    ) -> ModelRunnerOutput:
+        if self.forward_dag is None:
+            self.forward_dag = self._compiled_ray_dag()
+        # All workers are supposed to produce the same output. Only
+        # get the first output.
+        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
+        return output
+
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers.
+
+        Args:
+            args/kwargs: args/kwargs shared by all workers.
+            all_args/all_kwargs: args/kwargs specified per worker; entry i is
+                passed to worker i.
+        """
+        count = len(self.workers)
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, 0, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, 0, None)
+
+        # Start the ray workers first.
+        ray_workers = self.workers
+        ray_worker_outputs = [
+            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
+            for (worker, worker_args, worker_kwargs
+                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
+        ]
+
+        # Get the results of the ray workers.
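+        # This blocks until every actor has finished running the method.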
+        if self.workers:
+            ray_worker_outputs = ray.get(ray_worker_outputs)
+
+        return ray_worker_outputs
+
+    def _check_ray_compiled_graph_installation(self):
+        import importlib.util
+        adag_spec = importlib.util.find_spec(
+            "ray.experimental.compiled_dag_ref")
+        if adag_spec is None:
+            raise ValueError("Ray accelerated DAG is not installed. "
+                             "Run `pip install ray[adag]` to install it.")
+
+        cupy_spec = importlib.util.find_spec("cupy")
+        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
+            raise ValueError(
+                "cupy is not installed but required since "
+                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
+                "Run `pip install ray[adag]` and check cupy installation.")
+
+    def _compiled_ray_dag(self):
+        assert self.parallel_config.use_ray
+        self._check_ray_compiled_graph_installation()
+        from ray.dag import InputNode, MultiOutputNode
+
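+        # Bind each worker's execute_model to the same input batch; the
+        # MultiOutputNode gathers one output per worker.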
+        with InputNode() as input_batches:
+            outputs = [
+                worker.execute_model.bind(input_batches)
+                for worker in self.workers
+            ]
+            forward_dag = MultiOutputNode(outputs)
+
+        return forward_dag.experimental_compile()
+
+    def __del__(self):
+        self.shutdown()
diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py
new file mode 100644
index 000000000000..521ca8b2eef3
--- /dev/null
+++ b/vllm/v1/executor/ray_utils.py
@@ -0,0 +1,369 @@
+import importlib
+import os
+import time
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+from vllm.config import ParallelConfig
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import get_ip, update_environment_variables
+from vllm.v1.outputs import ModelRunnerOutput
+
+if TYPE_CHECKING:
+    from vllm.v1.core.scheduler import SchedulerOutput
+
+logger = init_logger(__name__)
+PG_WAIT_TIMEOUT = 60
+
+
+class WorkerWrapperBase:
+    """
+    The whole point of this class is to lazily initialize the worker.
+    We first instantiate the WorkerWrapper, which remembers the worker module
+    and class name. Environment variables can then be set on the target node
+    via `update_environment_variables`, and the real initialization happens
+    later in `init_worker`.
+
+    The worker class is obtained by dynamically importing it using
+    worker_module_name and worker_class_name.
+    """
+
+    def __init__(self,
+                 worker_module_name: str,
+                 worker_class_name: str,
+                 trust_remote_code: bool = False) -> None:
+        self.worker_module_name = worker_module_name
+        self.worker_class_name = worker_class_name
+        self.worker = None
+        if trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
+
+    @staticmethod
+    def update_environment_variables(envs: Dict[str, str]) -> None:
+        key = 'CUDA_VISIBLE_DEVICES'
+        if key in envs and key in os.environ:
+            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
+            # suppress the warning in `update_environment_variables`
+            del os.environ[key]
+        update_environment_variables(envs)
+
+    def init_worker(self, *args, **kwargs):
+        """
+        Inject some common logic, then construct the worker.
+        Arguments are passed to the worker class constructor.
+        """
+        # TODO(sang): Enable it
+        # enable_trace_function_call_for_thread()
+
+        # see https://github.com/NVIDIA/nccl/issues/1234
+        os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
+        # TODO(sang): Enable it
+        # from vllm.plugins import load_general_plugins
+        # load_general_plugins()
+
+        mod = importlib.import_module(self.worker_module_name)
+        worker_class = getattr(mod, self.worker_class_name)
+
+        self.worker = worker_class(*args, **kwargs)
+        assert self.worker is not None
+
+    def execute_method(self, method, *args, **kwargs):
+        try:
+            target = self if self.worker is None else self.worker
+            executor = getattr(target, method)
+            return executor(*args, **kwargs)
+        except Exception as e:
+            # If the driver worker also executes methods, exceptions in the
+            # remaining workers may cause deadlock in RPC frameworks like Ray.
+            # see https://github.com/vllm-project/vllm/issues/3455
+            # Print the error and inform the user how to solve it.
+            msg = (f"Error executing method {method}. "
+                   "This might cause deadlock in distributed execution.")
+            logger.exception(msg)
+            raise e
+
+
+try:
+    import ray
+    from ray.util import placement_group_table
+    from ray.util.placement_group import PlacementGroup
+    try:
+        from ray._private.state import available_resources_per_node
+    except ImportError:
+        # Ray 2.9.x doesn't expose `available_resources_per_node`
+        from ray._private.state import state as _state
+        available_resources_per_node = _state._available_resources_per_node
+
+    class RayWorkerWrapper(WorkerWrapperBase):
+
+        def __init__(self, *args, **kwargs) -> None:
+            super().__init__(*args, **kwargs)
+            # The compiled DAG runs the main execution in a different thread,
+            # which calls cuda.set_device. This flag indicates whether
+            # set_device has been called on that thread yet. It will be
+            # removed soon.
+            self.compiled_dag_cuda_device_set = False
+
+        def get_node_ip(self) -> str:
+            return get_ip()
+
+        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
+            node_id = ray.get_runtime_context().get_node_id()
+            gpu_ids = ray.get_gpu_ids()
+            return node_id, gpu_ids
+
+        def execute_model(
+            self,
+            scheduler_output: "SchedulerOutput",
+        ) -> ModelRunnerOutput:
+            # TODO(swang): This is needed right now because Ray CG executes
+            # on a background thread, so we need to reset torch's current
+            # device.
+            import torch
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
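+            # The compiled DAG hands the SchedulerOutput straight to the
+            # model runner; scheduling already happened in the engine process.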
+            output = self.worker.model_runner.execute_model(scheduler_output)
+            return output
+
+    ray_import_err = None
+
+except ImportError as e:
+    ray = None  # type: ignore
+    ray_import_err = e
+    RayWorkerWrapper = None  # type: ignore
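+    # Keep this module importable without Ray installed; callers should use
+    # ray_is_available() / assert_ray_available() before touching Ray APIs.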
+
+
+def ray_is_available() -> bool:
+    """Returns True if Ray is available."""
+    return ray is not None
+
+
+def assert_ray_available():
+    """Raise an exception if Ray is not available."""
+    if ray is None:
+        raise ValueError("Failed to import Ray, please install Ray with "
+                         "`pip install ray`.") from ray_import_err
+
+
+def _verify_bundles(placement_group: "PlacementGroup",
+                    parallel_config: ParallelConfig, device_str: str):
+    """Verify a given placement group has bundles located in the right place.
+
+    There are 2 rules.
+    - Warn if all tensor parallel workers cannot fit in a single node.
+    - Fail if the driver node is not included in the placement group.
+    """
+    assert ray.is_initialized(), (
+        "Ray is not initialized although distributed-executor-backend is ray.")
+    pg_data = placement_group_table(placement_group)
+    # bundle_idx -> node_id
+    bundle_to_node_ids = pg_data["bundles_to_node_id"]
+    # bundle_idx -> bundle (e.g., {"GPU": 1})
+    bundles = pg_data["bundles"]
+    # node_id -> List of bundle (e.g., {"GPU": 1})
+    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
+
+    for bundle_idx, node_id in bundle_to_node_ids.items():
+        node_id_to_bundle[node_id].append(bundles[bundle_idx])
+    driver_node_id = ray.get_runtime_context().get_node_id()
+
+    if driver_node_id not in node_id_to_bundle:
+        raise RuntimeError(
+            f"driver node id {driver_node_id} is not included in a placement "
+            f"group {placement_group.id}. Node id -> bundles "
+            f"{node_id_to_bundle}. "
+            "You don't have enough GPUs available in the current node. Check "
+            "`ray status` to see if you have available GPUs in a node "
+            f"{driver_node_id} before starting a vLLM engine.")
+
+    for node_id, bundles in node_id_to_bundle.items():
+        if len(bundles) < parallel_config.tensor_parallel_size:
+            logger.warning(
+                "tensor_parallel_size=%d "
+                "is bigger than the reserved number of %ss (%d "
+                "%ss) in node %s. Tensor parallel workers can be "
+                "spread out to 2+ nodes which can degrade the performance "
+                "unless you have fast interconnect across nodes, like "
+                "Infiniband. To resolve this issue, make sure you have more "
+                "than %d GPUs available at each node.",
+                parallel_config.tensor_parallel_size, device_str, len(bundles),
+                device_str, node_id, parallel_config.tensor_parallel_size)
+
+
+def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
+    """Wait until a placement group is ready.
+
+    It prints informative log messages if the placement group is
+    not created in time.
+    """
+    # Wait until PG is ready - this will block until all
+    # requested resources are available, and will timeout
+    # if they cannot be provisioned.
+    placement_group_specs = current_placement_group.bundle_specs
+
+    s = time.time()
+    pg_ready_ref = current_placement_group.ready()
+    wait_interval = 10
+    while time.time() - s < PG_WAIT_TIMEOUT:
+        ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
+        if len(ready) > 0:
+            break
+
+        # Exponential backoff for warning print.
+        wait_interval *= 2
+        logger.info(
+            "Waiting for creating a placement group of specs for "
+            "%d seconds. specs=%s. Check "
+            "`ray status` to see if you have enough resources.",
+            int(time.time() - s), placement_group_specs)
+
+    try:
+        ray.get(pg_ready_ref, timeout=0)
+    except ray.exceptions.GetTimeoutError:
+        raise ValueError(
+            "Cannot provide a placement group of "
+            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
+            "`ray status` to make sure the cluster has enough resources."
+        ) from None
+
+
+def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
+    ray.util.remove_placement_group(current_placement_group)
+    s = time.time()
+    wait_interval = 10
+    while time.time() - s < PG_WAIT_TIMEOUT:
+        pg = ray.util.get_current_placement_group()
+        if pg is None:
+            break
+
+        # Exponential backoff for warning print.
+        wait_interval *= 2
+        logger.info(
+            "Waiting for removal of the placement group for "
+            "%d seconds.", int(time.time() - s))
+        time.sleep(wait_interval)
+
+
+def initialize_ray_cluster(
+    parallel_config: ParallelConfig,
+    ray_address: Optional[str] = None,
+):
+    """Initialize the distributed cluster with Ray.
+
+    It will connect to the Ray cluster and create a placement group
+    for the workers, which includes the specification of the resources
+    for each distributed worker.
+
+    Args:
+        parallel_config: The configurations for parallel execution.
+        ray_address: The address of the Ray cluster. If None, uses
+            the default Ray cluster address.
+    """
+    assert_ray_available()
+
+    # Connect to a ray cluster.
+    if current_platform.is_rocm() or current_platform.is_xpu():
+        # Try to connect to an existing Ray instance, and create a new one
+        # with the current node's resources if none is found.
+        try:
+            ray.init("auto")
+        except ConnectionError:
+            logger.warning(
+                "No existing RAY instance detected. "
+                "A new instance will be launched with current node resources.")
+            ray.init(address=ray_address,
+                     ignore_reinit_error=True,
+                     num_gpus=parallel_config.world_size)
+    else:
+        ray.init(address=ray_address, ignore_reinit_error=True)
+
+    if parallel_config.placement_group:
+        # Placement group is already set.
+        return
+
+    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
+    # Create placement group for worker processes
+    current_placement_group = ray.util.get_current_placement_group()
+    if current_placement_group:
+        # We are in a placement group
+        bundles = current_placement_group.bundle_specs
+        # Verify that we can use the placement group.
+        device_bundles = 0
+        for bundle in bundles:
+            bundle_devices = bundle.get(device_str, 0)
+            if bundle_devices > 1:
+                raise ValueError(
+                    "Placement group bundle cannot have more than 1 "
+                    f"{device_str}.")
+            if bundle_devices:
+                device_bundles += 1
+        if parallel_config.world_size > device_bundles:
+            raise ValueError(
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the placement group. "
+                f"Required number of devices: {parallel_config.world_size}. "
+                f"Total number of devices: {device_bundles}.")
+    else:
+        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
+        if parallel_config.world_size > num_devices_in_cluster:
+            raise ValueError(
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the cluster.")
+        # Create a new placement group
+        placement_group_specs: List[Dict[str, float]] = ([{
+            device_str: 1.0
+        } for _ in range(parallel_config.world_size)])
+
+        # The vLLM engine itself also acts as a worker that runs the model on
+        # an accelerator, so it needs a device on the current node. Check
+        # that the current node has at least one device available.
+        current_ip = get_ip()
+        current_node_id = ray.get_runtime_context().get_node_id()
+        current_node_resource = available_resources_per_node()[current_node_id]
+        if current_node_resource.get(device_str, 0) < 1:
+            raise ValueError(
+                f"Current node has no {device_str} available. "
+                f"{current_node_resource=}. vLLM engine cannot start without "
+                f"{device_str}. Make sure you have at least 1 {device_str} "
+                f"available in a node {current_node_id=} {current_ip=}.")
+        # This way, at least one bundle is required to be created on the
+        # current node.
+        placement_group_specs[0][f"node:{current_ip}"] = 0.001
+
+        # By default, Ray packs bundles onto as few nodes as possible
+        # (the PACK strategy).
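+        # The tiny node:<current_ip> resource added above pins the first
+        # bundle to the driver's node, while PACK keeps the remaining
+        # bundles co-located where possible.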
+        current_placement_group = ray.util.placement_group(
+            placement_group_specs, strategy="PACK")
+        _wait_until_pg_ready(current_placement_group)
+
+    assert current_placement_group is not None
+    _verify_bundles(current_placement_group, parallel_config, device_str)
+    # Set the placement group in the parallel config
+    parallel_config.placement_group = current_placement_group
+
+
+def get_num_tpu_nodes() -> int:
+    from ray._private.accelerators import TPUAcceleratorManager
+    cluster_resources = ray.cluster_resources()
+    total_tpus = int(cluster_resources["TPU"])
+    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
+    assert total_tpus % tpus_per_node == 0
+    return total_tpus // tpus_per_node
+
+
+def get_num_nodes_in_placement_group() -> int:
+    pg_table = ray.util.placement_group_table()
+    current_pg = ray.util.get_current_placement_group()
+    num_nodes = 0
+
+    if current_pg:
+        nodes_in_pg = set()
+        for pg_key, pg in pg_table.items():
+            if pg_key == current_pg.id.hex():
+                for _, node in pg["bundles_to_node_id"].items():
+                    nodes_in_pg.add(node)
+        num_nodes = len(nodes_in_pg)
+
+    return num_nodes
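Below is a minimal usage sketch, not part of the diff: it assumes the `VLLM_USE_V1` flag selects the V1 engine and that `distributed_executor_backend="ray"` routes multi-GPU runs to the `RayGPUExecutor` added here; the exact flag names and executor routing may differ in the engine wiring.

```python
# Hypothetical smoke test for the V1 Ray executor path (see assumptions above).
import os

os.environ["VLLM_USE_V1"] = "1"  # assumed switch for the V1 engine

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",           # same model the modified test uses
    tensor_parallel_size=2,              # >1 GPU so the Ray path is exercised
    distributed_executor_backend="ray",  # assumed to select RayGPUExecutor
    enforce_eager=True,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=5))
print(outputs[0].outputs[0].text)
```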