Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ steps:
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
# test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py

- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"
Expand Down
46 changes: 45 additions & 1 deletion tests/distributed/test_comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)

from ..utils import init_test_distributed_environment, multi_process_parallel

Expand Down Expand Up @@ -47,6 +48,34 @@ def all_reduce_test_worker(
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tp_size)
]

index = rank % tp_size
partition_size = num_elements // tp_size
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
t = all_tensors[index]
t = tensor_model_parallel_reduce_scatter(t, 0)
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -211,6 +240,21 @@ def test_multi_process_tensor_parallel(
multi_process_parallel(monkeypatch, tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tesor_parallel_sequence_parallel(
tp_size: int,
test_target: Callable[..., Any],
monkeypatch: pytest.MonkeyPatch,
):
multi_process_parallel(monkeypatch, tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
Expand Down
282 changes: 282 additions & 0 deletions tests/distributed/test_sequence_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers in a node other than the head node, which can cause the test
to fail.
"""
import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional

import pytest

from vllm.config import TaskOption
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

logger = init_logger("test_sequence_parallel")

VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
tp_size: int
sp_enabled: bool
eager_mode: bool
chunked_prefill: bool


class SPTestOptions(NamedTuple):
multi_node_only: bool
load_format: Optional[str] = None


@dataclass
class SPTestSettings:
parallel_setups: list[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: list[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: list[str]
task: TaskOption
test_options: SPTestOptions

def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")

@staticmethod
def detailed(
*,
tp_base: int = 2,
multi_node_only: bool = False,
task: TaskOption = "auto",
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
# TODO support eager_mode = False
# ParallelSetup(tp_size=tp_base,
# sp_enabled=True,
# eager_mode=False,
# chunked_prefill=False),
# ParallelSetup(tp_size=tp_base,
# sp_enabled=True,
# eager_mode=False,
# chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=2 * tp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=True)
],
# only ray is supported for V1
distributed_backends=["mp", "mp", "ray", "ray"],
vllm_major_versions=["0", "1", "0", "1"],
task=task,
test_options=SPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)

@staticmethod
def fast(
*,
tp_base: int = 4,
task: TaskOption = "auto",
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
return SPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
sp_enabled=True,
eager_mode=True,
chunked_prefill=False),
],
distributed_backends=["mp", "mp"],
vllm_major_versions=["0", "1"],
task=task,
test_options=SPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)

def iter_params(self, model_id: str):
opts = self.test_options

for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_id, parallel_setup, backend, vllm_major_version,
self.task, opts)


def _compare_sp(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: SPTestOptions,
num_gpus_available: int,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
):
(
tp_size,
sp_enabled,
eager_mode,
chunked_prefill,
) = parallel_setup

multi_node_only, load_format = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")

trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides

if load_format == "dummy":
# Avoid OOM
text_overrides = {
"num_hidden_layers": 4,
"hidden_size": 512,
"intermediate_size": 800,
"num_attention_heads": 4,
"num_key_value_heads": 1,
}

if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")

pp_size = 1
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")

common_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if task != "auto":
common_args.extend(["--task", task])
if trust_remote_code:
common_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

sp_env = None
sp_args = [
*common_args,
"--enable-sequence-parallel",
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
distributed_backend,
]

tp_env = {
"VLLM_USE_V1": vllm_major_version,
}
tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
"mp",
]

try:
compare_two_settings(model_id,
sp_args,
tp_args,
sp_env,
tp_env,
method=method)
except Exception:
testing_ray_compiled_graph = sp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0":
# Ray Compiled Graph tests are flaky for V0,
# so we don't want to fail the test
logger.exception("Ray Compiled Graph tests failed")
else:
raise


SP_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
}

SP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"meta-llama/Llama-3.2-1B-Instruct",
]


@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"task", "test_options"),
[
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in SP_TEST_MODELS
],
)
@create_new_process_for_each_test()
def test_tp_sp_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: SPTestOptions,
num_gpus_available,
):
_compare_sp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False)
4 changes: 4 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,8 @@ class ParallelConfig:
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
enable_sequence_parallel: bool = False # Enable sequence parallelism.

# IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
Expand Down Expand Up @@ -2150,6 +2152,8 @@ def create_draft_parallel_config(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
enable_sequence_parallel=target_parallel_config.
enable_sequence_parallel,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
Expand Down
6 changes: 6 additions & 0 deletions vllm/distributed/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
return get_tp_group().all_reduce(input_)


def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
dim: int) -> torch.Tensor:
"""Reduce-scatter the input tensor across model parallel group."""
return get_tp_group().reduce_scatter(input_, dim)


def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
Expand Down
Loading