Skip to content

Commit 4db264a

Browse files
authored
Relax requirements for inline_triton output_like=None (#1087)
1 parent c7848db commit 4db264a

File tree

2 files changed: +76 additions, -13 deletions (lines changed)

2 files changed: +76 additions, -13 deletions (lines changed)

helion/language/inline_triton_ops.py

Lines changed: 56 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,32 @@ def _(
9696
return _fake_outputs(output_like)
9797

9898

99-
def _ensure_name(state: CodegenState, node: ast.AST) -> str:
99+
def _ensure_name(
100+
state: CodegenState,
101+
node: ast.AST,
102+
original: object,
103+
) -> str:
104+
if (
105+
isinstance(node, ast.Call)
106+
and isinstance(node.func, ast.Name)
107+
and node.func.id == "_host_tensor"
108+
):
109+
if not isinstance(original, torch.Tensor):
110+
raise exc.InvalidAPIUsage(
111+
"inline_triton host tensor placeholders must be torch.Tensor instances"
112+
)
113+
return state.device_function.tensor_arg(original).name
114+
if not isinstance(node, ast.AST):
115+
return repr(node)
116+
if isinstance(node, ast.Constant):
117+
return repr(node.value)
118+
if isinstance(original, torch.Tensor):
119+
try:
120+
tensor_arg = state.device_function.tensor_arg(original)
121+
except KeyError:
122+
pass
123+
else:
124+
return tensor_arg.name
100125
lifted = state.codegen.lift(node)
101126
assert isinstance(lifted, ast.Name)
102127
return lifted.id
@@ -118,9 +143,13 @@ def _format_triton_source(
118143
"inline_triton expects a dict literal when args is a mapping"
119144
)
120145
assert args_obj.keys() == args_ast.keys()
121-
format_args: dict[str, str] = {
122-
key: _ensure_name(state, args_ast[key]) for key in args_ast
123-
}
146+
format_args: dict[str, str] = {}
147+
for key in args_ast:
148+
format_args[key] = _ensure_name(
149+
state,
150+
args_ast[key],
151+
args_obj[key],
152+
)
124153
try:
125154
return source.format(**format_args)
126155
except (KeyError, IndexError, ValueError) as exc_value:
@@ -138,7 +167,10 @@ def _format_triton_source(
138167
if isinstance(args_ast, (ast.List, ast.Tuple))
139168
else list(args_ast)
140169
)
141-
names = [_ensure_name(state, node) for node in arg_nodes]
170+
names = [
171+
_ensure_name(state, node, arg)
172+
for node, arg in zip(arg_nodes, args_obj, strict=False)
173+
]
142174
try:
143175
expected_len = len(args_obj)
144176
except TypeError: # pragma: no cover - defensive
@@ -157,7 +189,10 @@ def _format_triton_source(
157189
raise exc.InvalidAPIUsage("inline_triton args must be a tuple/list or a mapping")
158190

159191

160-
def _parse_triton_source(source: str) -> tuple[list[ast.stmt], ast.AST]:
192+
def _parse_triton_source(
193+
source: str,
194+
require_expression: bool,
195+
) -> tuple[list[ast.stmt], ast.AST | None]:
161196
try:
162197
module = ast.parse(source)
163198
except SyntaxError as exc_value:
@@ -166,16 +201,21 @@ def _parse_triton_source(source: str) -> tuple[list[ast.stmt], ast.AST]:
166201
) from exc_value
167202

168203
if not module.body:
169-
raise exc.InvalidAPIUsage("triton_source must contain at least one expression")
204+
raise exc.InvalidAPIUsage("triton_source must contain code")
170205

171206
*prefix, last = module.body
172-
if not isinstance(last, ast.Expr):
207+
converted_prefix = [cast("ast.stmt", convert(stmt)) for stmt in prefix]
208+
209+
if isinstance(last, ast.Expr):
210+
return converted_prefix, convert(last.value)
211+
212+
if require_expression:
173213
raise exc.InvalidAPIUsage(
174-
"The last line of triton_source must be an expression"
214+
"The last line of triton_source must be an expression when output_like is provided"
175215
)
176216

177-
converted_prefix = [cast("ast.stmt", convert(stmt)) for stmt in prefix]
178-
return converted_prefix, convert(last.value)
217+
converted_prefix.append(cast("ast.stmt", convert(last)))
218+
return converted_prefix, None
179219

180220

181221
def _normalize_output_ast(output_ast: object) -> list[ast.AST]:
@@ -315,12 +355,15 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]:
315355
state.ast_args[1],
316356
)
317357

318-
statements, result_expr = _parse_triton_source(formatted)
358+
statements, result_expr = _parse_triton_source(
359+
formatted, require_expression=output_like is not None
360+
)
319361
for stmt in statements:
320362
state.add_statement(stmt)
321363

322364
if output_like is None:
323-
state.add_statement(create(ast.Expr, value=result_expr))
365+
if result_expr is not None:
366+
state.add_statement(create(ast.Expr, value=result_expr))
324367
return create(ast.Constant, value=None)
325368

326369
result_name = state.device_function.new_var("inline_triton_result")

test/test_inline_triton.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -169,3 +169,23 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
169169
bound = kernel.bind((x,))
170170
code = bound.to_triton_code(bound.config_spec.default_config())
171171
self.assertIn("tl.store(", code)
172+
173+
def test_inline_triton_none_output_allows_terminal_statement(self) -> None:
174+
@helion.kernel(autotune_effort="none")
175+
def kernel(grad_x_lock: torch.Tensor) -> torch.Tensor:
176+
for _ in hl.tile(grad_x_lock.shape):
177+
hl.inline_triton(
178+
"""
179+
while tl.atomic_cas({0} + {1}, 0, 1) == 1:
180+
pass
181+
""",
182+
args=(grad_x_lock, 0),
183+
output_like=None,
184+
)
185+
return grad_x_lock
186+
187+
grad_x_lock = torch.ones(4, device=DEVICE, dtype=torch.int32)
188+
bound = kernel.bind((grad_x_lock,))
189+
code = bound.to_triton_code(bound.config_spec.default_config())
190+
self.assertIn("while tl.atomic_cas", code)
191+
self.assertNotIn("_host_tensor", code)

0 commit comments

Comments
 (0)