Commit eb24dc4

[v1] torchrun compatibility (#13642)
Signed-off-by: youkaichao <[email protected]>
1 parent: 9bebc95

File tree

14 files changed: +67 / -24 lines


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions

@@ -503,6 +503,7 @@ steps:
 - entrypoints/llm/test_collective_rpc.py
 commands:
 - pytest -v -s entrypoints/llm/test_collective_rpc.py
+- VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
 - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
 - pytest -v -s ./compile/test_basic_correctness.py
 - pytest -v -s ./compile/test_wrapper.py
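
The new pipeline command runs the same example script under V1. For context, a torchrun-compatible vLLM script follows this general shape (a minimal sketch; the model name, prompts, and sampling settings are illustrative assumptions, not taken from the test file):

from vllm import LLM, SamplingParams

# Every rank started by `torchrun --nproc-per-node=2` executes this same script.
# "external_launcher" tells vLLM that the processes (and the process group)
# were already created by torchrun, so it should not spawn its own workers.
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

llm = LLM(
    model="facebook/opt-125m",  # assumed small model, for illustration only
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
)
outputs = llm.generate(prompts, sampling_params)  # identical on every rank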

tests/distributed/test_torchrun_example.py

Lines changed: 6 additions & 0 deletions

@@ -48,6 +48,12 @@ def test_consistent_across_ranks(obj):
 test_consistent_across_ranks(
     llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

+# make sure we can access the model parameters from the calling process
+# of the `LLM` instance.
+params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
+              model.parameters())
+test_consistent_across_ranks(len(params))
+
 # all ranks should have the same outputs
 for output in outputs:
     prompt = output.prompt
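
The helper `test_consistent_across_ranks` is defined earlier in the file and is not part of this diff. As a rough sketch of what such a check can look like (an assumption, not the repository's implementation), each rank gathers the value from all ranks and asserts they agree:

import torch.distributed as dist

def test_consistent_across_ranks(obj):
    # Collect `obj` from every rank, then require that all copies match.
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    dist.all_gather_object(gathered, obj)
    assert all(item == gathered[0] for item in gathered), gathered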

tests/v1/engine/test_engine_core.py

Lines changed: 4 additions & 2 deletions

@@ -5,6 +5,7 @@
 import time
 import uuid
 from concurrent.futures import Future
+from typing import List

 import pytest
 from transformers import AutoTokenizer
@@ -211,8 +212,9 @@ def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:

 class DummyExecutor(UniProcExecutor):

-    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
-        super().initialize(kv_cache_config)
+    def initialize_from_config(
+            self, kv_cache_configs: List[KVCacheConfig]) -> None:
+        super().initialize_from_config(kv_cache_configs)

         # This executor actually can only run 1 batch at a time
         self.semaphore = threading.Semaphore(1)

vllm/config.py

Lines changed: 5 additions & 0 deletions

@@ -1407,6 +1407,11 @@ def __post_init__(self) -> None:
         self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
         self.world_size_across_dp = self.world_size * self.data_parallel_size

+        if self.distributed_executor_backend == "external_launcher":
+            import os
+            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
+            logger.info("Disabling V1 multiprocessing for external launcher.")
+
         ray_only_devices = ["tpu"]
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
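
Setting the variable inside `__post_init__` works because `vllm.envs` resolves its attributes lazily from `os.environ` at access time (an assumption about `vllm.envs` behavior, consistent with the switch to `envs.VLLM_ENABLE_V1_MULTIPROCESSING` in the `llm_engine.py` hunk further down). A small sketch of the effect:

import os

# Flip the variable before the engine reads it, as the new config branch does.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

import vllm.envs as envs

# Because the value is looked up lazily, later checks observe the override.
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING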

vllm/executor/ray_distributed_executor.py

Lines changed: 1 addition & 1 deletion

@@ -541,7 +541,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
         # and the TP group executes in SPMD fashion.
         if self.use_v1:
             outputs = [
-                worker.execute_model.
+                worker.execute_model_ray.
                 bind(  # type: ignore[attr-defined]
                     outputs[i]) for i, worker in enumerate(tp_group)
             ]

vllm/executor/ray_utils.py

Lines changed: 3 additions & 1 deletion

@@ -112,10 +112,12 @@ def setup_device_if_necessary(self):
             torch.cuda.set_device(self.worker.device)
             self.compiled_dag_cuda_device_set = True

-    def execute_model(
+    def execute_model_ray(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> "ModelRunnerOutput":
+        # this method is used to compile ray CG,
+        # and it needs a special logic of self.setup_device_if_necessary()
         self.setup_device_if_necessary()
         assert self.worker is not None, "Worker is not initialized"
         if isinstance(scheduler_output, tuple):

vllm/executor/uniproc_executor.py

Lines changed: 4 additions & 3 deletions

@@ -93,9 +93,10 @@ def _init_executor(self) -> None:
             ("ExecutorWithExternalLauncher needs deterministic "
              "execution, so it"
              "does not support delay_factor in scheduling")
-        assert not envs.VLLM_USE_V1, \
-            ("V1 architecture cannot guarantee deterministic execution, "
-             "so it is not supported in ExecutorWithExternalLauncher.")
+        if envs.VLLM_USE_V1:
+            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
+                ("To get deterministic execution in V1, "
+                 "please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                                rpc_rank=0)
         # engines are launched in torchrun-compatible launchers

vllm/v1/engine/core.py

Lines changed: 1 addition & 1 deletion

@@ -110,7 +110,7 @@ def _initialize_kv_caches(self,
         num_cpu_blocks = 0

         # Initialize kv cache and warmup the execution
-        self.model_executor.initialize(kv_cache_configs)
+        self.model_executor.initialize_from_config(kv_cache_configs)

         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "

vllm/v1/engine/llm_engine.py

Lines changed: 7 additions & 2 deletions

@@ -4,10 +4,10 @@

 from typing_extensions import TypeVar

+import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase
-from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -44,6 +44,7 @@ def __init__(
         use_cached_outputs: bool = False,
         multiprocess_mode: bool = False,
     ) -> None:
+        self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config

@@ -83,6 +84,10 @@ def __init__(
             log_stats=False,  # FIXME: implement
         )

+        if not multiprocess_mode:
+            # for v0 compatibility
+            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
+
     @classmethod
     def from_engine_args(
         cls,
@@ -97,7 +102,7 @@ def from_engine_args(
         vllm_config = engine_args.create_engine_config(usage_context)
         executor_class = Executor.get_class(vllm_config)

-        if VLLM_ENABLE_V1_MULTIPROCESSING:
+        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
             logger.debug("Enabling multiprocessing for LLMEngine.")
             enable_multiprocessing = True
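
With the engine core kept in-process (multiprocess_mode disabled), the new `self.model_executor` alias restores the v0-style access path exercised by the torchrun test above. A short usage sketch, assuming `llm` is an `LLM` instance created with the external launcher as in the earlier example:

# Reach the underlying model through the same attribute chain as in v0.
params = list(llm.llm_engine.model_executor.driver_worker.worker.
              model_runner.model.parameters())
print(f"model has {len(params)} parameter tensors")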

vllm/v1/executor/abstract.py

Lines changed: 17 additions & 3 deletions

@@ -3,6 +3,9 @@
 from concurrent.futures import Future
 from typing import List, Type, Union

+import torch
+import torch.distributed as dist
+
 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.uniproc_executor import (  # noqa
@@ -49,12 +52,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
                 f"{distributed_executor_backend}")
         return executor_class

-    def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
+    def initialize_from_config(self,
+                               kv_cache_configs: List[KVCacheConfig]) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
+        self.collective_rpc("initialize_from_config",
+                            args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")

     def determine_available_memory(self) -> int:  # in bytes
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):


 class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
-    pass
+
+    def determine_available_memory(self) -> int:  # in bytes
+        # same as determine_num_available_blocks in v0,
+        # we need to get the min across all ranks.
+        memory = super().determine_available_memory()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
+        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return memory_tensor.item()
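
The MIN all-reduce above makes every rank adopt the smallest locally measured memory budget, so all ranks later create the same number of KV cache blocks. A self-contained sketch of the same pattern on a CPU (gloo) group, runnable under torchrun (the values are made up for illustration):

import torch
import torch.distributed as dist

# torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so env:// init works.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

local_free_bytes = 1_000_000_000 - rank * 4_096  # pretend ranks differ slightly
budget = torch.tensor([local_free_bytes], dtype=torch.int64)
dist.all_reduce(budget, op=dist.ReduceOp.MIN)    # every rank keeps the minimum
print(f"rank {rank}: agreed budget = {budget.item()} bytes")

dist.destroy_process_group()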
