diff --git a/vllm/envs.py b/vllm/envs.py
index 048d63bfec0f..7d1f46378f48 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -96,6 +96,7 @@
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_USE_AITER_MOE: bool = False
 
 
 def get_default_cache_root():
@@ -630,6 +631,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3":
     lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
+
+    # Whether to use the ROCm AITER fused MoE kernels
+    "VLLM_USE_AITER_MOE":
+    lambda: (os.environ.get("VLLM_USE_AITER_MOE", "False").lower() in
+             ("true", "1")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 00260313e72e..dff908fcf297 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -18,6 +18,7 @@
 from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
+is_hip = current_platform.is_rocm()
 
 
 @triton.jit
@@ -1164,8 +1165,52 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None) -> torch.Tensor:
+                  block_shape: Optional[List[int]] = None,
+                  expert_mask: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+
+    if is_hip and envs.VLLM_USE_AITER_MOE:
+        import aiter
+        from aiter.fused_moe_bf16_asm import moe_sorting_ck
+
+        local_E = E = w1.shape[0]
+        if expert_mask is not None:
+            E = expert_mask.numel()
+        topk = topk_ids.shape[1]
+        model_dim = w1.shape[-1]
+        dtype = hidden_states.dtype
+        scale_blk_k = block_shape[1]
+
+        # Sort tokens by expert and allocate the output buffer used by the
+        # AITER asm kernel.
+        (
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            out_asm,
+        ) = moe_sorting_ck(
+            topk_ids, topk_weights, E, model_dim, dtype, expert_mask=expert_mask
+        )
+
+        # Quantize activations to fp8 with per-token-group scales, then run the
+        # block-scale fused MoE kernel.
+        a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
+        aiter.fmoe_fp8_blockscale_g1u1(
+            out_asm,
+            a1,
+            w1,
+            w2,
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            topk,
+            w1_scale.view(local_E, -1),
+            w2_scale.view(local_E, -1),
+            a1_scale.t().contiguous(),
+            block_shape[0],
+            block_shape[1],
+            None,
+        )
+        return out_asm
+
+    # Not using AITER MoE
     if inplace:
         torch.ops.vllm.inplace_fused_experts(
             hidden_states, w1, w2, topk_weights, topk_ids, activation,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 052d4d54601f..086e958d273e 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -29,6 +29,7 @@
 else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
+is_hip = current_platform.is_rocm()
 
 
 class FusedMoeWeightScaleSupported(Enum):
@@ -287,6 +288,8 @@ def __init__(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
+        num_shared_experts: Optional[int] = 0,
+        routed_scaling_factor: Optional[float] = 1.0,
     ):
         super().__init__()
 
@@ -364,6 +367,7 @@ def __init__(
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
+        self.aiter_shuffled = False
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
@@ -669,6 +673,14 @@ def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         assert self.quant_method is not None
 
+        if is_hip and envs.VLLM_USE_AITER_MOE:
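+            # AITER's fused MoE path expects pre-shuffled weights: reorder
+            # w13/w2 with shuffle_weight once on the first forward call and
+            # reuse the shuffled copies (tracked via self.aiter_shuffled).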
+            from aiter.ops.shuffle import shuffle_weight
+
+            if not self.aiter_shuffled:
+                self.w13_weight.data = shuffle_weight(self.w13_weight, (16, 16))
+                self.w2_weight.data = shuffle_weight(self.w2_weight, (16, 16))
+                self.aiter_shuffled = True
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -698,7 +710,9 @@ def forward(self, hidden_states: torch.Tensor,
     def make_expert_params_mapping(
             cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
             ckpt_up_proj_name: str,
-            num_experts: int) -> List[Tuple[str, str, int, str]]:
+            num_experts: int,
+            num_shared_experts: Optional[int] = 0,
+    ) -> List[Tuple[str, str, int, str]]:
 
         return [
             # (param_name, weight_name, expert_id, shard_id)
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index b5409c7fe1b7..d80f5dfa3f9c 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -135,7 +135,10 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=config.scoring_func,
-            e_score_correction_bias=self.gate.e_score_correction_bias)
+            e_score_correction_bias=self.gate.e_score_correction_bias,
+            num_shared_experts=config.n_shared_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -715,7 +718,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
+            num_experts=self.config.n_routed_experts,
+            num_shared_experts=self.config.n_shared_experts,
+        )
 
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
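
For reference, a minimal usage sketch, assuming a ROCm build of vLLM with the aiter package installed; the model name below is illustrative only, and per the diff the kernel path targets block-scale fp8 MoE weights:

# Usage sketch: enable the flag before constructing the engine. Assumes a ROCm
# build of vLLM with `aiter` available; the model name is illustrative only.
import os

# Parsed case-insensitively; "1" or "true" enables the AITER MoE path.
os.environ["VLLM_USE_AITER_MOE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
outputs = llm.generate(["What is a mixture of experts?"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)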