diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 230dd8383420..b383812ea274 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -527,6 +527,8 @@ steps:
   # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
+  # test sequence parallel
+  - pytest -v -s distributed/test_sequence_parallel.py
 
 - label: Plugin Tests (2 GPUs) # 40min
   working_dir: "/vllm-workspace/tests"
diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index ac6d6aae3006..d97598b23fab 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -14,7 +14,8 @@
 from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce)
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
 
 from ..utils import init_test_distributed_environment, multi_process_parallel
 
@@ -47,6 +48,34 @@ def all_reduce_test_worker(
     torch.testing.assert_close(t, expected)
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
+                               pp_size: int, rank: int,
+                               distributed_init_port: str):
+    # It is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs and set its device to the
+    # correct one.
+    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    num_elements = 8
+    all_tensors = [
+        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
+        (r + 1) for r in range(tp_size)
+    ]
+
+    index = rank % tp_size
+    partition_size = num_elements // tp_size
+    all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
+    expected = all_reduce[index * partition_size:(index + 1) * partition_size]
+    t = all_tensors[index]
+    t = tensor_model_parallel_reduce_scatter(t, 0)
+    torch.testing.assert_close(t, expected)
+
+
 @ray.remote(num_gpus=1, max_calls=1)
 def all_gather_test_worker(
     monkeypatch: pytest.MonkeyPatch,
@@ -211,6 +240,21 @@ def test_multi_process_tensor_parallel(
     multi_process_parallel(monkeypatch, tp_size, 1, test_target)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_sequence_parallel(
+    tp_size: int,
+    test_target: Callable[..., Any],
+    monkeypatch: pytest.MonkeyPatch,
+):
+    multi_process_parallel(monkeypatch, tp_size, 1, test_target)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("pp_size", [2])
diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py
new file mode 100644
index 000000000000..7a828b1b6525
--- /dev/null
+++ b/tests/distributed/test_sequence_parallel.py
@@ -0,0 +1,282 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+WARNING: This test runs in both single-node (4 GPUs) and multi-node
+ (2 nodes with 2 GPUs each) modes.
If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. +""" +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config import TaskOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_sequence_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + sp_enabled: bool + eager_mode: bool + chunked_prefill: bool + + +class SPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class SPTestSettings: + parallel_setups: list[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. + distributed_backends: list[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: list[str] + task: TaskOption + test_options: SPTestOptions + + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + + @staticmethod + def detailed( + *, + tp_base: int = 2, + multi_node_only: bool = False, + task: TaskOption = "auto", + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + # TODO support eager_mode = False + # ParallelSetup(tp_size=tp_base, + # sp_enabled=True, + # eager_mode=False, + # chunked_prefill=False), + # ParallelSetup(tp_size=tp_base, + # sp_enabled=True, + # eager_mode=False, + # chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=2 * tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=True) + ], + # only ray is supported for V1 + distributed_backends=["mp", "mp", "ray", "ray"], + vllm_major_versions=["0", "1", "0", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fast( + *, + tp_base: int = 4, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp", "mp"], + vllm_major_versions=["0", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_id, parallel_setup, backend, vllm_major_version, + self.task, opts) + + +def _compare_sp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: SPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate", "encode"], + is_multimodal: bool, +): + ( + tp_size, + sp_enabled, + eager_mode, + chunked_prefill, + ) = 
parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size * pp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip("Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + + sp_env = None + sp_args = [ + *common_args, + "--enable-sequence-parallel", + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + tp_env = { + "VLLM_USE_V1": vllm_major_version, + } + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + try: + compare_two_settings(model_id, + sp_args, + tp_args, + sp_env, + tp_env, + method=method) + except Exception: + testing_ray_compiled_graph = sp_env is not None + if testing_ray_compiled_graph and vllm_major_version == "0": + # Ray Compiled Graph tests are flaky for V0, + # so we don't want to fail the test + logger.exception("Ray Compiled Graph tests failed") + else: + raise + + +SP_TEXT_GENERATION_MODELS = { + # [Decoder-only] + "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), +} + +SP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "meta-llama/Llama-3.2-1B-Instruct", +] + + +@pytest.mark.parametrize( + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "task", "test_options"), + [ + params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in SP_TEST_MODELS + ], +) +@create_new_process_for_each_test() +def test_tp_sp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: SPTestOptions, + num_gpus_available, +): + _compare_sp(model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False) diff --git a/vllm/config.py b/vllm/config.py index c510677d64ea..422a9543caad 100644 --- 
a/vllm/config.py +++ b/vllm/config.py @@ -1358,6 +1358,8 @@ class ParallelConfig: tensor_parallel_size: int = 1 # Number of tensor parallel groups. data_parallel_size: int = 1 # Number of data parallel groups. data_parallel_rank: int = 0 # Rank of the data parallel group. + enable_sequence_parallel: bool = False # Enable sequence parallelism. + # IP of the data parallel master. data_parallel_master_ip: str = "127.0.0.1" data_parallel_master_port: int = 29500 # Port of the data parallel master. @@ -2150,6 +2152,8 @@ def create_draft_parallel_config( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, tensor_parallel_size=speculative_draft_tensor_parallel_size, + enable_sequence_parallel=target_parallel_config. + enable_sequence_parallel, distributed_executor_backend=target_parallel_config. distributed_executor_backend, max_parallel_loading_workers=target_parallel_config. diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 0228264f91f9..0552254ba423 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -13,6 +13,12 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: return get_tp_group().all_reduce(input_) +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, + dim: int) -> torch.Tensor: + """Reduce-scatter the input tensor across model parallel group.""" + return get_tp_group().reduce_scatter(input_, dim) + + def tensor_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index eb12f8834b41..5f0d631a6bfe 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -35,6 +35,40 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output_tensor = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + # Perform reduce-scatter operation + torch.distributed.reduce_scatter_tensor(output_tensor, + input_tensor, + group=self.device_group) + + # Reshape before returning + return output_tensor.movedim(0, dim).contiguous() + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: if dim < 0: # Convert negative dim to positive. 
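For reference, the dim-0 semantics that the reduce_scatter implementations in this patch (and the new reduce_scatter_test_worker above) rely on can be checked without any process group. The sketch below is illustrative only and is not part of the diff; it simulates the ranks on a single process, and names such as fake_world_size and per_rank_inputs are made up for the example.

# Illustrative sketch (not part of the diff): emulate reduce-scatter along
# dim 0 on one process. Each simulated rank contributes a tensor; the result
# a rank keeps is the element-wise sum restricted to that rank's chunk.
import torch

fake_world_size = 2  # assumed value, mirroring tp_size=2 in the tests
num_elements = 8

per_rank_inputs = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(fake_world_size)
]

summed = torch.stack(per_rank_inputs, dim=0).sum(dim=0)
chunk_size = num_elements // fake_world_size

for r in range(fake_world_size):
    expected = summed[r * chunk_size:(r + 1) * chunk_size]
    print(f"rank {r} would receive {expected.tolist()}")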
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 07c9ff506092..8bca278f3888 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -70,6 +70,31 @@ def all_reduce(self, input_):
         torch.distributed.all_reduce(out, group=self.device_group)
         return out
 
+    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        assert input_tensor.shape[0] % world_size == 0
+        chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size, ) + input_tensor.shape[1:]
+
+        output = torch.empty(output_shape,
+                             dtype=input_tensor.dtype,
+                             device=input_tensor.device)
+
+        pynccl_comm.reduce_scatter(output, input_tensor)
+
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index f897f1950e4c..4899cbb8b7e8 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -118,6 +118,22 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
     return torch.empty_like(tensor)
 
 
+def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
+                   group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group.reduce_scatter(tensor, dim)
+
+
+def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
+                        group_name: str) -> torch.Tensor:
+    new_shape = list(tensor.shape)
+    new_shape[dim] = tensor.shape[dim] // world_size
+    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
+
+
 if supports_custom_op():
     direct_register_custom_op(
         op_name="all_reduce",
@@ -126,6 +142,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
         fake_impl=all_reduce_fake,
     )
 
+    direct_register_custom_op(
+        op_name="reduce_scatter",
+        op_func=reduce_scatter,
+        mutates_args=[],
+        fake_impl=reduce_scatter_fake,
+    )
+
 
 class GroupCoordinator:
     """
@@ -312,6 +335,18 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         return self.device_communicator.all_reduce(input_)
 
+    def reduce_scatter(self,
+                       input_: torch.Tensor,
+                       dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        # Bypass the function if we are using only 1 GPU.
+        if world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        return self.device_communicator.reduce_scatter(input_, dim)
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
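The reduce_scatter_fake registration above only needs to get the output shape right for torch.compile tracing. A quick sanity check of that contract is sketched below; it is illustrative only, and expected_reduce_scatter_shape is a made-up helper, not vLLM API.

# Illustrative sketch (not part of the diff): the shape rule encoded by
# reduce_scatter_fake -- only the scattered dim shrinks, by a factor of
# world_size; dtype and device are unchanged.
def expected_reduce_scatter_shape(shape, dim, world_size):
    new_shape = list(shape)
    new_shape[dim] = shape[dim] // world_size
    return tuple(new_shape)

assert expected_reduce_scatter_shape((8, 4096), dim=0, world_size=2) == (4, 4096)
assert expected_reduce_scatter_shape((16, 1024), dim=0, world_size=4) == (4, 1024)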
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 02a9ec46939c..b7d0407f164c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -115,6 +115,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_sequence_parallel: bool = False
     enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
@@ -435,6 +436,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='Number of tensor parallel replicas.')
+        parser.add_argument('--enable-sequence-parallel',
+                            '-sp',
+                            action='store_true',
+                            default=False,
+                            help='Enable sequence parallelism.')
         parser.add_argument(
             '--enable-expert-parallel',
             action='store_true',
@@ -1243,6 +1249,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_sequence_parallel=self.enable_sequence_parallel,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a0e2fa2918bd..819a99284065 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -74,6 +74,8 @@ class LLM:
             environments.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
+        enable_sequence_parallel: Enable sequence parallelism on top of tensor
+            parallelism.
         dtype: The data type for the model weights and activations. Currently,
             we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
             the `torch_dtype` attribute specified in the model config file.
@@ -164,6 +166,7 @@ def __init__(
         trust_remote_code: bool = False,
         allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
+        enable_sequence_parallel: bool = False,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
@@ -219,6 +222,7 @@ def __init__(
             trust_remote_code=trust_remote_code,
             allowed_local_media_path=allowed_local_media_path,
             tensor_parallel_size=tensor_parallel_size,
+            enable_sequence_parallel=enable_sequence_parallel,
             dtype=dtype,
             quantization=quantization,
             revision=revision,
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index e195a03c5cac..5359550c6319 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -38,6 +38,7 @@ class ForwardContext:
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
+    enable_sequence_parallel: bool  # whether sequence parallelism is enabled
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
 
@@ -53,11 +54,16 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def try_get_forward_context() -> Optional[ForwardContext]:
+    return _forward_context
+
+
 @contextmanager
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: int = 0):
+                        num_tokens: int = 0,
+                        enable_sequence_parallel: bool = False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -90,6 +96,12 @@ def set_forward_context(attn_metadata: Any, cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) + # TODO: support pipeline parallel + if vllm_config.parallel_config.enable_sequence_parallel: + assert vllm_config.parallel_config.pipeline_parallel_size == 1, ( + "sequence parallel doesn't work correctly when " + "combined with pipeline parallel") + global _forward_context prev_context = _forward_context _forward_context = ForwardContext( @@ -97,6 +109,7 @@ def set_forward_context(attn_metadata: Any, static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, + enable_sequence_parallel=enable_sequence_parallel, dp_metadata=dp_metadata) try: yield diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1ae574072b8f..6c680a525b7d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,7 +13,9 @@ get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import try_get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -469,6 +471,11 @@ def forward( ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + input_ = tensor_model_parallel_all_gather(input_, 0) + # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) @@ -1258,8 +1265,15 @@ def forward( output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + if self.reduce_results and self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + output = tensor_model_parallel_reduce_scatter(output_parallel, + dim=0) + else: + output = tensor_model_parallel_all_reduce(output_parallel) else: output = output_parallel diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 4a359725bad0..2a070d231330 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -10,6 +10,7 @@ import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) +from vllm.forward_context import try_get_forward_context from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -59,6 +60,11 @@ def forward( sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) if self.logits_as_input: logits = hidden_states else: @@ -105,6 +111,7 @@ def _get_logits( embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: # Get the logits for the next 
tokens. + logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index f65dfc3cb329..b6a92e4db850 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -9,7 +9,9 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import try_get_forward_context from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) from vllm.model_executor.parameter import BasevLLMParameter @@ -204,7 +206,6 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() - # Keep the input dimensions. tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -418,7 +419,13 @@ def forward(self, input_): if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + output = tensor_model_parallel_reduce_scatter(output_parallel, + dim=0) + else: + output = tensor_model_parallel_all_reduce(output_parallel) return output def extra_repr(self) -> str: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 657333c6d84c..bdf7eefd4046 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY @@ -856,7 +857,6 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): # depending on the input multimodal items. curr_group_outputs = self.model.get_multimodal_embeddings( **batched_mm_inputs) - for output in curr_group_outputs: encoder_outputs.append(output) @@ -1026,9 +1026,22 @@ def execute_model( for k, v in self.intermediate_tensors.items() }) + # only do sequence parallelism when num of tokens + # is divisible by parallel size. + # sequence parallelism uses torch.distributed.reduce_scatter which only + # supports the case when size is divisible by parallel size + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and num_input_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context( + attn_metadata, + self.vllm_config, + enable_sequence_parallel=enable_sequence_parallel, + ): hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -1039,6 +1052,9 @@ def execute_model( # For mid-pipeline stages, return the hidden states. 
return hidden_states + if enable_sequence_parallel: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -1218,6 +1234,7 @@ def _get_prompt_logprobs_dict( req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc_np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states, None) # Get the "target" tokens for each index. For prompt at index i, @@ -1295,15 +1312,26 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, - self.vllm_config, - num_tokens=num_tokens): + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and num_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_tokens, + enable_sequence_parallel=enable_sequence_parallel, + ): hidden_states = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if get_pp_group().is_last_rank and enable_sequence_parallel: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 473bd901b5b2..3cd7c994f809 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1556,7 +1556,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: capture_inputs) with set_forward_context(attn_metadata, self.vllm_config, - virtual_engine): + virtual_engine, batch_size): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1736,9 +1736,24 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() + num_tokens = (model_input.input_tokens.shape[0] + if model_input.input_tokens is not None else None) + + # only do sequence parallelism when num of tokens + # is divisible by parallel size. 
+ # sequence parallelism uses torch.distributed.reduce_scatter which only + # supports the case when size is divisible by parallel size + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and num_tokens is not None and num_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config, virtual_engine): + with set_forward_context( + model_input.attn_metadata, + self.vllm_config, + virtual_engine, + enable_sequence_parallel=enable_sequence_parallel): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -1785,8 +1800,14 @@ def execute_model( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) + with set_forward_context( + model_input.attn_metadata, + self.vllm_config, + virtual_engine, + enable_sequence_parallel=enable_sequence_parallel, + ): + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) if not self.is_driver_worker: return []
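For reviewers, a usage sketch of the new flag follows; it is not part of the diff. The model name is the one used by the new test file, enforce_eager matches the eager-only setups the tests exercise, and, per the gating in the model runners above, batches whose token count is not divisible by the tensor-parallel size fall back to the plain all-reduce path.

# Usage sketch (not part of the diff): enabling sequence parallelism from the
# Python API. The equivalent CLI flag added by this patch is
# --enable-sequence-parallel (together with --tensor-parallel-size 2).
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    tensor_parallel_size=2,
    enable_sequence_parallel=True,  # new flag introduced by this patch
    enforce_eager=True,             # eager mode, as in the tests above
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)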