@mstojkovicTT mstojkovicTT commented Nov 27, 2025

What does this PR do?

Fixes #42398

This PR replaces custom RMSNorm/T5-style norm implementations (e.g. in Llama) that manually compute variance and scaling with the built-in torch.nn.functional.rms_norm. For example, code like:

input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

is simplified to:

return F.rms_norm(hidden_states, hidden_states.shape[-1:], self.weight, self.variance_epsilon)

This keeps the normalization math and the epsilon handling the same while reducing the number of ops, which should improve performance for users without requiring any changes on their side.
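For reference, both snippets above compute the same per-vector formula, ignoring dtype casts. The symbols below are labels chosen for this note: x is one vector along the last dimension of hidden_states, n is hidden_states.shape[-1], w is self.weight, and ε is self.variance_epsilon.

```latex
y_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2} + \varepsilon}}
```

Note that epsilon is added inside the square root, to the mean of squares, in both the manual version and the built-in one.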

To verify performance and numerical equivalence, I wrote the following test:

import timeit
import torch
import torch.nn as nn

# Original implementation
class LlamaRMSNormHF(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# New implementation using torch.nn.functional.rms_norm
class LlamaRMSNormTorch(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        return nn.functional.rms_norm(
            hidden_states,
            hidden_states.shape[-1:],
            self.weight,
            self.variance_epsilon,
        )

def bench(module, x, iters=1000):
    # Warmup
    for _ in range(10):
        module(x)

    if x.is_cuda:
        torch.cuda.synchronize()
    start = timeit.default_timer()
    for _ in range(iters):
        module(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    end = timeit.default_timer()

    return (end - start) / iters


def test_llama_rms_norm_equivalence():
    torch.manual_seed(0)

    hidden_size = 64
    batch_size = 2
    seq_len = 3

    dtype = torch.bfloat16
    device = "cpu" 

    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)

    hf_module = LlamaRMSNormHF(hidden_size, eps=1e-6).to(device)
    new_module = LlamaRMSNormTorch(hidden_size, eps=1e-6).to(device)

    # make sure they have the same weights
    with torch.no_grad():
        new_module.weight.copy_(hf_module.weight)

    y_hf = hf_module(x)
    y_new = new_module(x)
    y_new = y_new.to(y_hf.dtype) # torch.allclose needs same dtype

    # Check numerically close
    print(torch.allclose(y_hf, y_new, atol=1e-5, rtol=1e-5))

    # speed benchmark
    t_hf = bench(hf_module, x)
    t_new = bench(new_module, x)

    print(f"HF   RMSNorm: {t_hf * 1e6:.2f} µs / call")
    print(f"F.rms_norm  : {t_new * 1e6:.2f} µs / call")

test_llama_rms_norm_equivalence()

The results show the following:

  • for cpu device:
True
HF   RMSNorm: 86.75 µs / call
F.rms_norm  : 47.25 µs / call
  • for cuda device:
True
HF   RMSNorm: 112.27 µs / call
F.rms_norm  : 83.05 µs / call

Note: for input dtypes with lower precision than float32 (e.g. bfloat16), the old implementation returns float32, while the new one returns the dtype of the input tensor. That is why I cast y_new to y_hf.dtype before comparing. (With float64 input, both implementations return float64.) This can be changed, depending on what we want to accomplish.

Who can review?

@Rocketknight1

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, apertus, arcee, aria, bamba, bitnet, blt, chameleon, clvp, csm, cwm, deepseek_v2, deepseek_v3, dia, diffllama, doge

@Rocketknight1
Member

Hey @mstojkovicTT, thanks for the PR! We definitely want the functions to be a drop-in replacement, so they should return exactly the same dtype as the old functions did.

Also, in your tests you're initializing self.weight = nn.Parameter(torch.ones(hidden_size)) which means that the final scaling step is just multiplying by 1 and having no effect. Can you try with randomly initialized weights with a mean of 1 but a little bit of variance instead, so we can see if everything is equivalent with the original?

Development

Successfully merging this pull request may close these issues.

LlamaRMSNorm and equivalent module implementations using torch ops

2 participants