@@ -359,6 +359,15 @@ def forward(
359
359
hidden_states = self .norm (hidden_states )
360
360
return hidden_states
361
361
362
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Build the FusedMoE expert-parameter mapping for this model.

    Each entry is a ``(param_name, weight_name, expert_id, shard_id)``
    tuple covering the expert weights as well as the fp8 weight and
    activation scales, derived from the checkpoint projection names.
    """
    mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.num_experts,
    )
    return mapping
362
371
def load_weights (self , weights : Iterable [tuple [str ,
363
372
torch .Tensor ]]) -> set [str ]:
364
373
stacked_params_mapping = [
@@ -370,16 +379,9 @@ def load_weights(self, weights: Iterable[tuple[str,
370
379
("gate_up_proj" , "up_proj" , 1 ),
371
380
]
372
381
373
- # Params for weights, fp8 weight scales, fp8 activation scales
374
- # (param_name, weight_name, expert_id, shard_id)
375
- expert_params_mapping = FusedMoE .make_expert_params_mapping (
376
- ckpt_gate_proj_name = "gate_proj" ,
377
- ckpt_down_proj_name = "down_proj" ,
378
- ckpt_up_proj_name = "up_proj" ,
379
- num_experts = self .config .num_experts )
380
-
381
382
params_dict = dict (self .named_parameters ())
382
383
loaded_params : set [str ] = set ()
384
+ expert_params_mapping = self .get_expert_mapping ()
383
385
for name , loaded_weight in weights :
384
386
if "rotary_emb.inv_freq" in name :
385
387
continue
@@ -511,3 +513,6 @@ def sample(
511
513
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint tensors by delegating to ``AutoWeightsLoader``.

    Returns the set of parameter names the loader consumed.
    """
    return AutoWeightsLoader(self).load_weights(weights)
516
+
517
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Expose the wrapped model's expert-parameter mapping unchanged."""
    return self.model.get_expert_mapping()
0 commit comments