System Info
transformers==4.37.2
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
@ArthurZucker @younesbelkada
Hi~ I found a bug in LlamaRMSNorm (lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py):
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
On the last line, if input_dtype is bfloat16 but self.weight is still in float32 (its default initialization), PyTorch's type promotion upcasts the product to float32, so the returned tensor is float32 rather than bfloat16. The last line should therefore be modified to:
return (self.weight * hidden_states).to(input_dtype)
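Below is a minimal, standalone sketch of the behavior. The class names RMSNormCurrent / RMSNormProposed are just illustrative (not from the library), and it assumes self.weight stays in float32, e.g. when the module is constructed directly rather than loaded with torch_dtype=torch.bfloat16:

import torch
import torch.nn as nn


class RMSNormCurrent(nn.Module):
    """Copy of the current LlamaRMSNorm forward, with a float32 weight."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # float32 by default
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # float32 weight * bfloat16 activations -> promoted to float32
        return self.weight * hidden_states.to(input_dtype)


class RMSNormProposed(RMSNormCurrent):
    """Same module, but casting to input_dtype after the multiplication."""

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # cast after the multiplication so the output keeps the input dtype
        return (self.weight * hidden_states).to(input_dtype)


x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
print(RMSNormCurrent(8)(x).dtype)   # torch.float32
print(RMSNormProposed(8)(x).dtype)  # torch.bfloat16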
Expected behavior
See above; the returned tensor should keep the input dtype (bfloat16 in, bfloat16 out). Looking forward to your reply~ Thank you!