
Commit 9a021b7

TimDettmers and younesbelkada authored and committed
Bugfix: LLaMA layer norm incorrectly changes input type and consumes lots of memory (huggingface#23535)

* Fixed bug where LLaMA layer norm would change input type.

* make fix-copies

---------

Co-authored-by: younesbelkada <[email protected]>
1 parent 72922ce commit 9a021b7
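
For context, the memory issue in the title comes from PyTorch type promotion: multiplying half-precision activations by the float32 rsqrt(variance) promotes them to float32, and the pre-fix code only cast back when the norm weight itself was fp16/bf16, so with fp32 norm weights the float32 activations leaked into the rest of the model. The snippet below is a minimal illustration of that promotion, assuming only torch; the tensors and names are hypothetical and not part of the patch.

import torch

# Hypothetical pre-fix scenario: fp16 activations, norm weight kept in fp32
# (as quantized or mixed-precision setups commonly do).
hidden_states = torch.randn(2, 4, dtype=torch.float16)
weight = torch.ones(4, dtype=torch.float32)

# fp16 * fp32 promotes the activations to fp32; the old code only cast back
# when the weight was fp16/bf16, so the fp32 result propagated downstream.
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
promoted = hidden_states * torch.rsqrt(variance + 1e-6)
print(promoted.dtype)             # torch.float32 -- input dtype lost
print((weight * promoted).dtype)  # torch.float32 -- grows activation memory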

File tree

2 files changed: +4 -10 lines changed

src/transformers/models/llama/modeling_llama.py

Lines changed: 2 additions & 5 deletions
@@ -81,14 +81,11 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 class LlamaRotaryEmbedding(torch.nn.Module):
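
For reference, the patched behaviour can be exercised in isolation with a minimal sketch, assuming only torch; RMSNormSketch is a hypothetical stand-in for the library class, not its actual name. The variance is still computed in float32 for numerical stability, but the result is now cast back to the caller's dtype.

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Standalone sketch of the patched RMSNorm forward."""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Variance is still computed in float32 for numerical stability ...
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # ... but the output is cast back to whatever dtype the caller passed in.
        return (self.weight * hidden_states).to(input_dtype)

norm = RMSNormSketch(4)                  # weight stays in fp32 by default
x = torch.randn(2, 4, dtype=torch.float16)
assert norm(x).dtype == torch.float16    # output dtype now matches the input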

src/transformers/models/open_llama/modeling_open_llama.py

Lines changed: 2 additions & 5 deletions
@@ -91,14 +91,11 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
