@@ -373,7 +373,7 @@ def visit(n: torch.fx.Node) -> None:
             # Broadcast to force ranks to match
             expand = ["None"] * (ndim - fake_val.ndim) + [":"] * fake_val.ndim
             ast_val = expr_from_string(
-                "tensor[" + ", ".join(expand) + "]", tensor=ast_val
+                "{tensor}[" + ", ".join(expand) + "]", tensor=ast_val
             )
         if (
             isinstance(ast_val, ast.Name)
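Every hunk in this patch is the same mechanical migration: templates handed to expr_from_string (and statement_from_string below) now spell their substitution slots as explicit {name} placeholders instead of bare identifiers. The apparent benefit is that only declared slots get spliced, so identifiers that legitimately appear in templates (tl, None, names inside shape strings) can no longer collide with a keyword-argument name. In plain string templates the slot is written {tensor}; inside f-strings it must be escaped as {{tensor}}. As intuition, here is a minimal sketch of how a placeholder-splicing helper could work; the names and behavior are assumptions for illustration, not this codebase's actual implementation:

    import ast

    def expr_from_string(template: str, **placeholders: ast.expr) -> ast.expr:
        """Parse `template`, splicing AST nodes into {name} placeholder slots."""
        # Rewrite "{tensor}" to the bare identifier "tensor" so the template
        # parses; a real implementation would track slots more carefully.
        source = template
        for name in placeholders:
            source = source.replace("{" + name + "}", name)

        class _Splice(ast.NodeTransformer):
            def visit_Name(self, node: ast.Name) -> ast.AST:
                # Only declared placeholders are swapped; other identifiers
                # (e.g. tl) pass through untouched.
                return placeholders.get(node.id, node)

        tree = _Splice().visit(ast.parse(source, mode="eval"))
        return ast.fix_missing_locations(tree).body

    subscript = expr_from_string(
        "{tensor}[None, :]", tensor=ast.Name(id="acc_tile", ctx=ast.Load())
    )
    print(ast.unparse(subscript))  # acc_tile[None, :]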
@@ -796,7 +796,7 @@ def codegen_unsqueeze(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     args = [":"] * ndim
     args.insert(dim, "None")
     return expr_from_string(
-        f"tensor[{', '.join(args)}]",
+        f"{{tensor}}[{', '.join(args)}]",
         tensor=tensor,
     )

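For intuition, with assumed values ndim=2 and dim=1 the index list built above produces the template below; the doubled braces matter because this template is an f-string, so the placeholder must survive interpolation as a literal {tensor}:

    ndim, dim = 2, 1  # assumed example values
    args = [":"] * ndim
    args.insert(dim, "None")
    print(f"{{tensor}}[{', '.join(args)}]")  # {tensor}[:, None, :]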
@@ -817,7 +817,7 @@ def codegen_view(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     shape_str = ctx.cg.device_function.tile_strategy.shape_str(
         [*node.meta["val"].size()]
     )
-    return expr_from_string(f"tl.reshape(tensor, {shape_str})", tensor=tensor)
+    return expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor)


 @register_lowering(
@@ -831,7 +831,7 @@ def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     dims = [*dims]  # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
     assert {*dims} == {*range(len(dims))}, dims
     return expr_from_string(
-        f"tl.permute(tensor, {dims!r})",
+        f"tl.permute({{tensor}}, {dims!r})",
         tensor=tensor,
     )

@@ -851,10 +851,12 @@ def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     broadcasting = [":"] * len(shape)
     for i in range(len(shape) - node.args[0].meta["val"].ndim):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
         broadcasting[i] = "None"
-    tensor = expr_from_string(f"tensor[{', '.join(broadcasting)}]", tensor=tensor)
+    tensor = expr_from_string(
+        f"{{tensor}}[{', '.join(broadcasting)}]", tensor=tensor
+    )
     shape_str = ctx.cg.device_function.tile_strategy.shape_str(shape)
     return expr_from_string(
-        f"tl.broadcast_to(tensor, {shape_str})",
+        f"tl.broadcast_to({{tensor}}, {shape_str})",
         tensor=tensor,
     )

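A hedged example of the two-step expand lowering, with assumed values: a rank-1 source expanded to a rank-3 shape first gets None indices prepended to align ranks, and the result is then fed to the tl.broadcast_to template:

    shape = ["BLOCK_A", "BLOCK_B", "BLOCK_C"]  # assumed target shape symbols
    src_ndim = 1  # assumed source rank
    broadcasting = [":"] * len(shape)
    for i in range(len(shape) - src_ndim):
        broadcasting[i] = "None"
    print(f"{{tensor}}[{', '.join(broadcasting)}]")  # {tensor}[None, None, :]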
@@ -945,7 +947,7 @@ def reduce_3d_dot(
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
         return expr_from_string(
-            f"tl.dot(lhs, rhs, acc=acc{precision_arg})",
+            f"tl.dot({{lhs}}, {{rhs}}, acc={{acc}}{precision_arg})",
             lhs=lhs,
             rhs=rhs,
             acc=acc,  # pyright: ignore[reportArgumentType]
@@ -954,7 +956,9 @@ def reduce_3d_dot(
         precision_arg = (
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
-        return expr_from_string(f"tl.dot(lhs, rhs{precision_arg})", lhs=lhs, rhs=rhs)
+        return expr_from_string(
+            f"tl.dot({{lhs}}, {{rhs}}{precision_arg})", lhs=lhs, rhs=rhs
+        )

     # create reshape, dot, then reshape
     lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
@@ -966,18 +970,18 @@ def reduce_3d_dot(
     out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
         [*node.meta["val"].size()]
     )
-    lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
-    rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
+    lhs_reshape = expr_from_string(f"tl.reshape({{lhs}}, {lhs_shape_str})", lhs=lhs)
+    rhs_reshape = expr_from_string(f"tl.reshape({{rhs}}, {rhs_shape_str})", rhs=rhs)
     if with_acc:
         acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
             [*node.args[0].meta["val"].size()[1:]]  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
         )
-        acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)  # pyright: ignore[reportArgumentType]
+        acc_reshape = expr_from_string(f"tl.reshape({{rhs}}, {acc_shape_str})", rhs=acc)  # pyright: ignore[reportArgumentType]
         precision_arg = (
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
         comp = expr_from_string(
-            f"tl.dot(lhs, rhs, acc=acc{precision_arg})",
+            f"tl.dot({{lhs}}, {{rhs}}, acc={{acc}}{precision_arg})",
             lhs=lhs_reshape,
             rhs=rhs_reshape,
             acc=acc_reshape,
@@ -987,11 +991,11 @@ def reduce_3d_dot(
             f", input_precision={datatype!r}" if datatype is not None else ""
         )
         comp = expr_from_string(
-            f"tl.dot(lhs, rhs{precision_arg})",
+            f"tl.dot({{lhs}}, {{rhs}}{precision_arg})",
             lhs=lhs_reshape,
             rhs=rhs_reshape,
         )
-    return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
+    return expr_from_string(f"tl.reshape({{lhs}}, {out_shape_str})", lhs=comp)


 @register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
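The escaping in these reduce_3d_dot templates mixes both interpolation stages: {lhs_shape_str} is expanded immediately by the f-string, while {{lhs}} must survive as a literal {lhs} slot for expr_from_string to fill with an AST node later. A small check of the mixing, with an assumed shape string:

    lhs_shape_str = "[BLOCK_M, BLOCK_K]"  # assumed shape string
    print(f"tl.reshape({{lhs}}, {lhs_shape_str})")
    # tl.reshape({lhs}, [BLOCK_M, BLOCK_K])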
@@ -1122,7 +1126,9 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str:

         # Regular variable assignment
         name = self.cg.device_function.new_var(node.name)
-        self.cg.add_statement(statement_from_string(f"{name} = result", result=result))
+        self.cg.add_statement(
+            statement_from_string(f"{name} = {{result}}", result=result)
+        )
         return name

     def _collect_multi_outputs(
@@ -1160,7 +1166,7 @@ def _collect_multi_outputs(
             if not isinstance(result, ast.Name):
                 var_name = self.cg.device_function.new_var(f"{node.name}_output{i}")
                 self.cg.add_statement(
-                    statement_from_string(f"{var_name} = result", result=result)
+                    statement_from_string(f"{var_name} = {{result}}", result=result)
                 )
                 result = create(ast.Name, id=var_name, ctx=ast.Load())
             final_outputs.append(result)
@@ -1239,7 +1245,9 @@ def codegen_call_with_graph(
             # Phi nodes will merge variable names from outside the loop, but the old value
             # of those variables could have usages.
             copy_name = cg.device_function.new_var(arg.id + "_copy")
-            cg.add_statement(statement_from_string(f"{copy_name} = arg", arg=arg))
+            cg.add_statement(
+                statement_from_string(f"{copy_name} = {{arg}}", arg=arg)
+            )
             new_args.append(expr_from_string(copy_name))
         else:
             new_args.append(cg.lift(arg))
@@ -1296,11 +1304,11 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )
     assert isinstance(dtype, torch.dtype)
     (length_arg,) = node.args  # expecting a single argument for length
-    expr = "tl.arange(0, length)"
+    expr = "tl.arange(0, {length})"
     if step != 1:
-        expr = f"step * {expr}"
+        expr = f"{{step}} * {expr}"
     if start != 0:
-        expr = f"start + {expr}"
+        expr = f"{{start}} + {expr}"
     if dtype != torch.int32:
         expr = f"({expr}).to({triton_type(dtype)})"
     return expr_from_string(
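Composed bottom-up, the iota template threads the literal placeholders through each wrapping f-string unchanged (braces inside an interpolated value need no escaping, only braces written directly in the f-string do). With assumed values start=4, step=2, and a non-int32 dtype, the final template would read:

    start, step = 4, 2  # assumed values
    expr = "tl.arange(0, {length})"
    if step != 1:
        expr = f"{{step}} * {expr}"
    if start != 0:
        expr = f"{{start}} + {expr}"
    expr = f"({expr}).to(tl.int64)"  # assumed non-int32 dtype cast
    print(expr)  # ({start} + {step} * tl.arange(0, {length})).to(tl.int64)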