Commit 53728f3
Merge pull request vllm-project#74 from raindaywhu/dev_whq_eplb
Reduce running time of forward_before and forward_end
2 parents 2e824cd + 1b78fb2 commit 53728f3

File tree

vllm_ascend/eplb/adaptor/vllm_adaptor.py
vllm_ascend/eplb/core/loader/device_transfer_loader.py

2 files changed: +29 -20 lines

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 26 additions & 16 deletions
@@ -45,29 +45,38 @@ def __init__(self, model, **args):
             self.expert_map_per_layer[self.num_dense_layers + layer_idx] =\
                 self.model.get_expert_map(self.num_dense_layers + layer_idx)
 
-        self.buffer_tensor_dict = dict()
         # TODO: here we set number of buffer tensor equal to number of expert in each layer, which can be improved
         num_buffer_tensor = torch.where(self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
-        self.init_buffer_tensor_dict(num_buffer_tensor)
+        self.buffer_tensor_list = [[] for _ in range(num_buffer_tensor)]
+        self.init_buffer_tensor(num_buffer_tensor)
+
+        self.expert_param_per_layer = dict()
+        self.init_expert_param_per_layer()
 
         self.log2phy_map_per_layer = dict()
         for layer_idx in range(self.num_moe_layers):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] =\
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
 
-    def init_buffer_tensor_dict(self, num_buffer_tensor):
+    def init_buffer_tensor(self, num_buffer_tensor):
         for name in self.expert_weight_names:
             complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
             expert_tensor = self.param_dict[complete_name].data[0:num_buffer_tensor]
-            self.buffer_tensor_dict[name] = torch.empty_like(expert_tensor)
-
-    def get_buffer_tensor(self, buffer_tensor_id):
-        return [self.buffer_tensor_dict[name][buffer_tensor_id] for name in self.expert_weight_names]
-
-    def get_expert_tensor(self, layer_id, global_expert_id_to_send):
-        local_expert_id = self.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item()
-        return [self.param_dict["model.layers." + str(layer_id) + ".mlp.experts." + name].data[local_expert_id]
-                for name in self.expert_weight_names]
+            buffer_tensors = torch.empty_like(expert_tensor)
+            for buffer_id in range(num_buffer_tensor):
+                self.buffer_tensor_list[buffer_id].append(buffer_tensors[buffer_id])
+
+    def init_expert_param_per_layer(self):
+        num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) +\
+            ".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
+        for moe_layer_id in range(self.num_moe_layers):
+            layer_idx = self.num_dense_layers + moe_layer_id
+            self.expert_param_per_layer[layer_idx] = list()
+            for local_expert_id in range(num_local_expert):
+                self.expert_param_per_layer[layer_idx].append(
+                    [self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name].data[local_expert_id]
+                     for name in self.expert_weight_names]
+                )
 
     def get_rank_expert_workload(
         self,
@@ -117,10 +126,11 @@ def do_update_expert_map(self, layer_id, updated_expert_map):
         self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
 
     def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id):
-        for name in self.expert_weight_names:
-            complete_name = "model.layers." + str(layer_id) + ".mlp.experts." + name
-            expert_tensor = self.param_dict[complete_name].data[local_expert_to_replace]
-            expert_tensor.copy_(self.buffer_tensor_dict[name][buffer_tensor_id])
+        for expert_tensor, buffer_tensor in zip(
+            self.expert_param_per_layer[layer_id][local_expert_to_replace],
+            self.buffer_tensor_list[buffer_tensor_id]
+        ):
+            expert_tensor.copy_(buffer_tensor)
 
     def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
         if self.log2phy_map_per_layer[layer_id] is not None:
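The speedup in do_update_expert_weight comes from moving the per-call string building and param_dict lookups out of the hot path: init_expert_param_per_layer caches, once at init time, a list of tensor views per (layer, local expert), and the update loop only calls copy_. A standalone sketch of the underlying pattern (illustrative names, not the vllm-ascend code):

    # Indexing a tensor returns a view that shares storage, so copy_()
    # through a cached view updates the parameter in place, with no name
    # building or dict lookup at update time.
    import torch

    num_local_experts, hidden = 4, 8
    weight = torch.zeros(num_local_experts, hidden)  # stands in for one expert weight parameter

    # Done once at init time, mirroring init_expert_param_per_layer:
    expert_views = [weight[i] for i in range(num_local_experts)]

    # Hot path, mirroring do_update_expert_weight:
    new_weights = torch.ones(hidden)
    expert_views[2].copy_(new_weights)

    assert torch.equal(weight[2], new_weights)  # the underlying parameter changed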

vllm_ascend/eplb/core/loader/device_transfer_loader.py

Lines changed: 3 additions & 4 deletions
@@ -56,15 +56,14 @@ def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info,
         self.comm_op_list = []
         for send_info in expert_send_info:
             dst_rank, global_expert_id_to_send = send_info
-            src_tensors_this_expert = self.eplb_adaptor.get_expert_tensor(layer_id, global_expert_id_to_send)
-            for src_tensor in src_tensors_this_expert:
+            local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item()
+            for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]:
                 self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
 
         buffer_tensor_id = 0
         for recv_info in expert_recv_info:
             recv_rank, global_expert_id_to_recv = recv_info
-            buffer_tensors_this_expert = self.eplb_adaptor.get_buffer_tensor(buffer_tensor_id)
-            for buffer_tensor in buffer_tensors_this_expert:
+            for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]:
                 self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
             local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item()
             self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id))
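The loader now reads the cached expert_param_per_layer and buffer_tensor_list directly instead of going through the removed getter methods, so assembling comm_op_list does no per-call tensor slicing. For context, a list of dist.P2POp entries like this is normally executed with the standard torch.distributed batching pattern; a sketch under that assumption (the actual execution code lives elsewhere in the loader):

    # Assumes the default process group is already initialized.
    import torch.distributed as dist

    def run_comm_ops(comm_op_list):
        if not comm_op_list:
            return
        # batch_isend_irecv posts all isend/irecv ops together and returns
        # one request handle per op.
        reqs = dist.batch_isend_irecv(comm_op_list)
        for req in reqs:
            req.wait()  # block until every send/recv has completed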
