1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -194,6 +194,7 @@ steps:
- tests/v1
commands:
# split the test to avoid interference
- VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- VLLM_USE_V1=1 pytest -v -s v1/core
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
3 changes: 3 additions & 0 deletions tests/distributed/test_torchrun_example.py
@@ -47,6 +47,9 @@ def test_consistent_across_ranks(obj):
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
Member Author: This is to test if we can directly access the model with llm.llm_engine.model_executor.driver_worker.worker.model_runner.model. It is used in https://github.com/volcengine/verl/blob/0a1b16f800c25ac80504038fd8b8be4282d6c606/verl/workers/sharding_manager/fsdp_vllm.py#L84

Collaborator: Maybe worth a comment?

model.parameters())
test_consistent_across_ranks(len(params))

# all ranks should have the same outputs
for output in outputs:
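To make the author's comment above concrete, here is a minimal sketch of the access pattern being tested: an LLM created under torchrun with the external-launcher backend, whose underlying torch.nn.Module is reached through the model_executor chain so that an external training framework (such as verl) can read or overwrite the weights in place. The model name and the distributed_executor_backend argument are illustrative assumptions; neither appears in this diff.

# Sketch only, not the actual test; launch it the same way as the pipeline entry above:
#   VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_USE_V1=1 torchrun --nproc-per-node=2 access_model_sketch.py
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",  # illustrative choice; any small model works
    tensor_parallel_size=2,
    # processes/ranks are created by torchrun, not by vLLM itself
    distributed_executor_backend="external_launcher",
)

# The access chain from the review comment: with multiprocessing disabled the
# executor lives in-process, so the module holding this rank's weights is
# reachable directly.
model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model

# A training framework can now inspect or load weights in place, for example:
num_params = sum(p.numel() for p in model.parameters())
print(f"rank-local parameter count: {num_params}")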
6 changes: 4 additions & 2 deletions tests/v1/engine/test_engine_core.py
@@ -5,6 +5,7 @@
import time
import uuid
from concurrent.futures import Future
from typing import List

import pytest
from transformers import AutoTokenizer
@@ -211,8 +212,9 @@ def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:

class DummyExecutor(UniProcExecutor):

def initialize(self, kv_cache_config: KVCacheConfig) -> None:
super().initialize(kv_cache_config)
def initialize_from_config(
self, kv_cache_configs: List[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs)

# This executor actually can only run 1 batch at a time
self.semaphore = threading.Semaphore(1)
7 changes: 4 additions & 3 deletions vllm/executor/uniproc_executor.py
@@ -93,9 +93,10 @@ def _init_executor(self) -> None:
("ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
"does not support delay_factor in scheduling")
assert not envs.VLLM_USE_V1, \
("V1 architecture cannot guarantee deterministic execution, "
"so it is not supported in ExecutorWithExternalLauncher.")
if envs.VLLM_USE_V1:
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
("To get deterministic execution in V1, "
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
# engines are launched in torchrun-compatible launchers
2 changes: 1 addition & 1 deletion vllm/v1/engine/core.py
@@ -110,7 +110,7 @@ def _initialize_kv_caches(self,
num_cpu_blocks = 0

# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_configs)
self.model_executor.initialize_from_config(kv_cache_configs)

elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
5 changes: 5 additions & 0 deletions vllm/v1/engine/llm_engine.py
@@ -44,6 +44,7 @@ def __init__(
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config

@@ -76,6 +77,10 @@ def __init__(
log_stats=False, # FIXME: implement
)

if not multiprocess_mode:
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore

@classmethod
def from_engine_args(
cls,
20 changes: 17 additions & 3 deletions vllm/v1/executor/abstract.py
@@ -3,6 +3,9 @@
from concurrent.futures import Future
from typing import List, Type, Union

import torch
import torch.distributed as dist

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa
@@ -49,12 +52,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
f"{distributed_executor_backend}")
return executor_class

def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
def initialize_from_config(self,
kv_cache_configs: List[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
self.collective_rpc("initialize_from_config",
args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")

def determine_available_memory(self) -> int: # in bytes
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
pass

def determine_available_memory(self) -> int: # in bytes
# same as determine_num_available_blocks below,
# we need to get the min across all ranks.
memory = super().determine_available_memory()
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return memory_tensor.item()
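The determine_available_memory override above is what keeps the KV cache sized identically on every rank: each rank profiles its own free memory, and the smallest value wins so that all ranks end up creating the same number of blocks. Below is a standalone sketch of that min-reduction, using a plain gloo process group in place of vLLM's get_world_group().cpu_group; the script name and the fake memory numbers are made up for illustration.

# Sketch of the min-reduction; launch with: torchrun --nproc-per-node=2 min_reduce_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # CPU group, standing in for get_world_group().cpu_group
rank = dist.get_rank()

# Pretend each rank profiled a different amount of free memory (in bytes).
local_free_bytes = 10_000_000_000 - rank * 1_000_000_000

memory_tensor = torch.tensor([local_free_bytes], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, op=dist.ReduceOp.MIN)

# Every rank now agrees on the same (smallest) memory budget for its KV cache.
print(f"rank {rank}: local={local_free_bytes}, agreed={memory_tensor.item()}")

dist.destroy_process_group()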
7 changes: 3 additions & 4 deletions vllm/v1/worker/gpu_worker.py
@@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

import torch
import torch.distributed
@@ -185,9 +185,8 @@ def determine_available_memory(self) -> int:
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()

def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
@@ -225,7 +224,7 @@ def execute_model(
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
return output if self.is_driver_worker else None

def profile(self, is_start: bool = True):
if self.profiler is None:
3 changes: 1 addition & 2 deletions vllm/v1/worker/tpu_worker.py
@@ -170,9 +170,8 @@ def get_model(self) -> nn.Module:
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()

def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
self.model_runner.initialize_kv_cache(kv_cache_config)

def check_health(self) -> None:
4 changes: 4 additions & 0 deletions vllm/worker/worker_base.py
@@ -567,6 +567,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
self.worker = worker_class(**kwargs)
assert self.worker is not None

def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
Member Author: @ruisearch42 @comaniac FYI, if a method needs to send different arguments to different ranks, the indexing should use self.rpc_rank, and it should happen in this WorkerWrapperBase.

self.worker.initialize_from_config(kv_cache_config) # type: ignore

def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
target = self if self.worker is None else self.worker
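To spell out the rpc_rank convention from the comment above, here is a pared-down sketch: the executor hands the same per-rank list to every wrapper, and it is the wrapper, not the worker, that picks out its own element. TinyWorker and TinyWorkerWrapper are simplified stand-ins for illustration, not the real vLLM classes; only the rpc_rank indexing mirrors the change in this file.

# Sketch of per-rank argument dispatch; class names are stand-ins.
from typing import Any, List


class TinyWorker:
    """Stand-in for a rank-local worker; it only ever sees its own config."""

    def initialize_from_config(self, kv_cache_config: Any) -> None:
        self.kv_cache_config = kv_cache_config


class TinyWorkerWrapper:
    """Stand-in for WorkerWrapperBase: demultiplexes per-rank arguments."""

    def __init__(self, rpc_rank: int) -> None:
        self.rpc_rank = rpc_rank
        self.worker = TinyWorker()

    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
        # Index by this wrapper's rpc_rank so the worker's signature stays
        # rank-agnostic: it receives one config, not the whole list.
        self.worker.initialize_from_config(kv_cache_configs[self.rpc_rank])


# The executor side broadcasts the same list to every wrapper, standing in
# for collective_rpc("initialize_from_config", args=(kv_cache_configs,)).
wrappers = [TinyWorkerWrapper(rpc_rank=r) for r in range(2)]
per_rank_configs = ["kv-cache-config-for-rank-0", "kv-cache-config-for-rank-1"]
for wrapper in wrappers:
    wrapper.initialize_from_config(per_rank_configs)

assert wrappers[1].worker.kv_cache_config == "kv-cache-config-for-rank-1"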