Merged
46 commits
c480438: remove zero_like + scatter (3outeille, Nov 27, 2025)
c615c47: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 27, 2025)
073326f: fix mixtral moe (3outeille, Nov 27, 2025)
8ff6c18: Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform… (3outeille, Nov 27, 2025)
f3457e2: fix other moe models as well (3outeille, Nov 27, 2025)
16737a4: fix ci (3outeille, Nov 27, 2025)
01da12d: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 27, 2025)
57541cd: fix modular mixtral (3outeille, Nov 27, 2025)
b7eb918: Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform… (3outeille, Nov 27, 2025)
3992748: fix qwen2_moe + qwen3_next (3outeille, Nov 28, 2025)
15f41b9: fix device mismatch for qwen3_vl_moe to pass tests (3outeille, Nov 28, 2025)
35e8bf8: fix modular mixtral (3outeille, Nov 28, 2025)
e6f026f: fix other models (3outeille, Nov 28, 2025)
14b7ac0: rm slow tokenizers (#40936) (itazap, Nov 27, 2025)
ec3f555: [loading/saving] Reverse all loading operations when saving (#42396) (Cyrilvallez, Nov 27, 2025)
326eb75: Fix T5 tests: use generation_config for generation parameters (#42419) (Abdennacer-Badaoui, Nov 28, 2025)
50cc1e9: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 28, 2025)
8bccd8c: linting (3outeille, Nov 28, 2025)
74e84d5: Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform… (3outeille, Nov 28, 2025)
718cc64: more fix to pass the CI tests (3outeille, Nov 28, 2025)
19db8c9: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 28, 2025)
1100864: fix lfm2 moe (3outeille, Nov 28, 2025)
7d024b9: Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform… (3outeille, Nov 28, 2025)
e6f82dc: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 28, 2025)
e982a15: fix docstring (3outeille, Nov 28, 2025)
98703cc: Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform… (3outeille, Nov 28, 2025)
84bb660: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 28, 2025)
3b14e7b: fix docstring (3outeille, Nov 28, 2025)
0ac90c8: Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform… (3outeille, Nov 28, 2025)
5e4e7de: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 28, 2025)
0399e13: fix qwen like model (3outeille, Nov 28, 2025)
af29eee: fix flex olmo (3outeille, Nov 28, 2025)
bf66927: revert lfm2 moe config (3outeille, Nov 28, 2025)
4d6e993: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 28, 2025)
144ec86: Merge branch 'main' into fix-moe-v5 (3outeille, Nov 28, 2025)
ede2116: make fixup (3outeille, Nov 28, 2025)
3132b5f: fix docstring (3outeille, Nov 28, 2025)
2e04f12: fix conversion mapping (3outeille, Nov 28, 2025)
61d1b87: Merge branch 'main' into fix-moe-v5 (3outeille, Dec 1, 2025)
cdb3eb1: fix inference of gpt-oss (3outeille, Dec 1, 2025)
5edd375: add some fixes to gpt-oss (but still not good) (3outeille, Dec 1, 2025)
8cc40f0: Merge branch 'main' into fix-moe-v5 (3outeille, Dec 1, 2025)
a02f8bf: fix modular (ArthurZucker, Dec 1, 2025)
d213808: we need errors I think (ArthurZucker, Dec 1, 2025)
1317b4d: fix config issue (ArthurZucker, Dec 1, 2025)
51d4b52: this was fixed (ArthurZucker, Dec 1, 2025)
2 changes: 2 additions & 0 deletions src/transformers/conversion_mapping.py
```diff
@@ -175,6 +175,8 @@ def _build_checkpoint_conversion_mapping():
     mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
     mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
     mapping["minimax"] = mapping["mixtral"].copy()
+    mapping["flex_olmo"] = mapping["qwen2_moe"].copy()
+    mapping["olmoe"] = mapping["qwen2_moe"].copy()

     return mapping
```

Comment on lines +178 to +179 (Collaborator): nice
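The new entries register flex_olmo and olmoe against the qwen2_moe conversion rules. A minimal sketch of the pattern, with a made-up rule (the real entries live in `_build_checkpoint_conversion_mapping`); the `.copy()` keeps each model's rule dict independent of the shared template:

```python
# Illustrative sketch only: the rule below is hypothetical, not the real mapping entry.
qwen2_moe_rules = {"experts.*.gate_proj": "experts.gate_up_proj"}
mapping = {"qwen2_moe": qwen2_moe_rules}

# flex_olmo and olmoe share qwen2_moe's checkpoint layout, so they reuse its rules.
# .copy() gives each model its own dict, so a later per-model tweak such as
# mapping["flex_olmo"]["extra_rule"] = ... cannot leak into the shared template.
mapping["flex_olmo"] = mapping["qwen2_moe"].copy()
mapping["olmoe"] = mapping["qwen2_moe"].copy()
```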
9 changes: 4 additions & 5 deletions src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
```diff
@@ -61,22 +61,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
```
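The same two-part fix recurs in every experts module below. First, `num_classes` for the one-hot can no longer be derived from `top_k_weights.shape[1]`: with the router now returning compact scores, that dimension is `top_k`, not `num_experts`, so the module's own `self.num_experts` is used instead. Second, `top_k_weights` now has shape `(num_tokens, top_k)`, so a token's weight for an expert lives at the top-k slot where that expert was chosen (`top_k_pos`), not at the expert id. A minimal sketch of the corrected indexing, with illustrative shapes and names:

```python
import torch

# Hedged sketch of the corrected dispatch; all shapes here are made up.
num_tokens, top_k, num_experts = 4, 2, 8
top_k_index = torch.randint(0, num_experts, (num_tokens, top_k))  # chosen experts
top_k_weights = torch.rand(num_tokens, top_k)                     # compact scores

# (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts)
expert_mask = expert_mask.permute(2, 1, 0)

for expert_idx in torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero():
    expert_idx = expert_idx[0]
    # top_k_pos is the slot (0..top_k-1) at which each token picked this expert;
    # that slot, not expert_idx, indexes the compact (num_tokens, top_k) weights.
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    weights = top_k_weights[token_idx, top_k_pos, None]  # (n_selected, 1)
```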
9 changes: 4 additions & 5 deletions src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
```diff
@@ -169,22 +169,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
```
9 changes: 4 additions & 5 deletions src/transformers/models/dots1/modeling_dots1.py
```diff
@@ -327,22 +327,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
```
15 changes: 7 additions & 8 deletions src/transformers/models/flex_olmo/modeling_flex_olmo.py
```diff
@@ -313,22 +313,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
@@ -351,8 +350,8 @@ def forward(self, hidden_states):
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices


 class FlexOlmoSparseMoeBlock(nn.Module):
@@ -364,7 +363,7 @@ def __init__(self, config):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        top_k_weights, top_k_index = self.gate(hidden_states)
+        _, top_k_weights, top_k_index = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
             batch_size, sequence_length, hidden_dim
         )
```
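This is the router change that makes the compact weights possible: instead of scattering the top-k values back into a dense `(num_tokens, num_experts)` matrix, the gate returns the `(num_tokens, top_k)` values directly, along with the raw logits (which callers such as an aux-loss path may still want) and the indices. A hedged before/after sketch with illustrative shapes:

```python
import torch

num_tokens, num_experts, top_k = 4, 8, 2  # made-up shapes
router_logits = torch.randn(num_tokens, num_experts)
router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)

# Before: scores scattered into a dense matrix, zero outside the top-k columns.
dense_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)

# After: the compact values are returned as-is; experts index them by top-k slot.
router_scores = router_top_value
assert router_scores.shape == (num_tokens, top_k)
# Gathering the dense matrix back at the chosen indices recovers the same values.
assert torch.equal(dense_scores.gather(1, router_indices), router_scores)
```

Dropping the zeros-plus-scatter round trip (the "remove zero_like + scatter" commit) avoids materializing a mostly-zero `(num_tokens, num_experts)` tensor on every forward pass.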
9 changes: 4 additions & 5 deletions src/transformers/models/glm4_moe/modeling_glm4_moe.py
```diff
@@ -350,22 +350,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
```
9 changes: 4 additions & 5 deletions src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
```diff
@@ -414,22 +414,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
```
31 changes: 19 additions & 12 deletions src/transformers/models/gpt_oss/modeling_gpt_oss.py
```diff
@@ -95,12 +95,11 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
         """
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        num_experts = routing_weights.shape[1]
         if hidden_states.device.type == "cpu" or self.training:
             next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
             with torch.no_grad():
                 expert_mask = torch.nn.functional.one_hot(
-                    router_indices, num_classes=num_experts + 1
+                    router_indices, num_classes=self.num_experts
                 )  # masking is also a class
                 expert_mask = expert_mask.permute(2, 1, 0)
                 # we sum on the top_k and on the sequence length to get which experts
@@ -110,10 +109,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
                 # expert_idx only have 1 element, so we can use scale for fast indexing
                 expert_idx = expert_idx[0]
                 # skip masking index
-                if expert_idx == num_experts:
+                if expert_idx == self.num_experts:
                     continue
                 with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx])
+                    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
                 current_state = hidden_states[token_idx]
                 gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                 gate, up = gate_up[..., ::2], gate_up[..., 1::2]
@@ -122,21 +121,29 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
                 glu = gate * torch.sigmoid(gate * self.alpha)
                 gated_output = (up + 1) * glu
                 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                weighted_output = out * routing_weights[token_idx, top_k_pos, None]
                 next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
             next_states = next_states.view(batch_size, -1, self.hidden_size)
         else:
-            hidden_states = hidden_states.repeat(num_experts, 1)
-            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+            num_tokens = hidden_states.shape[0]
+            hidden_states = hidden_states.repeat(self.num_experts, 1)
+            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
             gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
             gate = gate.clamp(min=None, max=self.limit)
             up = up.clamp(min=-self.limit, max=self.limit)
             glu = gate * torch.sigmoid(gate * self.alpha)
             next_states = torch.bmm(((up + 1) * glu), self.down_proj)
             next_states = next_states + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+            next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size)
+
+            full_routing_weights = torch.zeros(
+                num_tokens, self.num_experts, device=routing_weights.device, dtype=routing_weights.dtype
+            )
+            full_routing_weights.scatter_(1, router_indices, routing_weights)
+            full_routing_weights = full_routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1, 1)
+
+            next_states = next_states * full_routing_weights
             next_states = next_states.sum(dim=0)
         return next_states
@@ -155,8 +162,8 @@ def forward(self, hidden_states):
         router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices


 @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
@@ -167,7 +174,7 @@ def __init__(self, config):
         self.experts = GptOssExperts(config)

     def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
+        _, router_scores, router_indices = self.router(hidden_states)
         routed_out = self.experts(hidden_states, router_indices, router_scores)
         return routed_out, router_scores
```
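gpt_oss is the only model here with a batched (non-CPU, non-training) path, and that path previously assumed `routing_weights` was already dense over experts. Since the router now emits compact `(num_tokens, top_k)` weights, the batched branch rebuilds the dense matrix itself with a `scatter_` before broadcasting it against the per-expert outputs. A minimal sketch with illustrative shapes (batch and sequence dims folded into `num_tokens`):

```python
import torch

num_tokens, num_experts, top_k, hidden = 6, 4, 2, 8  # made-up shapes
routing_weights = torch.rand(num_tokens, top_k)
router_indices = torch.randint(0, num_experts, (num_tokens, top_k))
expert_out = torch.randn(num_experts, num_tokens, hidden)  # stand-in for the bmm output

# Rebuild dense (num_tokens, num_experts) weights from the compact top-k form;
# experts a token was not routed to keep weight zero.
full = torch.zeros(num_tokens, num_experts, dtype=routing_weights.dtype)
full.scatter_(1, router_indices, routing_weights)

# Broadcast against per-expert outputs and reduce over the expert axis.
full = full.transpose(0, 1)[..., None]  # (num_experts, num_tokens, 1)
mixed = (expert_out * full).sum(dim=0)  # (num_tokens, hidden)
```

Because unrouted experts carry weight zero, their outputs drop out of the sum; the batched path trades memory (every expert processes every token) for avoiding data-dependent gather/scatter on accelerators.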