Commit c15309a

[Model] Apply SharedFusedMoE to glm4_moe. (#24849)
Signed-off-by: whx-sjtu <[email protected]>
1 parent 4a9375f commit c15309a

File tree: 1 file changed
vllm/model_executor/models/glm4_moe.py

Lines changed: 55 additions & 30 deletions
@@ -46,6 +46,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -146,25 +147,6 @@ def __init__(
         self.physical_expert_end = (self.physical_expert_start +
                                     self.n_local_physical_experts)
 
-        self.experts = FusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func="sigmoid",
-            # we do scaling outside, set factor to 1.0 to avoid double mul
-            routed_scaling_factor=1.0,
-            e_score_correction_bias=self.gate.e_score_correction_bias,
-            enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
-
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -173,25 +155,68 @@ def __init__(
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
+                reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
+            self.experts = SharedFusedMoE(
+                shared_experts=self.shared_experts,
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func="sigmoid",
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts,
+            )
+        else:
+            self.experts = FusedMoE(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func="sigmoid",
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
+        # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states.to(dtype=torch.float32))
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+
+        fused_moe_out = self.experts(hidden_states=hidden_states,
+                                     router_logits=router_logits)
+
+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = fused_moe_out
+            assert shared_output is not None
+            final_hidden_states = \
+                final_hidden_states * self.routed_scaling_factor \
+                + shared_output
+        else:
+            final_hidden_states = fused_moe_out * self.routed_scaling_factor
+
         if self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
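
Note on the combination step: with this change the forward pass makes a single self.experts(...) call and, per the unpacking in the diff, receives a (shared_output, routed_output) pair when shared experts are configured; only the routed part is multiplied by routed_scaling_factor before the shared output is added back (the fused layer itself is given routed_scaling_factor=1.0 to avoid double scaling). A minimal sketch of that logic, with a hypothetical helper name not present in the patch:

import torch

def combine_moe_outputs(fused_moe_out, has_shared_experts: bool,
                        routed_scaling_factor: float) -> torch.Tensor:
    # Hypothetical helper mirroring the new forward logic in the diff above.
    if has_shared_experts:
        # SharedFusedMoE path: unpack (shared_output, routed_output) and
        # scale only the routed output ("we do scaling outside").
        shared_output, routed_output = fused_moe_out
        return routed_output * routed_scaling_factor + shared_output
    # Plain FusedMoE path: a single routed tensor is returned.
    return fused_moe_out * routed_scaling_factor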
