Skip to content

Commit 1169dfc

Browse files
wanghanqingLYTyangcheng (AJ)
authored and committed
Merge pull request vllm-project#74 from raindaywhu/dev_whq_eplb
running time reduction forward_before and forward_end
1 parent da88164 commit 1169dfc

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,28 @@ def init_expert_param_per_layer(self):
7878
for name in self.expert_weight_names]
7979
)
8080

81-
def get_rank_expert_workload(self, num_moe_layers):
82-
return self.model.get_all_moe_loads(num_moe_layers)
81+
def get_rank_expert_workload(
    self,
    num_moe_layers: int,
) -> torch.Tensor:
    """Aggregate the per-layer expert routing workload for this rank.

    Collects the top-k expert ids chosen in each MoE layer and counts,
    per layer, how many (token, k) routings went to each global expert.

    Args:
        num_moe_layers: number of MoE layers to query from the model.

    Returns:
        An int64 tensor of shape ``[num_moe_layers, global_expert_num]``
        where entry ``(l, e)`` is the number of routings to expert ``e``
        in layer ``l``. Lives on the same device as the topk-id tensors.
    """
    # Gather each layer's topk_ids -> list of [B, K] tensors.
    # NOTE(review): assumes every layer returns the same [B, K] shape,
    # otherwise torch.stack below raises — confirm against the model.
    all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]
    # Stack and flatten -> ids2d: [L, B*K]; int64 is required for use
    # as a scatter_add_ index.
    stacked = torch.stack(all_topk_ids, dim=0)  # [L, B, K]
    L, B, K = stacked.shape
    ids2d = stacked.view(L, B * K).to(torch.int64)  # [L, N]

    device = ids2d.device
    moe_load = torch.zeros((L, self.global_expert_num),
                           dtype=torch.int64, device=device)

    # One count per routing decision; same shape as the index tensor.
    ones2d = torch.ones_like(ids2d, dtype=torch.int64)

    # Internal invariants required by scatter_add_ (index/src shapes
    # must match and all operands must be 2-D here).
    assert moe_load.dim() == 2 and ids2d.dim() == 2 and ones2d.dim() == 2
    assert ids2d.shape == ones2d.shape

    # Per-layer histogram: moe_load[l, ids2d[l, n]] += 1 for every n.
    moe_load.scatter_add_(dim=1, index=ids2d, src=ones2d)
    return moe_load
83103

84104
def get_init_expert_map(self, num_moe_layers):
85105
expert_map = self.model.get_all_expert_map(num_moe_layers)

0 commit comments

Comments
 (0)