diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
index 2ac98976539e..b00519314d8b 100644
--- a/examples/offline_inference/data_parallel.py
+++ b/examples/offline_inference/data_parallel.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # usage:
-# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
-# python examples/offline_inference/data_parallel.py
+# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
@@ -55,7 +54,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     # Create an LLM.
     llm = LLM(model="ibm-research/PowerMoE-3b",
               tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True)
+              enforce_eager=True,
+              enable_expert_parallel=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
diff --git a/vllm/config.py b/vllm/config.py
index 3f1bff498129..7692f9ec80f5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -754,7 +754,7 @@ def verify_with_parallel_config(
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

-        if envs.VLLM_TEST_ENABLE_EP:
+        if parallel_config.enable_expert_parallel:
             self._verify_with_expert_parallelism()

         pipeline_parallel_size = parallel_config.pipeline_parallel_size
@@ -1334,6 +1334,7 @@ class ParallelConfig:
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
+    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.

     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 989eb4dbfd14..c18e2f391f81 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
@@ -439,6 +440,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             type=int,
             default=EngineArgs.tensor_parallel_size,
             help='Number of tensor parallel replicas.')
+        parser.add_argument(
+            '--enable-expert-parallel',
+            action='store_true',
+            help='Use expert parallelism instead of tensor parallelism '
+            'for MoE layers.')
         parser.add_argument(
             '--max-parallel-loading-workers',
             type=int,
@@ -1199,6 +1205,7 @@ def create_engine_config(self,
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             tokenizer_pool_config=TokenizerPoolConfig.create_config(
diff --git a/vllm/envs.py b/vllm/envs.py
index f6c038967b69..5d490efdadb2 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -86,7 +86,6 @@
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
     VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
-    VLLM_TEST_ENABLE_EP: bool = False
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -578,12 +577,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
                  ),

-    # If set, vLLM will use the experimental expert parallel implementation on
-    # the FusedMoE layer, using tensor parallelism size as expert parallelism
-    # size.
-    "VLLM_TEST_ENABLE_EP":
-    lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),
-
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 33d2896f3fd2..d0209eb40e8c 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -7,7 +7,6 @@
 import torch
 from torch.nn.parameter import UninitializedParameter

-import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -342,14 +341,6 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        # For smuggling this layer into the fused moe custom op
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError("Duplicate layer name: {}".format(prefix))
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
-
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
@@ -361,7 +352,21 @@ def __init__(
                        if self.dp_size == 1 else get_dp_group().rank_in_group)
         self.global_num_experts = num_experts

-        if envs.VLLM_TEST_ENABLE_EP:
+        # Use expert parallelism instead of tensor parallelism?
+        vllm_config = get_current_vllm_config()
+        use_ep = (vllm_config.parallel_config.enable_expert_parallel
+                  and self.tp_size > 1)
+
+        # For smuggling this layer into the fused moe custom op
+        self.use_direct_call = self.dp_size == 1
+        if self.use_direct_call:
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError("Duplicate layer name: {}".format(prefix))
+            compilation_config.static_forward_context[prefix] = self
+            self.layer_name = prefix
+
+        if use_ep:
             # Set TP size to 1 to adjust for EP and adjust EP size and rank
             # for DP attention.
             self.ep_rank = tp_rank + self.tp_size * self.dp_rank
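
With this change, expert parallelism for MoE layers is opted into through the new enable_expert_parallel engine argument (or the --enable-expert-parallel CLI flag) rather than the removed VLLM_TEST_ENABLE_EP environment variable. Below is a minimal offline sketch of the new switch; the prompt, sampling parameters, and tensor_parallel_size=2 are illustrative assumptions, not part of the patch.

from vllm import LLM, SamplingParams

# Sketch: shard MoE experts with expert parallelism instead of tensor
# parallelism. tensor_parallel_size=2 is an assumed two-GPU setup; per the
# patched FusedMoE layer, EP only takes effect when the TP size is > 1.
llm = LLM(
    model="ibm-research/PowerMoE-3b",   # MoE model used in the patched example
    tensor_parallel_size=2,             # assumption: 2 GPUs available
    enforce_eager=True,
    enable_expert_parallel=True,        # new argument introduced by this patch
)

sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The future of AI is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)

For online serving, the same option is exposed as --enable-expert-parallel, which add_cli_args maps onto EngineArgs.enable_expert_parallel and create_engine_config then passes into ParallelConfig.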