
Commit ee56353

[WIP] welford
stack-info: PR: #308, branch: yf225/stack/22
1 parent 5195d38 commit ee56353

File tree

2 files changed: +131 −0 lines changed


benchmarks/run.py

Lines changed: 5 additions & 0 deletions
@@ -94,6 +94,11 @@
         "examples.grouped_gemm",
         "grouped_gemm_tritonbench",
     ),
+    "welford": (
+        "tritonbench.operators.welford.operator",
+        "examples.welford",
+        "welford_tritonbench",
+    ),
 }
examples/welford.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration
TRITONBENCH_ARGS = {"primals_1": None, "primals_2": None, "primals_3": None}


@helion.kernel()
def welford_layer_norm(
    weight: torch.Tensor,
    bias: torch.Tensor,
    input: torch.Tensor,
    eps: float = 1e-5,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Welford algorithm for computing layer norm.

    Args:
        weight: Scale parameter (gamma) with shape [D]
        bias: Shift parameter (beta) with shape [D]
        input: Input tensor with shape [S, D]
        eps: Small value to avoid division by zero

    Returns:
        Tuple of (output tensor, mean, inv_std)
    """
    S, D = input.shape

    # Output tensors
    out = torch.empty_like(input)
    mean_out = torch.empty((S, 1), dtype=torch.float32, device=input.device)
    inv_std_out = torch.empty((S, 1), dtype=torch.float32, device=input.device)

    # Process rows in tiles
    for tile_s in hl.tile(S):
        # Compute mean using simple reduction first
        row_sum = hl.zeros([tile_s], dtype=torch.float32)

        for tile_d in hl.tile(D):
            x = input[tile_s, tile_d].to(torch.float32)
            row_sum = row_sum + x.sum(dim=-1)

        mean = row_sum / D

        # Store mean
        mean_out[tile_s, 0] = mean

        # Compute variance using the mean
        var_sum = hl.zeros([tile_s], dtype=torch.float32)

        for tile_d in hl.tile(D):
            x = input[tile_s, tile_d].to(torch.float32)
            diff = x - mean[:, None]
            var_sum = var_sum + (diff * diff).sum(dim=-1)

        variance = var_sum / D

        # Compute inverse standard deviation
        inv_std = torch.rsqrt(variance + eps)

        # Store inv_std
        inv_std_out[tile_s, 0] = inv_std

        # Apply normalization
        for tile_d in hl.tile(D):
            x_orig = input[tile_s, tile_d]
            x_normalized = (x_orig - mean[:, None]) * inv_std[:, None]

            # Apply scale and bias
            out[tile_s, tile_d] = x_normalized * weight[tile_d] + bias[tile_d]

    return out, mean_out, inv_std_out


def welford_tritonbench(primals_1, primals_2, primals_3):
    """
    Wrapper for tritonbench that matches the expected interface.

    Args:
        primals_1: weight (gamma) parameter
        primals_2: bias (beta) parameter
        primals_3: input tensor

    Returns:
        Tuple of (output, input, mean, inv_std) to match tritonbench interface
    """
    # Run the welford layer norm kernel
    output, mean, inv_std = welford_layer_norm(primals_1, primals_2, primals_3)

    return (output, primals_3, mean, inv_std)


def reference_layer_norm(weight, bias, input, eps=1e-5):
    """PyTorch reference implementation for layer normalization."""
    return torch.nn.functional.layer_norm(input, (input.shape[-1],), weight, bias, eps)


def check(S: int, D: int) -> None:
    # Create input tensors
    weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
    bias = torch.randn(D, device="cuda", dtype=torch.bfloat16)
    input = torch.randn([S, D], device="cuda", dtype=torch.bfloat16)

    # Run comparison - just compare the output tensor
    run_example(
        {"helion": lambda w, b, x: welford_layer_norm(w, b, x)[0]},
        lambda w, b, x: reference_layer_norm(w, b, x),
        (weight, bias, input),
    )


def main() -> None:
    # Test with various shapes
    check(262144, 1024)
    check(262144, 2048)
    check(512, 768)
    check(1024, 1024)


if __name__ == "__main__":
    main()
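
Note on the [WIP] title: the kernel above computes the row mean and variance with two separate reduction passes over D, not yet with Welford's single-pass update. For reference, a minimal plain-Python sketch of the Welford running update and of the pairwise combine step that a tiled single-pass reduction would accumulate; the helper names welford_update and welford_combine are illustrative only and are not part of this commit:

def welford_update(count: int, mean: float, m2: float, x: float) -> tuple[int, float, float]:
    # Welford's online update: fold one new sample into running
    # (count, mean, M2) statistics; after n samples, variance = M2 / n.
    count += 1
    delta = x - mean
    mean += delta / count
    delta2 = x - mean
    m2 += delta * delta2
    return count, mean, m2


def welford_combine(
    n_a: int, mean_a: float, m2_a: float,
    n_b: int, mean_b: float, m2_b: float,
) -> tuple[int, float, float]:
    # Merge two partial (count, mean, M2) triples (Chan et al.), the operation
    # a tiled reduction over D would apply to per-tile statistics.
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2

With a combine step of this form, a single pass over D can produce both the per-row mean and M2, so inv_std follows as rsqrt(M2 / D + eps) without re-reading the input a second time.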
