diff --git a/examples/rms_norm.py b/examples/rms_norm.py new file mode 100644 index 000000000..554332826 --- /dev/null +++ b/examples/rms_norm.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import torch + +import helion +from helion._testing import run_example +import helion.language as hl + + +@helion.kernel(static_shapes=True) +def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + m, n = x.size() + assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}" + + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + + for tile_m in hl.tile(m): + x_tile = x[tile_m, :].to(torch.float32) + + # Compute RMS: sqrt(mean(x^2)) + x_squared = x_tile * x_tile + mean_x_squared = torch.mean(x_squared, dim=-1, keepdim=True) + rms = torch.rsqrt(mean_x_squared + eps) + + # Apply normalization and weight + normalized = x_tile * rms + out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype) + + return out + + +def rms_norm_tritonbench(H: int, inp: torch.Tensor) -> torch.Tensor: + """Wrapper for tritonbench that matches expected interface.""" + weight = torch.ones(H, device=inp.device, dtype=inp.dtype) + return rms_norm(inp, weight, eps=1e-6) + + +def rms_norm_pytorch( + x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5 +) -> torch.Tensor: + input_dtype = x.dtype + hidden_states = x.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + return weight * hidden_states.to(input_dtype) + + +def check(m: int, n: int) -> None: + x = torch.randn([m, n], device="cuda", dtype=torch.float16) + weight = torch.randn([n], device="cuda", dtype=torch.float16) + run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5)) + + +def main() -> None: + check(32, 64) + check(128, 256) + check(1024, 1024) + + +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index db40d1893..f1a597b50 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -989,6 +989,54 @@ def _moe_matmul_ogs_make_precompiler(A: torch.Tensor, W: torch.Tensor, expert_to from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_moe_matmul_ogs_kernel)(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A.stride(0), A.stride(1), C.stride(0), C.stride(1), W.stride(0), W.stride(1), W.stride(2), expert_token_counts.stride(0), expert_token_offsets.stride(0), sorted_to_orig_token_idx.stride(0), max_T_per_expert, N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3) +--- assertExpectedJournal(TestExamples.test_rms_norm) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime.triton_compat import libdevice + +@triton.jit +def _rms_norm_kernel(x, weight, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + load = tl.load(x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None) + v_0 = load.to(tl.float32) + v_1 = v_0 * v_0 + mean_x_squared_extra = tl.reshape(tl.sum(v_1, 1), [_BLOCK_SIZE_0, 1]) + v_2 = 256 + v_3 = mean_x_squared_extra / v_2.to(tl.float32) + v_4 = v_3 + eps + v_5 = libdevice.rsqrt(v_4) + v_6 = v_0 * v_5 + load_1 = tl.load(weight + indices_1 * 1, None) + v_7 = load_1.to(tl.float32) + v_8 = v_7[None, :] + v_9 = v_6 * v_8 + v_10 = v_9.to(tl.float16) + tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_10, None) + +def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05): + m, n = x.size() + assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}' + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + _BLOCK_SIZE_0 = 16 + _RDIM_SIZE_1 = 256 + _rms_norm_kernel[triton.cdiv(128, _BLOCK_SIZE_0),](x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return out + +def _rms_norm_make_precompiler(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05): + m, n = x.size() + assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}' + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + _BLOCK_SIZE_0 = 16 + _RDIM_SIZE_1 = 256 + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_rms_norm_kernel)(x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + --- assertExpectedJournal(TestExamples.test_softmax) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index 999594d09..b32ac21a2 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -237,6 +237,26 @@ def test_softmax_two_pass_block_ptr(self): ) ) + def test_rms_norm(self): + args = ( + torch.randn([128, 256], device=DEVICE, dtype=torch.float16), + torch.randn([256], device=DEVICE, dtype=torch.float16), + 1e-5, + ) + # Import and use the reference implementation from rms_norm.py + mod = import_path(EXAMPLES_DIR / "rms_norm.py") + expected = mod.rms_norm_pytorch(*args) + + self.assertExpectedJournal( + check_example( + "rms_norm", + args, + expected, + block_sizes=[16], + indexing="pointer", + ) + ) + def test_embedding_pointers(self): args = ( torch.randint(0, 1024, [8, 128], device=DEVICE, dtype=torch.int32),