@@ -289,42 +289,47 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
289289 block_info .from_config (self .device_function .config )
290290 )
291291 )
292- elif isinstance (type_info , SequenceType ):
292+ elif isinstance (type_info , SequenceType ) and all (
293+ isinstance (x , TileIndexType ) for x in type_info .unpack ()
294+ ):
293295 values = type_info .unpack ()
294- if all (isinstance (x , TileIndexType ) for x in values ):
295- block_infos = [env .block_sizes [x .block_id ] for x in values ] # pyright: ignore[reportAttributeAccessIssue]
296- return expr_from_string (
297- self .host_function .literal_expr (
298- [
299- x .from_config (self .device_function .config )
300- for x in block_infos
301- ]
302- )
296+ block_infos = [env .block_sizes [x .block_id ] for x in values ] # pyright: ignore[reportAttributeAccessIssue]
297+ return expr_from_string (
298+ self .host_function .literal_expr (
299+ [x .from_config (self .device_function .config ) for x in block_infos ]
303300 )
301+ )
304302 elif (
305303 isinstance (fn_type_info := func_node ._type_info , CallableType )
306304 and is_api_func (api := fn_type_info .value )
307305 and api ._codegen is not None
308306 ):
307+ ast_args = []
308+ ast_kwargs = {}
309309 proxy_args = []
310310 proxy_kwargs = {}
311311 for arg in node .args :
312312 assert not isinstance (arg , ast .Starred )
313313 assert isinstance (arg , ExtendedAST )
314314 assert arg ._type_info is not None
315+ ast_args .append (arg )
315316 proxy_args .append (arg ._type_info .proxy ())
316317 for kwarg in node .keywords :
317318 assert kwarg .arg is not None
318319 assert isinstance (kwarg .value , ExtendedAST )
319320 assert kwarg .value ._type_info is not None
321+ ast_kwargs [kwarg .arg ] = kwarg .value
320322 proxy_kwargs [kwarg .arg ] = kwarg .value ._type_info .proxy ()
323+ ast_params = api ._signature .bind (* ast_args , ** ast_kwargs )
321324 proxy_params = api ._signature .bind (* proxy_args , ** proxy_kwargs )
325+ ast_params .apply_defaults ()
322326 proxy_params .apply_defaults ()
323327 return api ._codegen ( # pyright: ignore[reportReturnType]
324328 CodegenState (
325329 self ,
326330 None ,
327331 proxy_args = [* proxy_params .arguments .values ()],
332+ ast_args = [* ast_params .arguments .values ()],
328333 )
329334 )
330335 return self .generic_visit (node )
0 commit comments