@@ -313,7 +313,7 @@ def getattr(self):
313313class EvalFunc :
314314 """Class for a callable pyscript function."""
315315
316- def __init__ (self , func_def , code_list , code_str , global_ctx ):
316+ def __init__ (self , func_def , code_list , code_str , global_ctx , async_func = False ):
317317 """Initialize a function calling context."""
318318 self .func_def = func_def
319319 self .name = func_def .name
@@ -338,6 +338,7 @@ def __init__(self, func_def, code_list, code_str, global_ctx):
338338 self .trigger = []
339339 self .trigger_service = set ()
340340 self .has_closure = False
341+ self .async_func = async_func
341342
342343 def get_name (self ):
343344 """Return the function name."""
@@ -930,14 +931,18 @@ async def ast_not_implemented(self, arg, *args):
930931 name = "ast_" + arg .__class__ .__name__ .lower ()
931932 raise NotImplementedError (f"{ self .name } : not implemented ast " + name )
932933
933- async def aeval (self , arg , undefined_check = True ):
934+ async def aeval (self , arg , undefined_check = True , do_await = True ):
934935 """Vector to specific function based on ast class type."""
935936 name = "ast_" + arg .__class__ .__name__ .lower ()
936937 try :
937938 if hasattr (arg , "lineno" ):
938939 self .lineno = arg .lineno
939940 self .col_offset = arg .col_offset
940- val = await getattr (self , name , self .ast_not_implemented )(arg )
941+ val = (
942+ await getattr (self , name , self .ast_not_implemented )(arg )
943+ if do_await
944+ else getattr (self , name , self .ast_not_implemented )(arg )
945+ )
941946 if undefined_check and isinstance (val , EvalName ):
942947 raise NameError (f"name '{ val .name } ' is not defined" )
943948 return val
@@ -1102,7 +1107,7 @@ async def ast_classdef(self, arg):
11021107 del sym_table ["__init__" ]
11031108 sym_table_assign [arg .name ].set (type (arg .name , tuple (bases ), sym_table ))
11041109
1105- async def ast_functiondef (self , arg ):
1110+ async def ast_functiondef (self , arg , async_func = False ):
11061111 """Evaluate function definition."""
11071112 other_dec = []
11081113 dec_name = None
@@ -1158,7 +1163,7 @@ async def executor_wrap(*args, **kwargs):
11581163 self .sym_table [arg .name ].set (func )
11591164 return
11601165
1161- func = EvalFunc (arg , self .code_list , self .code_str , self .global_ctx )
1166+ func = EvalFunc (arg , self .code_list , self .code_str , self .global_ctx , async_func )
11621167 await func .eval_defaults (self )
11631168 await func .resolve_nonlocals (self )
11641169 name = func .get_name ()
@@ -1215,7 +1220,7 @@ async def ast_lambda(self, arg):
12151220
12161221 async def ast_asyncfunctiondef (self , arg ):
12171222 """Evaluate async function definition."""
1218- return await self .ast_functiondef (arg )
1223+ return await self .ast_functiondef (arg , async_func = True )
12191224
12201225 async def ast_try (self , arg ):
12211226 """Execute try...except statement."""
@@ -2020,7 +2025,10 @@ async def ast_formattedvalue(self, arg):
20202025
20212026 async def ast_await (self , arg ):
20222027 """Evaluate await expr."""
2023- return await self .aeval (arg .value )
2028+ coro = await self .aeval (arg .value , do_await = False )
2029+ if coro and asyncio .iscoroutine (coro ):
2030+ return await coro
2031+ return coro
20242032
20252033 async def get_target_names (self , lhs ):
20262034 """Recursively find all the target names mentioned in the AST tree."""
0 commit comments