diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 108f606e2fb8..3b1735fdcf7a 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import json import os from collections import defaultdict from dataclasses import dataclass @@ -48,6 +49,24 @@ class RayWorkerMetaData: class RayDistributedExecutor(DistributedExecutorBase): + """Ray-based distributed executor""" + + # These env vars are worker-specific, therefore are NOT copied + # from the driver to the workers + WORKER_SPECIFIC_ENV_VARS = { + "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES" + } + + config_home = envs.VLLM_CONFIG_ROOT + # This file contains a list of env vars that should not be copied + # from the driver to the Ray workers. + non_carry_over_env_vars_file = os.path.join( + config_home, "ray_non_carry_over_env_vars.json") + if os.path.exists(non_carry_over_env_vars_file): + with open(non_carry_over_env_vars_file) as f: + non_carry_over_env_vars = set(json.load(f)) + else: + non_carry_over_env_vars = set() uses_ray: bool = True @@ -311,9 +330,9 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # Environment variables to copy from driver to workers env_vars_to_copy = [ - "VLLM_ATTENTION_BACKEND", "TPU_CHIPS_PER_HOST_BOUNDS", - "TPU_HOST_BOUNDS", "VLLM_USE_V1", "VLLM_TRACE_FUNCTION", - "VLLM_TORCH_PROFILER_DIR", "VLLM_TEST_ENABLE_EP" + v for v in envs.environment_variables + if v not in self.WORKER_SPECIFIC_ENV_VARS + and v not in self.non_carry_over_env_vars ] # Copy existing env vars to each worker's args @@ -323,9 +342,14 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): if name in os.environ: args[name] = os.environ[name] + logger.info("non_carry_over_env_vars from config: %s", + self.non_carry_over_env_vars) logger.info( "Copying the following environment variables to workers: %s", [v for v in env_vars_to_copy if v in os.environ]) + logger.info( + "If certain env vars should NOT be copied to workers, add them to " + "%s file", self.non_carry_over_env_vars_file) self._env_vars_for_all_workers = ( all_args_to_update_environment_variables)