
Commit 6316a9e

3outeille, itazap, ArthurZucker, Cyrilvallez, and Abdennacer-Badaoui authored
Fix MoE for V5 (#42456)
* remove zero_like + scatter * fix mixtral moe * fix other moe models as well * fix ci * fix modular mixtral * fix qwen2_moe + qwen3_next * fix device mismatch for qwen3_vl_moe to pass tests * fix modular mixtral * fix other models * rm slow tokenizers (#40936) * fixes missed * gemma test fix * refactor * rm legacy from llama * added renaming * add _model * update legacy * update legacy * fix docstring * always load blank, then set _tokenizer if we have it * new toks * update all berttokenizer based models * apply feedback - delete bert duplicates * more models --> fast only * more convert_slow models * fix common test refs * updating fast only tokenizers * openai and pegasus * enable sentencepiecebackend * more models * code gen * t5 * code gen tests * speecht5 * mbart * mbart50 * more models * more models * layouglmv2 * update tests * update tests * update tests * pretrainedtokenizer * whisper * whisper * layoutxlm and storing backends * refactor sentencepiecebackend and additional_special_tokens * renaming tokenization_utils --> tokenization_python * udpate tests * bert test * blenderbot * clip * codegen * code_llama * cohere * deberata, deberat v2, funnel * gpt2 * batch update tests * pegasus qwen2 roberta * more models * layout tests * some renaming * fix references to utils_fast * fix refs * fix refs * fix refs * fix refs * fix refs * fix refs * fix refs * fix some tests * regression * fix refs * fix refs * missed the most crucial file in my last commit * fix refs * fix refs * fix refs * batch encode fix * fix some tests * BC for batch_decode bc too many refs * more tests * fix more tests * fix for processors * fixing more models * deleted mbart50 by accident * seamless m4t * albert fix * whisper * layout3 * attempt to fix cached tokenizers on CI * trying another fix on CI * again try to work around CI * bertweet * tapas * mbart50 * luke * mluke * markuplm * markuplm * fix some more auto tests * some random model failures * mistralcommontestser * more fixes * ref fix * siglip * marian * plbart * update utils toks * seamless m4t * roc bert * udpate byt5 test * xlm * esm * roformer * code llama * biogpt * m2m100 * dpr and flaubert * xlm and speech to text * tok backend pass object * tokenizer object pass * wav2vec2 * wav2vec2 * cpmant * update utils tokenizers * cpmant * bartpho * test apply chat template assistant mask * apply chat template video * apply chat template assistant mask * test torch * update from slow in base and fix donut processor errors * auto to point to tokenizers backend, fix kosmos2 * some non model fixes for old slow models that no longer have their own tokenizer file as they are the same as bert * missed file from last commit * idefics2 * fixup * fixup * pretrained tokenizer fast test update * stash * bad merged * cherry pick more stuff that did not merge well * fix gptsw3 * nit warn for now * update error raising * just ran fixup * bring back bert legacy * fix * nit * fix 56 errors on blenderbotsmall? 
* 18 for blenderbotsmall * tok auto * missed clip * fix tests * something missed * token healing * tok common tests update - nonmodel * try to fix non-model test in test_tokenization_utils * fix hub tests * try to fix hub tests * custom vocab related fixed * bert jap * BERT JAP * rename bert legacy to bert legacy * Wav2vec2 * fix in tok python to update total vocab size - fixes speech t5 * blender bot small * forgot test file * test failures * marian * gpt2 tiktoken * big bird / marian * udop * forgot couple changes * test_serve fix * missing import * a couple processors fixes * style partly * fix to fetch tests ci * Revert branch back to commit f5bc69e state * revert branch to styling * update mistral after merge * fixes for non model tests * some processor test fixes * more processor test fixes * more processor fixes * hub tests * python tok utils * fix hub test * make style for now * remove problemattic fic copies * python utils/check_copies.py --fix_and_overwrite * more styling * fixup * silence docstirng * fix import? * fix imports * add the local test as well * throw spm error * llamas * fix a couple tests * broke ci * broke ci * broke ci * broke ci * add logs to debug gemma on ci * gemma and llama * gemma * revert las commit * gemma debug * gemma debug * gemma * safely import spiece backend * tok tests * check none * setup and qual * ruff * del dev files * tok auto * fill docstrings * update auto * blenderbot small nit * add migration guide * move mixtral patch to `TokenizersBackend`, move `TokenizerExtractor` * rename MistralCommonTokenizer to MistralCommonB ackend * nit * fix failures * fixup * remoove one old test * mark the slow one as slow * very small fixes * update auto mapping for missing ones * fixup lorsd * fixup doc and stuff * should be the final fixe * processing update * update * FIX or brute AI fix the llava test * style * slow? * fix is offline mode? * fix mt5 * One tok utils (#42462) * consolidate python and utils tokenization files, they are copies * ruff and ref * Format * fix cohere * ? * up * am I dumbb? * grumble --------- Co-authored-by: Arthur <[email protected]> * [loading/saving] Reverse all loading operations when saving (#42396) * first shot * default to reversing * oupso * oupsi 2 * oupsi 3 * fix renamed kwargs * fix timm_wrapper * remove fix_state_dict methods * can do it all the time, with __init__ as well * doc * oupsi * fix * create helper * fix annotation annoying isue * small fix * small fixes * alright commit all that already * oupsi * the fix * update quantizers * this works * the hardcoded regex got me hard.... 
* style * the final one * cleanup a bit * better * style * oupsi readded it * do it inside the ops instead - no need for full names anymore * reverse quantizers and simplify signatures * small thingy * add no_grad decorator * utils to rename keys * oupssii again * add test * simplify nicely * Fix T5 tests: use generation_config for generation parameters (#42419) * pass the generation parameters to generate() * fix use_task_specific_params to separate model.config and model.generation_config params * fix style * some fixes * remove redundant check * update expectation for llama_7b_bf16 on rocm * Update tests/models/llama/test_modeling_llama.py Co-authored-by: Rémi Ouazan <[email protected]> --------- Co-authored-by: Rémi Ouazan <[email protected]> * linting * more fix to pass the CI tests * fix lfm2 moe * fix docstring * fix docstring * fix qwen like model * fix flex olmo * revert lfm2 moe config * make fixup * fix docstring * fix conversion mapping * fix inference of gpt-oss * add some fixes to gpt-oss (but still not good) * fix modular * we need errors I think * fix config issue * this was fixed --------- Co-authored-by: Ita Zaporozhets <[email protected]> Co-authored-by: Arthur <[email protected]> Co-authored-by: Cyril Vallez <[email protected]> Co-authored-by: BADAOUI Abdennacer <[email protected]> Co-authored-by: Rémi Ouazan <[email protected]>
1 parent eb399a9 commit 6316a9e

30 files changed: +177 −168 lines changed

src/transformers/conversion_mapping.py

Lines changed: 2 additions & 0 deletions
@@ -175,6 +175,8 @@ def _build_checkpoint_conversion_mapping():
     mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
     mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
     mapping["minimax"] = mapping["mixtral"].copy()
+    mapping["flex_olmo"] = mapping["qwen2_moe"].copy()
+    mapping["olmoe"] = mapping["qwen2_moe"].copy()

     return mapping
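flex_olmo and olmoe are registered by copying the qwen2_moe entry, so they pick up the same checkpoint-conversion rules when their weights are loaded. A toy sketch of the registry pattern, with invented placeholder rules rather than the real conversion objects:

# Toy sketch of the registry pattern used here (rule contents invented for illustration):
# model types that share an experts layout reuse another entry's conversion rules via .copy(),
# so later edits to one list do not leak into the others.
def build_mapping():
    mapping = {"qwen2_moe": ["merge gate_proj/up_proj -> gate_up_proj", "stack per-expert weights"]}
    mapping["flex_olmo"] = mapping["qwen2_moe"].copy()
    mapping["olmoe"] = mapping["qwen2_moe"].copy()
    return mapping

print(build_mapping()["olmoe"])  # same rules as qwen2_moe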

src/transformers/models/deepseek_v2/modeling_deepseek_v2.py

Lines changed: 4 additions & 5 deletions
@@ -61,22 +61,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
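The same pattern is applied to every MoE variant touched below: the router no longer scatters probabilities into a dense (num_tokens, num_experts) matrix, so top_k_weights arrives with shape (num_tokens, top_k) and must be indexed by the slot position recovered from the one-hot mask (top_k_pos) rather than by the expert id. A minimal, self-contained sketch of that indexing (not the transformers code; shapes and the print are illustrative only):

# Minimal sketch (not the transformers code) of why the expert loop indexes top_k_weights with
# top_k_pos: the router now returns only the k selected weights per token, shaped
# (num_tokens, top_k), so the column index is the slot inside the top-k, not the expert id.
import torch

num_tokens, num_experts, top_k = 5, 8, 2
logits = torch.randn(num_tokens, num_experts)
top_k_weights, top_k_index = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)

# one_hot over the real number of experts; no extra "masked" class is needed anymore
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts)  # (tokens, top_k, experts)
expert_mask = expert_mask.permute(2, 1, 0)                                        # (experts, top_k, tokens)

for expert_idx in range(num_experts):
    # for each token routed to this expert, top_k_pos is the slot in which it was selected
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    if token_idx.numel() == 0:
        continue
    weights = top_k_weights[token_idx, top_k_pos]  # correct lookup into the (tokens, top_k) matrix
    print(expert_idx, token_idx.tolist(), weights.shape)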

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 4 additions & 5 deletions
@@ -169,22 +169,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states

src/transformers/models/dots1/modeling_dots1.py

Lines changed: 4 additions & 5 deletions
@@ -327,22 +327,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states

src/transformers/models/flex_olmo/modeling_flex_olmo.py

Lines changed: 7 additions & 8 deletions
@@ -313,22 +313,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states
@@ -351,8 +350,8 @@ def forward(self, hidden_states):
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices


 class FlexOlmoSparseMoeBlock(nn.Module):
@@ -364,7 +363,7 @@ def __init__(self, config):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        top_k_weights, top_k_index = self.gate(hidden_states)
+        _, top_k_weights, top_k_index = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
             batch_size, sequence_length, hidden_dim
         )
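The router's contract changes accordingly: it used to return a dense score matrix built with scatter_ plus the indices, and now returns the raw logits together with the compact top-k scores and indices, which is why the sparse MoE block unpacks three values. A rough sketch of the new return shape, assuming a topk-then-softmax router rather than reproducing the exact FlexOlmo gate:

# Rough sketch of the router's new return contract (not the exact FlexOlmo gate): instead of
# scattering the top-k probabilities into a dense (num_tokens, num_experts) matrix, it returns
# the logits plus the compact top-k scores and indices, and the MoE block unpacks three values.
import torch

def route(router_logits: torch.Tensor, top_k: int):
    top_value, top_index = torch.topk(router_logits, top_k, dim=-1)
    top_value = torch.softmax(top_value, dim=-1, dtype=top_value.dtype)
    # previous behaviour, kept for comparison:
    # dense = torch.zeros_like(router_logits).scatter_(1, top_index, top_value)
    return router_logits, top_value, top_index

logits = torch.randn(6, 4)                                  # 6 tokens, 4 experts
router_logits, top_k_weights, top_k_index = route(logits, top_k=2)
print(top_k_weights.shape, top_k_index.shape)               # torch.Size([6, 2]) torch.Size([6, 2])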

src/transformers/models/glm4_moe/modeling_glm4_moe.py

Lines changed: 4 additions & 5 deletions
@@ -350,22 +350,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states

src/transformers/models/glm4v_moe/modeling_glm4v_moe.py

Lines changed: 4 additions & 5 deletions
@@ -414,22 +414,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

         return final_hidden_states

src/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 19 additions & 12 deletions
@@ -95,12 +95,11 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
         """
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        num_experts = routing_weights.shape[1]
         if hidden_states.device.type == "cpu" or self.training:
             next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
             with torch.no_grad():
                 expert_mask = torch.nn.functional.one_hot(
-                    router_indices, num_classes=num_experts + 1
+                    router_indices, num_classes=self.num_experts
                 )  # masking is also a class
                 expert_mask = expert_mask.permute(2, 1, 0)
                 # we sum on the top_k and on the sequence length to get which experts
@@ -110,10 +109,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
                 # expert_idx only have 1 element, so we can use scale for fast indexing
                 expert_idx = expert_idx[0]
                 # skip masking index
-                if expert_idx == num_experts:
+                if expert_idx == self.num_experts:
                     continue
                 with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx])
+                    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
                 current_state = hidden_states[token_idx]
                 gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                 gate, up = gate_up[..., ::2], gate_up[..., 1::2]
@@ -122,21 +121,29 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
                 glu = gate * torch.sigmoid(gate * self.alpha)
                 gated_output = (up + 1) * glu
                 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                weighted_output = out * routing_weights[token_idx, top_k_pos, None]
                 next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
             next_states = next_states.view(batch_size, -1, self.hidden_size)
         else:
-            hidden_states = hidden_states.repeat(num_experts, 1)
-            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+            num_tokens = hidden_states.shape[0]
+            hidden_states = hidden_states.repeat(self.num_experts, 1)
+            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
             gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
             gate = gate.clamp(min=None, max=self.limit)
             up = up.clamp(min=-self.limit, max=self.limit)
             glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
             next_states = next_states + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+            next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size)
+
+            full_routing_weights = torch.zeros(
+                num_tokens, self.num_experts, device=routing_weights.device, dtype=routing_weights.dtype
+            )
+            full_routing_weights.scatter_(1, router_indices, routing_weights)
+            full_routing_weights = full_routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1, 1)
+
+            next_states = next_states * full_routing_weights
             next_states = next_states.sum(dim=0)
         return next_states

@@ -155,8 +162,8 @@ def forward(self, hidden_states):
         router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices


 @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
@@ -167,7 +174,7 @@ def __init__(self, config):
         self.experts = GptOssExperts(config)

     def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
+        _, router_scores, router_indices = self.router(hidden_states)
         routed_out = self.experts(hidden_states, router_indices, router_scores)
         return routed_out, router_scores
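Because the router now emits only the compact (num_tokens, top_k) scores, the batched (non-CPU) expert path has to rebuild a dense per-expert weight matrix before broadcasting it over the bmm outputs, which is what the new full_routing_weights scatter does. A small self-contained sketch of that reconstruction and the weighted sum (shapes are made up for illustration; this is not the GptOssExperts module itself):

# Small self-contained sketch (made-up shapes, not the GptOssExperts module) of the batched path:
# compact (num_tokens, top_k) scores are scattered back into a dense (num_tokens, num_experts)
# matrix so they can be broadcast against the per-expert bmm outputs and summed over experts.
import torch

num_tokens, num_experts, top_k, hidden = 6, 4, 2, 8
logits = torch.randn(num_tokens, num_experts)
routing_weights, router_indices = torch.topk(logits, top_k, dim=-1)
routing_weights = torch.softmax(routing_weights, dim=-1)

full_routing_weights = torch.zeros(num_tokens, num_experts, dtype=routing_weights.dtype)
full_routing_weights.scatter_(1, router_indices, routing_weights)   # unselected experts stay at 0

expert_out = torch.randn(num_experts, num_tokens, hidden)           # stand-in for the bmm outputs
mixed = (expert_out * full_routing_weights.transpose(0, 1)[..., None]).sum(dim=0)
print(mixed.shape)                                                   # torch.Size([6, 8])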
