
Commit 5540fd8

mengluy0125 authored and facebook-github-bot committed
Support layernorm without bias (#585)
Summary: The current layernorm example only supports the case with bias, so we add support for the case without bias.

Differential Revision: D82171738
1 parent 247be92 commit 5540fd8
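For reference, here is a plain-PyTorch sketch of the semantics this change gives the example: when bias is None, the normalized values are only scaled by weight; otherwise the bias is added as before. This mirrors the kernel math in the diff below; layer_norm_ref is an illustrative name, not part of the change.

import torch

def layer_norm_ref(x: torch.Tensor, weight: torch.Tensor, bias=None, eps: float = 1e-5) -> torch.Tensor:
    # Same math as the Helion example: normalize in fp32, scale, optionally shift.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    out = normalized * weight.to(torch.float32)
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(x.dtype)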

4 files changed: +27 -16 lines changed


examples/layer_norm.py

Lines changed: 20 additions & 13 deletions
@@ -8,6 +8,8 @@
 # %%
 from __future__ import annotations
 
+from typing import Optional
+
 import torch
 
 import helion
@@ -21,7 +23,7 @@ def layer_norm_fwd(
     x: torch.Tensor,
     normalized_shape: list[int],
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
     eps: float = 1e-5,
 ) -> torch.Tensor:
     """
@@ -30,14 +32,15 @@ def layer_norm_fwd(
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
        eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
-    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
+    if bias is not None:
+        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
     assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
@@ -49,7 +52,10 @@ def layer_norm_fwd(
         acc = x[tile_m, :].to(torch.float32)
         var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
         normalized = (acc - mean) * torch.rsqrt(var + eps)
-        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+        if bias is not None:
+            acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+        else:
+            acc = normalized * (weight[:].to(torch.float32))
         out[tile_m, :] = acc.to(x.dtype)
     return out
 
@@ -70,15 +76,16 @@ def main() -> None:
     weight = torch.randn([dim], device=device, dtype=torch.float16)
     bias = torch.randn([dim], device=device, dtype=torch.float16)
     eps = 1e-4
-    run_example(
-        layer_norm_fwd,
-        torch.nn.functional.layer_norm,
-        (x, [dim], weight, bias, eps),
-        kernel_name="helion",
-        baseline_name="torch",
-        rtol=1e-3,
-        atol=1e-3,
-    )
+    for b in [bias, None]:
+        run_example(
+            layer_norm_fwd,
+            torch.nn.functional.layer_norm,
+            (x, [dim], weight, b, eps),
+            kernel_name="helion",
+            baseline_name="torch",
+            rtol=1e-3,
+            atol=1e-3,
+        )
 
 
 # %%

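A minimal usage sketch for the updated example, assuming a CUDA device; it mirrors what main() now exercises through run_example. torch.nn.functional.layer_norm already accepts bias=None, so it serves as the baseline for both the bias and no-bias paths.

import torch

x = torch.randn([32, 64], device="cuda", dtype=torch.float16)
weight = torch.randn([64], device="cuda", dtype=torch.float16)

# No-bias path: pass None where a bias tensor used to be required.
out = layer_norm_fwd(x, [64], weight, None, 1e-4)
ref = torch.nn.functional.layer_norm(x, [64], weight, None, 1e-4)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
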
helion/_compiler/compile_environment.py

Lines changed: 2 additions & 0 deletions
@@ -204,6 +204,8 @@ def cached_create_unbacked_symint(
         return result
 
     def to_fake(self, obj: object, origin: Origin) -> object:
+        if obj is None:
+            return None
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
         if isinstance(obj, (bool, int, float)):

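Why the to_fake change is needed: with bias=None, a None argument now flows through argument fake-ification, which previously only handled tensors, scalars, and other known types. A minimal sketch of the idea (illustrative only; a meta-device tensor stands in for Helion's real fake-tensor conversion):

import torch

def to_fake_sketch(obj):
    if obj is None:
        return None  # an omitted bias must stay None all the way into tracing
    if isinstance(obj, torch.Tensor):
        # stand-in for the real conversion: keeps shape/dtype, carries no data
        return torch.empty_like(obj, device="meta")
    return obj
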
helion/runtime/kernel.py

Lines changed: 1 addition & 0 deletions
@@ -757,6 +757,7 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
     types.BuiltinFunctionType: lambda fn, x: x,
     torch.fx.GraphModule: _graph_module_key,
     ConstExpr: lambda fn, x: x.value, # pyright: ignore[reportAttributeAccessIssue]
+    type(None): lambda fn, x: None,
 }
 
 

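The kernel.py change extends the per-type table that turns kernel arguments into hashable keys, so a None argument now contributes a stable key as well. A minimal sketch of that dispatch-by-type pattern (names are illustrative, not Helion's internals):

from typing import Any, Callable, Hashable

# Per-type key extractors, analogous to the table extended above.
_key_extractors: dict[type, Callable[[Any], Hashable]] = {
    int: lambda x: x,
    type(None): lambda x: None,  # None arguments map to a fixed key
}

def arg_key(x: Any) -> Hashable:
    return _key_extractors[type(x)](x)

print(arg_key(None), arg_key(3))  # -> None 3
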
test/test_examples.expected

Lines changed: 4 additions & 3 deletions
@@ -1303,21 +1303,22 @@ def _helion_layer_norm_fwd(bias, x, weight, out, bias_size_0, bias_stride_0, out
     v_15 = tl.cast(v_14, tl.float16)
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])
 
-def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: Optional[torch.Tensor] = None, eps: float=1e-05, *, _launcher=_default_launcher):
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
-    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
+    if bias is not None:
+        assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
     assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
     assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
     out = torch.empty([m, n], dtype=x.dtype, device=x.device)