
Commit 66d32ab

vasqu authored and ArthurZucker committed

[OPT] Fix attention scaling (#38290)

* fix opt attention scaling
* add comment to why we do this

1 parent f4fc422 commit 66d32ab

File tree

1 file changed: +6 -2 lines changed

src/transformers/models/opt/modeling_opt.py

Lines changed: 6 additions & 2 deletions
@@ -154,7 +154,11 @@ def forward(
         """Input shape: Batch x Time x Channel"""
         bsz, tgt_len, _ = hidden_states.size()
 
-        # get query proj
+        # Scaling is susceptible to floating point arithmetic imprecision,
+        # which can lead to different results (this is dependent on the model,
+        # e.g. Whisper is one such case). We therefore keep the original order
+        # of scaling to follow the original implementation and enforce no
+        # scaling (1.0) in the attention call below.
         query_states = self.q_proj(hidden_states) * self.scaling
         query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
 
@@ -187,7 +191,7 @@ def forward(
             value_states,
             attention_mask,
             dropout=0.0 if not self.training else self.dropout,
-            scaling=self.scaling,
+            scaling=1.0,
             **kwargs,
         )
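
The comment added in the first hunk is the crux of the change: scaling the query projection and then computing attention scores is mathematically equivalent to computing raw scores and scaling them afterwards, but the two orderings round at different points and can give slightly different numbers. OPT's original implementation scales at the query projection, so the patch keeps that order and passes scaling=1.0 to the attention call so the kernel does not scale a second time. Below is a minimal sketch in plain PyTorch (not the transformers code; the tensor shapes and the use of scaled_dot_product_attention are illustrative assumptions) showing both the rounding difference and the resulting call pattern.

# Minimal sketch (assumptions: shapes, dtype, and the use of PyTorch's
# scaled_dot_product_attention; this is not the transformers source).
import torch
import torch.nn.functional as F

torch.manual_seed(0)

bsz, num_heads, tgt_len, head_dim = 1, 12, 32, 80
scaling = head_dim ** -0.5  # not a power of two, so multiplying by it rounds

query = torch.randn(bsz, num_heads, tgt_len, head_dim)
key = torch.randn(bsz, num_heads, tgt_len, head_dim)
value = torch.randn(bsz, num_heads, tgt_len, head_dim)

# Order A (what OPT does): scale the query first, then compute raw scores.
scores_scale_first = (query * scaling) @ key.transpose(-1, -2)

# Order B: compute raw scores first, then apply the scaling.
scores_scale_last = (query @ key.transpose(-1, -2)) * scaling

# Mathematically identical, but rounding happens at different points, so the
# two results typically differ by a tiny amount -- which is what the added
# comment refers to.
print((scores_scale_first - scores_scale_last).abs().max())

# The pattern the patch settles on: pre-scale the query and tell the attention
# kernel not to scale again. The scale argument requires PyTorch >= 2.1.
attn_out = F.scaled_dot_product_attention(query * scaling, key, value, scale=1.0)
print(attn_out.shape)  # torch.Size([1, 12, 32, 80])

In the OPT module itself, whichever attention implementation ends up running receives scaling=1.0 because the query states were already multiplied by self.scaling at the projection.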

0 commit comments
