
Commit c19532d

mengluy0125 authored and facebook-github-bot committed
Support layernorm without bias (#585)
Summary: The current layer norm only supports the case with a bias term; this change adds support for the case without bias.

Differential Revision: D82171738
1 parent ead8a63 commit c19532d

File tree

5 files changed: +98, -16 lines

examples/layer_norm.py

Lines changed: 20 additions & 13 deletions
@@ -21,7 +21,7 @@ def layer_norm_fwd(
     x: torch.Tensor,
     normalized_shape: list[int],
     weight: torch.Tensor,
-    bias: torch.Tensor,
+    bias: torch.Tensor | None = None,
     eps: float = 1e-5,
 ) -> torch.Tensor:
     """
@@ -30,14 +30,15 @@ def layer_norm_fwd(
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
-    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
+    if bias is not None:
+        assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
     assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
@@ -49,7 +50,12 @@ def layer_norm_fwd(
         acc = x[tile_m, :].to(torch.float32)
         var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
         normalized = (acc - mean) * torch.rsqrt(var + eps)
-        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
+        if bias is not None:
+            acc = normalized * (weight[:].to(torch.float32)) + (
+                bias[:].to(torch.float32)
+            )
+        else:
+            acc = normalized * (weight[:].to(torch.float32))
         out[tile_m, :] = acc.to(x.dtype)
     return out

@@ -70,15 +76,16 @@ def main() -> None:
     weight = torch.randn([dim], device=device, dtype=torch.float16)
     bias = torch.randn([dim], device=device, dtype=torch.float16)
     eps = 1e-4
-    run_example(
-        layer_norm_fwd,
-        torch.nn.functional.layer_norm,
-        (x, [dim], weight, bias, eps),
-        kernel_name="helion",
-        baseline_name="torch",
-        rtol=1e-3,
-        atol=1e-3,
-    )
+    for b in [bias, None]:
+        run_example(
+            layer_norm_fwd,
+            torch.nn.functional.layer_norm,
+            (x, [dim], weight, b, eps),
+            kernel_name="helion",
+            baseline_name="torch",
+            rtol=1e-3,
+            atol=1e-3,
+        )


 # %%
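For reference, the arithmetic both branches of the updated example compute can be spot-checked against a plain eager-PyTorch sketch like the one below (the helper name, shapes, and tolerances here are illustrative, not part of the repository):

```python
import torch

def layer_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    # Accumulate in FP32, mirroring the kernel's .to(torch.float32) casts.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    out = (acc - mean) * torch.rsqrt(var + eps) * weight.to(torch.float32)
    if bias is not None:  # bias is optional, matching the new signature
        out = out + bias.to(torch.float32)
    return out.to(x.dtype)

# Both cases should agree with
# torch.nn.functional.layer_norm(x, [x.size(-1)], weight, bias, eps)
# to within the rtol/atol=1e-3 used by run_example in main().
```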

helion/_compiler/compile_environment.py

Lines changed: 2 additions & 0 deletions
@@ -204,6 +204,8 @@ def cached_create_unbacked_symint(
         return result

     def to_fake(self, obj: object, origin: Origin) -> object:
+        if obj is None:
+            return None
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
         if isinstance(obj, (bool, int, float)):
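The two added lines let `None` flow through argument fakeification unchanged, which is what allows `bias=None` to reach the kernel at trace time. A standalone sketch of the pattern (not Helion's actual `CompileEnvironment`; the meta-tensor conversion below merely stands in for `_to_fake_tensor`):

```python
import torch

def to_fake_sketch(obj: object) -> object:
    # Early-out mirroring the change above: optional arguments such as
    # bias=None pass through instead of hitting the tensor/scalar branches.
    if obj is None:
        return None
    if isinstance(obj, torch.Tensor):
        # Stand-in for _to_fake_tensor: keep shape/dtype metadata, drop real storage.
        return torch.empty_like(obj, device="meta")
    if isinstance(obj, (bool, int, float)):
        return obj
    raise TypeError(f"unsupported argument type: {type(obj).__name__}")

# to_fake_sketch(None) -> None; to_fake_sketch(torch.randn(4)) -> meta tensor of shape (4,).
```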

helion/runtime/kernel.py

Lines changed: 1 addition & 0 deletions
@@ -757,6 +757,7 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
     types.BuiltinFunctionType: lambda fn, x: x,
     torch.fx.GraphModule: _graph_module_key,
     ConstExpr: lambda fn, x: x.value,  # pyright: ignore[reportAttributeAccessIssue]
+    type(None): lambda fn, x: None,
 }


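The table being extended maps each argument type to a function that produces a hashable specialization key, so registering `type(None)` lets a call with `bias=None` be keyed and cached like any other argument. A simplified, hypothetical version of the idea (Helion's real table keys on more properties than shown):

```python
from typing import Any, Callable, Hashable

import torch

# Hypothetical key table illustrating the lookup; not Helion's actual mapping.
_KEY_FNS: dict[type, Callable[[Any], Hashable]] = {
    torch.Tensor: lambda x: (x.dtype, x.device.type, x.dim()),
    int: lambda x: x,
    float: lambda x: x,
    type(None): lambda x: None,  # without this entry, None args would raise KeyError
}

def spec_key(args: tuple[Any, ...]) -> Hashable:
    return tuple(_KEY_FNS[type(a)](a) for a in args)

# spec_key((torch.randn(2, 3), None)) -> ((torch.float32, 'cpu', 2), None)
```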
test/test_examples.expected

Lines changed: 65 additions & 3 deletions
@@ -1303,21 +1303,22 @@ def _helion_layer_norm_fwd(bias, x, weight, out, bias_size_0, bias_stride_0, out
     v_15 = tl.cast(v_14, tl.float16)
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])

-def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
-    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
+    if bias is not None:
+        assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
     assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
     assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
     out = torch.empty([m, n], dtype=x.dtype, device=x.device)
@@ -1326,6 +1327,67 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestExamples.test_layernorm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_layer_norm_fwd(x, weight, out, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < n
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = tl.cast(load, tl.float32)
+    var_mean_extra = tl.cast(tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_1 = var_mean_extra / n.to(tl.float32)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_1, tl.full([], 0, tl.float32))
+    v_2 = v_0 - _mask_to_1
+    v_3 = v_2 * v_2
+    var_mean_extra_2 = tl.cast(tl.reshape(tl.sum(v_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_4 = var_mean_extra_2 / n.to(tl.float32)
+    v_5 = v_0 - v_1
+    v_6 = v_4 + eps
+    v_7 = libdevice.rsqrt(v_6)
+    v_8 = v_5 * v_7
+    load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
+    v_9 = tl.cast(load_1, tl.float32)
+    v_10 = v_9[None, :]
+    v_11 = v_8 * v_10
+    v_12 = tl.cast(v_11, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_12, mask_0[:, None] & mask_1[None, :])
+
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
+    """
+    Performs 1D layer normalization on the input tensor using Helion.
+    Args:
+        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
+        normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        weight (torch.Tensor): Learnable scale parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
+        eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
+    Returns:
+        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+    """
+    m, n = x.size()
+    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
+    if bias is not None:
+        assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
+    assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
+    assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(n)
+    _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations

test/test_examples.py

Lines changed: 10 additions & 0 deletions
@@ -627,6 +627,16 @@ def test_layernorm(self):
             )
         )

+        args = (x, [64], weight)
+        self.assertExpectedJournal(
+            check_example(
+                "layer_norm",
+                args,
+                torch.nn.functional.layer_norm(*args),
+                fn_name="layer_norm_fwd",
+            )
+        )
+
     @skipIfRefEager("ref eager mode hits CUDA indexing error with hl.store")
     def test_jagged_softmax(self):
         num_rows, max_cols = 128, 64
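The baseline for the new no-bias case comes straight from eager PyTorch: `torch.nn.functional.layer_norm` already defaults `bias` to `None`, so the expected tensor is simply the three-argument call. A quick local repro with illustrative shapes (the test's own tensors come from its existing setup):

```python
import torch

x = torch.randn(32, 64)
weight = torch.randn(64)

args = (x, [64], weight)
# bias defaults to None in F.layer_norm, matching the kernel's new optional bias.
expected = torch.nn.functional.layer_norm(*args)
print(expected.shape)  # torch.Size([32, 64])
```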