Skip to content

Commit 67fc426

Browse files
authored
[Misc] Print FusedMoE detail info (#13974)
1 parent 9804145 commit 67fc426

File tree

1 file changed

+20
-0
lines changed
  • vllm/model_executor/layers/fused_moe

1 file changed

+20
-0
lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,3 +737,23 @@ def _load_fp8_scale(self, param: torch.nn.Parameter,
737737
# If we are in the row parallel case (down_proj)
738738
else:
739739
param_data[expert_id] = loaded_weight
740+
741+
def extra_repr(self) -> str:
742+
743+
s = (
744+
f"global_num_experts={self.global_num_experts}, "
745+
f"local_num_experts={self.local_num_experts}, "
746+
f"top_k={self.top_k}, "
747+
f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501
748+
f"tp_size={self.tp_size},\n"
749+
f"ep_size={self.ep_size}, "
750+
f"reduce_results={self.reduce_results}, "
751+
f"renormalize={self.renormalize}, "
752+
f"use_grouped_topk={self.use_grouped_topk}")
753+
754+
if self.use_grouped_topk:
755+
s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501
756+
757+
s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501
758+
759+
return s

0 commit comments

Comments
 (0)