diff --git a/benchmarks/run.py b/benchmarks/run.py
index 3a87130dd..46694cad9 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -75,6 +75,11 @@
         "examples.fp8_attention",
         "fp8_attention_tritonbench",
     ),
+    "layer_norm": (
+        "tritonbench.operators.layer_norm.operator",
+        "examples.layer_norm",
+        "helion_layer_norm_wrapper",
+    ),
 }


diff --git a/examples/layer_norm.py b/examples/layer_norm.py
new file mode 100644
index 000000000..9a52fd32d
--- /dev/null
+++ b/examples/layer_norm.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+
+# TODO(PaulZhang12): Support autotuning, setting reduction_loops currently errors
+@helion.kernel(
+    static_shapes=True,
+    config=helion.Config(
+        block_sizes=[32],
+        reduction_loops=[None],
+        range_unroll_factors=[0],
+        range_warp_specializes=[],
+        range_num_stages=[0],
+        range_multi_buffers=[None],
+        range_flattens=[None],
+        num_warps=4,
+        num_stages=3,
+        indexing="pointer",
+        pid_type="flat",
+    ),
+)
+def layer_norm_fwd(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    m, n = x.size()
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
+    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+
+    for tile_m in hl.tile(m):
+        acc = x[tile_m, :].to(
+            torch.float32
+        )  # TODO(PaulZhang12): Eliminate this cast, currently necessary
+
+        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+
+        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+
+        out[tile_m, :] = acc
+    return out
+
+
+def helion_layer_norm_wrapper(
+    x: torch.Tensor,
+    dims: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float = 1e-5,
+) -> Any:  # noqa: ANN401
+    assert len(dims) == 1, "Helion layer norm only supports 1D layer norm currently"
+    return layer_norm_fwd(x, weight, bias, eps)
+
+
+def main() -> None:
+    batch_size = 32
+    dim = 64
+    device = "cuda"
+
+    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
+    weight = torch.randn([dim], device=device, dtype=torch.float16)
+    bias = torch.randn([dim], device=device, dtype=torch.float16)
+    eps = 1e-4
+
+    run_example(
+        helion_layer_norm_wrapper,
+        torch.nn.functional.layer_norm,
+        (x, [dim], weight, bias, eps),
+        kernel_name="helion",
+        baseline_name="torch",
+        rtol=1e-3,
+        atol=1e-3,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index bad1eef1b..e900deb33 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -883,6 +883,56 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
     _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestExamples.test_layernorm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    load = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
+    v_0 = load.to(tl.float32)
+    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = 64
+    v_2 = var_mean_extra / v_1.to(tl.float32)
+    v_3 = v_0 - v_2
+    v_4 = v_3 * v_3
+    var_mean_extra_2 = tl.reshape(tl.sum(v_4, 1), [_BLOCK_SIZE_0, 1])
+    v_5 = 64
+    v_6 = var_mean_extra_2 / v_5.to(tl.float32)
+    v_7 = v_0 - v_2
+    v_8 = v_6 + eps
+    v_9 = libdevice.rsqrt(v_8)
+    v_10 = v_7 * v_9
+    load_1 = tl.load(weight + indices_1 * 1, None)
+    v_11 = load_1.to(tl.float32)
+    v_12 = v_11[None, :]
+    v_13 = v_10 * v_12
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_14 = load_2.to(tl.float32)
+    v_15 = v_14[None, :]
+    v_16 = v_13 + v_15
+    v_17 = v_16.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_17, None)
+
+def layer_norm_fwd(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
+    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {n}'
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 64
+    _launcher(_layer_norm_fwd_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations

diff --git a/test/test_examples.py b/test/test_examples.py
index 95a415759..5a65cf0be 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -599,6 +599,21 @@ def test_fp8_attention(self):
             )
         )

+    def test_layernorm(self):
+        x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
+        weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
+        bias = torch.randn([64], device=DEVICE, dtype=torch.float16)
+
+        self.assertExpectedJournal(
+            check_example(
+                "layer_norm",
+                (x, weight, bias),
+                torch.nn.functional.layer_norm(*(x, [64], weight, bias)),
+                fn_name="layer_norm_fwd",
+                block_sizes=[32],
+            )
+        )
+

 if __name__ == "__main__":
     unittest.main()
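For reference, a minimal sketch of calling the new wrapper directly, assuming a CUDA device and that examples/layer_norm.py is importable as examples.layer_norm; the argument order and tolerances mirror the run_example call in the example's main().

# Minimal usage sketch (assumptions: CUDA device available, examples.layer_norm on the path).
import torch

from examples.layer_norm import helion_layer_norm_wrapper

x = torch.randn([32, 64], device="cuda", dtype=torch.float16)
weight = torch.randn([64], device="cuda", dtype=torch.float16)
bias = torch.randn([64], device="cuda", dtype=torch.float16)

# The wrapper follows torch.nn.functional.layer_norm's
# (input, normalized_shape, weight, bias, eps) argument order.
out = helion_layer_norm_wrapper(x, [64], weight, bias, 1e-5)
ref = torch.nn.functional.layer_norm(x, [64], weight, bias, 1e-5)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)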