Skip to content

Commit 1067577

Browse files
jiqing-fengSunMarc
andauthored
fix gpt-oss out shape (#40535)
* fix out shape Signed-off-by: jiqing-feng <[email protected]> * reset gpt-oss modeling Signed-off-by: jiqing-feng <[email protected]> * fix copies Signed-off-by: jiqing-feng <[email protected]> * fix tests Signed-off-by: jiqing-feng <[email protected]> --------- Signed-off-by: jiqing-feng <[email protected]> Co-authored-by: Marc Sun <[email protected]>
1 parent 7efb4c8 commit 1067577

File tree

3 files changed

+2
-5
lines changed

3 files changed

+2
-5
lines changed

src/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
116116
glu = gate * torch.sigmoid(gate * self.alpha)
117117
gated_output = (up + 1) * glu
118118
out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
119-
weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
119+
weighted_output = out * routing_weights[token_idx, expert_idx, None]
120120
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
121121
next_states = next_states.view(batch_size, -1, self.hidden_size)
122122
else:

src/transformers/models/gpt_oss/modular_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
115115
glu = gate * torch.sigmoid(gate * self.alpha)
116116
gated_output = (up + 1) * glu
117117
out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
118-
weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
118+
weighted_output = out * routing_weights[token_idx, expert_idx, None]
119119
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
120120
next_states = next_states.view(batch_size, -1, self.hidden_size)
121121
else:

tests/models/gpt_oss/test_modeling_gpt_oss.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,6 @@ def test_flex_attention_with_grads(self):
128128
def test_generate_compile_model_forward_fullgraph(self):
129129
return super().test_generate_compile_model_forward_fullgraph()
130130

131-
def test_batching_equivalence(self, **kwargs):
132-
return super().test_batching_equivalence(atol=5e-4, rtol=1e-3)
133-
134131

135132
RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json"
136133

0 commit comments

Comments
 (0)