 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -146,25 +147,6 @@ def __init__(
         self.physical_expert_end = (self.physical_expert_start +
                                     self.n_local_physical_experts)

-        self.experts = FusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func="sigmoid",
-            # we do scaling outside, set factor to 1.0 to avoid double mul
-            routed_scaling_factor=1.0,
-            e_score_correction_bias=self.gate.e_score_correction_bias,
-            enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
-
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -173,25 +155,68 @@ def __init__(
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
+                reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
+            self.experts = SharedFusedMoE(
+                shared_experts=self.shared_experts,
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func="sigmoid",
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts,
+            )
+        else:
+            self.experts = FusedMoE(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func="sigmoid",
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
+        # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states.to(dtype=torch.float32))
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+
+        fused_moe_out = self.experts(hidden_states=hidden_states,
+                                     router_logits=router_logits)
+
+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = fused_moe_out
+            assert shared_output is not None
+            final_hidden_states = \
+                final_hidden_states * self.routed_scaling_factor \
+                + shared_output
+        else:
+            final_hidden_states = fused_moe_out * self.routed_scaling_factor
+
         if self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
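
For reference, a minimal standalone sketch of the output handling the diff introduces in `forward()`: when shared experts are present, the fused op returns a `(shared_output, routed_output)` tuple and only the routed part is multiplied by `routed_scaling_factor`. The `combine_moe_outputs` helper and toy tensors below are illustrative only, not part of the change or of the vLLM API.

```python
import torch


def combine_moe_outputs(fused_moe_out, routed_scaling_factor: float,
                        has_shared_experts: bool) -> torch.Tensor:
    # Mirrors the forward() logic above: with shared experts the fused op
    # returns (shared_output, routed_output); otherwise it returns a single
    # tensor. Only the routed path is scaled; the shared path is added as-is.
    if has_shared_experts:
        shared_output, routed_output = fused_moe_out
        return routed_output * routed_scaling_factor + shared_output
    return fused_moe_out * routed_scaling_factor


# Toy shapes only; in the model these tensors come from SharedFusedMoE/FusedMoE.
routed = torch.randn(4, 16)
shared = torch.randn(4, 16)
assert combine_moe_outputs((shared, routed), 2.5, True).shape == (4, 16)
assert combine_moe_outputs(routed, 2.5, False).shape == (4, 16)
```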