1 change: 1 addition & 0 deletions requirements.txt
@@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette]
pynvml == 11.5.0
triton >= 2.1.0
cupy-cuda12x == 12.3.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
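
The pinned cupy-cuda12x wheel provides the NCCL bindings that the new cupy_utils module below imports. A minimal check that the bindings are importable in the current environment (not part of this PR; assumes the installed wheel matches the local CUDA toolkit):

```python
# Quick sanity check, outside the PR, that the CuPy NCCL bindings used below import cleanly.
import cupy
from cupy.cuda import nccl                  # NCCL bindings wrapped by cupy_utils
from cupyx.distributed import NCCLBackend   # backend subclassed by cupy_utils

print("cupy", cupy.__version__, "nccl", nccl.get_version())
```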
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -283,7 +283,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
is_driver_worker=True,
)

self._run_workers("init_model")
self._run_workers("init_model", cupy_port=get_open_port())
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
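get_open_port() comes from vllm.utils and is not shown in this diff; it appears to pick the TCP port used for the CuPy NCCL store set up during init_model. A hedged sketch of what such a helper typically does (the real vllm.utils implementation may differ):

```python
# Hypothetical sketch of a get_open_port()-style helper; not the vllm.utils code.
import socket

def get_open_port() -> int:
    # Bind to port 0 so the OS chooses a free port, then report which one it picked.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]
```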
13 changes: 9 additions & 4 deletions vllm/model_executor/parallel_utils/communication_op.py
@@ -1,14 +1,15 @@
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union

-from torch.distributed import ProcessGroup
-
import torch
+from torch.distributed import ProcessGroup

+from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
is_cupy_nccl_enabled_for_all_reduce,
)
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce

@@ -31,8 +32,12 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_)
if out is not None:
return out
-    torch.distributed.all_reduce(input_,
-                                 group=get_tensor_model_parallel_group())
+    if is_cupy_nccl_enabled_for_all_reduce():
+        # TODO: support multiple parallel groups.
+        cupy_utils.all_reduce(input_)
+    else:
+        torch.distributed.all_reduce(input_,
+                                     group=get_tensor_model_parallel_group())
return input_


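For context, a hedged sketch of a typical caller of tensor_model_parallel_all_reduce(): a row-parallel linear layer whose partial outputs are summed across tensor-parallel ranks. The class and names are illustrative, not copied from vLLM; the point is that the call site is unchanged whichever backend the dispatch above selects:

```python
# Illustrative caller; assumes vLLM's tensor-parallel groups are already initialized.
import torch
import torch.nn as nn

from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)

class RowParallelLinearSketch(nn.Module):
    def __init__(self, input_size_per_partition: int, output_size: int):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(output_size, input_size_per_partition))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        partial = torch.nn.functional.linear(x, self.weight)
        # The caller does not care which backend runs the all-reduce; the
        # dispatch above picks the custom kernel, CuPy NCCL (under CUDA graph
        # capture), or torch.distributed automatically.
        return tensor_model_parallel_all_reduce(partial)
```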
130 changes: 130 additions & 0 deletions vllm/model_executor/parallel_utils/cupy_utils.py
@@ -0,0 +1,130 @@
"""CuPy utilities for all-reduce.

We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce errors out during graph capture.

NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib

import torch
from torch.distributed import ReduceOp

try:
import cupy
from cupy.cuda import nccl
from cupyx.distributed import NCCLBackend
except ImportError as e:
cupy = e
nccl = None

class NCCLBackend:
...


_OP_MAPPING = {
ReduceOp.SUM: "sum",
ReduceOp.PRODUCT: "prod",
ReduceOp.MIN: "min",
ReduceOp.MAX: "max",
}


class NCCLBackendWithBFloat16(NCCLBackend):
# This is enough to add bfloat16 support for most operations,
# but broadcast will fail (will require changes in compiled
# cupy code).
def _get_nccl_dtype_and_count(self, array, count=None):
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
torch_dtype = getattr(array, "_torch_dtype", None)
if torch_dtype is torch.bfloat16:
nccl_dtype = nccl.NCCL_BFLOAT16
return nccl_dtype, count

def barrier(self) -> None:
raise RuntimeError(
"Currently, CuPy NCCL barrier is not supported since the TCP "
"store is immediately stopped after the initialization.")


_NCCL_BACKEND = None
_WORLD_SIZE = 0


def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return _NCCL_BACKEND is not None


@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
stream.device_index)
with cupy_stream:
yield


def init_process_group(world_size: int, rank: int, host: str,
port: int) -> None:
"""Initializes the CuPy NCCL backend.

# TODO: handle NCCL timeouts.
"""
assert not is_initialized()

if isinstance(cupy, Exception):
raise ImportError(
"NCCLBackend is not available. Please install cupy.") from cupy

# TODO(woosuk): Create TP and PP process groups for CuPy.
global _NCCL_BACKEND
global _WORLD_SIZE
assert world_size > 0, f"{world_size=} should be a positive integer"
    assert 0 <= rank < world_size, (
        f"{rank=} should be an integer in [0, {world_size})")

cupy.cuda.runtime.setDevice(torch.cuda.current_device())
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
_WORLD_SIZE = world_size

# Stop the TCP store to prevent the deadlock issues at termination time.
# FIXME(woosuk): This is hacky. Find a more robust solution.
if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
_NCCL_BACKEND._store.stop()


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
# Hack to support bfloat16
torch_dtype = input_.dtype
if torch_dtype is torch.bfloat16:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_ = input_.view(torch.float16)
cupy_input = cupy.asarray(input_)
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
out_array=cupy_input,
op=_OP_MAPPING[op])


def destroy_process_group() -> None:
"""Destroys the NCCL backend."""
global _NCCL_BACKEND
global _WORLD_SIZE
_NCCL_BACKEND = None
_WORLD_SIZE = 0


def get_world_size() -> int:
"""Returns the world size."""
return _WORLD_SIZE


def get_nccl_backend():
return _NCCL_BACKEND
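
all_reduce() above works around CuPy's lack of a bfloat16 dtype by viewing the tensor as float16 before handing it to CuPy, while NCCLBackendWithBFloat16 reports NCCL_BFLOAT16 for the actual reduction. The trick relies on Tensor.view(dtype) reinterpreting the same storage when element sizes match. A small standalone illustration (assumes a CUDA device; not part of the PR):

```python
# Demonstrates the bfloat16 -> float16 "view" trick used by all_reduce() above.
import torch

x = torch.randn(8, dtype=torch.bfloat16, device="cuda")
y = x.view(torch.float16)                   # same storage, same 2-byte elements
assert y.data_ptr() == x.data_ptr()         # no copy was made
assert y.view(torch.bfloat16).equal(x)      # round-trips bit-for-bit
```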
37 changes: 37 additions & 0 deletions vllm/model_executor/parallel_utils/parallel_state.py
@@ -3,9 +3,12 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib

import torch

from vllm.model_executor.parallel_utils import cupy_utils

# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
@@ -206,3 +209,37 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None

# Destroy the cupy states if any.
cupy_utils.destroy_process_group()


# Whether to use cupy for nccl all reduce.
# We use CuPy for all-reduce when using CUDA graphs, because torch.distributed
# is not well supported by CUDA graphs.
_ENABLE_CUPY_FOR_ALL_REDUCE = False


@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
"""use CuPy nccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No-op.
# NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
yield
else:
global _ENABLE_CUPY_FOR_ALL_REDUCE
old = _ENABLE_CUPY_FOR_ALL_REDUCE
_ENABLE_CUPY_FOR_ALL_REDUCE = True

stream = torch.cuda.current_stream()
with cupy_utils.set_cupy_stream(stream):
yield
_ENABLE_CUPY_FOR_ALL_REDUCE = old


def is_cupy_nccl_enabled_for_all_reduce():
"""check if CuPy nccl is enabled for all reduce"""
global _ENABLE_CUPY_FOR_ALL_REDUCE
return _ENABLE_CUPY_FOR_ALL_REDUCE
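
with_cupy_nccl_for_all_reduce() only flips a module-level flag and pins CuPy to the current torch stream; the collective itself still runs through cupy_utils. A hedged, standalone sketch of that underlying path, assuming two local GPUs, CuPy installed, and an arbitrary free TCP port (illustrative only, not a vLLM test):

```python
# Exercises cupy_utils directly, outside the engine, with two local ranks.
import torch
import torch.multiprocessing as mp

from vllm.model_executor.parallel_utils import cupy_utils

def _worker(rank: int, world_size: int, port: int) -> None:
    torch.cuda.set_device(rank)
    cupy_utils.init_process_group(world_size, rank, host="127.0.0.1", port=port)
    x = torch.full((8,), float(rank + 1), device="cuda")
    # Mirror what with_cupy_nccl_for_all_reduce() arranges: run the collective
    # on the current torch stream via CuPy's ExternalStream wrapper.
    with cupy_utils.set_cupy_stream(torch.cuda.current_stream()):
        cupy_utils.all_reduce(x)                         # in-place sum across ranks
    torch.cuda.synchronize()
    assert torch.allclose(x, torch.full_like(x, 3.0))    # 1.0 + 2.0
    cupy_utils.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2, 29511), nprocs=2)         # port 29511 is arbitrary
```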
7 changes: 5 additions & 2 deletions vllm/test_utils.py
@@ -15,8 +15,11 @@ def init_test_distributed_environment(
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
-    init_distributed_environment(parallel_config, rank,
-                                 distributed_init_method)
+    init_distributed_environment(
+        parallel_config,
+        rank,
+        cupy_port=None,
+        distributed_init_method=distributed_init_method)


def multi_process_tensor_parallel(
52 changes: 39 additions & 13 deletions vllm/worker/model_runner.py
@@ -5,11 +5,15 @@
import torch
import torch.nn as nn

-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
+from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
+                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
from vllm.model_executor.parallel_utils.parallel_state import (
with_cupy_nccl_for_all_reduce)
from vllm.model_executor.parallel_utils import custom_all_reduce
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -644,6 +648,10 @@ def list_loras(self) -> Set[int]:

@torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None:
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self.cupy_nccl_backend = get_nccl_backend()

assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to "
"unexpected consequences if the model is not static. To "
@@ -674,6 +682,12 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:

# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
# kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or CuPy NCCL. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or CuPy NCCL if it is disabled or not supported.
with custom_all_reduce.capture():
for batch_size in reversed(batch_size_capture_list):
# Create dummy input_metadata.
@@ -713,6 +727,14 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
# This usually takes < 10 seconds.
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

def __del__(self) -> None:
# Delete the CUDA graphs before deleting the CuPy NCCL communicator.
# NOTE(woosuk): This is necessary because otherwise deadlocks can
# happen.
# FIXME(woosuk): This is a bit hacky. Find a more robust solution.
self.graph_runners.clear()
self.cupy_nccl_backend = None


class CUDAGraphRunner:

@@ -734,25 +756,29 @@ def capture(
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
-        self.model(
-            input_ids,
-            positions,
-            kv_caches,
-            input_metadata,
-        )
-        torch.cuda.synchronize()
-
-        # Capture the graph.
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):
-            hidden_states = self.model(
-                input_ids,
-                positions,
-                kv_caches,
-                input_metadata,
-            )
-        torch.cuda.synchronize()
+        with with_cupy_nccl_for_all_reduce():
+            self.model(
+                input_ids,
+                positions,
+                kv_caches,
+                input_metadata,
+            )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
+        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
+        self.graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+            with with_cupy_nccl_for_all_reduce():
+                hidden_states = self.model(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    input_metadata,
+                )
+        torch.cuda.synchronize()

# Save the input and output buffers.
self.input_buffers = {
"input_ids": input_ids,
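The NOTE added in capture_model() spells out the priority among the three all-reduce backends. Restated as a tiny helper (a paraphrase of that comment, not vLLM code):

```python
# Paraphrase of the backend-selection rule described in capture_model() above.
def pick_all_reduce_backend(custom_kernel_usable: bool, using_cuda_graph: bool) -> str:
    if custom_kernel_usable:
        # The custom all-reduce kernel is always preferred when enabled and supported.
        return "custom all-reduce kernel"
    # Otherwise fall back to an NCCL path: CuPy NCCL under CUDA graphs,
    # PyTorch NCCL (torch.distributed) in eager mode.
    return "CuPy NCCL" if using_cuda_graph else "PyTorch NCCL"
```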