@@ -964,7 +964,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1131,9 +1131,13 @@ def _matmul_split_k_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1
         load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
         load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    eq = offset_2 == 0
+    if eq:
+        acc_copy_1 = acc
+        acc = acc_copy_1
     tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
 
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1610,7 +1614,7 @@ from helion.runtime import default_launcher as _default_launcher
 import test.test_examples as _global_source0
 
 @triton.jit
-def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
     num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
     inner_2d_pid = tl.program_id(0)
@@ -1640,15 +1644,15 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t
     v_4 = v_3.to(tl.float16)
     tl.store(out + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_4, None)
 
-def matmul_with_epilogue(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 64
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 16
-    _launcher(_matmul_with_epilogue_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
+    _launcher(_matmul_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
     return out
 
 --- assertExpectedJournal(TestExamples.test_template_via_closure1)
@@ -1663,7 +1667,7 @@ from helion.runtime import default_launcher as _default_launcher
 import test.test_examples as _global_source0
 
 @triton.jit
-def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
     num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
     inner_2d_pid = tl.program_id(0)
@@ -1690,15 +1694,15 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t
     v_4 = v_3.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_4, boundary_check=[0, 1])
 
-def matmul_with_epilogue(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 64
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 16
-    _launcher(_matmul_with_epilogue_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
+    _launcher(_matmul_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
     return out
 
 --- assertExpectedJournal(TestExamples.test_template_via_closure2)
@@ -1713,7 +1717,7 @@ from helion.runtime import default_launcher as _default_launcher
 import test.test_examples as _global_source0
 
 @triton.jit
-def _matmul_with_epilogue_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
     num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
     inner_2d_pid = tl.program_id(0)
@@ -1737,13 +1741,13 @@ def _matmul_with_epilogue_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_
     v_2 = v_1.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_2, boundary_check=[0, 1])
 
-def matmul_with_epilogue(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 64
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 16
-    _launcher(_matmul_with_epilogue_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
+    _launcher(_matmul_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
     return out
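For context, a usage sketch of the epilogue parameter these journal entries now reflect. It is modeled on Helion's matmul example; the bias tensor, the shapes, and the relu-plus-bias epilogue are illustrative assumptions, not values taken from this diff:

    import torch

    x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    y = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    bias = torch.randn(1, 1024, device="cuda", dtype=torch.float16)

    # The epilogue receives the accumulator tile and the tile indices and
    # returns the value to store; capturing `bias` in the lambda's closure
    # is what surfaces as the epilogue_closure_0 kernel argument above.
    out = matmul(x, y, epilogue=lambda acc, tile: torch.relu(acc + bias[tile]))

    # Omitting the argument uses the new identity default
    # (lambda acc, tile: acc) and reproduces the plain matmul result.
    out_plain = matmul(x, y)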