@@ -855,6 +855,62 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher):
     _launcher(_addToBoth_kernel, (triton.cdiv(a_n, _BLOCK_SIZE_0) * triton.cdiv(a_m, _BLOCK_SIZE_1) + triton.cdiv(b_n, _BLOCK_SIZE_2) * triton.cdiv(b_m, _BLOCK_SIZE_3) + triton.cdiv(c_n, _BLOCK_SIZE_4) * triton.cdiv(c_m, _BLOCK_SIZE_5),), x0, x1, x2, x0.stride(0), x0.stride(1), x1.stride(0), x1.stride(1), x2.stride(0), x2.stride(1), a_n, a_m, c0, b_n, b_m, c1, c_n, c_m, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3)
     return (x0, x1, x2)
 
+--- assertExpectedJournal(TestLoops.test_nested_loop_accumulator)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _nested_loop_accumulator_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, N, M, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
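+    # One program per batch element: the first loop nest accumulates the sum
+    # of this element's (N, M) slice, the second subtracts the mean everywhere.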
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    acc = tl.full([1], 0.0, tl.float32)
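+    # Pass 1: tile over (N, M) and accumulate the running total in acc.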
+    for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < N
+        acc_copy = acc
+        acc = acc_copy
+        for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+            mask_2 = indices_2 < M
+            acc_copy_0_copy = acc
+            acc_copy_0_copy_0 = acc_copy_0_copy
+            vals = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_1[None, :, None] & mask_2[None, None, :], other=0)
+            sum_1 = tl.sum(vals, 2)
+            sum_2 = tl.sum(sum_1, 1)
+            acc = acc_copy_0_copy_0 + sum_2
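+    # Mean over all N * M elements of the slice.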
+    mul = M * N
+    v_1 = mul.to(tl.float32)
+    v_2 = acc / v_1
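+    # Pass 2: reload each (N, M) tile and store vals - mean.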
+    for offset_3 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < N
+        v_2_copy = v_2
+        v_2_copy_0 = v_2_copy
+        for offset_4 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_4):
+            indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+            mask_4 = indices_4 < M
+            v_2_copy_0_copy = v_2_copy_0
+            v_2_copy_0_copy_0 = v_2_copy_0_copy
+            vals_1 = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_3[None, :, None] * x_stride_1 + indices_4[None, None, :] * x_stride_2), mask_3[None, :, None] & mask_4[None, None, :], other=0)
+            subscript = v_2_copy_0_copy_0[:, None, None]
+            v_3 = vals_1 - subscript
+            tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_3[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_3, mask_3[None, :, None] & mask_4[None, None, :])
+
+def nested_loop_accumulator(x: torch.Tensor, *, _launcher=_default_launcher):
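+    # Host wrapper: launches one program per batch element (grid = (B,)).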
+    B, N, M = x.size()
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_1 = 2
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 2
+    _BLOCK_SIZE_4 = 4
+    _launcher(_nested_loop_accumulator_kernel, (B,), x, out, out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), N, M, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestLoops.test_pointwise_device_loop)
 from __future__ import annotations
 
@@ -977,3 +1033,70 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_2 = 64
     _launcher(_matmul_kernel, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestLoops.test_three_pass_kernel)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _three_pass_kernel_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, B, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
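+    # Standardizes each length-M row of x in three passes over M:
+    # row sum, row sum of squares, then (x - mean) / sqrt(var + eps).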
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < B
+    sum_val = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
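+    # Pass 1: accumulate per-row sums over tiles of M.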
+    for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < M
+        sum_val_copy = sum_val
+        sum_val_copy_0 = sum_val_copy
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        sum_1 = tl.sum(load, 1)
+        sum_val = sum_val_copy_0 + sum_1
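+    # Pass 2: accumulate per-row sums of squares.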
+    sum_sq = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < M
+        sum_sq_copy = sum_sq
+        sum_sq_copy_0 = sum_sq_copy
+        vals = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        v_1 = vals * vals
+        sum_2 = tl.sum(v_1, 1)
+        sum_sq = sum_sq_copy_0 + sum_2
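+    # mean = sum / M; var = E[x^2] - mean^2; std = sqrt(var + 1e-06).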
+    v_3 = M.to(tl.float32)
+    v_4 = sum_val / v_3
+    v_5 = M.to(tl.float32)
+    v_6 = sum_sq / v_5
+    v_7 = v_4 * v_4
+    v_8 = v_6 - v_7
+    v_9 = 1e-06
+    v_10 = v_8 + v_9
+    v_11 = libdevice.sqrt(v_10)
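+    # Pass 3: write out (x - mean) / std for each row tile.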
+    for offset_3 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < M
+        v_4_copy = v_4
+        v_11_copy = v_11
+        v_4_copy_0 = v_4_copy
+        v_11_copy_0 = v_11_copy
+        vals_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_0[:, None] & mask_3[None, :], other=0)
+        subscript = v_4_copy_0[:, None]
+        v_12 = vals_1 - subscript
+        subscript_1 = v_11_copy_0[:, None]
+        v_13 = v_12 / subscript_1
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_3[None, :] * out_stride_1), v_13, mask_0[:, None] & mask_3[None, :])
+
+def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
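+    # Host wrapper: each program normalizes a _BLOCK_SIZE_0-row block of x.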
+    B, M = x.size()
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 2
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    _BLOCK_SIZE_3 = 8
+    _launcher(_three_pass_kernel_kernel, (triton.cdiv(B, _BLOCK_SIZE_0),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out