Commit 08de077

Prevent naming conflicts in expr_from_string placeholder replacement (#519)
1 parent d3d9a48 commit 08de077

16 files changed (+146, -71 lines)

helion/_compiler/ast_extension.py

Lines changed: 57 additions & 3 deletions
@@ -2,6 +2,7 @@
 import ast
 import enum
+import re
 import threading
 import typing
 from typing import TYPE_CHECKING
@@ -158,16 +159,68 @@ def create_arguments(args: list[ast.arg]) -> ast.arguments:
 
 
 def statement_from_string(template: str, **placeholders: ast.AST) -> ast.stmt:
-    (statement,) = ast.parse(template).body
+    """
+    Create an AST statement from a template string with placeholders.
+
+    Uses {placeholder} syntax to mark placeholders that should be replaced with AST nodes.
+    This supports two common patterns:
+
+    1. Regular strings - placeholders use single braces:
+       expr_from_string("tl.load({ptr} + {offset}, {mask})",
+                        ptr=ptr_ast, offset=offset_ast, mask=mask_ast)
+
+    2. f-strings - placeholders use double braces (which become single braces):
+       name = "my_tensor"
+       expr_from_string(f"tl.load({name} + {{offset}}, {{mask}})",
+                        offset=offset_ast, mask=mask_ast)
+       # In the f-string, {name} is interpolated to "my_tensor",
+       # while {{offset}} becomes {offset} for placeholder replacement
+    """
     location: SourceLocation = current_location()
 
+    # Find all placeholders and validate
+    pattern = r"\{(\w+)\}(?!:)"  # {word} not followed by colon (avoid dict keys)
+    used = set(re.findall(pattern, template))
+    if missing := used - placeholders.keys():
+        raise KeyError(f"Missing placeholders: {sorted(missing)}")
+
+    # Replace placeholders with unique identifiers to avoid naming conflicts
+    # For example, "{x}" in "x = {x}" must not conflict with the variable "x"
+    mapping = {}
+
+    def make_unique(m: re.Match[str]) -> str:
+        # Extract placeholder name from the regex match (e.g., "offset" from "{offset}")
+        name = m.group(1)
+        # Create a unique identifier that can't exist in user code
+        # Using double underscores and "placeholder" to ensure uniqueness
+        uid = f"__placeholder_{len(mapping)}__"
+        # Store the mapping from unique ID to the actual AST node
+        mapping[uid] = placeholders[name]
+        return uid
+
+    # First pass: Replace all {placeholder} with __placeholder_N__ in the template
+    # This prevents conflicts and allows ast.parse to create a valid AST
+    modified_template = re.sub(pattern, make_unique, template)
+
+    # Parse the modified template into an AST
+    (statement,) = ast.parse(modified_template).body
+
+    # Second pass: Recursively walk the AST and replace __placeholder_N__ identifiers
+    # with the actual AST nodes provided by the user
     def _replace(node: _R) -> _R:
+        # Handle lists by recursively transforming each element
         if isinstance(node, list):
             return [_replace(item) for item in node]  # pyright: ignore[reportReturnType]
+
+        # Pass through non-AST nodes unchanged (e.g., strings, numbers)
         if not isinstance(node, ast.AST):
             return node
-        if isinstance(node, ast.Name) and node.id in placeholders:
-            return placeholders[node.id]  # pyright: ignore[reportReturnType]
+
+        # Replace placeholder names with their corresponding AST nodes
+        if isinstance(node, ast.Name) and node.id in mapping:
+            return mapping[node.id]  # pyright: ignore[reportReturnType]
+
+        # Recursively transform all child nodes and wrap in ExtendedAST subclass
         cls = get_wrapper_cls(type(node))
         return location.to_ast(  # pyright: ignore[reportReturnType]
             cls(
@@ -176,6 +229,7 @@ def _replace(node: _R) -> _R:
             )
         )
 
+    # Apply the second pass transformation to replace all placeholders
     return _replace(statement)
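
To make the conflict concrete, here is a minimal standalone sketch of the two-pass scheme above (illustration only, not part of the commit: it omits helion's ExtendedAST wrapping and source-location tracking, and the helper names are hypothetical):

import ast
import re

def sketch_statement_from_string(template: str, **placeholders: ast.AST) -> ast.stmt:
    pattern = r"\{(\w+)\}(?!:)"
    mapping: dict[str, ast.AST] = {}

    def make_unique(m: re.Match[str]) -> str:
        # Same trick as the commit: swap {name} for a collision-proof identifier
        uid = f"__placeholder_{len(mapping)}__"
        mapping[uid] = placeholders[m.group(1)]
        return uid

    (statement,) = ast.parse(re.sub(pattern, make_unique, template)).body

    class Replacer(ast.NodeTransformer):
        def visit_Name(self, node: ast.Name) -> ast.AST:
            # Only the unique IDs are substituted; user variables are untouched
            return mapping.get(node.id, node)

    return ast.fix_missing_locations(Replacer().visit(statement))

# The case the unique IDs guard against: the placeholder {x} shares its name
# with the assignment target "x". Naive name-based substitution would also
# rewrite the target; the two-pass version leaves it alone.
stmt = sketch_statement_from_string("x = {x}", x=ast.parse("y + 1", mode="eval").body)
print(ast.unparse(stmt))  # x = y + 1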

helion/_compiler/device_function.py

Lines changed: 2 additions & 2 deletions
@@ -523,8 +523,8 @@ def codegen_function_call(self) -> ast.AST:
         assert pid is not None
         # TODO(jansel): we should run CSE this statement
         call_statement = statement_from_string(
-            f"_launcher({self.name}, __call_grid_expr, {', '.join(args)})",
-            __call_grid_expr=pid.codegen_grid(),
+            f"_launcher({self.name}, {{call_grid_expr}}, {', '.join(args)})",
+            call_grid_expr=pid.codegen_grid(),
         )
         assert isinstance(call_statement, ExtendedAST)
         # Mark the kernel call we can find it in codegen_precompile_def
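
As an aside (hypothetical values, not from the commit), the doubled braces work because Python's f-string interpolation turns {{call_grid_expr}} into the literal {call_grid_expr}, which the placeholder pass then substitutes, while {self.name} and {', '.join(args)} are expanded immediately:

name = "_kernel_fn"   # stands in for self.name
args = ["x", "y", "n"]
template = f"_launcher({name}, {{call_grid_expr}}, {', '.join(args)})"
print(template)  # _launcher(_kernel_fn, {call_grid_expr}, x, y, n)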

helion/_compiler/generate_ast.py

Lines changed: 3 additions & 1 deletion
@@ -80,7 +80,9 @@ def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Na
         assert isinstance(expr, ExtendedAST), expr
         with expr:
             varname = self.tmpvar(dce=dce, prefix=prefix)
-            self.add_statement(statement_from_string(f"{varname} = expr", expr=expr))
+            self.add_statement(
+                statement_from_string(f"{varname} = {{expr}}", expr=expr)
+            )
         return create(ast.Name, id=varname, ctx=ast.Load())
 
     @contextlib.contextmanager

helion/_compiler/helper_function.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Na
         if isinstance(expr, ast.Name):
             return expr
         varname = self.tmpvar(dce=dce, prefix=prefix)
-        self.add_statement(statement_from_string(f"{varname} = expr", expr=expr))
+        self.add_statement(statement_from_string(f"{varname} = {{expr}}", expr=expr))
         return create(ast.Name, id=varname, ctx=ast.Load())

helion/_compiler/indexing_strategy.py

Lines changed: 13 additions & 13 deletions
@@ -87,7 +87,7 @@ def codegen_load(
         extra = ", other=0"
         name = state.device_function.tensor_arg(fake_tensor).name
         return expr_from_string(
-            f"tl.load({name} + offset, mask{extra})",
+            f"tl.load({name} + {{offset}}, {{mask}}{extra})",
             offset=indexing.index_expr,
             mask=indexing.mask_expr,
         )
@@ -103,7 +103,7 @@ def codegen_store(
         indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
         name = state.device_function.tensor_arg(fake_tensor).name
         return expr_from_string(
-            f"tl.store({name} + offset, value, mask)",
+            f"tl.store({name} + {{offset}}, {{value}}, {{mask}})",
             value=value,
             offset=indexing.index_expr,
             mask=indexing.mask_expr,
@@ -131,7 +131,7 @@ def codegen_load(
         return indexing.reshape_load(
             state,
             expr_from_string(
-                f"tl.load(block_ptr, boundary_check={indexing.boundary_check(state)}, padding_option='zero')",
+                f"tl.load({{block_ptr}}, boundary_check={indexing.boundary_check(state)}, padding_option='zero')",
                 block_ptr=indexing.make_block_ptr(state),
             ),
         )
@@ -153,7 +153,7 @@ def codegen_store(
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
         return expr_from_string(
-            f"tl.store(block_ptr, value, boundary_check={indexing.boundary_check(state)})",
+            f"tl.store({{block_ptr}}, {{value}}, boundary_check={indexing.boundary_check(state)})",
             block_ptr=indexing.make_block_ptr(state),
             value=indexing.reshape_store(state, value),
         )
@@ -268,7 +268,7 @@ def codegen_load(
         desc_arg = indexing.tensor_descriptor_arg(state)
         if desc_arg.permutation is not None:
             load_expr = expr_from_string(
-                f"tl.permute(load_result, {desc_arg.inverse_permutation!r})",
+                f"tl.permute({{load_result}}, {desc_arg.inverse_permutation!r})",
                 load_result=load_expr,
             )
 
@@ -296,12 +296,12 @@ def codegen_store(
         if desc_arg.permutation is not None:
             # Apply permutation to the value
             store_value = expr_from_string(
-                f"tl.permute(store_val, {desc_arg.permutation!r})",
+                f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
                 store_val=store_value,
             )
 
         return expr_from_string(
-            f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, value)",
+            f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
             value=store_value,
         )
 
@@ -372,7 +372,7 @@ def get_mask_expr(
         mask_exprs.append(dev_ptr_mask_expr)
 
         if indexing.has_mask():
-            mask_exprs.append(f"(tensor_mask){tensor_broadcast}")
+            mask_exprs.append(f"({{tensor_mask}}){tensor_broadcast}")
         return expr_from_string(
             "&".join(mask_exprs), tensor_mask=indexing.mask_expr
         )
@@ -407,7 +407,7 @@ def codegen_load(
 
         dtype = triton_type(tensor_like.dtype)
         return expr_from_string(
-            f"tl.load((base.to(tl.pointer_type({dtype}))){stack_broadcast} + (offset){tensor_broadcast}, mask{extra})",
+            f"tl.load(({{base}}.to(tl.pointer_type({dtype}))){stack_broadcast} + ({{offset}}){tensor_broadcast}, {{mask}}{extra})",
             base=dev_ptrs_ast,
             offset=indexing.index_expr,
             mask=mask_expr,
@@ -439,7 +439,7 @@ def codegen_store(
 
         dtype = triton_type(tensor_like.dtype)
         return expr_from_string(
-            f"tl.store(base.to(tl.pointer_type({dtype})){stack_broadcast} + (offset){tensor_broadcast}, value, mask)",
+            f"tl.store({{base}}.to(tl.pointer_type({dtype})){stack_broadcast} + ({{offset}}){tensor_broadcast}, {{value}}, {{mask}})",
             base=dev_ptrs_ast,
             value=value,
             offset=indexing.index_expr,
@@ -616,7 +616,7 @@ def create(
 
         kwargs = {}
         if extra_mask is not None:
-            mask_values.setdefault("_extra_mask")
+            mask_values.setdefault("{_extra_mask}")
             kwargs["_extra_mask"] = extra_mask
         return SubscriptIndexing(
             expr_from_string("+".join(index_expr)),
@@ -710,13 +710,13 @@ def reshape_load(self, state: CodegenState, node: ast.AST) -> ast.AST:
         if not self.need_reshape(node):
             return node
         shape = state.tile_strategy.shape_str(self.reshaped_size)
-        return expr_from_string(f"tl.reshape(node, {shape})", node=node)
+        return expr_from_string(f"tl.reshape({{node}}, {shape})", node=node)
 
     def reshape_store(self, state: CodegenState, node: ast.AST) -> ast.AST:
         if not self.need_reshape(node):
             return node
         shape = state.tile_strategy.shape_str(self.block_shape)
-        return expr_from_string(f"tl.reshape(node, {shape})", node=node)
+        return expr_from_string(f"tl.reshape({{node}}, {shape})", node=node)
 
     @staticmethod
     def is_supported(

helion/_compiler/inductor_lowering.py

Lines changed: 28 additions & 20 deletions
@@ -373,7 +373,7 @@ def visit(n: torch.fx.Node) -> None:
             # Broadcast to force ranks to match
             expand = ["None"] * (ndim - fake_val.ndim) + [":"] * fake_val.ndim
             ast_val = expr_from_string(
-                "tensor[" + ", ".join(expand) + "]", tensor=ast_val
+                "{tensor}[" + ", ".join(expand) + "]", tensor=ast_val
             )
         if (
             isinstance(ast_val, ast.Name)
@@ -796,7 +796,7 @@ def codegen_unsqueeze(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     args = [":"] * ndim
     args.insert(dim, "None")
     return expr_from_string(
-        f"tensor[{', '.join(args)}]",
+        f"{{tensor}}[{', '.join(args)}]",
         tensor=tensor,
     )
 
@@ -817,7 +817,7 @@ def codegen_view(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     shape_str = ctx.cg.device_function.tile_strategy.shape_str(
         [*node.meta["val"].size()]
     )
-    return expr_from_string(f"tl.reshape(tensor, {shape_str})", tensor=tensor)
+    return expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor)
 
 
 @register_lowering(
@@ -831,7 +831,7 @@ def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     dims = [*dims]  # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
     assert {*dims} == {*range(len(dims))}, dims
     return expr_from_string(
-        f"tl.permute(tensor, {dims!r})",
+        f"tl.permute({{tensor}}, {dims!r})",
         tensor=tensor,
     )
 
@@ -851,10 +851,12 @@ def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     broadcasting = [":"] * len(shape)
     for i in range(len(shape) - node.args[0].meta["val"].ndim):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
         broadcasting[i] = "None"
-    tensor = expr_from_string(f"tensor[{', '.join(broadcasting)}]", tensor=tensor)
+    tensor = expr_from_string(
+        f"{{tensor}}[{', '.join(broadcasting)}]", tensor=tensor
+    )
     shape_str = ctx.cg.device_function.tile_strategy.shape_str(shape)
     return expr_from_string(
-        f"tl.broadcast_to(tensor, {shape_str})",
+        f"tl.broadcast_to({{tensor}}, {shape_str})",
         tensor=tensor,
     )
 
@@ -945,7 +947,7 @@ def reduce_3d_dot(
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
         return expr_from_string(
-            f"tl.dot(lhs, rhs, acc=acc{precision_arg})",
+            f"tl.dot({{lhs}}, {{rhs}}, acc={{acc}}{precision_arg})",
            lhs=lhs,
            rhs=rhs,
            acc=acc,  # pyright: ignore[reportArgumentType]
@@ -954,7 +956,9 @@
         precision_arg = (
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
-        return expr_from_string(f"tl.dot(lhs, rhs{precision_arg})", lhs=lhs, rhs=rhs)
+        return expr_from_string(
+            f"tl.dot({{lhs}}, {{rhs}}{precision_arg})", lhs=lhs, rhs=rhs
+        )
 
     # create reshape, dot, then reshape
     lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
@@ -966,18 +970,18 @@
     out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
         [*node.meta["val"].size()]
     )
-    lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
-    rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
+    lhs_reshape = expr_from_string(f"tl.reshape({{lhs}}, {lhs_shape_str})", lhs=lhs)
+    rhs_reshape = expr_from_string(f"tl.reshape({{rhs}}, {rhs_shape_str})", rhs=rhs)
     if with_acc:
         acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
             [*node.args[0].meta["val"].size()[1:]]  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
         )
-        acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)  # pyright: ignore[reportArgumentType]
+        acc_reshape = expr_from_string(f"tl.reshape({{rhs}}, {acc_shape_str})", rhs=acc)  # pyright: ignore[reportArgumentType]
         precision_arg = (
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
         comp = expr_from_string(
-            f"tl.dot(lhs, rhs, acc=acc{precision_arg})",
+            f"tl.dot({{lhs}}, {{rhs}}, acc={{acc}}{precision_arg})",
             lhs=lhs_reshape,
             rhs=rhs_reshape,
             acc=acc_reshape,
@@ -987,11 +991,11 @@
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
         comp = expr_from_string(
-            f"tl.dot(lhs, rhs{precision_arg})",
+            f"tl.dot({{lhs}}, {{rhs}}{precision_arg})",
             lhs=lhs_reshape,
             rhs=rhs_reshape,
         )
-    return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
+    return expr_from_string(f"tl.reshape({{lhs}}, {out_shape_str})", lhs=comp)
 
 
 @register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
@@ -1122,7 +1126,9 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str:
 
         # Regular variable assignment
         name = self.cg.device_function.new_var(node.name)
-        self.cg.add_statement(statement_from_string(f"{name} = result", result=result))
+        self.cg.add_statement(
+            statement_from_string(f"{name} = {{result}}", result=result)
+        )
         return name
 
     def _collect_multi_outputs(
@@ -1160,7 +1166,7 @@ def _collect_multi_outputs(
             if not isinstance(result, ast.Name):
                 var_name = self.cg.device_function.new_var(f"{node.name}_output{i}")
                 self.cg.add_statement(
-                    statement_from_string(f"{var_name} = result", result=result)
+                    statement_from_string(f"{var_name} = {{result}}", result=result)
                 )
                 result = create(ast.Name, id=var_name, ctx=ast.Load())
             final_outputs.append(result)
@@ -1239,7 +1245,9 @@ def codegen_call_with_graph(
                 # Phi nodes will merge variable names from outside the loop, but the old value
                 # of those variables could have usages.
                 copy_name = cg.device_function.new_var(arg.id + "_copy")
-                cg.add_statement(statement_from_string(f"{copy_name} = arg", arg=arg))
+                cg.add_statement(
+                    statement_from_string(f"{copy_name} = {{arg}}", arg=arg)
+                )
                 new_args.append(expr_from_string(copy_name))
             else:
                 new_args.append(cg.lift(arg))
@@ -1296,11 +1304,11 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )
     assert isinstance(dtype, torch.dtype)
     (length_arg,) = node.args  # expecting a single argument for length
-    expr = "tl.arange(0, length)"
+    expr = "tl.arange(0, {length})"
     if step != 1:
-        expr = f"step * {expr}"
+        expr = f"{{step}} * {expr}"
    if start != 0:
-        expr = f"start + {expr}"
+        expr = f"{{start}} + {expr}"
     if dtype != torch.int32:
         expr = f"({expr}).to({triton_type(dtype)})"
     return expr_from_string(
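
To see how the codegen_iota template composes before substitution, with hypothetical start=2 and step=3 (the dtype branch is omitted here since it only wraps the string in a cast):

start, step = 2, 3  # hypothetical values
expr = "tl.arange(0, {length})"
if step != 1:
    expr = f"{{step}} * {expr}"
if start != 0:
    expr = f"{{start}} + {expr}"
print(expr)  # {start} + {step} * tl.arange(0, {length})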
