Skip to content

Commit 34e1ec2

Browse files
karthickaiyf225
authored andcommitted
[Benchmark] Welford kernel and example
stack-info: PR: #614, branch: karthickai/stack/1
1 parent 9857041 commit 34e1ec2

File tree

4 files changed

+255
-0
lines changed

4 files changed

+255
-0
lines changed

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ class RunResult:
156156
("examples.matmul_split_k", "matmul_split_k_tritonbench"),
157157
],
158158
),
159+
"welford": (
160+
"tritonbench.operators.welford.operator",
161+
"examples.welford",
162+
"welford",
163+
),
159164
}
160165

161166

@@ -240,6 +245,14 @@ class RunResult:
240245
"helion_jsd_tritonbench-speedup": "helion_speedup",
241246
"helion_jsd_tritonbench-accuracy": "helion_accuracy",
242247
},
248+
"welford": {
249+
"test_welford-speedup": "triton_speedup",
250+
"test_welford-accuracy": "triton_accuracy",
251+
"torch_compile_layer_norm-speedup": "torch_compile_speedup",
252+
"torch_compile_layer_norm-accuracy": "torch_compile_accuracy",
253+
"helion_welford-speedup": "helion_speedup",
254+
"helion_welford-accuracy": "helion_accuracy",
255+
},
243256
}
244257

245258

examples/welford.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
Welford Example
3+
================
4+
5+
This example demonstrates how to implement a welford layernorm using Helion.
6+
"""
7+
8+
# %%
9+
# Imports
10+
# -------
11+
from __future__ import annotations
12+
13+
import torch
14+
15+
import helion
16+
from helion._testing import run_example
17+
import helion.language as hl
18+
19+
20+
# %%
21+
# Welford Kernel Implementations
22+
# -------------------
23+
@helion.kernel()
24+
def welford(
25+
weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float = 1e-05
26+
) -> torch.Tensor:
27+
"""
28+
Applies LayerNorm using Welford's algorithm for mean/variance.
29+
Args:
30+
weight: weight tensor of shape [N]
31+
bias: bias tensor of shape [N]
32+
x: input tensor of shape [M, N]
33+
Returns:
34+
Output tensor of shape [M, N]
35+
"""
36+
m, n = x.size()
37+
38+
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
39+
40+
for tile_m in hl.tile(m):
41+
acc_cnt = torch.zeros_like(x[tile_m, 0], dtype=torch.float32)
42+
acc_mean = torch.zeros_like(acc_cnt)
43+
acc_m2 = torch.zeros_like(acc_cnt)
44+
45+
for tile_n in hl.tile(n):
46+
chunk = x[tile_m, tile_n]
47+
Tn = chunk.size(-1)
48+
sum_x = torch.sum(chunk, dim=-1)
49+
sum_x2 = torch.sum(chunk * chunk, dim=-1)
50+
mean_c = sum_x / Tn
51+
m2_c = sum_x2 - (sum_x * sum_x) / Tn
52+
53+
delta = mean_c - acc_mean
54+
new_cnt = acc_cnt + Tn
55+
new_mean = acc_mean + delta * (Tn / new_cnt)
56+
new_m2 = acc_m2 + m2_c + delta * delta * (acc_cnt * Tn / new_cnt)
57+
58+
acc_cnt, acc_mean, acc_m2 = new_cnt, new_mean, new_m2
59+
60+
rstd_tile = torch.rsqrt(acc_m2 / acc_cnt + eps)
61+
mean_col = acc_mean[:, None]
62+
rstd_col = rstd_tile[:, None]
63+
64+
for tile_n in hl.tile(n):
65+
xi_chuck = x[tile_m, tile_n]
66+
w_chuck = weight[tile_n][None, :]
67+
b_chuck = bias[tile_n][None, :]
68+
69+
y = (xi_chuck - mean_col) * rstd_col
70+
y = y * w_chuck + b_chuck
71+
72+
out[tile_m, tile_n] = y.to(x.dtype)
73+
return out
74+
75+
76+
# %%
77+
# Baseline Function
78+
# -------------------
79+
def eager_layer_norm(
80+
weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float = 1e-05
81+
) -> torch.Tensor:
82+
return torch.nn.functional.layer_norm(
83+
x, normalized_shape=(x.shape[-1],), weight=weight, bias=bias, eps=eps
84+
)
85+
86+
87+
# %%
88+
# Verification Function
89+
# -------------------
90+
def check(s: int, d: int) -> None:
91+
"""
92+
Verify the welford kernel implementation against PyTorch's native layer_norm function.
93+
94+
Args:
95+
s: First dimension of the test tensor
96+
d: Second dimension of the test tensor
97+
"""
98+
99+
weight = torch.rand((d,), device="cuda:0", dtype=torch.float32)
100+
bias = torch.rand((d,), device="cuda:0", dtype=torch.float32)
101+
x = torch.rand((s, d), device="cuda:0", dtype=torch.float32)
102+
103+
kernels = {"helion": welford}
104+
run_example(kernels, eager_layer_norm, (weight, bias, x))
105+
106+
107+
# %%
108+
# Main Function
109+
# -----------
110+
def main() -> None:
111+
"""
112+
Main entry point that runs the welford kernel verification with different tensor sizes.
113+
114+
Tests with two configurations:
115+
- 262144x1536
116+
- 262144x2048
117+
"""
118+
check(262144, 1536)
119+
check(262144, 2048)
120+
121+
122+
if __name__ == "__main__":
123+
main()

test/test_examples.expected

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,3 +3040,102 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
30403040
_BLOCK_SIZE_2 = 16
30413041
_launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
30423042
return out
3043+
3044+
--- assertExpectedJournal(TestExamples.test_welford)
3045+
from __future__ import annotations
3046+
3047+
import torch
3048+
import triton
3049+
import triton.language as tl
3050+
from torch._inductor.runtime.triton_compat import libdevice
3051+
from helion.runtime import default_launcher as _default_launcher
3052+
3053+
@triton.jit
3054+
def _helion_welford(x, weight, bias, out, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
3055+
pid_0 = tl.program_id(0)
3056+
offset_0 = pid_0 * _BLOCK_SIZE_0
3057+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
3058+
mask_0 = indices_0 < m
3059+
acc_cnt = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
3060+
acc_mean = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
3061+
acc_m2 = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
3062+
for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
3063+
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
3064+
mask_1 = indices_1 < n
3065+
acc_mean_copy = acc_mean
3066+
acc_cnt_copy = acc_cnt
3067+
acc_m2_copy = acc_m2
3068+
acc_mean_copy_0 = acc_mean_copy
3069+
acc_cnt_copy_0 = acc_cnt_copy
3070+
acc_m2_copy_0 = acc_m2_copy
3071+
chunk = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
3072+
sum_x = tl.cast(tl.sum(chunk, 1), tl.float32)
3073+
v_0 = chunk * chunk
3074+
sum_x2 = tl.cast(tl.sum(v_0, 1), tl.float32)
3075+
_BLOCK_SIZE_1_ = _BLOCK_SIZE_1
3076+
v_1 = tl.cast(_BLOCK_SIZE_1_, tl.float32)
3077+
v_2 = sum_x / v_1
3078+
v_3 = sum_x * sum_x
3079+
_BLOCK_SIZE_1__1 = _BLOCK_SIZE_1
3080+
v_4 = tl.cast(_BLOCK_SIZE_1__1, tl.float32)
3081+
v_5 = v_3 / v_4
3082+
v_6 = sum_x2 - v_5
3083+
v_7 = v_2 - acc_mean_copy_0
3084+
_BLOCK_SIZE_1__2 = _BLOCK_SIZE_1
3085+
v_8 = tl.cast(_BLOCK_SIZE_1__2, tl.float32)
3086+
acc_cnt = acc_cnt_copy_0 + v_8
3087+
v_10 = tl.full([], 1, tl.int32)
3088+
v_11 = v_10 / acc_cnt
3089+
_BLOCK_SIZE_1__3 = _BLOCK_SIZE_1
3090+
v_12 = tl.cast(_BLOCK_SIZE_1__3, tl.float32)
3091+
v_13 = v_11 * v_12
3092+
v_14 = v_7 * v_13
3093+
acc_mean = acc_mean_copy_0 + v_14
3094+
v_16 = acc_m2_copy_0 + v_6
3095+
v_17 = v_7 * v_7
3096+
_BLOCK_SIZE_1__4 = _BLOCK_SIZE_1
3097+
v_18 = tl.cast(_BLOCK_SIZE_1__4, tl.float32)
3098+
v_19 = acc_cnt_copy_0 * v_18
3099+
v_20 = v_19 / acc_cnt
3100+
v_21 = v_17 * v_20
3101+
acc_m2 = v_16 + v_21
3102+
v_23 = acc_m2 / acc_cnt
3103+
v_24 = v_23 + eps
3104+
v_25 = libdevice.rsqrt(v_24)
3105+
mean_col = acc_mean[:, None]
3106+
rstd_col = v_25[:, None]
3107+
for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
3108+
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
3109+
mask_2 = indices_2 < n
3110+
mean_col_copy = mean_col
3111+
rstd_col_copy = rstd_col
3112+
mean_col_copy_0 = mean_col_copy
3113+
rstd_col_copy_0 = rstd_col_copy
3114+
xi_chuck = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
3115+
load_1 = tl.load(weight + indices_2 * weight_stride_0, mask_2, other=0)
3116+
w_chuck = load_1[None, :]
3117+
load_2 = tl.load(bias + indices_2 * bias_stride_0, mask_2, other=0)
3118+
b_chuck = load_2[None, :]
3119+
v_26 = xi_chuck - mean_col_copy_0
3120+
v_27 = v_26 * rstd_col_copy_0
3121+
v_28 = v_27 * w_chuck
3122+
v_29 = v_28 + b_chuck
3123+
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_29, mask_0[:, None] & mask_2[None, :])
3124+
3125+
def welford(weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
3126+
"""
3127+
Applies LayerNorm using Welford's algorithm for mean/variance.
3128+
Args:
3129+
weight: weight tensor of shape [N]
3130+
bias: bias tensor of shape [N]
3131+
x: input tensor of shape [M, N]
3132+
Returns:
3133+
Output tensor of shape [M, N]
3134+
"""
3135+
m, n = x.size()
3136+
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
3137+
_BLOCK_SIZE_0 = 16
3138+
_BLOCK_SIZE_1 = 16
3139+
_BLOCK_SIZE_2 = 16
3140+
_launcher(_helion_welford, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, bias, out, bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
3141+
return out

test/test_examples.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,26 @@ def test_cross_entropy(self):
288288
)
289289
)
290290

291+
def test_welford(self):
292+
s, d = 128, 1024
293+
weight = torch.rand((d,), device=DEVICE, dtype=torch.float32)
294+
bias = torch.rand((d,), device=DEVICE, dtype=torch.float32)
295+
x = torch.rand((s, d), device=DEVICE, dtype=torch.float32)
296+
297+
self.assertExpectedJournal(
298+
check_example(
299+
"welford",
300+
(weight, bias, x),
301+
torch.nn.functional.layer_norm(
302+
x,
303+
normalized_shape=(x.shape[-1],),
304+
weight=weight,
305+
bias=bias,
306+
eps=1e-05,
307+
),
308+
)
309+
)
310+
291311
def test_rms_norm_fwd(self):
292312
args = (
293313
torch.randn([128, 256], device=DEVICE, dtype=torch.float16),

0 commit comments

Comments
 (0)