@@ -78,8 +78,28 @@ def init_expert_param_per_layer(self):
                      for name in self.expert_weight_names]
                 )

-    def get_rank_expert_workload(self, num_moe_layers):
-        return self.model.get_all_moe_loads(num_moe_layers)
+    def get_rank_expert_workload(
+        self,
+        num_moe_layers: int,
+    ) -> torch.Tensor:
+        # Collect each layer's topk_ids -> list of [B, K]
+        all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]
+        # Stack and flatten -> ids2d: [L, B*K]
+        stacked = torch.stack(all_topk_ids, dim=0)  # [L, B, K]
+        L, B, K = stacked.shape
+        ids2d = stacked.view(L, B * K).to(torch.int64)  # [L, N]
+
+        device = ids2d.device
+        moe_load = torch.zeros((L, self.global_expert_num),
+                               dtype=torch.int64, device=device)
+
+        ones2d = torch.ones_like(ids2d, dtype=torch.int64)
+
+        assert moe_load.dim() == 2 and ids2d.dim() == 2 and ones2d.dim() == 2
+        assert ids2d.shape == ones2d.shape
+
+        moe_load.scatter_add_(dim=1, index=ids2d, src=ones2d)
+        return moe_load

     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
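For context, here is a minimal standalone sketch of the counting scheme the new method uses. The shapes and the random `topk_ids` below are illustrative stand-ins (the fake tensors replace `self.model.get_topk_ids(i)`), not values from this repository:

```python
import torch

# Toy shapes (hypothetical, chosen only for illustration).
L, B, K, global_expert_num = 2, 4, 2, 8

# Fake per-layer top-k routing decisions: a list of [B, K] expert-id tensors,
# standing in for self.model.get_topk_ids(i).
all_topk_ids = [torch.randint(0, global_expert_num, (B, K)) for _ in range(L)]

# Same steps as get_rank_expert_workload: stack, flatten, then histogram
# expert hits per layer with scatter_add_.
stacked = torch.stack(all_topk_ids, dim=0)      # [L, B, K]
ids2d = stacked.view(L, B * K).to(torch.int64)  # [L, B*K]
moe_load = torch.zeros((L, global_expert_num), dtype=torch.int64)
moe_load.scatter_add_(dim=1, index=ids2d, src=torch.ones_like(ids2d))

# Every token contributes K hits per layer, so each row sums to B*K.
assert torch.equal(moe_load.sum(dim=1),
                   torch.full((L,), B * K, dtype=torch.int64))
print(moe_load)  # [L, global_expert_num] per-expert load counts
```

A single `scatter_add_` over the flattened `[L, B*K]` id tensor builds all per-layer expert histograms in one vectorized call, rather than relying on the model to pre-aggregate loads as the removed `get_all_moe_loads` path did.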