LlamaRMSNorm() Dtype Casting Error #30236

@Ritz111


System Info

transformers==4.37.2

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

@ArthurZucker @younesbelkada
Hi~ I found a bug in LlamaRMSNorm(nn.Module) (lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py):

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

On the last line, if input_dtype is bfloat16, the returned tensor will still be float32, because self.weight is initialized as float32 and the multiplication promotes the bfloat16 operand. The last line should therefore be modified to:

return (self.weight * hidden_states).to(input_dtype)
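
For illustration, here is a minimal standalone sketch of the dtype promotion (not the transformers code itself). It assumes the weight stays in float32, as created in __init__ (e.g. when parameters are kept in fp32 for mixed-precision training), while the activations arrive in bfloat16; hidden_size and the epsilon are arbitrary demo values.

import torch
import torch.nn as nn

hidden_size = 8  # arbitrary size for the demo
weight = nn.Parameter(torch.ones(hidden_size))                    # float32 by default
hidden_states = torch.randn(2, hidden_size, dtype=torch.bfloat16)

input_dtype = hidden_states.dtype
hs = hidden_states.to(torch.float32)
variance = hs.pow(2).mean(-1, keepdim=True)
hs = hs * torch.rsqrt(variance + 1e-6)

current = weight * hs.to(input_dtype)     # float32 * bfloat16 -> promoted to float32
proposed = (weight * hs).to(input_dtype)  # multiply in float32, then cast back

print(current.dtype)   # torch.float32
print(proposed.dtype)  # torch.bfloat16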

Expected behavior

See above. Looking forward to your reply~ Thank you!
