|
| 1 | +""" |
| 2 | +Jagged Layer Normalization Example |
| 3 | +================================= |
| 4 | +
|
| 5 | +This example demonstrates how to compute layer normalization on jagged tensors |
| 6 | +using Helion. The implementation closely follows the torch_jagged_layer_norm_torch_sum |
| 7 | +algorithm from tritonbench but is optimized for Helion's tiling approach. |
| 8 | +
|
| 9 | +A jagged tensor is a nested tensor where each sequence can have different lengths. |
| 10 | +Layer normalization is applied across the feature dimension (last dimension) for |
| 11 | +each individual sequence, computing mean and variance only over valid elements. |
| 12 | +""" |
| 13 | + |
| 14 | +# %% |
| 15 | +# Imports |
| 16 | +# ------- |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import itertools |
| 20 | +from typing import Callable |
| 21 | + |
| 22 | +import torch |
| 23 | + |
| 24 | +import helion |
| 25 | +from helion._testing import run_example |
| 26 | +import helion.language as hl |
| 27 | + |
| 28 | + |
| 29 | +# %% |
| 30 | +# Jagged Layer Norm Kernel |
| 31 | +# ---------------------- |
| 32 | +@helion.kernel(use_default_config=True) |
| 33 | +def jagged_layer_norm_kernel( |
| 34 | + x_values: torch.Tensor, # [total_L, M] - compressed values |
| 35 | + x_offsets: torch.Tensor, # [B+1] - sequence start offsets |
| 36 | + eps: float = 1e-6, |
| 37 | +) -> torch.Tensor: |
| 38 | + """ |
| 39 | + Compute layer normalization on jagged tensor using Helion. |
| 40 | +
|
| 41 | + This kernel implements layer normalization for jagged tensors by: |
| 42 | + 1. Computing mean and variance for each sequence individually |
| 43 | + 2. Normalizing values within each sequence |
| 44 | + 3. Applying optional affine transformation (weight/bias) |
| 45 | +
|
| 46 | + Args: |
| 47 | + x_values: Compressed values tensor of shape [total_L, M] |
| 48 | + x_offsets: Sequence boundary offsets of shape [B+1] |
| 49 | + eps: Small value for numerical stability |
| 50 | +
|
| 51 | + Returns: |
| 52 | + Normalized tensor of same shape as x_values [total_L, M] |
| 53 | + """ |
| 54 | + total_L, M = x_values.shape |
| 55 | + B = x_offsets.size(0) - 1 |
| 56 | + |
| 57 | + # Output tensor |
| 58 | + out = torch.empty_like(x_values) |
| 59 | + |
| 60 | + x_flat = x_values.view(-1) |
| 61 | + out_flat = out.view(-1) |
| 62 | + |
| 63 | + # Process sequences in tiles |
| 64 | + for tile_b in hl.tile(B): |
| 65 | + # Get sequence boundaries for this tile |
| 66 | + starts = x_offsets[tile_b] |
| 67 | + ends = x_offsets[tile_b.index + 1] |
| 68 | + seq_lengths = ends - starts |
| 69 | + max_seq_len = seq_lengths.amax() |
| 70 | + |
| 71 | + # Initialize accumulators for mean and variance computation |
| 72 | + mean_acc = hl.zeros([tile_b], dtype=x_values.dtype) |
| 73 | + var_acc = hl.zeros([tile_b], dtype=x_values.dtype) |
| 74 | + |
| 75 | + # First pass: compute mean |
| 76 | + for tile_m in hl.tile(M): |
| 77 | + row_sums = hl.zeros([tile_b, tile_m], dtype=x_values.dtype) |
| 78 | + for tile_k in hl.tile(0, max_seq_len): |
| 79 | + # Compute indices into x_values |
| 80 | + indices = starts[:, None] + tile_k.index[None, :] |
| 81 | + flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :] |
| 82 | + |
| 83 | + # Create mask for valid elements |
| 84 | + row_mask = tile_k.index[None, :] < seq_lengths[:, None] |
| 85 | + combined_mask = row_mask[:, :, None] |
| 86 | + |
| 87 | + # Load values with masking |
| 88 | + x_slice = hl.load( |
| 89 | + x_flat, |
| 90 | + [flat_indices], |
| 91 | + extra_mask=combined_mask, |
| 92 | + ) |
| 93 | + |
| 94 | + # Accumulate sum for mean (sum across sequence dimension) |
| 95 | + row_sums = row_sums + x_slice.sum(dim=1) |
| 96 | + mean_acc = mean_acc + row_sums.sum(dim=1) |
| 97 | + seq_lengths_float = seq_lengths.to(x_values.dtype) |
| 98 | + mean_acc = mean_acc / (seq_lengths_float * M) |
| 99 | + |
| 100 | + # Second pass: compute variance |
| 101 | + for tile_m in hl.tile(M): |
| 102 | + var_sums = hl.zeros([tile_b, tile_m], dtype=x_values.dtype) |
| 103 | + for tile_k in hl.tile(0, max_seq_len): |
| 104 | + # Compute indices into x_values |
| 105 | + indices = starts[:, None] + tile_k.index[None, :] |
| 106 | + flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :] |
| 107 | + |
| 108 | + # Create mask for valid elements |
| 109 | + row_mask = tile_k.index[None, :] < seq_lengths[:, None] |
| 110 | + combined_mask = row_mask[:, :, None] |
| 111 | + |
| 112 | + # Load values with masking |
| 113 | + x_slice = hl.load( |
| 114 | + x_flat, |
| 115 | + [flat_indices], |
| 116 | + extra_mask=combined_mask, |
| 117 | + ) |
| 118 | + |
| 119 | + # Compute centered values |
| 120 | + centered = torch.where( |
| 121 | + combined_mask, |
| 122 | + x_slice.to(torch.float32) - mean_acc[:, None, None], |
| 123 | + 0.0, |
| 124 | + ) |
| 125 | + |
| 126 | + # Accumulate squared differences for variance |
| 127 | + var_sums = var_sums + (centered * centered).sum(dim=1) |
| 128 | + var_acc = var_acc + var_sums.sum(dim=1) |
| 129 | + |
| 130 | + # Compute variance and reciprocal standard deviation |
| 131 | + variance = var_acc / (seq_lengths_float * M) |
| 132 | + rstd = torch.rsqrt(variance + eps) |
| 133 | + |
| 134 | + # Third pass: compute layernorm |
| 135 | + for tile_m in hl.tile(M): |
| 136 | + for tile_k in hl.tile(0, max_seq_len): |
| 137 | + # Compute indices into x_values |
| 138 | + indices = starts[:, None] + tile_k.index[None, :] |
| 139 | + flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :] |
| 140 | + |
| 141 | + # Create mask for valid elements |
| 142 | + row_mask = tile_k.index[None, :] < seq_lengths[:, None] |
| 143 | + combined_mask = row_mask[:, :, None] |
| 144 | + |
| 145 | + # Load values with masking |
| 146 | + x_slice = hl.load( |
| 147 | + x_flat, |
| 148 | + [flat_indices], |
| 149 | + extra_mask=combined_mask, |
| 150 | + ) |
| 151 | + |
| 152 | + # Normalize |
| 153 | + normalized = torch.where( |
| 154 | + combined_mask, |
| 155 | + (x_slice.to(torch.float32) - mean_acc[:, None, None]) |
| 156 | + * rstd[:, None, None], |
| 157 | + 0.0, |
| 158 | + ) |
| 159 | + |
| 160 | + # Store result |
| 161 | + hl.store( |
| 162 | + out_flat, |
| 163 | + [flat_indices], |
| 164 | + normalized.to(x_values.dtype), |
| 165 | + extra_mask=combined_mask, |
| 166 | + ) |
| 167 | + |
| 168 | + return out.reshape(total_L, M) |
| 169 | + |
| 170 | + |
| 171 | +# %% |
| 172 | +# Reference Implementation |
| 173 | +# ------------------------------ |
| 174 | +def reference_jagged_layer_norm_pytorch( |
| 175 | + x_values: torch.Tensor, |
| 176 | + x_offsets: torch.Tensor, |
| 177 | + eps: float = 1e-6, |
| 178 | +) -> torch.Tensor: |
| 179 | + """ |
| 180 | + Simple reference implementation using unbind approach for validation. |
| 181 | + """ |
| 182 | + |
| 183 | + return torch.cat( |
| 184 | + [ |
| 185 | + torch.nn.functional.layer_norm( |
| 186 | + x_values[x_offsets[i] : x_offsets[i + 1], :], |
| 187 | + x_values[x_offsets[i] : x_offsets[i + 1], :].shape, |
| 188 | + eps=eps, |
| 189 | + ) |
| 190 | + for i in range(x_offsets.shape[0] - 1) |
| 191 | + ], |
| 192 | + dim=0, |
| 193 | + ) |
| 194 | + |
| 195 | + |
| 196 | +# %% |
| 197 | +# Benchmark Wrapper |
| 198 | +# --------------- |
| 199 | +def jagged_layer_norm_tritonbench( |
| 200 | + tb_op: object, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float |
| 201 | +) -> Callable[[], torch.Tensor]: |
| 202 | + """ |
| 203 | + Wrapper for tritonbench that matches the expected interface. |
| 204 | +
|
| 205 | + Args: |
| 206 | + tb_op: TritonBench operator instance |
| 207 | + x: Nested tensor in jagged format with shape (B, *, M) |
| 208 | + B: Batch size |
| 209 | + M: Number of features |
| 210 | + seqlen: Maximum sequence length |
| 211 | + sparsity: Sparsity factor (not used) |
| 212 | +
|
| 213 | + Returns: |
| 214 | + Callable that returns normalized tensor values |
| 215 | + """ |
| 216 | + x_values = x._values |
| 217 | + x_offsets = x._offsets # pyright: ignore[reportAttributeAccessIssue] |
| 218 | + |
| 219 | + return lambda: jagged_layer_norm_kernel(x_values, x_offsets, eps=1e-6) |
| 220 | + |
| 221 | + |
| 222 | +# %% |
| 223 | +# Helper function to create test data |
| 224 | +# --------------------------------- |
| 225 | +def create_test_jagged_tensor( |
| 226 | + B: int, |
| 227 | + M: int, |
| 228 | + max_seqlen: int, |
| 229 | + device: str = "cuda", |
| 230 | + dtype: torch.dtype = torch.float32, |
| 231 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 232 | + """Create test jagged tensor data.""" |
| 233 | + |
| 234 | + # Generate random sequence lengths |
| 235 | + seq_lengths = torch.randint(1, max_seqlen + 1, (B,), device=device) |
| 236 | + |
| 237 | + # Create offsets |
| 238 | + x_offsets = torch.cat( |
| 239 | + [ |
| 240 | + torch.zeros(1, dtype=torch.long, device=device), |
| 241 | + torch.cumsum(seq_lengths, dim=0), |
| 242 | + ] |
| 243 | + ) |
| 244 | + |
| 245 | + # Create values |
| 246 | + nnz = int(x_offsets[-1]) |
| 247 | + x_data = torch.randn(nnz, M, dtype=dtype, device=device) |
| 248 | + |
| 249 | + return x_data, x_offsets |
| 250 | + |
| 251 | + |
| 252 | +# %% |
| 253 | +# Main Function |
| 254 | +# ----------- |
| 255 | +def main() -> None: |
| 256 | + """ |
| 257 | + Main entry point for jagged layer norm example. |
| 258 | +
|
| 259 | + Creates test data and compares the Helion implementation against |
| 260 | + both PyTorch reference implementations. |
| 261 | + """ |
| 262 | + # B, M, max_seqlen = 3, 4, 3 |
| 263 | + B_list = [2**n for n in list(range(5, 16, 3))] |
| 264 | + M_list = [2**n for n in list(range(5, 10, 3))] |
| 265 | + max_seqlen_list = [128] |
| 266 | + eps = 1e-6 |
| 267 | + device = "cuda" |
| 268 | + |
| 269 | + for B, M, max_seqlen in itertools.product(B_list, M_list, max_seqlen_list): |
| 270 | + x_data, x_offsets = create_test_jagged_tensor( |
| 271 | + B, M, max_seqlen, device, dtype=torch.float32 |
| 272 | + ) |
| 273 | + run_example( |
| 274 | + lambda x, o, eps: jagged_layer_norm_kernel(x, o, eps), |
| 275 | + lambda x, o, eps: reference_jagged_layer_norm_pytorch(x, o, eps), |
| 276 | + (x_data, x_offsets, eps), |
| 277 | + ) |
| 278 | + |
| 279 | + |
| 280 | +# %% |
| 281 | +if __name__ == "__main__": |
| 282 | + main() |
0 commit comments