Commit 1fbcbfe

initial version

1 parent 1ac5365 commit 1fbcbfe

File tree

5 files changed: +856 -83 lines changed

benchmarks/run.py

Lines changed: 3 additions & 3 deletions
@@ -126,7 +126,7 @@ class RunResult:
         "layer_norm": (
             "tritonbench.operators.layer_norm.operator",
             "examples.layer_norm",
-            "layer_norm_fwd",
+            "layer_norm",
         ),
         "jagged_softmax": (
             "tritonbench.operators.jagged_softmax.operator",
@@ -174,8 +174,8 @@ class RunResult:
         "liger_layer_norm-accuracy": "triton_accuracy",
         "torch_compile_layer_norm-speedup": "torch_compile_speedup",
         "torch_compile_layer_norm-accuracy": "torch_compile_accuracy",
-        "helion_layer_norm_fwd-speedup": "helion_speedup",
-        "helion_layer_norm_fwd-accuracy": "helion_accuracy",
+        "helion_layer_norm-speedup": "helion_speedup",
+        "helion_layer_norm-accuracy": "helion_accuracy",
     },
     "softmax": {
         "triton_softmax-speedup": "triton_speedup",

examples/layer_norm.py

Lines changed: 217 additions & 15 deletions
@@ -1,13 +1,16 @@
 """
-Helion Layer Normalization Forward Example
-==========================================
+Helion Layer Normalization Forward and Backward Example
+========================================================
 This example demonstrates a Helion kernel implementation of 1D layer normalization
-using FP16 inputs and compares it against PyTorch's built-in layer_norm function.
+with both forward and backward passes using FP16 inputs and compares it against
+PyTorch's built-in layer_norm function.
 """

 # %%
 from __future__ import annotations

+from typing import Any
+
 import torch

 import helion
@@ -17,47 +20,224 @@

 # %%
 @helion.kernel
-def layer_norm_fwd(
+def layer_norm_fwd_kernel(
     x: torch.Tensor,
     normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor | None = None,
     eps: float = 1e-5,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
+        bias (torch.Tensor | None): Optional learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
-        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            - The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+            - Mean tensor of shape [batch_size], in FP32.
+            - Reciprocal standard deviation tensor of shape [batch_size], in FP32.
     """
     m, n = x.size()
-    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
     if bias is not None:
-        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
+        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
     assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
     assert normalized_shape[0] == n, (
         f"normalized shape mismatch {normalized_shape[0]} != {n}"
     )
     out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    mean = torch.empty([m], dtype=torch.float32, device=x.device)
+    rstd = torch.empty([m], dtype=torch.float32, device=x.device)
+
     for tile_m in hl.tile(m):
         acc = x[tile_m, :].to(torch.float32)
-        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
-        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        # Compute mean
+        mean_val = torch.sum(acc, dim=-1) / n
+        # Compute variance
+        centered = acc - mean_val[:, None]
+        var_val = torch.sum(centered * centered, dim=-1) / n
+        # Compute reciprocal standard deviation
+        rstd_val = torch.rsqrt(var_val + eps)
+        # Normalize
+        normalized = centered * rstd_val[:, None]
+        # Apply affine transformation
         if bias is not None:
             acc = normalized * (weight[:].to(torch.float32)) + (
                 bias[:].to(torch.float32)
             )
         else:
             acc = normalized * (weight[:].to(torch.float32))
         out[tile_m, :] = acc.to(x.dtype)
-    return out
+        mean[tile_m] = mean_val
+        rstd[tile_m] = rstd_val
+    return out, mean, rstd
+
+
+# %%
+@helion.kernel
+def layer_norm_bwd_dwdb(
+    grad_out: torch.Tensor,
+    x: torch.Tensor,
+    mean: torch.Tensor,
+    rstd: torch.Tensor,
+    weight: torch.Tensor,
+    compute_bias_grad: hl.constexpr = True,  # type: ignore[valid-type]
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """
+    Compute gradients for weight (dW) and optionally bias (dB) parameters.
+
+    This kernel performs reduction across the batch dimension (M) to accumulate
+    gradients for each feature dimension's weight and bias parameters.
+
+    Args:
+        grad_out: Gradient w.r.t layer norm output [M, N]
+        x: Original input tensor [M, N]
+        mean: Per-sample mean computed in forward pass [M]
+        rstd: Per-sample reciprocal standard deviation from forward pass [M]
+        weight: Weight parameter (used only for dtype/device info) [N]
+        compute_bias_grad: Whether to compute bias gradient (default: True)
+
+    Returns:
+        (grad_weight, grad_bias): Gradients for weight and bias (if computed), both shape [N]
+        grad_bias is None if compute_bias_grad is False
+    """
+    m, n = x.shape
+    n = hl.specialize(n)
+
+    dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
+    if compute_bias_grad:
+        db = torch.empty([n], dtype=weight.dtype, device=weight.device)
+    else:
+        db = None
+
+    # Reduce across rows (M) inside the kernel without atomics
+    rdim = hl.register_reduction_dim(m)
+
+    for tile_n in hl.tile(n):
+        rows = hl.arange(0, rdim)
+        # Load slices for all rows in rdim and this tile of columns
+        x_blk = x[rows, tile_n].to(torch.float32)
+        dy_blk = grad_out[rows, tile_n].to(torch.float32)
+        mean_vec = mean[rows]
+        rstd_vec = rstd[rows]
+
+        x_hat_blk = (x_blk - mean_vec[:, None]) * rstd_vec[:, None]
+        dw_tile = torch.sum(dy_blk * x_hat_blk, dim=0).to(weight.dtype)
+
+        dw[tile_n] = dw_tile
+        if compute_bias_grad:
+            db_tile = torch.sum(dy_blk, dim=0).to(weight.dtype)
+            db[tile_n] = db_tile  # type: ignore[index]
+
+    if compute_bias_grad:
+        return dw, db
+    return dw, None
+
+
+@helion.kernel
+def layer_norm_bwd_dx(
+    grad_out: torch.Tensor,
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    mean: torch.Tensor,
+    rstd: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute gradient for input tensor (dX).
+
+    This kernel computes per-sample gradients by performing reductions across
+    the feature dimension (N) for each sample in the batch.
+
+    Args:
+        grad_out: Gradient w.r.t layer norm output [M, N]
+        x: Original input tensor [M, N]
+        weight: Weight parameter [N]
+        mean: Per-sample mean computed in forward pass [M]
+        rstd: Per-sample reciprocal standard deviation from forward pass [M]
+
+    Returns:
+        grad_x: Gradient w.r.t input tensor, shape [M, N]
+    """
+    m, n = x.shape
+    n = hl.specialize(n)
+
+    grad_x = torch.empty_like(x)
+
+    for tile_m in hl.tile(m):
+        x_tile = x[tile_m, :].to(torch.float32)
+        dy_tile = grad_out[tile_m, :].to(torch.float32)
+        w = weight[:].to(torch.float32)
+        mean_tile = mean[tile_m]
+        rstd_tile = rstd[tile_m]
+
+        x_hat = (x_tile - mean_tile[:, None]) * rstd_tile[:, None]
+        wdy = w * dy_tile
+        c1 = torch.sum(x_hat * wdy, dim=-1) / n
+        c2 = torch.sum(wdy, dim=-1) / n
+        dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_tile[:, None]
+        grad_x[tile_m, :] = dx.to(x.dtype)
+
+    return grad_x
+
+
+# %%
+class LayerNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x: torch.Tensor,
+        normalized_shape: list[int],
+        weight: torch.Tensor,
+        bias: torch.Tensor | None,
+        eps: float,
+    ) -> torch.Tensor:
+        """Forward pass for layer normalization."""
+        y, mean, rstd = layer_norm_fwd_kernel(x, normalized_shape, weight, bias, eps)
+        ctx.save_for_backward(x, weight, bias, mean, rstd)  # type: ignore[arg-type]
+        ctx.normalized_shape = normalized_shape  # type: ignore[attr-defined]
+        return y
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_output: torch.Tensor,
+    ) -> tuple[
+        torch.Tensor | None, None, torch.Tensor | None, torch.Tensor | None, None
+    ]:
+        """Backward pass for layer normalization split into two separate kernels for efficiency."""
+        grad_out = grad_output  # Use common name internally
+        x, weight, bias, mean, rstd = ctx.saved_tensors  # type: ignore[attr-defined]
+
+        # Check if bias gradient is needed
+        compute_bias_grad = bias is not None
+
+        # First kernel: Compute gradients for weight and bias by reducing across batch dimension (M)
+        grad_weight, grad_bias = layer_norm_bwd_dwdb(
+            grad_out, x, mean, rstd, weight, compute_bias_grad
+        )
+
+        # Second kernel: Compute gradient for input (dx) using per-sample reductions across feature dimension (N)
+        grad_x = layer_norm_bwd_dx(grad_out, x, weight, mean, rstd)
+
+        return grad_x, None, grad_weight, grad_bias, None
+
+
+# %%
+def layer_norm(
+    x: torch.Tensor,
+    normalized_shape: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """Layer normalization with forward + backward support."""
+    return LayerNormFunction.apply(x, normalized_shape, weight, bias, eps)  # type: ignore[no-any-return]


 # %%
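
Side note (not part of the commit): the math these kernels implement can be sanity-checked against PyTorch autograd in plain eager code. The sketch below uses only stock PyTorch and illustrative names (e.g. _reference_check); it recomputes mean/rstd as in layer_norm_fwd_kernel, the c1/c2 form of dX as in layer_norm_bwd_dx, and the row reductions for dW/dB as in layer_norm_bwd_dwdb, then compares all three against autograd in float64.

# Sanity-check sketch (not part of this commit): re-derive the kernels' math in
# plain PyTorch and compare against autograd. Names here are illustrative.
import torch


def _reference_check(m: int = 8, n: int = 16, eps: float = 1e-5) -> None:
    x = torch.randn(m, n, dtype=torch.float64, requires_grad=True)
    w = torch.randn(n, dtype=torch.float64, requires_grad=True)
    b = torch.randn(n, dtype=torch.float64, requires_grad=True)

    # Forward: the same quantities layer_norm_fwd_kernel stores (mean, rstd).
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # population variance
    rstd = torch.rsqrt(var + eps)
    x_hat = (x - mean) * rstd
    y = x_hat * w + b

    dy = torch.randn_like(y)

    # Backward: dX via the c1/c2 reduction used by layer_norm_bwd_dx ...
    wdy = w * dy
    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
    c2 = wdy.mean(dim=-1, keepdim=True)
    dx_manual = (wdy - (x_hat * c1 + c2)) * rstd
    # ... and dW/dB via the row reductions used by layer_norm_bwd_dwdb.
    dw_manual = (dy * x_hat).sum(dim=0)
    db_manual = dy.sum(dim=0)

    # Autograd on the same graph should agree with the closed-form gradients.
    y.backward(dy)
    torch.testing.assert_close(dx_manual.detach(), x.grad)
    torch.testing.assert_close(dw_manual.detach(), w.grad)
    torch.testing.assert_close(db_manual.detach(), b.grad)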
@@ -72,21 +252,43 @@ def main() -> None:
     batch_size = 32
     dim = 64
     device = "cuda"
+
+    # Test forward pass only
+    print("\n=== Forward Pass Test ===")
     x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
     weight = torch.randn([dim], device=device, dtype=torch.float16)
     bias = torch.randn([dim], device=device, dtype=torch.float16)
     eps = 1e-4
     for b in [bias, None]:
         run_example(
-            layer_norm_fwd,
+            layer_norm,
             torch.nn.functional.layer_norm,
             (x, [dim], weight, b, eps),
-            kernel_name="helion",
-            baseline_name="torch",
             rtol=1e-3,
             atol=1e-3,
         )

+    # Test forward + backward pass
+    print("\n\n=== Forward + Backward Pass Test ===")
+    x_grad = torch.randn(
+        [batch_size, dim], device=device, dtype=torch.float16, requires_grad=True
+    )
+    weight_grad = torch.randn(
+        [dim], device=device, dtype=torch.float16, requires_grad=True
+    )
+    bias_grad = torch.randn(
+        [dim], device=device, dtype=torch.float16, requires_grad=True
+    )
+    for b in [bias_grad, None]:
+        run_example(
+            layer_norm,
+            torch.nn.functional.layer_norm,
+            (x_grad, [dim], weight_grad, b, eps),
+            rtol=1e-3,
+            atol=1e-3,
+            bwd=True,
+        )
+

 # %%
 if __name__ == "__main__":
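
For completeness, a usage sketch of the new autograd-enabled wrapper (assumptions: a CUDA device and the helion package are available, and examples.layer_norm from this commit is importable from the repository root; this snippet is not part of the commit):

# Usage sketch for the new layer_norm wrapper (assumptions noted above).
import torch

from examples.layer_norm import layer_norm

x = torch.randn(32, 64, device="cuda", dtype=torch.float16, requires_grad=True)
weight = torch.randn(64, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.randn(64, device="cuda", dtype=torch.float16, requires_grad=True)

# Forward runs layer_norm_fwd_kernel; backward dispatches to the two bwd kernels.
y = layer_norm(x, [64], weight, bias, eps=1e-5)
y.sum().backward()
print(x.grad.shape, weight.grad.shape, bias.grad.shape)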
