
Commit 3405224

[Benchmark] jagged_layer_norm kernel and test
stack-info: PR: #704, branch: Sibylau/stack/6
1 parent 19a7442 commit 3405224

4 files changed: +526 −0 lines changed

benchmarks/run.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -224,6 +224,11 @@ class RunResult:
             "num_inputs": 10,  # int4_gemm takes long time on Benchmark CI, so use fewer inputs instead.
         },
     ),
+    "jagged_layer_norm": (
+        "tritonbench.operators.jagged_layer_norm.operator",
+        "examples.jagged_layer_norm",
+        "jagged_layer_norm_tritonbench",
+    ),
 }
@@ -348,6 +353,12 @@ class RunResult:
         "helion_grouped_gemm_jagged_persistent_tritonbench-speedup": "helion_speedup",
         "helion_grouped_gemm_jagged_persistent_tritonbench-accuracy": "helion_accuracy",
     },
+    "jagged_layer_norm": {
+        "torch_compile_grouped_gemm-speedup": "torch_compile_speedup",
+        "torch_compile_grouped_gemm-accuracy": "torch_compile_accuracy",
+        "helion_jagged_layer_norm_tritonbench-speedup": "helion_speedup",
+        "helion_jagged_layer_norm_tritonbench-accuracy": "helion_accuracy",
+    },
 }
```
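As a rough illustration of what the new registration entry encodes (this is an assumption, not the actual dispatch logic in benchmarks/run.py): the tuple pairs the tritonbench operator module with the Helion example module and the wrapper function that module exports. Resolving such a tuple could look roughly like:

```python
# Illustrative sketch only -- the real resolution logic in benchmarks/run.py may differ.
import importlib

entry = (
    "tritonbench.operators.jagged_layer_norm.operator",  # tritonbench operator module
    "examples.jagged_layer_norm",                        # Helion example module
    "jagged_layer_norm_tritonbench",                     # wrapper exported by the example
)

operator_path, example_module, wrapper_name = entry
wrapper = getattr(importlib.import_module(example_module), wrapper_name)
```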

examples/jagged_layer_norm.py

Lines changed: 282 additions & 0 deletions
```python
"""
Jagged Layer Normalization Example
==================================

This example demonstrates how to compute layer normalization on jagged tensors
using Helion. The implementation closely follows the torch_jagged_layer_norm_torch_sum
algorithm from tritonbench but is optimized for Helion's tiling approach.

A jagged tensor is a nested tensor where each sequence can have a different length.
Layer normalization is applied across the feature dimension (last dimension) for
each individual sequence, computing mean and variance only over valid elements.
"""

# %%
# Imports
# -------
from __future__ import annotations

import itertools
from typing import Callable

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Jagged Layer Norm Kernel
# ------------------------
@helion.kernel(use_default_config=True)
def jagged_layer_norm_kernel(
    x_values: torch.Tensor,  # [total_L, M] - compressed values
    x_offsets: torch.Tensor,  # [B+1] - sequence start offsets
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Compute layer normalization on a jagged tensor using Helion.

    This kernel implements layer normalization for jagged tensors by:
    1. Computing mean and variance over the valid elements of each sequence
    2. Normalizing the values within each sequence

    No affine (weight/bias) transformation is applied.

    Args:
        x_values: Compressed values tensor of shape [total_L, M]
        x_offsets: Sequence boundary offsets of shape [B+1]
        eps: Small value for numerical stability

    Returns:
        Normalized tensor of same shape as x_values [total_L, M]
    """
    total_L, M = x_values.shape
    B = x_offsets.size(0) - 1

    # Output tensor
    out = torch.empty_like(x_values)

    x_flat = x_values.view(-1)
    out_flat = out.view(-1)

    # Process sequences in tiles
    for tile_b in hl.tile(B):
        # Get sequence boundaries for this tile
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        seq_lengths = ends - starts
        max_seq_len = seq_lengths.amax()

        # Initialize accumulators for mean and variance computation
        mean_acc = hl.zeros([tile_b], dtype=x_values.dtype)
        var_acc = hl.zeros([tile_b], dtype=x_values.dtype)

        # First pass: compute mean
        for tile_m in hl.tile(M):
            row_sums = hl.zeros([tile_b, tile_m], dtype=x_values.dtype)
            for tile_k in hl.tile(0, max_seq_len):
                # Compute indices into x_values
                indices = starts[:, None] + tile_k.index[None, :]
                flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]

                # Create mask for valid elements
                row_mask = tile_k.index[None, :] < seq_lengths[:, None]
                combined_mask = row_mask[:, :, None]

                # Load values with masking
                x_slice = hl.load(
                    x_flat,
                    [flat_indices],
                    extra_mask=combined_mask,
                )

                # Accumulate sum for mean (sum across sequence dimension)
                row_sums = row_sums + x_slice.sum(dim=1)
            mean_acc = mean_acc + row_sums.sum(dim=1)
        seq_lengths_float = seq_lengths.to(x_values.dtype)
        mean_acc = mean_acc / (seq_lengths_float * M)

        # Second pass: compute variance
        for tile_m in hl.tile(M):
            var_sums = hl.zeros([tile_b, tile_m], dtype=x_values.dtype)
            for tile_k in hl.tile(0, max_seq_len):
                # Compute indices into x_values
                indices = starts[:, None] + tile_k.index[None, :]
                flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]

                # Create mask for valid elements
                row_mask = tile_k.index[None, :] < seq_lengths[:, None]
                combined_mask = row_mask[:, :, None]

                # Load values with masking
                x_slice = hl.load(
                    x_flat,
                    [flat_indices],
                    extra_mask=combined_mask,
                )

                # Compute centered values
                centered = torch.where(
                    combined_mask,
                    x_slice.to(torch.float32) - mean_acc[:, None, None],
                    0.0,
                )

                # Accumulate squared differences for variance
                var_sums = var_sums + (centered * centered).sum(dim=1)
            var_acc = var_acc + var_sums.sum(dim=1)

        # Compute variance and reciprocal standard deviation
        variance = var_acc / (seq_lengths_float * M)
        rstd = torch.rsqrt(variance + eps)

        # Third pass: compute layernorm
        for tile_m in hl.tile(M):
            for tile_k in hl.tile(0, max_seq_len):
                # Compute indices into x_values
                indices = starts[:, None] + tile_k.index[None, :]
                flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]

                # Create mask for valid elements
                row_mask = tile_k.index[None, :] < seq_lengths[:, None]
                combined_mask = row_mask[:, :, None]

                # Load values with masking
                x_slice = hl.load(
                    x_flat,
                    [flat_indices],
                    extra_mask=combined_mask,
                )

                # Normalize
                normalized = torch.where(
                    combined_mask,
                    (x_slice.to(torch.float32) - mean_acc[:, None, None])
                    * rstd[:, None, None],
                    0.0,
                )

                # Store result
                hl.store(
                    out_flat,
                    [flat_indices],
                    normalized.to(x_values.dtype),
                    extra_mask=combined_mask,
                )

    return out.reshape(total_L, M)


# %%
# Reference Implementation
# ------------------------
def reference_jagged_layer_norm_pytorch(
    x_values: torch.Tensor,
    x_offsets: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Simple reference implementation using an unbind approach, for validation.
    """
    return torch.cat(
        [
            torch.nn.functional.layer_norm(
                x_values[x_offsets[i] : x_offsets[i + 1], :],
                x_values[x_offsets[i] : x_offsets[i + 1], :].shape,
                eps=eps,
            )
            for i in range(x_offsets.shape[0] - 1)
        ],
        dim=0,
    )


# %%
# Benchmark Wrapper
# -----------------
def jagged_layer_norm_tritonbench(
    tb_op: object, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
) -> Callable[[], torch.Tensor]:
    """
    Wrapper for tritonbench that matches the expected interface.

    Args:
        tb_op: TritonBench operator instance
        x: Nested tensor in jagged format with shape (B, *, M)
        B: Batch size
        M: Number of features
        seqlen: Maximum sequence length
        sparsity: Sparsity factor (not used)

    Returns:
        Callable that returns the normalized tensor values
    """
    x_values = x._values
    x_offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]

    return lambda: jagged_layer_norm_kernel(x_values, x_offsets, eps=1e-6)


# %%
# Helper function to create test data
# -----------------------------------
def create_test_jagged_tensor(
    B: int,
    M: int,
    max_seqlen: int,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create test jagged tensor data."""
    # Generate random sequence lengths
    seq_lengths = torch.randint(1, max_seqlen + 1, (B,), device=device)

    # Create offsets
    x_offsets = torch.cat(
        [
            torch.zeros(1, dtype=torch.long, device=device),
            torch.cumsum(seq_lengths, dim=0),
        ]
    )

    # Create values
    nnz = int(x_offsets[-1])
    x_data = torch.randn(nnz, M, dtype=dtype, device=device)

    return x_data, x_offsets


# %%
# Main Function
# -------------
def main() -> None:
    """
    Main entry point for the jagged layer norm example.

    Creates test data and compares the Helion implementation against
    the PyTorch reference implementation.
    """
    B_list = [2**n for n in range(5, 16, 3)]
    M_list = [2**n for n in range(5, 10, 3)]
    max_seqlen_list = [128]
    eps = 1e-6
    device = "cuda"

    for B, M, max_seqlen in itertools.product(B_list, M_list, max_seqlen_list):
        x_data, x_offsets = create_test_jagged_tensor(
            B, M, max_seqlen, device, dtype=torch.float32
        )
        run_example(
            lambda x, o, eps: jagged_layer_norm_kernel(x, o, eps),
            lambda x, o, eps: reference_jagged_layer_norm_pytorch(x, o, eps),
            (x_data, x_offsets, eps),
        )


# %%
if __name__ == "__main__":
    main()
```
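For readers unfamiliar with the jagged layout, here is a minimal CPU-only sketch (not part of the commit) of how `x_values` and `x_offsets` pack three variable-length sequences, and of the per-sequence normalization that `reference_jagged_layer_norm_pytorch` computes. The sequence lengths and `M = 4` features are arbitrary illustration values.

```python
# Hypothetical walk-through of the jagged layout; assumes M = 4 features.
import torch

seq_lengths = torch.tensor([2, 3, 1])
x_offsets = torch.cat([torch.zeros(1, dtype=torch.long), seq_lengths.cumsum(0)])  # tensor([0, 2, 5, 6])
x_values = torch.randn(int(x_offsets[-1]), 4)  # [total_L = 6, M = 4]

# Sequence i occupies rows x_offsets[i]:x_offsets[i+1]; layer norm uses the mean and
# variance of all seq_len * M elements of that sequence, matching the kernel's
# (seq_lengths_float * M) divisor.
out = torch.cat(
    [
        torch.nn.functional.layer_norm(
            x_values[x_offsets[i] : x_offsets[i + 1]],
            x_values[x_offsets[i] : x_offsets[i + 1]].shape,
            eps=1e-6,
        )
        for i in range(x_offsets.numel() - 1)
    ],
    dim=0,
)
assert out.shape == x_values.shape  # [6, 4]
```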

0 commit comments

Comments
 (0)