@@ -3040,3 +3040,102 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
30403040 _BLOCK_SIZE_2 = 16
30413041 _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
30423042 return out
3043+
3044+ --- assertExpectedJournal(TestExamples.test_welford)
3045+ from __future__ import annotations
3046+
3047+ import torch
3048+ import triton
3049+ import triton.language as tl
3050+ from torch._inductor.runtime.triton_compat import libdevice
3051+ from helion.runtime import default_launcher as _default_launcher
3052+
3053+ @triton.jit
3054+ def _helion_welford(x, weight, bias, out, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
3055+ pid_0 = tl.program_id(0)
3056+ offset_0 = pid_0 * _BLOCK_SIZE_0
3057+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
3058+ mask_0 = indices_0 < m
3059+ acc_cnt = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
3060+ acc_mean = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
3061+ acc_m2 = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
3062+ for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
3063+ indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
3064+ mask_1 = indices_1 < n
3065+ acc_mean_copy = acc_mean
3066+ acc_cnt_copy = acc_cnt
3067+ acc_m2_copy = acc_m2
3068+ acc_mean_copy_0 = acc_mean_copy
3069+ acc_cnt_copy_0 = acc_cnt_copy
3070+ acc_m2_copy_0 = acc_m2_copy
3071+ chunk = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
3072+ sum_x = tl.cast(tl.sum(chunk, 1), tl.float32)
3073+ v_0 = chunk * chunk
3074+ sum_x2 = tl.cast(tl.sum(v_0, 1), tl.float32)
3075+ _BLOCK_SIZE_1_ = _BLOCK_SIZE_1
3076+ v_1 = tl.cast(_BLOCK_SIZE_1_, tl.float32)
3077+ v_2 = sum_x / v_1
3078+ v_3 = sum_x * sum_x
3079+ _BLOCK_SIZE_1__1 = _BLOCK_SIZE_1
3080+ v_4 = tl.cast(_BLOCK_SIZE_1__1, tl.float32)
3081+ v_5 = v_3 / v_4
3082+ v_6 = sum_x2 - v_5
3083+ v_7 = v_2 - acc_mean_copy_0
3084+ _BLOCK_SIZE_1__2 = _BLOCK_SIZE_1
3085+ v_8 = tl.cast(_BLOCK_SIZE_1__2, tl.float32)
3086+ acc_cnt = acc_cnt_copy_0 + v_8
3087+ v_10 = tl.full([], 1, tl.int32)
3088+ v_11 = v_10 / acc_cnt
3089+ _BLOCK_SIZE_1__3 = _BLOCK_SIZE_1
3090+ v_12 = tl.cast(_BLOCK_SIZE_1__3, tl.float32)
3091+ v_13 = v_11 * v_12
3092+ v_14 = v_7 * v_13
3093+ acc_mean = acc_mean_copy_0 + v_14
3094+ v_16 = acc_m2_copy_0 + v_6
3095+ v_17 = v_7 * v_7
3096+ _BLOCK_SIZE_1__4 = _BLOCK_SIZE_1
3097+ v_18 = tl.cast(_BLOCK_SIZE_1__4, tl.float32)
3098+ v_19 = acc_cnt_copy_0 * v_18
3099+ v_20 = v_19 / acc_cnt
3100+ v_21 = v_17 * v_20
3101+ acc_m2 = v_16 + v_21
3102+ v_23 = acc_m2 / acc_cnt
3103+ v_24 = v_23 + eps
3104+ v_25 = libdevice.rsqrt(v_24)
3105+ mean_col = acc_mean[:, None]
3106+ rstd_col = v_25[:, None]
3107+ for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
3108+ indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
3109+ mask_2 = indices_2 < n
3110+ mean_col_copy = mean_col
3111+ rstd_col_copy = rstd_col
3112+ mean_col_copy_0 = mean_col_copy
3113+ rstd_col_copy_0 = rstd_col_copy
3114+ xi_chuck = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
3115+ load_1 = tl.load(weight + indices_2 * weight_stride_0, mask_2, other=0)
3116+ w_chuck = load_1[None, :]
3117+ load_2 = tl.load(bias + indices_2 * bias_stride_0, mask_2, other=0)
3118+ b_chuck = load_2[None, :]
3119+ v_26 = xi_chuck - mean_col_copy_0
3120+ v_27 = v_26 * rstd_col_copy_0
3121+ v_28 = v_27 * w_chuck
3122+ v_29 = v_28 + b_chuck
3123+ tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_29, mask_0[:, None] & mask_2[None, :])
3124+
3125+ def welford(weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
3126+ """
3127+ Applies LayerNorm using Welford's algorithm for mean/variance.
3128+ Args:
3129+ weight: weight tensor of shape [N]
3130+ bias: bias tensor of shape [N]
3131+ x: input tensor of shape [M, N]
3132+ Returns:
3133+ Output tensor of shape [M, N]
3134+ """
3135+ m, n = x.size()
3136+ out = torch.empty([m, n], dtype=x.dtype, device=x.device)
3137+ _BLOCK_SIZE_0 = 16
3138+ _BLOCK_SIZE_1 = 16
3139+ _BLOCK_SIZE_2 = 16
3140+ _launcher(_helion_welford, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, bias, out, bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
3141+ return out
0 commit comments