61 changes: 61 additions & 0 deletions examples/rms_norm.py
@@ -0,0 +1,61 @@
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(static_shapes=True)
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :].to(torch.float32)

        # Compute the reciprocal RMS: 1 / sqrt(mean(x^2) + eps)
        x_squared = x_tile * x_tile
        mean_x_squared = torch.mean(x_squared, dim=-1, keepdim=True)
        rms = torch.rsqrt(mean_x_squared + eps)

        # Apply normalization and weight
        normalized = x_tile * rms
        out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)

    return out


def rms_norm_tritonbench(H: int, inp: torch.Tensor) -> torch.Tensor:
    """Wrapper for tritonbench that matches expected interface."""
    weight = torch.ones(H, device=inp.device, dtype=inp.dtype)
    return rms_norm(inp, weight, eps=1e-6)


def rms_norm_pytorch(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    input_dtype = x.dtype
    hidden_states = x.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


def check(m: int, n: int) -> None:
    x = torch.randn([m, n], device="cuda", dtype=torch.float16)
    weight = torch.randn([n], device="cuda", dtype=torch.float16)
    run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5))


def main() -> None:
    check(32, 64)
    check(128, 256)
    check(1024, 1024)


if __name__ == "__main__":
    main()
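For reference, the per-row computation the example implements (row i, column j, row length n, learned scale w):

\mathrm{rms\_norm}(x)_{ij} = \frac{x_{ij}}{\sqrt{\tfrac{1}{n}\sum_{k=1}^{n} x_{ik}^{2} + \epsilon}}\, w_j

Note that eps sits inside the square root, matching torch.rsqrt(mean_x_squared + eps) in both the kernel and the reference.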
48 changes: 48 additions & 0 deletions test/test_examples.expected
@@ -989,6 +989,54 @@ def _moe_matmul_ogs_make_precompiler(A: torch.Tensor, W: torch.Tensor, expert_to
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_moe_matmul_ogs_kernel)(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A.stride(0), A.stride(1), C.stride(0), C.stride(1), W.stride(0), W.stride(1), W.stride(2), expert_token_counts.stride(0), expert_token_offsets.stride(0), sorted_to_orig_token_idx.stride(0), max_T_per_expert, N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_rms_norm)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice

@triton.jit
def _rms_norm_kernel(x, weight, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    load = tl.load(x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None)
    v_0 = load.to(tl.float32)
    v_1 = v_0 * v_0
    mean_x_squared_extra = tl.reshape(tl.sum(v_1, 1), [_BLOCK_SIZE_0, 1])
    v_2 = 256
    v_3 = mean_x_squared_extra / v_2.to(tl.float32)
    v_4 = v_3 + eps
    v_5 = libdevice.rsqrt(v_4)
    v_6 = v_0 * v_5
    load_1 = tl.load(weight + indices_1 * 1, None)
    v_7 = load_1.to(tl.float32)
    v_8 = v_7[None, :]
    v_9 = v_6 * v_8
    v_10 = v_9.to(tl.float16)
    tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_10, None)

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05):
    m, n = x.size()
    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    _BLOCK_SIZE_0 = 16
    _RDIM_SIZE_1 = 256
    _rms_norm_kernel[triton.cdiv(128, _BLOCK_SIZE_0),](x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return out

def _rms_norm_make_precompiler(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05):
    m, n = x.size()
    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    _BLOCK_SIZE_0 = 16
    _RDIM_SIZE_1 = 256
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_rms_norm_kernel)(x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_softmax)
from __future__ import annotations

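Two details in the journal above are worth noting. First, because the example is compiled with static_shapes=True, the test's (128, 256) input shape is baked into the generated code: the row stride 256 appears as a literal in the pointer arithmetic, and _RDIM_SIZE_1 = 256 covers the whole reduction dimension in one block. Second, the kernel upcasts to float32 before the reduction (v_0 = load.to(tl.float32)), mirroring the upcast in the Helion source and the PyTorch reference. A small sketch of why the upcast matters, assuming a CUDA device as elsewhere in this PR (tensors and values here are illustrative, not from the test suite):

import torch

x = torch.randn([1, 4096], device="cuda", dtype=torch.float16) * 10
# Squaring in fp16 rounds each product to fp16 before the reduction;
# the generated kernel squares and sums in fp32 instead.
mean_fp16 = (x * x).mean(-1)
mean_fp32 = (x.float() * x.float()).mean(-1)
print((mean_fp16.float() - mean_fp32).abs().item())  # small but nonzero drift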
20 changes: 20 additions & 0 deletions test/test_examples.py
@@ -237,6 +237,26 @@ def test_softmax_two_pass_block_ptr(self):
            )
        )

    def test_rms_norm(self):
        args = (
            torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
            torch.randn([256], device=DEVICE, dtype=torch.float16),
            1e-5,
        )
        # Import and use the reference implementation from rms_norm.py
        mod = import_path(EXAMPLES_DIR / "rms_norm.py")
        expected = mod.rms_norm_pytorch(*args)

        self.assertExpectedJournal(
            check_example(
                "rms_norm",
                args,
                expected,
                block_sizes=[16],
                indexing="pointer",
            )
        )

    def test_embedding_pointers(self):
        args = (
            torch.randint(0, 1024, [8, 128], device=DEVICE, dtype=torch.int32),
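To exercise the kernel outside the test harness, python examples/rms_norm.py runs the checks above. Alternatively, a minimal standalone comparison, as a sketch: it assumes a CUDA device and that the repository root is on sys.path so the example imports as a module (the import path is an assumption, not something this PR sets up).

import torch
from examples.rms_norm import rms_norm, rms_norm_pytorch  # import path is an assumption

x = torch.randn([128, 256], device="cuda", dtype=torch.float16)
w = torch.randn([256], device="cuda", dtype=torch.float16)
# fp16 in/out with fp32 math inside, so compare with a loose tolerance.
torch.testing.assert_close(rms_norm(x, w, 1e-5), rms_norm_pytorch(x, w, 1e-5), rtol=1e-2, atol=1e-2)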