From af875dca06948428e23282de8be44ceb77d6327e Mon Sep 17 00:00:00 2001
From: Menglu Yu
Date: Thu, 11 Sep 2025 13:23:41 -0700
Subject: [PATCH] Support layernorm without bias (#585)

Summary: The current layernorm only supports the case with bias; we thus add
support for the case without bias.

Differential Revision: D82171738
---
 examples/layer_norm.py                  | 33 +++++++-----
 helion/_compiler/compile_environment.py |  2 +
 helion/runtime/kernel.py                |  1 +
 test/test_examples.expected             | 70 +++++++++++++++++++++++--
 test/test_examples.py                   | 16 +++++-
 5 files changed, 104 insertions(+), 18 deletions(-)

diff --git a/examples/layer_norm.py b/examples/layer_norm.py
index e7ef49448..c22468111 100644
--- a/examples/layer_norm.py
+++ b/examples/layer_norm.py
@@ -21,7 +21,7 @@ def layer_norm_fwd(
     x: torch.Tensor,
     normalized_shape: list[int],
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: torch.Tensor | None = None,
     eps: float = 1e-5,
 ) -> torch.Tensor:
     """
@@ -30,14 +30,15 @@ def layer_norm_fwd(
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
-    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
+    if bias is not None:
+        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
     assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
@@ -49,7 +50,12 @@ def layer_norm_fwd(
         acc = x[tile_m, :].to(torch.float32)
         var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
         normalized = (acc - mean) * torch.rsqrt(var + eps)
-        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+        if bias is not None:
+            acc = normalized * (weight[:].to(torch.float32)) + (
+                bias[:].to(torch.float32)
+            )
+        else:
+            acc = normalized * (weight[:].to(torch.float32))
         out[tile_m, :] = acc.to(x.dtype)
     return out
 
@@ -70,15 +76,16 @@ def main() -> None:
     weight = torch.randn([dim], device=device, dtype=torch.float16)
     bias = torch.randn([dim], device=device, dtype=torch.float16)
     eps = 1e-4
-    run_example(
-        layer_norm_fwd,
-        torch.nn.functional.layer_norm,
-        (x, [dim], weight, bias, eps),
-        kernel_name="helion",
-        baseline_name="torch",
-        rtol=1e-3,
-        atol=1e-3,
-    )
+    for b in [bias, None]:
+        run_example(
+            layer_norm_fwd,
+            torch.nn.functional.layer_norm,
+            (x, [dim], weight, b, eps),
+            kernel_name="helion",
+            baseline_name="torch",
+            rtol=1e-3,
+            atol=1e-3,
+        )
 
 
 # %%
diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 3e05ad16c..ea31747ee 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -204,6 +204,8 @@ def cached_create_unbacked_symint(
         return result
 
     def to_fake(self, obj: object, origin: Origin) -> object:
+        if obj is None:
+            return None
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
         if isinstance(obj, (bool, int, float)):
diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py
index 2a7b95887..de2abf090 100644
--- a/helion/runtime/kernel.py
+++ b/helion/runtime/kernel.py
@@ -757,6 +757,7 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
     types.BuiltinFunctionType: lambda fn, x: x,
     torch.fx.GraphModule: _graph_module_key,
     ConstExpr: lambda fn, x: x.value,  # pyright: ignore[reportAttributeAccessIssue]
+    type(None): lambda fn, x: None,
 }
 
 
diff --git a/test/test_examples.expected b/test/test_examples.expected
index ba4451d3b..705879197 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -1262,7 +1262,7 @@ def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _lau
     _launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.reshape(N, M)
 
---- assertExpectedJournal(TestExamples.test_layernorm)
+--- assertExpectedJournal(TestExamples.test_layernorm_with_bias)
 from __future__ import annotations
 
 import torch
@@ -1303,21 +1303,22 @@ def _helion_layer_norm_fwd(bias, x, weight, out, bias_size_0, bias_stride_0, out
     v_15 = tl.cast(v_14, tl.float16)
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])
 
-def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
-    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
+    if bias is not None:
+        assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
     assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
     assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
     out = torch.empty([m, n], dtype=x.dtype, device=x.device)
@@ -1326,6 +1327,67 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_layernorm_without_bias)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_layer_norm_fwd(x, weight, out, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < n
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = tl.cast(load, tl.float32)
+    var_mean_extra = tl.cast(tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_1 = var_mean_extra / n.to(tl.float32)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_1, tl.full([], 0, tl.float32))
+    v_2 = v_0 - _mask_to_1
+    v_3 = v_2 * v_2
+    var_mean_extra_2 = tl.cast(tl.reshape(tl.sum(v_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_4 = var_mean_extra_2 / n.to(tl.float32)
+    v_5 = v_0 - v_1
+    v_6 = v_4 + eps
+    v_7 = libdevice.rsqrt(v_6)
+    v_8 = v_5 * v_7
+    load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
+    v_9 = tl.cast(load_1, tl.float32)
+    v_10 = v_9[None, :]
+    v_11 = v_8 * v_10
+    v_12 = tl.cast(v_11, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_12, mask_0[:, None] & mask_1[None, :])
+
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
+    """
+    Performs 1D layer normalization on the input tensor using Helion.
+    Args:
+        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
+        normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        weight (torch.Tensor): Learnable scale parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
+        eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
+    Returns:
+        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+    """
+    m, n = x.size()
+    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
+    if bias is not None:
+        assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
+    assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
+    assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(n)
+    _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations
 
diff --git a/test/test_examples.py b/test/test_examples.py
index bc36c47aa..49ef491fd 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -611,7 +611,7 @@ def test_fp8_attention(self):
             )
         )
 
-    def test_layernorm(self):
+    def test_layernorm_with_bias(self):
         x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
         weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
         bias = torch.randn([64], device=DEVICE, dtype=torch.float16)
@@ -627,6 +627,20 @@ def test_layernorm(self):
             )
         )
 
+    def test_layernorm_without_bias(self):
+        x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
+        weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
+
+        args = (x, [64], weight)
+        self.assertExpectedJournal(
+            check_example(
+                "layer_norm",
+                args,
+                torch.nn.functional.layer_norm(*args),
+                fn_name="layer_norm_fwd",
+            )
+        )
+
     @skipIfRefEager("ref eager mode hits CUDA indexing error with hl.store")
     def test_jagged_softmax(self):
         num_rows, max_cols = 128, 64
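
Usage sketch (not part of the patch): with this change applied, layer_norm_fwd accepts bias=None, so both code paths can be checked against torch.nn.functional.layer_norm. This is a minimal sketch, assuming a CUDA device is available and that examples/layer_norm.py is importable as shown (e.g. run from the repository root); the shapes, dtype, and tolerances mirror the example's main().

import torch

# Hypothetical import path; adjust to however examples/ is exposed in your checkout.
from examples.layer_norm import layer_norm_fwd

x = torch.randn([32, 64], device="cuda", dtype=torch.float16)
weight = torch.randn([64], device="cuda", dtype=torch.float16)
bias = torch.randn([64], device="cuda", dtype=torch.float16)

# Exercise the existing with-bias path and the new bias=None path.
for b in (bias, None):
    out = layer_norm_fwd(x, [64], weight, b, 1e-4)
    ref = torch.nn.functional.layer_norm(x, [64], weight, b, 1e-4)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)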