 # %%
 from __future__ import annotations
 
+from typing import Optional
+
 import torch
 
 import helion
@@ -21,7 +23,7 @@ def layer_norm_fwd(
     x: torch.Tensor,
     normalized_shape: list[int],
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
     eps: float = 1e-5,
 ) -> torch.Tensor:
     """
@@ -30,14 +32,15 @@ def layer_norm_fwd(
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
-    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
+    if bias is not None:
+        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
     assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
@@ -49,7 +52,10 @@ def layer_norm_fwd(
         acc = x[tile_m, :].to(torch.float32)
         var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
         normalized = (acc - mean) * torch.rsqrt(var + eps)
-        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+        if bias is not None:
+            acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+        else:
+            acc = normalized * (weight[:].to(torch.float32))
         out[tile_m, :] = acc.to(x.dtype)
     return out
 
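
For readers who want the math outside the tile loop: below is a minimal eager-PyTorch sketch of the per-row computation the kernel performs, including the new optional-bias branch. The standalone function and its name `layer_norm_reference` are illustrative assumptions, not part of this commit.

    # Hypothetical eager-mode reference (not part of this commit): compute
    # mean/variance over the last dim with correction=0, normalize with rsqrt,
    # scale by weight, and add bias only when it is provided.
    def layer_norm_reference(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        eps: float = 1e-5,
    ) -> torch.Tensor:
        acc = x.to(torch.float32)
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * weight.to(torch.float32)
        if bias is not None:
            acc = acc + bias.to(torch.float32)
        return acc.to(x.dtype)
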
@@ -70,15 +76,16 @@ def main() -> None:
     weight = torch.randn([dim], device=device, dtype=torch.float16)
     bias = torch.randn([dim], device=device, dtype=torch.float16)
     eps = 1e-4
-    run_example(
-        layer_norm_fwd,
-        torch.nn.functional.layer_norm,
-        (x, [dim], weight, bias, eps),
-        kernel_name="helion",
-        baseline_name="torch",
-        rtol=1e-3,
-        atol=1e-3,
-    )
+    for b in [bias, None]:
+        run_example(
+            layer_norm_fwd,
+            torch.nn.functional.layer_norm,
+            (x, [dim], weight, b, eps),
+            kernel_name="helion",
+            baseline_name="torch",
+            rtol=1e-3,
+            atol=1e-3,
+        )
 
 
 # %%
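
As a usage note, here is a hedged sketch of calling the updated kernel directly, with and without a bias. The shapes, the "cuda" device, and the tolerance check mirror main() above but are assumptions, not output of this change; running it requires a GPU that Helion supports.

    # Illustrative call sites (assumed shapes and device).
    x = torch.randn([2048, 1024], device="cuda", dtype=torch.float16)
    weight = torch.randn([1024], device="cuda", dtype=torch.float16)
    bias = torch.randn([1024], device="cuda", dtype=torch.float16)

    y_with_bias = layer_norm_fwd(x, [1024], weight, bias)
    y_without_bias = layer_norm_fwd(x, [1024], weight)  # bias now defaults to None

    ref = torch.nn.functional.layer_norm(x, [1024], weight, bias)
    torch.testing.assert_close(y_with_bias, ref, rtol=1e-3, atol=1e-3)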