diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6d052d0f7f4a..0654fcfef0da 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -22,8 +22,18 @@ steps: working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. -- label: Distributed Correctness Test - command: pytest -v -s --forked test_basic_distributed_correctness.py +- label: Distributed pynccl Test + command: pytest -v -s --forked test_pynccl.py + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 2 # only support 1 or 2 for now. + +- label: Distributed Correctness Test-facebook/opt-125m + command: TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 2 # only support 1 or 2 for now. + +- label: Distributed Correctness Test-meta-llama/Llama-2-7b-hf + command: TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5211dc180798..2db687a287ef 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -49,7 +49,7 @@ jobs: matrix: os: ['ubuntu-20.04'] python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt. + pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements.txt. cuda-version: ['11.8', '12.1'] steps: diff --git a/CMakeLists.txt b/CMakeLists.txt index 66842e6845ed..be3dc520e43f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,9 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") +# used when building pytorch-related extensions +set(TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;8.6;8.9;9.0") + # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") @@ -28,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2") +set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1") set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") diff --git a/Dockerfile b/Dockerfile index 1f254c76fe5a..d78ddd25ccf7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,9 @@ RUN ldconfig /usr/local/cuda-12.1/compat/ WORKDIR /workspace +# used for downloading files +RUN apt install -y wget unzip + # install build and runtime dependencies COPY requirements.txt requirements.txt RUN --mount=type=cache,target=/root/.cache/pip \ diff --git a/MANIFEST.in b/MANIFEST.in index aa16da6500e6..677fa19721fc 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,3 +4,4 @@ include CMakeLists.txt recursive-include cmake * recursive-include csrc * +recursive-include vllm/lib * diff --git a/pyproject.toml b/pyproject.toml index b6d7649477dc..509c2a630b4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.1.2", + "torch == 2.2.1", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build.txt b/requirements-build.txt index a8efcde590bb..2bc07fb152aa 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging setuptools>=49.4.0 -torch==2.1.2 +torch==2.2.1 wheel diff --git a/requirements.txt b/requirements.txt index e136defad494..57996f5cc231 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,9 @@ psutil ray >= 2.9 sentencepiece # Required for LLaMA tokenizer. numpy -torch == 2.1.2 +torch == 2.2.1 +xformers == 0.0.25 # Requires PyTorch 2.2.1. transformers >= 4.39.0 # Required for StarCoder2. -xformers == 0.0.23.post1 # Required for CUDA 12.1. fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. diff --git a/setup.py b/setup.py index 47cac5996f81..07d08be5fbb2 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,12 @@ from shutil import which import torch from torch.utils.cpp_extension import CUDA_HOME +import zipfile +import shutil +import logging +import tempfile + +logger = logging.getLogger(__name__) ROOT_DIR = os.path.dirname(__file__) @@ -188,6 +194,48 @@ def _install_punica() -> bool: return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) +if _is_cuda(): + + # tricky part, nccl 2.19 has a bug that increased memory overhead + # of cudagraph. 
However, pytorch has binary dependencies on nccl 2.19, + # simply `pip install nvidia-nccl-cu12==2.18.3` will break pytorch, + # so we have to manually download nccl 2.18 and keep the library to + # a secrect place + + # Define the URL of the file and the directory to unzip to + file_url = ('https://files.pythonhosted.org/packages/44/6e/' + '3c9cd7007072f8a63dae7b5eddd1cc1525fd357377467ce3a4749b02d5ff' + '/nvidia_nccl_cu12-2.18.3-py3-none-manylinux1_x86_64.whl') + + logger.info('Installing NVIDIA NCCL library...') + + target_dir = os.path.dirname(os.path.abspath(__file__)) + "/vllm/lib/" + with tempfile.TemporaryDirectory() as temp_dir: + local_zip_path = ( + f"{temp_dir}/" + "nvidia_nccl_cu12-2.18.3-py3-none-manylinux1_x86_64.whl") + # make sure the target directory exists + os.makedirs(target_dir, exist_ok=True) + # Check if the file is already downloaded + if os.path.exists(target_dir + "nvidia"): + logger.info('library already exists.') + else: + # Download the file + logger.info('Downloading file...') + os.system(f"wget {file_url} -q -P {temp_dir}/") + # Unzip the file + logger.info('Unzipping file...') + with zipfile.ZipFile(local_zip_path, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + shutil.rmtree(f"{temp_dir}/nvidia_nccl_cu12-2.18.3.dist-info") + os.remove(local_zip_path) + # Move the unzipped files to the target directory + logger.info('Moving files...') + os.system(f"mv {temp_dir}/nvidia {target_dir}") + so_path = f"{target_dir}/nvidia/nccl/lib/libnccl.so.2" + os.rename(so_path, so_path.replace(".so.2", ".so.2.18.3")) + + def get_hipcc_rocm_version(): # Run the hipcc --version command result = subprocess.run(['hipcc', '--version'], @@ -330,7 +378,10 @@ def get_requirements() -> List[str]: ext_modules.append(CMakeExtension(name="vllm._C")) package_data = { - "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] + "vllm": [ + "py.typed", "model_executor/layers/fused_moe/configs/*.json", + "lib/nvidia/nccl/lib/libnccl.so.2.18.3" + ] } if os.environ.get("VLLM_USE_PRECOMPILED"): package_data["vllm"].append("*.so") @@ -362,6 +413,8 @@ def get_requirements() -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, + cmdclass={ + "build_ext": cmake_build_ext if not _is_neuron() else build_ext, + }, package_data=package_data, ) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 82075356fccb..75d6a84adfc7 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -1,13 +1,23 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. -Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. 
+Run: + +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_basic_distributed_correctness.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_basic_distributed_correctness.py +``` """ +import os import pytest import torch MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", + os.environ["TEST_DIST_MODEL"], ] diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py new file mode 100644 index 000000000000..58376306c277 --- /dev/null +++ b/tests/distributed/test_pynccl.py @@ -0,0 +1,88 @@ +# this script is not run with `pytest`. +# It is run with `torchrun`. +import os +import multiprocessing +import pytest +import torch +from vllm.model_executor.parallel_utils.pynccl import ( + NCCLCommunicator, + ncclGetUniqueId, +) + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes = [] + for i in range(number_of_processes): + env = os.environ.copy() + env['RANK'] = str(i) + env['WORLD_SIZE'] = str(number_of_processes) + env['MASTER_ADDR'] = 'localhost' + env['MASTER_PORT'] = '12345' + p = multiprocessing.Process(target=fn, args=(env, )) + processes.append(p) + p.start() + + for p in processes: + p.join() + + +def update_env(fn): + + def wrapper(env): + import os + os.environ.update(env) + fn() + + return wrapper + + +@update_env +def worker_fn(): + comm = NCCLCommunicator() + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == comm.world_size + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl(): + distributed_run(worker_fn, 2) + + +@update_env +def worker_fn_with_cudagraph(): + with torch.no_grad(): + graph = torch.cuda.CUDAGraph() + comm = NCCLCommunicator() + # run something in the default stream to initialize torch engine + a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + torch.cuda.synchronize() + with torch.cuda.graph(graph, stream=comm.stream): + comm.all_reduce(a) + comm.stream.synchronize() + assert a.mean().cpu().item() == comm.world_size**0 + graph.replay() + comm.stream.synchronize() + assert a.mean().cpu().item() == comm.world_size**2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_with_cudagraph(): + distributed_run(worker_fn_with_cudagraph, 2) + + +def test_ncclGetUniqueId(): + unique_id = ncclGetUniqueId() + # `list(unique_id.internal)` is something like this: + # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # as long as the function doesn't raise an exception, we're good + assert unique_id is not None diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 6f00fd001d95..28433d31f56a 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -4,12 +4,12 @@ import torch from torch.distributed import ProcessGroup -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils from 
vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, - is_cupy_nccl_enabled_for_all_reduce, + is_pynccl_enabled_for_all_reduce, ) from vllm.model_executor.parallel_utils.custom_all_reduce import ( custom_all_reduce) @@ -33,9 +33,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if is_cupy_nccl_enabled_for_all_reduce(): + if is_pynccl_enabled_for_all_reduce(): # TODO: support multiple parallel groups. - cupy_utils.all_reduce(input_) + pynccl_utils.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) diff --git a/vllm/model_executor/parallel_utils/cupy_utils.py b/vllm/model_executor/parallel_utils/cupy_utils.py deleted file mode 100644 index f8cffc01e3c3..000000000000 --- a/vllm/model_executor/parallel_utils/cupy_utils.py +++ /dev/null @@ -1,130 +0,0 @@ -"""CuPy utilities for all-reduce. - -We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing -CUDA graphs, because torch.distributed.all_reduce causes errors when capturing -CUDA graphs. - -NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8. -TODO: Remove this file when torch.distributed.all_reduce is fixed. -""" -import contextlib - -import torch -from torch.distributed import ReduceOp - -try: - import cupy - from cupy.cuda import nccl - from cupyx.distributed import NCCLBackend -except ImportError as e: - cupy = e - nccl = None - - class NCCLBackend: - ... - - -_OP_MAPPING = { - ReduceOp.SUM: "sum", - ReduceOp.PRODUCT: "prod", - ReduceOp.MIN: "min", - ReduceOp.MAX: "max", -} - - -class NCCLBackendWithBFloat16(NCCLBackend): - # This is enough to add bfloat16 support for most operations, - # but broadcast will fail (will require changes in compiled - # cupy code). - def _get_nccl_dtype_and_count(self, array, count=None): - nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count) - torch_dtype = getattr(array, "_torch_dtype", None) - if torch_dtype is torch.bfloat16: - nccl_dtype = nccl.NCCL_BFLOAT16 - return nccl_dtype, count - - def barrier(self) -> None: - raise RuntimeError( - "Currently, CuPy NCCL barrier is not supported since the TCP " - "store is immediately stopped after the initialization.") - - -_NCCL_BACKEND = None -_WORLD_SIZE = 0 - - -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return _NCCL_BACKEND is not None - - -@contextlib.contextmanager -def set_cupy_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream, - stream.device_index) - with cupy_stream: - yield - - -def init_process_group(world_size: int, rank: int, host: str, - port: int) -> None: - """Initializes the CuPy NCCL backend. - - # TODO: handle NCCL timeouts. - """ - assert not is_initialized() - - if isinstance(cupy, Exception): - raise ImportError( - "NCCLBackend is not available. Please install cupy.") from cupy - - # TODO(woosuk): Create TP and PP process groups for CuPy. 
- global _NCCL_BACKEND - global _WORLD_SIZE - assert world_size > 0, f"{world_size=} should be a positive integer" - assert 0 <= rank < world_size, ( - f"{rank=} should be a integer between [0, {world_size})") - - cupy.cuda.runtime.setDevice(torch.cuda.current_device()) - _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port) - _WORLD_SIZE = world_size - - # Stop the TCP store to prevent the deadlock issues at termination time. - # FIXME(woosuk): This is hacky. Find a more robust solution. - if rank == 0 and hasattr(_NCCL_BACKEND, "_store"): - _NCCL_BACKEND._store.stop() - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - # Hack to support bfloat16 - torch_dtype = input_.dtype - if torch_dtype is torch.bfloat16: - # We need to view as float16, otherwise - # cupy will fail. This will not change - # the underlying data. - input_ = input_.view(torch.float16) - cupy_input = cupy.asarray(input_) - cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access - _NCCL_BACKEND.all_reduce(in_array=cupy_input, - out_array=cupy_input, - op=_OP_MAPPING[op]) - - -def destroy_process_group() -> None: - """Destroys the NCCL backend.""" - global _NCCL_BACKEND - global _WORLD_SIZE - _NCCL_BACKEND = None - _WORLD_SIZE = 0 - - -def get_world_size() -> int: - """Returns the world size.""" - return _WORLD_SIZE - - -def get_nccl_backend(): - return _NCCL_BACKEND diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index c821936d06e4..63890d9cd5bd 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -7,7 +7,7 @@ import torch -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None @@ -210,36 +210,36 @@ def destroy_model_parallel(): global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None - # Destroy the cupy states if any. - cupy_utils.destroy_process_group() + # Destroy the pynccl states if any. + pynccl_utils.destroy_process_group() -# Whether to use cupy for nccl all reduce. -# We use cupy for all reduce when using CUDA graph, because torch.distributed +# Whether to use pynccl for nccl all reduce. +# We use pynccl for all reduce when using CUDA graph, because torch.distributed # is not well supported by CUDA graph. -_ENABLE_CUPY_FOR_ALL_REDUCE = False +_ENABLE_PYNCCL_FOR_ALL_REDUCE = False @contextlib.contextmanager -def with_cupy_nccl_for_all_reduce(): - """use CuPy nccl instead of torch.distributed for all reduce""" +def with_pynccl_for_all_reduce(): + """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: # No-op. # NOTE(woosuk): We don't initialize CuPy when tp_size is 1. 
yield else: - global _ENABLE_CUPY_FOR_ALL_REDUCE - old = _ENABLE_CUPY_FOR_ALL_REDUCE - _ENABLE_CUPY_FOR_ALL_REDUCE = True + global _ENABLE_PYNCCL_FOR_ALL_REDUCE + old = _ENABLE_PYNCCL_FOR_ALL_REDUCE + _ENABLE_PYNCCL_FOR_ALL_REDUCE = True stream = torch.cuda.current_stream() - with cupy_utils.set_cupy_stream(stream): + with pynccl_utils.set_pynccl_stream(stream): yield - _ENABLE_CUPY_FOR_ALL_REDUCE = old + _ENABLE_PYNCCL_FOR_ALL_REDUCE = old -def is_cupy_nccl_enabled_for_all_reduce(): - """check if CuPy nccl is enabled for all reduce""" - global _ENABLE_CUPY_FOR_ALL_REDUCE - return _ENABLE_CUPY_FOR_ALL_REDUCE +def is_pynccl_enabled_for_all_reduce(): + """check if pynccl is enabled for all reduce""" + global _ENABLE_PYNCCL_FOR_ALL_REDUCE + return _ENABLE_PYNCCL_FOR_ALL_REDUCE diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py new file mode 100644 index 000000000000..9f0aaf5f9321 --- /dev/null +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -0,0 +1,239 @@ +# ===================== pynccl.py ================================== +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/199366 +# ==================================================== + +# ===================== import region ===================== +import torch +import ctypes +import torch.distributed as dist +from torch.distributed import ReduceOp +import datetime +import os +import glob +import logging + +logger = logging.getLogger(__name__) + +so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") + +# manually load the nccl library +if so_file: + logger.info( + f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") +else: + _path = os.path.dirname(os.path.abspath(__file__)) + so_file = glob.glob(f"{_path}/../../lib/nvidia/nccl/lib/libnccl.so.*")[0] + logger.info(f"Loading nccl from vLLM builtin file {so_file}") +nccl = ctypes.CDLL(so_file) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int + +# equivalent to c declaration: +# ncclResult_t ncclGetVersion(int *version); +_c_ncclGetVersion = nccl.ncclGetVersion +_c_ncclGetVersion.restype = ctypes.c_int +_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] + + +def ncclGetVersion() -> int: + version = ctypes.c_int() + result = _c_ncclGetVersion(ctypes.byref(version)) + assert result == 0 + # something like 21903 --> "2.19.3" + version_str = str(version.value) + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + +class NcclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +# equivalent to c declaration: +# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); +_c_ncclGetUniqueId = nccl.ncclGetUniqueId +_c_ncclGetUniqueId.restype = ctypes.c_int +_c_ncclGetUniqueId.argtypes = 
[ctypes.POINTER(NcclUniqueId)] + + +def ncclGetUniqueId() -> NcclUniqueId: + unique_id = NcclUniqueId() + result = _c_ncclGetUniqueId(ctypes.byref(unique_id)) + assert result == 0 + return unique_id + + +# equivalent to c declaration: +# ncclResult_t ncclCommInitRank( +# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); +# note that ncclComm_t is a pointer type, so the first argument +# is a pointer to a pointer +_c_ncclCommInitRank = nccl.ncclCommInitRank +_c_ncclCommInitRank.restype = ctypes.c_int +_c_ncclCommInitRank.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int +] + + +# enums +class ncclDataType_t(ctypes.c_int): + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +class ncclRedOp_t(ctypes.c_int): + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +# equivalent to c declaration: +# ncclResult_t ncclAllReduce( +# const void* sendbuff, void* recvbuff, size_t count, +# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, +# udaStream_t stream); +# note that cudaStream_t is a pointer type, so the last argument is a pointer +_c_ncclAllReduce = nccl.ncclAllReduce +_c_ncclAllReduce.restype = ctypes.c_int +_c_ncclAllReduce.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p +] + +# equivalent to c declaration: +# ncclResult_t ncclCommDestroy(ncclComm_t comm); +_c_ncclCommDestroy = nccl.ncclCommDestroy +_c_ncclCommDestroy.restype = ctypes.c_int +_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] + + +class NCCLCommunicator: + + def __init__( + self, + backend=None, + init_method=None, + timeout=datetime.timedelta(seconds=10), + world_size: int = -1, + rank: int = -1, + store=None, + group_name: str = "", + pg_options=None, + ): + if not dist.is_initialized(): + backend = backend or "nccl" + assert backend == 'nccl', ( + "only use nccl backend for starting the NCCL communicator") + dist.init_process_group(backend=backend, + init_method=init_method, + timeout=timeout, + world_size=world_size, + rank=rank, + store=store, + group_name=group_name, + pg_options=pg_options) + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + torch.cuda.set_device(self.rank) + if self.rank == 0: + self.unique_id = ncclGetUniqueId() + else: + self.unique_id = NcclUniqueId() + tensor = 
torch.ByteTensor(list(self.unique_id.internal)).cuda( + self.rank) + dist.broadcast(tensor, src=0) + byte_list = tensor.cpu().tolist() + self.unique_id = NcclUniqueId() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + self.comm = ctypes.c_void_p() + result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, + self.unique_id, self.rank) + assert result == 0 + self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}") + + def all_reduce(self, + tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if stream is None: + stream = self.stream + result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), + ctypes.c_void_p(tensor.data_ptr()), + tensor.numel(), + ncclDataType_t.from_torch(tensor.dtype), + ncclRedOp_t.from_torch(op), self.comm, + ctypes.c_void_p(stream.cuda_stream)) + assert result == 0 + + def __del__(self): + dist.destroy_process_group() + _c_ncclCommDestroy(self.comm) + + +# ===================== pynccl.py ===================== diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py new file mode 100644 index 000000000000..e498526b71bb --- /dev/null +++ b/vllm/model_executor/parallel_utils/pynccl_utils.py @@ -0,0 +1,67 @@ +import contextlib +import logging +import torch + +from typing import Optional +from torch.distributed import ReduceOp + +logger = logging.getLogger(__name__) + +try: + from vllm.model_executor.parallel_utils.pynccl import ( + NCCLCommunicator, + ncclGetVersion, + ) + logger.info(f"vLLM is using nccl=={ncclGetVersion()}") +except Exception as e: + # in non-NVIDIA environments, we can't import the nccl module + # e.g. when running on machines with AMD GPUs + logger.info(f"Failed to import NCCL library: {e}") + logger.info("It is expected if you are not running on NVIDIA GPUs.") + pass + +comm: Optional["NCCLCommunicator"] = None + + +def is_initialized() -> bool: + """Returns whether the NCCL backend is initialized.""" + return comm is not None + + +@contextlib.contextmanager +def set_pynccl_stream(stream: torch.cuda.Stream): + """Set the cuda stream for communication""" + try: + comm.stream = stream + yield + finally: + pass + + +def init_process_group(world_size: int, rank: int, host: str, + port: int) -> None: + assert not is_initialized() + global comm + comm = NCCLCommunicator(init_method=f"tcp://{host}:{port}", + world_size=world_size, + rank=rank) + + +def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: + """All-reduces the input tensor across the process group.""" + assert input_.is_cuda, f"{input_} should be a cuda tensor" + comm.all_reduce(input_, op) + + +def destroy_process_group() -> None: + global comm + comm = None + + +def get_world_size() -> int: + """Returns the world size.""" + return comm.world_size + + +def get_nccl_backend(): + return comm diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b8eeb51379f4..374f519afc81 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,11 +11,11 @@ from vllm.logger import init_logger from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( - with_cupy_nccl_for_all_reduce) + 
with_pynccl_for_all_reduce) from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata @@ -720,7 +720,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. - self.cupy_nccl_backend = cupy_utils.get_nccl_backend() + self.cupy_nccl_backend = pynccl_utils.get_nccl_backend() assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " @@ -834,7 +834,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_cupy_nccl(): + with _maybe_pynccl(): self.model( input_ids, positions, @@ -848,7 +848,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 - with _maybe_cupy_nccl(): + with _maybe_pynccl(): hidden_states = self.model( input_ids, positions, @@ -899,9 +899,10 @@ def __call__(self, *args, **kwargs): @contextlib.contextmanager -def _maybe_cupy_nccl(): - if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized(): - with with_cupy_nccl_for_all_reduce(): +def _maybe_pynccl(): + if pynccl_utils.is_initialized( + ) and not custom_all_reduce.is_initialized(): + with with_pynccl_for_all_reduce(): yield else: yield diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d8999dc17212..c979effae048 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig) from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar @@ -262,8 +262,8 @@ def init_distributed_environment( init_method=distributed_init_method, ) - if cupy_utils.is_initialized(): - cupy_world_size = cupy_utils.get_world_size() + if pynccl_utils.is_initialized(): + cupy_world_size = pynccl_utils.get_world_size() if cupy_world_size != parallel_config.world_size: raise RuntimeError( "cupy.distributed is already initialized but the cupy world " @@ -273,7 +273,7 @@ def init_distributed_environment( # NOTE(woosuk): We don't initialize CuPy process group when world size # is 1. # TODO(woosuk): Support multi-node connection. - cupy_utils.init_process_group( + pynccl_utils.init_process_group( world_size=parallel_config.world_size, rank=rank, host="localhost", @@ -282,8 +282,8 @@ def init_distributed_environment( # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cuda()) - if cupy_utils.is_initialized(): - cupy_utils.all_reduce(torch.zeros(1).cuda()) + if pynccl_utils.is_initialized(): + pynccl_utils.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size)
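
The pieces above can be exercised end to end with a short driver script. The following is a minimal sketch, not part of the patch: it assumes a single host with two GPUs, a hypothetical file name `check_pynccl.py`, and launching via `torchrun`; it relies only on `NCCLCommunicator`, `ncclGetVersion`, and the `VLLM_NCCL_SO_PATH` override defined in `vllm/model_executor/parallel_utils/pynccl.py`.

```python
# check_pynccl.py -- minimal sketch (hypothetical file name), assuming a
# 2-GPU host. Launch with: torchrun --nproc_per_node=2 check_pynccl.py
#
# To use a different NCCL build, export VLLM_NCCL_SO_PATH=/path/to/libnccl.so.2
# before running; otherwise the libnccl.so.2.18.3 bundled under vllm/lib by
# setup.py is loaded.
import torch

from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
                                                       ncclGetVersion)

# NCCLCommunicator() initializes torch.distributed itself (env:// rendezvous
# from the RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables set by torchrun),
# broadcasts the NCCL unique id from rank 0, and creates a dedicated
# communication stream.
comm = NCCLCommunicator()
print(f"rank {comm.rank}: NCCL version {ncclGetVersion()}")

x = torch.ones(1024, device=f"cuda:{comm.rank}")
# Make sure the allocation on the default stream is visible before launching
# the collective on comm.stream.
torch.cuda.synchronize()
comm.all_reduce(x)          # in-place sum across all ranks on comm.stream
comm.stream.synchronize()
assert x.mean().item() == comm.world_size
```

The unique id is exchanged over an ordinary `torch.distributed` broadcast of a byte tensor, so the wrapper avoids a second rendezvous mechanism, and all collectives issued through it run on `comm.stream`, which is what allows them to be captured inside a CUDA graph (see `test_pynccl_with_cudagraph` above).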