From e6c79775ceb3360f8904df49a48c866db7bd5dc9 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 4 Jun 2025 21:33:47 +0800 Subject: [PATCH 1/4] [CI]Moe alltoall communication optimization for unquantized sence. Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 365 +++++++++++++++++++++++------------ 1 file changed, 238 insertions(+), 127 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 4853b272824..2d0eb9d1f6a 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -15,11 +15,11 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm/tests/kernels/test_moe.py -from typing import Callable, Optional +from typing import Callable, List, Optional import torch -import torch.distributed as dist import torch_npu +import torch.distributed as dist from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_world_size, @@ -38,6 +38,70 @@ USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM +def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, + max_row_per_ep_rank: int, num_tokens: int, + top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + original_total_elements = num_tokens * top_k + device = topk_ids.device + original_dtype = topk_ids.dtype + + if original_total_elements == 0: + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + unpad_indices = torch.full((original_total_elements, ), + -1, + dtype=torch.long, + device=device) + return topk_ids_pad, unpad_indices + + experts_per_ep_rank_val = expert_num // ep_size + if experts_per_ep_rank_val == 0: + raise ValueError( + "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. 
" + "Ensure expert_num >= ep_size.") + + assigned_ep_rank = (topk_ids.float() / + experts_per_ep_rank_val).to(original_dtype) + indices_arange = torch.arange(topk_ids.shape[0], device=device) + + is_new_segment = torch.cat((torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1])) + temp_start_markers = torch.full_like(indices_arange, + -1, + dtype=indices_arange.dtype) + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] + start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) + indices_in_rec_cond_list_for_all = cumsum_kept - 1 + unpad_indices = torch.where( + is_kept_mask, indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long)) + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + if topk_ids.shape[0] > 0: + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + temp_pad_buffer = torch.full((output_len + 1, ), + expert_num, + dtype=original_dtype, + device=device) + output_len_tensor = torch.tensor(output_len, + dtype=torch.long, + device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, + output_len_tensor) + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) + topk_ids_pad = temp_pad_buffer[:output_len] + return topk_ids_pad, unpad_indices + + def fused_experts_with_mc2( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -147,138 +211,184 @@ def fused_experts_with_mc2( return hidden_states -# currently expert parallelism implemented with all2all -# is under-optimized. 
-def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, -): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - device = hidden_states.device - - if expert_map is not None: - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_group.world_size - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(top_k, -1).permute( - 1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - scatter_sizes = global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) - - gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, - scatter_sizes, - group=ep_group.device_group) - scatter_size_list = scatter_sizes.cpu().tolist() - gather_size_list = gather_sizes.cpu().tolist() - - expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) - - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) +def apply_mlp(hidden_states_wrapper: List[torch.Tensor], + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) - hidden_states = hidden_states[sorted_idx] - else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + Returns: + hidden_states: output hidden states after MLP. 
+ """ - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) + assert len(hidden_states_wrapper) == 1 + hidden_states = hidden_states_wrapper.pop() w1 = w1.transpose(1, 2) - gate_up_out_list = torch_npu.npu_grouped_matmul( + hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], split_item=2, - group_list_type=0, + group_list_type=group_list_type, group_type=0, - group_list=expert_tokens, + group_list=group_list, ) - # TODO: Remove this in the future. - hidden_states = torch.cat(gate_up_out_list, dim=0) + hidden_states = torch.cat(hidden_states, dim=0) hidden_states = torch_npu.npu_swiglu(hidden_states) w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( + hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w2], split_item=2, - group_list_type=0, + group_list_type=group_list_type, group_type=0, - group_list=expert_tokens, + group_list=group_list, ) - hidden_states = torch.cat(down_out_list, dim=0) + hidden_states = torch.cat(hidden_states, dim=0) + return hidden_states - if expert_map is not None: - resorted_idx = torch.argsort(sorted_idx) - hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - gather_size_list, - scatter_size_list) - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) +# currently expert parallelism implemented with all2all +# is under-optimized. +def fused_experts_with_all2all( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + global_batch_size: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + device = hidden_states.device + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, + device=device).view(top_k, + -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * + max_model_len // ep_group.world_size + + 1) * top_k * 2 + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[ + expert_idx_buffer_scatter != global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + 
device=expert_idx_buffer_scatter.device) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, + group=ep_group.device_group) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + hidden_states = hidden_states[sorted_idx] + group_list_type = 0 + + hidden_states_wrapper = [hidden_states] + del hidden_states + + hidden_states = apply_mlp(hidden_states_wrapper, + w1, + w2, + expert_tokens, + group_list_type=group_list_type) + + resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) + hidden_states = hidden_states[resorted_idx] + hidden_states_scatter = torch.zeros( + (mask.shape[0], hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states_scatter[mask] = hidden_states + hidden_states_gatter = torch.empty_like( + hidden_states_scatter, + dtype=hidden_states_scatter.dtype, + device=hidden_states_scatter.device) + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[ + expert_idx_buffer_scatter != global_num_experts] + if hidden_states_gatter.shape[0] != row_idx_len: + hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states[unpad_indices != -1] = hidden_states_gatter else: # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. 
- final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) + hidden_states = hidden_states_gatter + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) return final_hidden_states @@ -586,6 +696,7 @@ def __init__(self, moe: MoEConfig = None): self.ep_size = ep_group.world_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.local_batch_size = self.global_batch_size // self.ep_size + self.max_model_len = vllm_config.model_config.max_model_len self.enable_graph_mode = False additional_config = get_current_vllm_config().additional_config @@ -617,21 +728,22 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, + top_k: int, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + use_grouped_topk: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = False, enable_force_load_balance: bool = False, **kwargs, - ): + ) -> torch.Tensor: + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern if global_num_experts == 256: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( @@ -687,18 +799,17 @@ def apply( top_k=top_k, expert_map=expert_map) else: - # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into fused_moe module. - # Therefore, all2all is needed no matter how dp/tp is set so as to - # dispatch/combine tokens. - return fused_experts_with_all2all(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=get_ep_group()) + return fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + max_model_len=self.max_model_len, + global_batch_size=self.global_batch_size, + expert_map=expert_map, + ep_group=get_ep_group()) class AscendFusedMoE(FusedMoE): From 3b1c3df90419b91c4def070b944c3649b815b286 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 4 Jun 2025 21:48:40 +0800 Subject: [PATCH 2/4] [CI]Moe alltoall communication optimization for unquantized sence. 
Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 2d0eb9d1f6a..da692b24df5 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -18,8 +18,8 @@ from typing import Callable, List, Optional import torch -import torch_npu import torch.distributed as dist +import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_world_size, From 7f86c069167fdb646b2b4b2a03767a8da7c4fc9d Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 4 Jun 2025 21:33:47 +0800 Subject: [PATCH 3/4] [CI]Moe alltoall communication optimization for unquantized sence. Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 365 +++++++++++++++++++++++------------ 1 file changed, 238 insertions(+), 127 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6aff62fc627..c2abd1efcfe 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -15,11 +15,11 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm/tests/kernels/test_moe.py -from typing import Callable, Optional +from typing import Callable, List, Optional import torch -import torch.distributed as dist import torch_npu +import torch.distributed as dist from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_world_size, @@ -39,6 +39,70 @@ USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM +def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, + max_row_per_ep_rank: int, num_tokens: int, + top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + original_total_elements = num_tokens * top_k + device = topk_ids.device + original_dtype = topk_ids.dtype + + if original_total_elements == 0: + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + unpad_indices = torch.full((original_total_elements, ), + -1, + dtype=torch.long, + device=device) + return topk_ids_pad, unpad_indices + + experts_per_ep_rank_val = expert_num // ep_size + if experts_per_ep_rank_val == 0: + raise ValueError( + "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. 
" + "Ensure expert_num >= ep_size.") + + assigned_ep_rank = (topk_ids.float() / + experts_per_ep_rank_val).to(original_dtype) + indices_arange = torch.arange(topk_ids.shape[0], device=device) + + is_new_segment = torch.cat((torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1])) + temp_start_markers = torch.full_like(indices_arange, + -1, + dtype=indices_arange.dtype) + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] + start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) + indices_in_rec_cond_list_for_all = cumsum_kept - 1 + unpad_indices = torch.where( + is_kept_mask, indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long)) + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + if topk_ids.shape[0] > 0: + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + temp_pad_buffer = torch.full((output_len + 1, ), + expert_num, + dtype=original_dtype, + device=device) + output_len_tensor = torch.tensor(output_len, + dtype=torch.long, + device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, + output_len_tensor) + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) + topk_ids_pad = temp_pad_buffer[:output_len] + return topk_ids_pad, unpad_indices + + def fused_experts_with_mc2(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -147,138 +211,184 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, return hidden_states -# currently expert parallelism implemented with all2all -# is under-optimized. 
-def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, -): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - device = hidden_states.device - - if expert_map is not None: - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_group.world_size - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(top_k, -1).permute( - 1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - scatter_sizes = global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) - - gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, - scatter_sizes, - group=ep_group.device_group) - scatter_size_list = scatter_sizes.cpu().tolist() - gather_size_list = gather_sizes.cpu().tolist() - - expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) - - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) +def apply_mlp(hidden_states_wrapper: List[torch.Tensor], + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) - hidden_states = hidden_states[sorted_idx] - else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + Returns: + hidden_states: output hidden states after MLP. 
+ """ - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) + assert len(hidden_states_wrapper) == 1 + hidden_states = hidden_states_wrapper.pop() w1 = w1.transpose(1, 2) - gate_up_out_list = torch_npu.npu_grouped_matmul( + hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], split_item=2, - group_list_type=0, + group_list_type=group_list_type, group_type=0, - group_list=expert_tokens, + group_list=group_list, ) - # TODO: Remove this in the future. - hidden_states = torch.cat(gate_up_out_list, dim=0) + hidden_states = torch.cat(hidden_states, dim=0) hidden_states = torch_npu.npu_swiglu(hidden_states) w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( + hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w2], split_item=2, - group_list_type=0, + group_list_type=group_list_type, group_type=0, - group_list=expert_tokens, + group_list=group_list, ) - hidden_states = torch.cat(down_out_list, dim=0) + hidden_states = torch.cat(hidden_states, dim=0) + return hidden_states - if expert_map is not None: - resorted_idx = torch.argsort(sorted_idx) - hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - gather_size_list, - scatter_size_list) - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) +# currently expert parallelism implemented with all2all +# is under-optimized. +def fused_experts_with_all2all( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + global_batch_size: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + device = hidden_states.device + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, + device=device).view(top_k, + -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * + max_model_len // ep_group.world_size + + 1) * top_k * 2 + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[ + expert_idx_buffer_scatter != global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + 
device=expert_idx_buffer_scatter.device) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, + group=ep_group.device_group) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + hidden_states = hidden_states[sorted_idx] + group_list_type = 0 + + hidden_states_wrapper = [hidden_states] + del hidden_states + + hidden_states = apply_mlp(hidden_states_wrapper, + w1, + w2, + expert_tokens, + group_list_type=group_list_type) + + resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) + hidden_states = hidden_states[resorted_idx] + hidden_states_scatter = torch.zeros( + (mask.shape[0], hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states_scatter[mask] = hidden_states + hidden_states_gatter = torch.empty_like( + hidden_states_scatter, + dtype=hidden_states_scatter.dtype, + device=hidden_states_scatter.device) + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[ + expert_idx_buffer_scatter != global_num_experts] + if hidden_states_gatter.shape[0] != row_idx_len: + hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states[unpad_indices != -1] = hidden_states_gatter else: # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. 
- final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) + hidden_states = hidden_states_gatter + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) return final_hidden_states @@ -586,6 +696,7 @@ def __init__(self, moe: MoEConfig = None): self.ep_size = ep_group.world_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.local_batch_size = self.global_batch_size // self.ep_size + self.max_model_len = vllm_config.model_config.max_model_len ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled @@ -614,21 +725,22 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, + top_k: int, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + use_grouped_topk: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = False, enable_force_load_balance: bool = False, **kwargs, - ): + ) -> torch.Tensor: + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern if global_num_experts == 256: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( @@ -685,18 +797,17 @@ def apply( top_k=top_k, expert_map=expert_map) else: - # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into fused_moe module. - # Therefore, all2all is needed no matter how dp/tp is set so as to - # dispatch/combine tokens. - return fused_experts_with_all2all(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=get_ep_group()) + return fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + max_model_len=self.max_model_len, + global_batch_size=self.global_batch_size, + expert_map=expert_map, + ep_group=get_ep_group()) class AscendFusedMoE(FusedMoE): From f30f99ddcb8856d06ff429dd15cd2031212ffd89 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 4 Jun 2025 21:48:40 +0800 Subject: [PATCH 4/4] [CI]Moe alltoall communication optimization for unquantized sence. Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index c2abd1efcfe..03463b17ff2 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -18,8 +18,8 @@ from typing import Callable, List, Optional import torch -import torch_npu import torch.distributed as dist +import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_world_size,