33 changes: 20 additions & 13 deletions examples/layer_norm.py
@@ -21,7 +21,7 @@ def layer_norm_fwd(
x: torch.Tensor,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor,
bias: torch.Tensor | None = None,
eps: float = 1e-5,
) -> torch.Tensor:
"""
@@ -30,14 +30,15 @@ def layer_norm_fwd(
x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
weight (torch.Tensor): Learnable scale parameter of shape [dim].
bias (torch.Tensor): Learnable bias parameter of shape [dim].
bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
Returns:
torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
"""
m, n = x.size()
assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
if bias is not None:
assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
assert len(normalized_shape) == 1, (
"Helion layer norm only supports 1D layer norm currently"
)
@@ -49,7 +50,12 @@ def layer_norm_fwd(
acc = x[tile_m, :].to(torch.float32)
var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
normalized = (acc - mean) * torch.rsqrt(var + eps)
acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
if bias is not None:
acc = normalized * (weight[:].to(torch.float32)) + (
bias[:].to(torch.float32)
)
else:
acc = normalized * (weight[:].to(torch.float32))
out[tile_m, :] = acc.to(x.dtype)
return out

@@ -70,15 +76,16 @@ def main() -> None:
weight = torch.randn([dim], device=device, dtype=torch.float16)
bias = torch.randn([dim], device=device, dtype=torch.float16)
eps = 1e-4
run_example(
layer_norm_fwd,
torch.nn.functional.layer_norm,
(x, [dim], weight, bias, eps),
kernel_name="helion",
baseline_name="torch",
rtol=1e-3,
atol=1e-3,
)
for b in [bias, None]:
run_example(
layer_norm_fwd,
torch.nn.functional.layer_norm,
(x, [dim], weight, b, eps),
kernel_name="helion",
baseline_name="torch",
rtol=1e-3,
atol=1e-3,
)
Review comment (Contributor):
@mengluy0125 thanks for the PR! would you like to add the no-bias test in test_examples.py too?

def test_layernorm(self):
x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
bias = torch.randn([64], device=DEVICE, dtype=torch.float16)
args = (x, [64], weight, bias)
self.assertExpectedJournal(
check_example(
"layer_norm",
args,
torch.nn.functional.layer_norm(*args),
fn_name="layer_norm_fwd",
)
)



# %%
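For quick reference, a minimal standalone check of the updated signature might look like the sketch below; it mirrors the main() loop above, and the import path, device, shapes, and tolerances are illustrative assumptions rather than part of the PR.

import torch
from examples.layer_norm import layer_norm_fwd  # assumed import path for this example module

device = "cuda"  # Helion kernels run on a GPU device
x = torch.randn([32, 64], device=device, dtype=torch.float16)
weight = torch.randn([64], device=device, dtype=torch.float16)
bias = torch.randn([64], device=device, dtype=torch.float16)

for b in (bias, None):  # exercise both the with-bias and no-bias paths
    out = layer_norm_fwd(x, [64], weight, b, eps=1e-4)
    ref = torch.nn.functional.layer_norm(x, [64], weight, b, eps=1e-4)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)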
2 changes: 2 additions & 0 deletions helion/_compiler/compile_environment.py
@@ -204,6 +204,8 @@ def cached_create_unbacked_symint(
return result

def to_fake(self, obj: object, origin: Origin) -> object:
if obj is None:
return None
if isinstance(obj, torch.Tensor):
return self._to_fake_tensor(obj, origin.to_source())
if isinstance(obj, (bool, int, float)):
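The compile_environment change above makes None a legal argument during fake-ification, so an omitted bias never reaches the tensor or scalar branches. A rough sketch of the pattern as a standalone function (hypothetical code, not Helion's actual implementation):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def to_fake_sketch(obj: object, fake_mode: FakeTensorMode) -> object:
    if obj is None:
        return None  # pass optional arguments (e.g. bias=None) through unchanged
    if isinstance(obj, torch.Tensor):
        return fake_mode.from_tensor(obj)  # metadata-only stand-in for tracing
    if isinstance(obj, (bool, int, float)):
        return obj
    raise TypeError(f"unsupported argument type: {type(obj)}")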
1 change: 1 addition & 0 deletions helion/runtime/kernel.py
@@ -757,6 +757,7 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
types.BuiltinFunctionType: lambda fn, x: x,
torch.fx.GraphModule: _graph_module_key,
ConstExpr: lambda fn, x: x.value, # pyright: ignore[reportAttributeAccessIssue]
type(None): lambda fn, x: None,
}


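The kernel.py entry registers type(None) in the table that reduces each argument to a hashable specialization key, so a call with bias=None can be cached and dispatched like any other signature. A simplified sketch of that idea (hypothetical names, not the actual Kernel cache):

import torch
from typing import Hashable

def arg_key(x: object) -> Hashable:
    if x is None:
        return None  # the new mapping: None is itself a valid, hashable key
    if isinstance(x, torch.Tensor):
        return (x.dtype, x.device.type, x.ndim)  # specialize on tensor metadata
    return type(x)  # other constants specialize on their type in this sketch

def cache_key(args: tuple) -> tuple[Hashable, ...]:
    return tuple(arg_key(a) for a in args)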
70 changes: 66 additions & 4 deletions test/test_examples.expected
@@ -1262,7 +1262,7 @@ def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _lau
_launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
return out.reshape(N, M)

--- assertExpectedJournal(TestExamples.test_layernorm)
--- assertExpectedJournal(TestExamples.test_layernorm_with_bias)
from __future__ import annotations

import torch
@@ -1303,21 +1303,22 @@ def _helion_layer_norm_fwd(bias, x, weight, out, bias_size_0, bias_stride_0, out
v_15 = tl.cast(v_14, tl.float16)
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])

def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
"""
Performs 1D layer normalization on the input tensor using Helion.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
weight (torch.Tensor): Learnable scale parameter of shape [dim].
bias (torch.Tensor): Learnable bias parameter of shape [dim].
bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
Returns:
torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
"""
m, n = x.size()
assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
if bias is not None:
assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
@@ -1326,6 +1327,67 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
_launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_layernorm_without_bias)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_layer_norm_fwd(x, weight, out, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < m
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
mask_1 = indices_1 < n
load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = tl.cast(load, tl.float32)
var_mean_extra = tl.cast(tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_1 = var_mean_extra / n.to(tl.float32)
_mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_1, tl.full([], 0, tl.float32))
v_2 = v_0 - _mask_to_1
v_3 = v_2 * v_2
var_mean_extra_2 = tl.cast(tl.reshape(tl.sum(v_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_4 = var_mean_extra_2 / n.to(tl.float32)
v_5 = v_0 - v_1
v_6 = v_4 + eps
v_7 = libdevice.rsqrt(v_6)
v_8 = v_5 * v_7
load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
v_9 = tl.cast(load_1, tl.float32)
v_10 = v_9[None, :]
v_11 = v_8 * v_10
v_12 = tl.cast(v_11, tl.float16)
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_12, mask_0[:, None] & mask_1[None, :])

def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
"""
Performs 1D layer normalization on the input tensor using Helion.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
weight (torch.Tensor): Learnable scale parameter of shape [dim].
bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
Returns:
torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
"""
m, n = x.size()
assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
if bias is not None:
assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_1 = triton.next_power_of_2(n)
_launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_matmul)
from __future__ import annotations

16 changes: 15 additions & 1 deletion test/test_examples.py
@@ -611,7 +611,7 @@ def test_fp8_attention(self):
)
)

def test_layernorm(self):
def test_layernorm_with_bias(self):
x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
bias = torch.randn([64], device=DEVICE, dtype=torch.float16)
@@ -627,6 +627,20 @@ def test_layernorm(self):
)
)

def test_layernorm_without_bias(self):
x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
weight = torch.randn([64], device=DEVICE, dtype=torch.float16)

args = (x, [64], weight)
self.assertExpectedJournal(
check_example(
"layer_norm",
args,
torch.nn.functional.layer_norm(*args),
fn_name="layer_norm_fwd",
)
)

@skipIfRefEager("ref eager mode hits CUDA indexing error with hl.store")
def test_jagged_softmax(self):
num_rows, max_cols = 128, 64