@@ -96,7 +96,32 @@ def _(
9696 return _fake_outputs (output_like )
9797
9898
99- def _ensure_name (state : CodegenState , node : ast .AST ) -> str :
99+ def _ensure_name (
100+ state : CodegenState ,
101+ node : ast .AST ,
102+ original : object ,
103+ ) -> str :
104+ if (
105+ isinstance (node , ast .Call )
106+ and isinstance (node .func , ast .Name )
107+ and node .func .id == "_host_tensor"
108+ ):
109+ if not isinstance (original , torch .Tensor ):
110+ raise exc .InvalidAPIUsage (
111+ "inline_triton host tensor placeholders must be torch.Tensor instances"
112+ )
113+ return state .device_function .tensor_arg (original ).name
114+ if not isinstance (node , ast .AST ):
115+ return repr (node )
116+ if isinstance (node , ast .Constant ):
117+ return repr (node .value )
118+ if isinstance (original , torch .Tensor ):
119+ try :
120+ tensor_arg = state .device_function .tensor_arg (original )
121+ except KeyError :
122+ pass
123+ else :
124+ return tensor_arg .name
100125 lifted = state .codegen .lift (node )
101126 assert isinstance (lifted , ast .Name )
102127 return lifted .id
@@ -118,9 +143,13 @@ def _format_triton_source(
118143 "inline_triton expects a dict literal when args is a mapping"
119144 )
120145 assert args_obj .keys () == args_ast .keys ()
121- format_args : dict [str , str ] = {
122- key : _ensure_name (state , args_ast [key ]) for key in args_ast
123- }
146+ format_args : dict [str , str ] = {}
147+ for key in args_ast :
148+ format_args [key ] = _ensure_name (
149+ state ,
150+ args_ast [key ],
151+ args_obj [key ],
152+ )
124153 try :
125154 return source .format (** format_args )
126155 except (KeyError , IndexError , ValueError ) as exc_value :
@@ -138,7 +167,10 @@ def _format_triton_source(
138167 if isinstance (args_ast , (ast .List , ast .Tuple ))
139168 else list (args_ast )
140169 )
141- names = [_ensure_name (state , node ) for node in arg_nodes ]
170+ names = [
171+ _ensure_name (state , node , arg )
172+ for node , arg in zip (arg_nodes , args_obj , strict = False )
173+ ]
142174 try :
143175 expected_len = len (args_obj )
144176 except TypeError : # pragma: no cover - defensive
@@ -157,7 +189,10 @@ def _format_triton_source(
157189 raise exc .InvalidAPIUsage ("inline_triton args must be a tuple/list or a mapping" )
158190
159191
160- def _parse_triton_source (source : str ) -> tuple [list [ast .stmt ], ast .AST ]:
192+ def _parse_triton_source (
193+ source : str ,
194+ require_expression : bool ,
195+ ) -> tuple [list [ast .stmt ], ast .AST | None ]:
161196 try :
162197 module = ast .parse (source )
163198 except SyntaxError as exc_value :
@@ -166,16 +201,21 @@ def _parse_triton_source(source: str) -> tuple[list[ast.stmt], ast.AST]:
166201 ) from exc_value
167202
168203 if not module .body :
169- raise exc .InvalidAPIUsage ("triton_source must contain at least one expression " )
204+ raise exc .InvalidAPIUsage ("triton_source must contain code " )
170205
171206 * prefix , last = module .body
172- if not isinstance (last , ast .Expr ):
207+ converted_prefix = [cast ("ast.stmt" , convert (stmt )) for stmt in prefix ]
208+
209+ if isinstance (last , ast .Expr ):
210+ return converted_prefix , convert (last .value )
211+
212+ if require_expression :
173213 raise exc .InvalidAPIUsage (
174- "The last line of triton_source must be an expression"
214+ "The last line of triton_source must be an expression when output_like is provided "
175215 )
176216
177- converted_prefix = [ cast ("ast.stmt" , convert (stmt )) for stmt in prefix ]
178- return converted_prefix , convert ( last . value )
217+ converted_prefix . append ( cast ("ast.stmt" , convert (last )))
218+ return converted_prefix , None
179219
180220
181221def _normalize_output_ast (output_ast : object ) -> list [ast .AST ]:
@@ -315,12 +355,15 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]:
315355 state .ast_args [1 ],
316356 )
317357
318- statements , result_expr = _parse_triton_source (formatted )
358+ statements , result_expr = _parse_triton_source (
359+ formatted , require_expression = output_like is not None
360+ )
319361 for stmt in statements :
320362 state .add_statement (stmt )
321363
322364 if output_like is None :
323- state .add_statement (create (ast .Expr , value = result_expr ))
365+ if result_expr is not None :
366+ state .add_statement (create (ast .Expr , value = result_expr ))
324367 return create (ast .Constant , value = None )
325368
326369 result_name = state .device_function .new_var ("inline_triton_result" )
0 commit comments