11"""
2- Helion Layer Normalization Forward Example
3- ==========================================
2+ Helion Layer Normalization Forward and Backward Example
3+ ========================================================
44This example demonstrates a Helion kernel implementation of 1D layer normalization
5- using FP16 inputs and compares it against PyTorch's built-in layer_norm function.
5+ with both forward and backward passes using FP16 inputs and compares it against
6+ PyTorch's built-in layer_norm function.
67"""
78
89# %%
910from __future__ import annotations
1011
12+ from typing import Any
13+
1114import torch
1215
1316import helion

 # %%
 @helion.kernel
-def layer_norm_fwd(
+def layer_norm_fwd_kernel(
     x: torch.Tensor,
     normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor | None = None,
     eps: float = 1e-5,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
+        bias (torch.Tensor | None): Optional learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
-        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            - The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+            - Mean tensor of shape [batch_size], in FP32.
+            - Reciprocal standard deviation tensor of shape [batch_size], in FP32.
     """
     m, n = x.size()
-    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
     if bias is not None:
-        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
+        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
     assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
     assert normalized_shape[0] == n, (
         f"normalized shape mismatch {normalized_shape[0]} != {n}"
     )
     out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    mean = torch.empty([m], dtype=torch.float32, device=x.device)
+    rstd = torch.empty([m], dtype=torch.float32, device=x.device)
+
     for tile_m in hl.tile(m):
         acc = x[tile_m, :].to(torch.float32)
-        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
-        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        # Compute mean
+        mean_val = torch.sum(acc, dim=-1) / n
+        # Compute variance
+        centered = acc - mean_val[:, None]
+        var_val = torch.sum(centered * centered, dim=-1) / n
+        # Compute reciprocal standard deviation
+        rstd_val = torch.rsqrt(var_val + eps)
+        # Normalize
+        normalized = centered * rstd_val[:, None]
+        # Apply affine transformation
         if bias is not None:
             acc = normalized * (weight[:].to(torch.float32)) + (
                 bias[:].to(torch.float32)
             )
         else:
             acc = normalized * (weight[:].to(torch.float32))
         out[tile_m, :] = acc.to(x.dtype)
-    return out
+        mean[tile_m] = mean_val
+        rstd[tile_m] = rstd_val
+    return out, mean, rstd
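+
+# Note: the statistics above use the biased variance (dividing by N, equivalent to
+# correction=0), which is what torch.nn.functional.layer_norm computes; mean and rstd
+# are returned so the backward kernels below can reuse them instead of recomputing
+# the per-row statistics from x.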
+
+
+# %%
+@helion.kernel
+def layer_norm_bwd_dwdb(
+    grad_out: torch.Tensor,
+    x: torch.Tensor,
+    mean: torch.Tensor,
+    rstd: torch.Tensor,
+    weight: torch.Tensor,
+    compute_bias_grad: hl.constexpr = True,  # type: ignore[valid-type]
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """
+    Compute gradients for the weight (dW) and optionally the bias (dB) parameters.
+
+    This kernel performs a reduction across the batch dimension (M) to accumulate
+    gradients for each feature dimension's weight and bias parameters.
+
+    Args:
+        grad_out: Gradient w.r.t. the layer norm output [M, N]
+        x: Original input tensor [M, N]
+        mean: Per-sample mean computed in the forward pass [M]
+        rstd: Per-sample reciprocal standard deviation from the forward pass [M]
+        weight: Weight parameter (used only for dtype/device info) [N]
+        compute_bias_grad: Whether to compute the bias gradient (default: True)
+
+    Returns:
+        (grad_weight, grad_bias): Gradients for weight and bias (if computed), both of shape [N].
+        grad_bias is None if compute_bias_grad is False.
+    """
+    m, n = x.shape
+    n = hl.specialize(n)
+
+    dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
+    if compute_bias_grad:
+        db = torch.empty([n], dtype=weight.dtype, device=weight.device)
+    else:
+        db = None
+
+    # Reduce across rows (M) inside the kernel without atomics
+    rdim = hl.register_reduction_dim(m)
+
+    for tile_n in hl.tile(n):
+        rows = hl.arange(0, rdim)
+        # Load slices for all rows in rdim and this tile of columns
+        x_blk = x[rows, tile_n].to(torch.float32)
+        dy_blk = grad_out[rows, tile_n].to(torch.float32)
+        mean_vec = mean[rows]
+        rstd_vec = rstd[rows]
+
+        x_hat_blk = (x_blk - mean_vec[:, None]) * rstd_vec[:, None]
+        dw_tile = torch.sum(dy_blk * x_hat_blk, dim=0).to(weight.dtype)
+
+        dw[tile_n] = dw_tile
+        if compute_bias_grad:
+            db_tile = torch.sum(dy_blk, dim=0).to(weight.dtype)
+            db[tile_n] = db_tile  # type: ignore[index]
+
+    if compute_bias_grad:
+        return dw, db
+    return dw, None
+
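+# For reference, with x_hat = (x - mean) * rstd the parameter gradients reduce over
+# the batch dimension:
+#   dW[j] = sum_i grad_out[i, j] * x_hat[i, j]
+#   dB[j] = sum_i grad_out[i, j]
+# which is what the dw/db accumulations above compute for each column tile.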
+
+@helion.kernel
+def layer_norm_bwd_dx(
+    grad_out: torch.Tensor,
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    mean: torch.Tensor,
+    rstd: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute the gradient for the input tensor (dX).
+
+    This kernel computes per-sample gradients by performing reductions across
+    the feature dimension (N) for each sample in the batch.
+
+    Args:
+        grad_out: Gradient w.r.t. the layer norm output [M, N]
+        x: Original input tensor [M, N]
+        weight: Weight parameter [N]
+        mean: Per-sample mean computed in the forward pass [M]
+        rstd: Per-sample reciprocal standard deviation from the forward pass [M]
+
+    Returns:
+        grad_x: Gradient w.r.t. the input tensor, shape [M, N]
+    """
+    m, n = x.shape
+    n = hl.specialize(n)
+
+    grad_x = torch.empty_like(x)
+
+    for tile_m in hl.tile(m):
+        x_tile = x[tile_m, :].to(torch.float32)
+        dy_tile = grad_out[tile_m, :].to(torch.float32)
+        w = weight[:].to(torch.float32)
+        mean_tile = mean[tile_m]
+        rstd_tile = rstd[tile_m]
+
+        x_hat = (x_tile - mean_tile[:, None]) * rstd_tile[:, None]
+        wdy = w * dy_tile
+        c1 = torch.sum(x_hat * wdy, dim=-1) / n
+        c2 = torch.sum(wdy, dim=-1) / n
+        dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_tile[:, None]
+        grad_x[tile_m, :] = dx.to(x.dtype)
+
+    return grad_x
+
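+# For reference, the dX formula above follows from y = x_hat * weight + bias with
+# x_hat = (x - mean) * rstd. Writing wdy = weight * grad_out, differentiating through
+# the per-row mean and variance gives
+#   dX = rstd * (wdy - mean_N(wdy) - x_hat * mean_N(x_hat * wdy))
+# where mean_N averages over the feature dimension; c2 and c1 above are those two means.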
+
+# %%
+class LayerNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x: torch.Tensor,
+        normalized_shape: list[int],
+        weight: torch.Tensor,
+        bias: torch.Tensor | None,
+        eps: float,
+    ) -> torch.Tensor:
+        """Forward pass for layer normalization."""
+        y, mean, rstd = layer_norm_fwd_kernel(x, normalized_shape, weight, bias, eps)
+        ctx.save_for_backward(x, weight, bias, mean, rstd)  # type: ignore[arg-type]
+        ctx.normalized_shape = normalized_shape  # type: ignore[attr-defined]
+        return y
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_output: torch.Tensor,
+    ) -> tuple[
+        torch.Tensor | None, None, torch.Tensor | None, torch.Tensor | None, None
+    ]:
+        """Backward pass for layer normalization, split into two separate kernels for efficiency."""
+        grad_out = grad_output  # Use a common name internally
+        x, weight, bias, mean, rstd = ctx.saved_tensors  # type: ignore[attr-defined]
+
+        # Check whether the bias gradient is needed
+        compute_bias_grad = bias is not None
+
+        # First kernel: compute gradients for weight and bias by reducing across the batch dimension (M)
+        grad_weight, grad_bias = layer_norm_bwd_dwdb(
+            grad_out, x, mean, rstd, weight, compute_bias_grad
+        )
+
+        # Second kernel: compute the gradient for the input (dX) using per-sample reductions across the feature dimension (N)
+        grad_x = layer_norm_bwd_dx(grad_out, x, weight, mean, rstd)
+
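+        # backward() must return one gradient per forward() input, in order:
+        # x, normalized_shape, weight, bias, eps; the non-tensor arguments get None.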
+        return grad_x, None, grad_weight, grad_bias, None
+
+
+# %%
+def layer_norm(
+    x: torch.Tensor,
+    normalized_shape: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """Layer normalization with forward + backward support."""
+    return LayerNormFunction.apply(x, normalized_shape, weight, bias, eps)  # type: ignore[no-any-return]
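+
+# A minimal usage sketch (hypothetical shapes, assuming a CUDA device): gradients
+# flow through LayerNormFunction just like the built-in layer norm, e.g.
+#   x = torch.randn(32, 64, device="cuda", dtype=torch.float16, requires_grad=True)
+#   w = torch.randn(64, device="cuda", dtype=torch.float16, requires_grad=True)
+#   layer_norm(x, [64], w).sum().backward()  # populates x.grad and w.grad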


 # %%
@@ -72,21 +252,43 @@ def main() -> None:
     batch_size = 32
     dim = 64
     device = "cuda"
+
+    # Test forward pass only
+    print("\n=== Forward Pass Test ===")
     x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
     weight = torch.randn([dim], device=device, dtype=torch.float16)
     bias = torch.randn([dim], device=device, dtype=torch.float16)
     eps = 1e-4
     for b in [bias, None]:
         run_example(
-            layer_norm_fwd,
+            layer_norm,
             torch.nn.functional.layer_norm,
             (x, [dim], weight, b, eps),
-            kernel_name="helion",
-            baseline_name="torch",
             rtol=1e-3,
             atol=1e-3,
         )

+    # Test forward + backward pass
+    print("\n\n=== Forward + Backward Pass Test ===")
+    x_grad = torch.randn(
+        [batch_size, dim], device=device, dtype=torch.float16, requires_grad=True
+    )
+    weight_grad = torch.randn(
+        [dim], device=device, dtype=torch.float16, requires_grad=True
+    )
+    bias_grad = torch.randn(
+        [dim], device=device, dtype=torch.float16, requires_grad=True
+    )
+    for b in [bias_grad, None]:
+        run_example(
+            layer_norm,
+            torch.nn.functional.layer_norm,
+            (x_grad, [dim], weight_grad, b, eps),
+            rtol=1e-3,
+            atol=1e-3,
+            bwd=True,
+        )
+

 # %%
 if __name__ == "__main__":