@@ -45,29 +45,38 @@ def __init__(self, model, **args):
4545            self .expert_map_per_layer [self .num_dense_layers  +  layer_idx ] = \
4646                self .model .get_expert_map (self .num_dense_layers  +  layer_idx )
4747
48-         self .buffer_tensor_dict  =  dict ()
4948        # TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved 
5049        num_buffer_tensor  =  torch .where (self .expert_map_per_layer [self .num_dense_layers ] !=  - 1 )[0 ].numel ()
51-         self .init_buffer_tensor_dict (num_buffer_tensor )
50+         self .buffer_tensor_list  =  [[] for  _  in  range (num_buffer_tensor )]
51+         self .init_buffer_tensor (num_buffer_tensor )
52+ 
53+         self .expert_param_per_layer  =  dict ()
54+         self .init_expert_param_per_layer ()
5255
5356        self .log2phy_map_per_layer  =  dict ()
5457        for  layer_idx  in  range (self .num_moe_layers ):
5558            self .log2phy_map_per_layer [self .num_dense_layers  +  layer_idx ] = \
5659                self .model .get_log2phy_map (self .num_dense_layers  +  layer_idx )
5760
58-     def  init_buffer_tensor_dict (self , num_buffer_tensor ):
61+     def  init_buffer_tensor (self , num_buffer_tensor ):
5962        for  name  in  self .expert_weight_names :
6063            complete_name  =  "model.layers."  +  str (self .num_dense_layers ) +  ".mlp.experts."  +  name 
6164            expert_tensor  =  self .param_dict [complete_name ].data [0 :num_buffer_tensor ]
62-             self .buffer_tensor_dict [name ] =  torch .empty_like (expert_tensor )
63- 
64-     def  get_buffer_tensor (self , buffer_tensor_id ):
65-         return  [self .buffer_tensor_dict [name ][buffer_tensor_id ] for  name  in  self .expert_weight_names ]
66- 
67-     def  get_expert_tensor (self , layer_id , global_expert_id_to_send ):
68-         local_expert_id  =  self .expert_map_per_layer_cpu [layer_id ][global_expert_id_to_send ].item ()
69-         return  [self .param_dict ["model.layers."  +  str (layer_id ) +  ".mlp.experts."  +  name ].data [local_expert_id ]
70-             for  name  in  self .expert_weight_names ]
65+             buffer_tensors  =  torch .empty_like (expert_tensor )
66+             for  buffer_id  in  range (num_buffer_tensor ):
67+                 self .buffer_tensor_list [buffer_id ].append (buffer_tensors [buffer_id ])
68+ 
69+     def  init_expert_param_per_layer (self ):
70+         num_local_expert  =  self .param_dict ["model.layers."  +  str (self .num_dense_layers ) + \
71+             ".mlp.experts."  +  self .expert_weight_names [0 ]].data .shape [0 ]
72+         for  moe_layer_id  in  range (self .num_moe_layers ):
73+             layer_idx  =  self .num_dense_layers  +  moe_layer_id 
74+             self .expert_param_per_layer [layer_idx ] =  list ()
75+             for  local_expert_id  in  range (num_local_expert ):
76+                 self .expert_param_per_layer [layer_idx ].append (
77+                     [self .param_dict ["model.layers."  +  str (layer_idx ) +  ".mlp.experts."  +  name ].data [local_expert_id ]
78+                         for  name  in  self .expert_weight_names ]
79+                 )
7180
7281    def  get_rank_expert_workload (
7382        self ,
@@ -117,10 +126,11 @@ def do_update_expert_map(self, layer_id, updated_expert_map):
117126        self .expert_map_per_layer_cpu [layer_id ].copy_ (updated_expert_map )
118127
119128    def  do_update_expert_weight (self , layer_id , local_expert_to_replace , buffer_tensor_id ):
120-         for  name  in  self .expert_weight_names :
121-             complete_name  =  "model.layers."  +  str (layer_id ) +  ".mlp.experts."  +  name 
122-             expert_tensor  =  self .param_dict [complete_name ].data [local_expert_to_replace ]
123-             expert_tensor .copy_ (self .buffer_tensor_dict [name ][buffer_tensor_id ])
129+         for  expert_tensor , buffer_tensor  in  zip (
130+             self .expert_param_per_layer [layer_id ][local_expert_to_replace ],
131+             self .buffer_tensor_list [buffer_tensor_id ]
132+         ):
133+             expert_tensor .copy_ (buffer_tensor )
124134
125135    def  do_update_log2phy_map (self , layer_id , updated_log2phy_map ):
126136        if  self .log2phy_map_per_layer [layer_id ] is  not   None :
0 commit comments