@@ -1303,21 +1303,22 @@ def _helion_layer_norm_fwd(bias, x, weight, out, bias_size_0, bias_stride_0, out
     v_15 = tl.cast(v_14, tl.float16)
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])
 
-def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
-        bias (torch.Tensor): Learnable bias parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
     Returns:
         torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
     """
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
-    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
+    if bias is not None:
+        assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
     assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
     assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
     out = torch.empty([m, n], dtype=x.dtype, device=x.device)
@@ -1326,6 +1327,67 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_layernorm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_layer_norm_fwd(x, weight, out, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < n
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = tl.cast(load, tl.float32)
+    var_mean_extra = tl.cast(tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_1 = var_mean_extra / n.to(tl.float32)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_1, tl.full([], 0, tl.float32))
+    v_2 = v_0 - _mask_to_1
+    v_3 = v_2 * v_2
+    var_mean_extra_2 = tl.cast(tl.reshape(tl.sum(v_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_4 = var_mean_extra_2 / n.to(tl.float32)
+    v_5 = v_0 - v_1
+    v_6 = v_4 + eps
+    v_7 = libdevice.rsqrt(v_6)
+    v_8 = v_5 * v_7
+    load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
+    v_9 = tl.cast(load_1, tl.float32)
+    v_10 = v_9[None, :]
+    v_11 = v_8 * v_10
+    v_12 = tl.cast(v_11, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_12, mask_0[:, None] & mask_1[None, :])
+
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor | None=None, eps: float=1e-05, *, _launcher=_default_launcher):
+    """
+    Performs 1D layer normalization on the input tensor using Helion.
+    Args:
+        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
+        normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        weight (torch.Tensor): Learnable scale parameter of shape [dim].
+        bias (Optional[torch.Tensor]): Learnable bias parameter of shape [dim].
+        eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
+    Returns:
+        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+    """
+    m, n = x.size()
+    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
+    if bias is not None:
+        assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
+    assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
+    assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(n)
+    _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
13291391--- assertExpectedJournal(TestExamples.test_matmul)
13301392from __future__ import annotations
13311393