Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions examples/sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel()
def sum_kernel(x: torch.Tensor) -> torch.Tensor:
    """Sum 2D tensor along the last dimension."""
    # Unpacking into exactly two names means x must be 2-D. `n` is not
    # referenced again in this function body.
    m, n = x.shape
    # One output element per row of x, matching x's dtype and device.
    out = torch.empty([m], dtype=x.dtype, device=x.device)

    # tile_m indexes the first dimension of both x and out; each iteration
    # reduces the selected rows over the last dimension and stores the result.
    for tile_m in hl.tile(m):
        out[tile_m] = x[tile_m, :].sum(-1)

    return out


def check(m: int, n: int) -> None:
    """Exercise ``sum_kernel`` on a random ``[m, n]`` float32 CUDA tensor.

    The kernel (registered under the label ``"helion"``), an eager
    ``Tensor.sum(-1)`` callable, and the one-element argument tuple are handed
    to ``run_example``. The eager callable is presumably the reference
    implementation -- confirm against ``helion._testing.run_example``.
    """
    inputs = (torch.randn([m, n], device="cuda", dtype=torch.float32),)
    run_example(
        {"helion": sum_kernel},
        lambda t: t.sum(-1),
        inputs,
    )


def main() -> None:
    """Run ``check`` once per example problem size, in order."""
    for rows, cols in ((512, 256), (1024, 1024)):
        check(rows, cols)


# Run the example shapes only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
39 changes: 39 additions & 0 deletions test/test_examples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,44 @@ def _softmax_two_pass_make_precompiler(x: torch.Tensor):
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_softmax_two_pass_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_sum)
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _sum_kernel_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _REDUCTION_BLOCK_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
indices_0 = offset_0 + tl.zeros([1], tl.int32)
sum_1_acc = tl.full([1, _REDUCTION_BLOCK_1], 0, tl.float32)
for roffset_1 in tl.range(0, n, step=_REDUCTION_BLOCK_1):
rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
mask_1 = rindex_1 < n
load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_1[None, :], other=0)
v_0 = sum_1_acc + load
sum_1_acc = v_0
sum_1 = tl.sum(sum_1_acc, 1)
tl.store(out + indices_0 * out_stride_0, sum_1, None)

def sum_kernel(x: torch.Tensor):
"""Sum 2D tensor along the last dimension."""
m, n = x.shape
out = torch.empty([m], dtype=x.dtype, device=x.device)
_REDUCTION_BLOCK_1 = 32768
_sum_kernel_kernel[m,](x, out, out.stride(0), x.stride(0), x.stride(1), n, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
return out

def _sum_kernel_make_precompiler(x: torch.Tensor):
"""Sum 2D tensor along the last dimension."""
m, n = x.shape
out = torch.empty([m], dtype=x.dtype, device=x.device)
_REDUCTION_BLOCK_1 = 32768
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_sum_kernel_kernel)(x, out, out.stride(0), x.stride(0), x.stride(1), n, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_template_via_closure0)
from __future__ import annotations

Expand Down Expand Up @@ -1490,3 +1528,4 @@ def _matmul_with_epilogue_make_precompiler(x: Tensor, y: Tensor, epilogue: Calla
_BLOCK_SIZE_2 = 16
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_matmul_with_epilogue_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)

13 changes: 13 additions & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,19 @@ def test_matmul_split_k(self):
)
)

def test_sum(self):
    """Check the ``sum`` example's ``sum_kernel`` against ``torch.sum``.

    Uses a random 512x512 float32 input on ``DEVICE`` with
    ``block_sizes=[1]`` and ``reduction_loops=[32768]``, then passes whatever
    ``check_example`` returns to ``assertExpectedJournal``.
    """
    x = torch.randn([512, 512], device=DEVICE, dtype=torch.float32)
    expected = torch.sum(x, dim=-1)
    journal = check_example(
        "sum",
        (x,),
        expected,
        fn_name="sum_kernel",
        block_sizes=[1],
        reduction_loops=[32768],
    )
    self.assertExpectedJournal(journal)


# Hand control to unittest's runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Loading