Skip to content

Commit 56b3cdf

Browse files
committed
Implement get_expert_mapping
Signed-off-by: Shane A <[email protected]>
1 parent d1adc72 commit 56b3cdf

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

vllm/model_executor/models/flex_olmo.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,15 @@ def forward(
359359
hidden_states = self.norm(hidden_states)
360360
return hidden_states
361361

362+
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
363+
# Params for weights, fp8 weight scales, fp8 activation scales
364+
# (param_name, weight_name, expert_id, shard_id)
365+
return FusedMoE.make_expert_params_mapping(
366+
ckpt_gate_proj_name="gate_proj",
367+
ckpt_down_proj_name="down_proj",
368+
ckpt_up_proj_name="up_proj",
369+
num_experts=self.config.num_experts)
370+
362371
def load_weights(self, weights: Iterable[tuple[str,
363372
torch.Tensor]]) -> set[str]:
364373
stacked_params_mapping = [
@@ -370,16 +379,9 @@ def load_weights(self, weights: Iterable[tuple[str,
370379
("gate_up_proj", "up_proj", 1),
371380
]
372381

373-
# Params for weights, fp8 weight scales, fp8 activation scales
374-
# (param_name, weight_name, expert_id, shard_id)
375-
expert_params_mapping = FusedMoE.make_expert_params_mapping(
376-
ckpt_gate_proj_name="gate_proj",
377-
ckpt_down_proj_name="down_proj",
378-
ckpt_up_proj_name="up_proj",
379-
num_experts=self.config.num_experts)
380-
381382
params_dict = dict(self.named_parameters())
382383
loaded_params: set[str] = set()
384+
expert_params_mapping = self.get_expert_mapping()
383385
for name, loaded_weight in weights:
384386
if "rotary_emb.inv_freq" in name:
385387
continue
@@ -511,3 +513,6 @@ def sample(
511513
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
512514
loader = AutoWeightsLoader(self)
513515
return loader.load_weights(weights)
516+
517+
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
518+
return self.model.get_expert_mapping()

0 commit comments

Comments
 (0)