@@ -1314,6 +1314,44 @@ def _softmax_two_pass_make_precompiler(x: torch.Tensor):
13141314 from helion.runtime.precompile_shim import make_precompiler
13151315 return make_precompiler(_softmax_two_pass_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
13161316
1317+ --- assertExpectedJournal(TestExamples.test_sum)
1318+ from __future__ import annotations
1319+
1320+ import torch
1321+ import triton
1322+ import triton.language as tl
1323+
# Triton kernel: each program sums the elements of `x` addressed by its
# axis-0 program id (scaled by `x_stride_0`) across the `n` positions of the
# reduced dimension, and writes the single result into `out`.
#
#   x, out              -- pointers to the input and output buffers
#   out_stride_0        -- element stride of `out`
#   x_stride_0/1        -- element strides of `x` for the kept / reduced dims
#   n                   -- extent of the reduced dimension
#   _REDUCTION_BLOCK_1  -- compile-time number of elements handled per loop step
@triton.jit
def _sum_kernel_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _REDUCTION_BLOCK_1: tl.constexpr):
    # Lift the scalar program id into a 1-element int32 block so it can be
    # broadcast against the reduction indices below.
    row = tl.program_id(0) + tl.zeros([1], tl.int32)
    # Partial sums are kept in float32, one lane per position within a chunk.
    acc = tl.full([1, _REDUCTION_BLOCK_1], 0, tl.float32)
    for chunk_start in tl.range(0, n, step=_REDUCTION_BLOCK_1):
        cols = chunk_start + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
        # Lanes at or beyond `n` are masked off and read as 0, so they do not
        # contribute to the sum.
        in_bounds = cols < n
        src = x + (row[:, None] * x_stride_0 + cols[None, :] * x_stride_1)
        tile = tl.load(src, in_bounds[None, :], other=0)
        acc = acc + tile
    # Collapse the per-lane partial sums along axis 1 into one value.
    row_total = tl.sum(acc, 1)
    tl.store(out + row * out_stride_0, row_total, None)
1338+
def sum_kernel(x: torch.Tensor):
    """Sum a 2D tensor along its last dimension.

    Args:
        x: Tensor whose shape unpacks into exactly two sizes ``(m, n)``.

    Returns:
        A tensor of shape ``[m]`` with the same dtype and device as ``x``,
        filled in by ``_sum_kernel_kernel``.
    """
    num_rows, num_cols = x.shape
    result = torch.empty([num_rows], dtype=x.dtype, device=x.device)
    reduction_block = 32768
    # One kernel program is launched per entry of the first dimension.
    grid = (num_rows,)
    _sum_kernel_kernel[grid](
        x,
        result,
        result.stride(0),
        x.stride(0),
        x.stride(1),
        num_cols,
        reduction_block,
        num_warps=4,
        num_stages=3,
    )
    return result
1346+
def _sum_kernel_make_precompiler(x: torch.Tensor):
    """Build the precompiler for the last-dimension sum kernel.

    Prepares the same output buffer, strides and block size that
    ``sum_kernel`` passes to ``_sum_kernel_kernel``, but hands them to the
    callable produced by ``make_precompiler`` instead of launching a grid.

    Args:
        x: Tensor whose shape unpacks into exactly two sizes ``(m, n)``.

    Returns:
        Whatever the callable returned by
        ``make_precompiler(_sum_kernel_kernel)`` returns for these arguments.
    """
    num_rows, num_cols = x.shape
    result = torch.empty([num_rows], dtype=x.dtype, device=x.device)
    reduction_block = 32768
    # Imported lazily, only when a precompiler is actually requested.
    from helion.runtime.precompile_shim import make_precompiler

    precompile = make_precompiler(_sum_kernel_kernel)
    return precompile(
        x,
        result,
        result.stride(0),
        x.stride(0),
        x.stride(1),
        num_cols,
        reduction_block,
        num_warps=4,
        num_stages=3,
    )
1354+
13171355--- assertExpectedJournal(TestExamples.test_template_via_closure0)
13181356from __future__ import annotations
13191357
@@ -1490,3 +1528,4 @@ def _matmul_with_epilogue_make_precompiler(x: Tensor, y: Tensor, epilogue: Calla
14901528 _BLOCK_SIZE_2 = 16
14911529 from helion.runtime.precompile_shim import make_precompiler
14921530 return make_precompiler(_matmul_with_epilogue_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1531+
0 commit comments