Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion helion/_compiler/inductor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ..exc import InductorLoweringError
from ..language._decorators import APIFunc
from ..language._decorators import is_api_func
from .ast_extension import ExtendedAST
from .ast_extension import create
from .ast_extension import expr_from_string
from .ast_extension import statement_from_string
Expand Down Expand Up @@ -350,8 +351,22 @@ def visit(n: torch.fx.Node) -> None:
ast_val = expr_from_string(
"tensor[" + ", ".join(expand) + "]", tensor=ast_val
)
input_asts.append(ast_val)
if (
isinstance(ast_val, ast.Name)
and ast_val.id in device_function._constexpr_args
):
# introduce a copy so triton doesn't complain about `id.to(...)` calls
assert isinstance(ast_val, ExtendedAST)
with ast_val:
copy_var = device_function.new_var(f"{ast_val.id}_", dce=True)
ctx.cg.add_statement(
statement_from_string(f"{copy_var} = {ast_val.id}")
)
input_asts.append(expr_from_string(f"{copy_var}"))
else:
input_asts.append(ast_val)

device_function: DeviceFunction = ctx.cg.device_function
ndim: int = max([x.ndim for x in self.input_fake_tensors(node)] or (0,))
input_asts: list[ast.AST] = []
map_arg((node.args, node.kwargs), visit)
Expand Down
48 changes: 48 additions & 0 deletions test/test_misc.expected
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,54 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestMisc.test_tile_block_size_constexpr_fix)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice

@triton.jit
def _test_tile_block_size_usage_kernel(out, x_size_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
_BLOCK_SIZE_0_ = _BLOCK_SIZE_0
v_0 = _BLOCK_SIZE_0_.to(tl.int32)
v_1 = indices_0 % v_0
v_2 = tl.full([], 0, tl.int32)
v_3 = v_1 != v_2
v_4 = libdevice.signbit(v_1) != 0 if v_1.dtype is tl.float32 else v_1 < 0
v_5 = libdevice.signbit(v_0) != 0 if v_0.dtype is tl.float32 else v_0 < 0
v_6 = v_4 != v_5
v_7 = v_3 & v_6
v_8 = v_1 + v_0
v_9 = tl.where(v_7, v_8, v_1)
sub = -1 + _BLOCK_SIZE_0
v_10 = sub.to(tl.int32)
v_11 = v_9 == v_10
v_12 = tl.full([], 0, tl.int64)
v_13 = tl.full([], 1, tl.int64)
v_14 = v_13[None]
v_15 = v_12[None]
v_16 = tl.where(v_11, v_14, v_15)
v_17 = v_16.to(tl.int32)
tl.store(out + indices_0 * out_stride_0, v_17, mask_0)

def test_tile_block_size_usage(x: torch.Tensor):
out = torch.zeros_like(x, dtype=torch.int32)
_BLOCK_SIZE_0 = 32
_test_tile_block_size_usage_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](out, x.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def _test_tile_block_size_usage_make_precompiler(x: torch.Tensor):
out = torch.zeros_like(x, dtype=torch.int32)
_BLOCK_SIZE_0 = 32
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_test_tile_block_size_usage_kernel)(out, x.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestMisc.test_torch_alloc)
from __future__ import annotations

Expand Down
20 changes: 20 additions & 0 deletions test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,26 @@ def test_tile_id(x: torch.Tensor) -> torch.Tensor:
result = test_tile_id.bind((x,)).compile_config(config)(x)
self.assertEqual(result.sum().item(), 16)

def test_tile_block_size_constexpr_fix(self):
    """Test that tile.block_size can be used in expressions without compilation errors."""
    # NOTE(review): the generated Triton code for this kernel is pinned
    # byte-for-byte by assertExpectedJournal below — do not rename locals or
    # reformat the kernel body without regenerating test_misc.expected.

    @helion.kernel(use_default_config=True)
    def test_tile_block_size_usage(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x, dtype=torch.int32)
        for tile in hl.tile(x.shape[0]):
            # This should not cause a compilation error when tile.block_size is used
            # in expressions that generate .to() calls
            block_size_temp = tile.block_size
            # mask is True exactly at the last index within each tile
            # (index % block_size == block_size - 1).
            mask = tile.index % block_size_temp == block_size_temp - 1
            out[tile] = torch.where(mask, 1, 0)
        return out

    # 32 elements with the default block size — presumably yields at least one
    # full tile so the "last position in tile" mask fires; verify against config.
    x = torch.randn(32, device=DEVICE)
    code, result = code_and_output(test_tile_block_size_usage, (x,))
    # Compare the emitted Triton source against the recorded journal entry.
    self.assertExpectedJournal(code)
    # The result should have 1s at positions that are last in their tile
    self.assertTrue(result.sum().item() > 0)


# Allow running this test module directly (e.g. `python test_misc.py`).
if __name__ == "__main__":
    unittest.main()
Loading