53
53
54
54
ALL_DECORATORS = TRIG_DECORATORS .union ({"service" })
55
55
56
+
56
57
def ast_eval_exec_factory (ast_ctx , mode ):
57
58
"""Generate a function that executes eval() or exec() with given ast_ctx."""
58
59
@@ -291,11 +292,20 @@ async def trigger_init(self):
291
292
"event_trigger" ,
292
293
"mqtt_trigger" ,
293
294
}
295
+ arg_check = {
296
+ "event_trigger" : {"arg_cnt" : {1 , 2 }},
297
+ "mqtt_trigger" : {"arg_cnt" : {1 , 2 }},
298
+ "state_active" : {"arg_cnt" : {1 }},
299
+ "state_trigger" : {"arg_cnt" : {"*" }, "type" : {list , set }},
300
+ "task_unique" : {"arg_cnt" : {1 }},
301
+ "time_active" : {"arg_cnt" : {0 , "*" }},
302
+ "time_trigger" : {"arg_cnt" : {0 , "*" }},
303
+ }
294
304
295
305
decorator_used = set ()
296
306
for dec in self .decorators :
297
307
dec_name , dec_args , dec_kwargs = dec [0 ], dec [1 ], dec [2 ]
298
- if dec_name in decorator_used :
308
+ if dec_name in decorator_used and "*" not in arg_check . get ( dec_name , { "arg_cnt" : {}})[ "arg_cnt" ] :
299
309
self .logger .error (
300
310
"%s defined in %s: decorator %s repeated; ignoring decorator" ,
301
311
self .name ,
@@ -313,7 +323,10 @@ async def trigger_init(self):
313
323
if dec_args is not None :
314
324
trig_args [dec_name ]["args" ] += dec_args
315
325
if dec_kwargs is not None :
316
- trig_args [dec_name ]["kwargs" ] = dec_kwargs
326
+ if "kwargs" in trig_args [dec_name ]:
327
+ trig_args [dec_name ]["kwargs" ].update (dec_kwargs )
328
+ else :
329
+ trig_args [dec_name ]["kwargs" ] = dec_kwargs
317
330
elif dec_name == "service" :
318
331
if dec_args is not None :
319
332
self .logger .error (
@@ -394,15 +407,6 @@ async def do_service_call(func, ast_ctx, data):
394
407
# check that we have the right number of arguments, and that they are
395
408
# strings
396
409
#
397
- arg_check = {
398
- "event_trigger" : {"arg_cnt" : {1 , 2 }},
399
- "mqtt_trigger" : {"arg_cnt" : {1 , 2 }},
400
- "state_active" : {"arg_cnt" : {1 }},
401
- "state_trigger" : {"arg_cnt" : {"*" }, "type" : {list , set }},
402
- "task_unique" : {"arg_cnt" : {1 }},
403
- "time_active" : {"arg_cnt" : {0 , "*" }},
404
- "time_trigger" : {"arg_cnt" : {0 , "*" }},
405
- }
406
410
for dec_name , arg_info in arg_check .items ():
407
411
arg_cnt = arg_info ["arg_cnt" ]
408
412
if dec_name not in trig_args :
@@ -518,36 +522,27 @@ def trigger_stop(self):
518
522
519
523
async def eval_decorators (self , ast_ctx ):
520
524
"""Evaluate the function decorators arguments."""
521
- self .decorators = []
522
525
code_str , code_list = ast_ctx .code_str , ast_ctx .code_list
523
526
ast_ctx .code_str , ast_ctx .code_list = self .code_str , self .code_list
524
527
525
- dec_funcs = []
528
+ dec_other = []
529
+ dec_trig = []
526
530
for dec in self .func_def .decorator_list :
527
- if isinstance (dec , ast .Call ) and isinstance (dec .func , ast .Name ) and dec .func .id in ALL_DECORATORS :
531
+ if (
532
+ isinstance (dec , ast .Call )
533
+ and isinstance (dec .func , ast .Name )
534
+ and dec .func .id in ALL_DECORATORS
535
+ ):
528
536
args = [await ast_ctx .aeval (arg ) for arg in dec .args ]
529
537
kwargs = {keyw .arg : await ast_ctx .aeval (keyw .value ) for keyw in dec .keywords }
530
- if len (kwargs ) == 0 :
531
- kwargs = None
532
- self .decorators .append ([dec .func .id , args , kwargs ])
538
+ dec_trig .append ([dec .func .id , args , kwargs if len (kwargs ) > 0 else None ])
533
539
elif isinstance (dec , ast .Name ) and dec .id in ALL_DECORATORS :
534
- self . decorators .append ([dec .id , None , None ])
540
+ dec_trig .append ([dec .id , None , None ])
535
541
else :
536
- dec_funcs .append (await ast_ctx .aeval (dec ))
537
-
538
- def make_dec_call (func ):
539
- async def dec_call (* args_tuple , ** kwargs ):
540
- args = list (args_tuple )
541
- if len (args ) > 0 and isinstance (args [0 ], AstEval ):
542
- args .pop (0 )
543
- return await func (ast_ctx , * args , ** kwargs )
544
-
545
- return dec_call
546
-
547
- for func in reversed (dec_funcs ):
548
- self .call = await ast_ctx .call_func (func , None , make_dec_call (self .call ))
542
+ dec_other .append (await ast_ctx .aeval (dec ))
549
543
550
544
ast_ctx .code_str , ast_ctx .code_list = code_str , code_list
545
+ return dec_trig , reversed (dec_other )
551
546
552
547
async def resolve_nonlocals (self , ast_ctx ):
553
548
"""Tag local variables and resolve nonlocals."""
@@ -729,11 +724,18 @@ class EvalFuncVar:
729
724
def __init__ (self , func ):
730
725
"""Initialize instance with given EvalFunc function."""
731
726
self .func = func
727
+ self .ast_ctx = None
732
728
733
729
def get_func (self ):
734
730
"""Return the EvalFunc function."""
735
731
return self .func
736
732
733
+ def remove_func (self ):
734
+ """Remove and return the EvalFunc function."""
735
+ func = self .func
736
+ self .func = None
737
+ return func
738
+
737
739
async def call (self , ast_ctx , * args , ** kwargs ):
738
740
"""Call the EvalFunc function."""
739
741
return await self .func .call (ast_ctx , * args , ** kwargs )
@@ -742,9 +744,22 @@ def get_name(self):
742
744
"""Return the function name."""
743
745
return self .func .get_name ()
744
746
747
+ def set_ast_ctx (self , ast_ctx ):
748
+ """Set the ast context."""
749
+ self .ast_ctx = ast_ctx
750
+
751
+ def get_ast_ctx (self ):
752
+ """Return the ast context."""
753
+ return self .ast_ctx
754
+
745
755
def __del__ (self ):
746
756
"""On deletion, stop any triggers for this function."""
747
- self .func .trigger_stop ()
757
+ if self .func :
758
+ self .func .trigger_stop ()
759
+
760
+ async def __call__ (self , * args , ** kwargs ):
761
+ """Call the EvalFunc function using our saved ast ctx."""
762
+ return await self .func .call (self .ast_ctx , * args , ** kwargs )
748
763
749
764
750
765
class EvalFuncVarClassInst (EvalFuncVar ):
@@ -760,35 +775,6 @@ async def call(self, ast_ctx, *args, **kwargs):
760
775
return await self .func .call (ast_ctx , self .class_inst , * args , ** kwargs )
761
776
762
777
763
- class EvalFuncVarAstCtx :
764
- """Class for a callable pyscript function with ast context."""
765
-
766
- def __init__ (self , ast_ctx , eval_func_var ):
767
- """Initialize instance with given EvalFunc function."""
768
- self .eval_func_var = eval_func_var
769
- self .ast_ctx = ast_ctx
770
-
771
- async def call (self , ast_ctx , * args , ** kwargs ):
772
- """Call the EvalFunc function."""
773
- return await self .eval_func_var .call (ast_ctx , * args , ** kwargs )
774
-
775
- def get_name (self ):
776
- """Return the function name."""
777
- return self .eval_func_var .get_name ()
778
-
779
- def get_ast_ctx (self ):
780
- """Return the ast context."""
781
- return self .ast_ctx
782
-
783
- def get_eval_func_var (self ):
784
- """Return the eval_func_var."""
785
- return self .eval_func_var
786
-
787
- async def __call__ (self , * args , ** kwargs ):
788
- """Call the EvalFunc function using our saved ast ctx."""
789
- return await self .eval_func_var .call (self .ast_ctx , * args , ** kwargs )
790
-
791
-
792
778
class AstEval :
793
779
"""Python interpreter AST object evaluator."""
794
780
@@ -833,7 +819,7 @@ async def aeval(self, arg, undefined_check=True):
833
819
if undefined_check and isinstance (val , EvalName ):
834
820
raise NameError (f"name '{ val .name } ' is not defined" )
835
821
if isinstance (val , EvalFuncVar ):
836
- return EvalFuncVarAstCtx (self , val )
822
+ val . set_ast_ctx (self )
837
823
return val
838
824
except Exception as err :
839
825
if not self .exception_obj :
@@ -1012,11 +998,22 @@ async def ast_functiondef(self, arg):
1012
998
1013
999
func = EvalFunc (arg , self .code_list , self .code_str , self .global_ctx )
1014
1000
await func .eval_defaults (self )
1015
- await func .eval_decorators (self )
1016
1001
await func .resolve_nonlocals (self )
1017
- await func .trigger_init ()
1018
1002
name = func .get_name ()
1019
- func_var = EvalFuncVar (func )
1003
+ dec_trig , dec_other = await func .eval_decorators (self )
1004
+ for dec_func in dec_other :
1005
+ func = await self .call_func (dec_func , None , func )
1006
+ if isinstance (func , EvalFuncVar ):
1007
+ func = func .remove_func ()
1008
+ dec_trig += func .decorators
1009
+ if isinstance (func , EvalFunc ):
1010
+ func .decorators = dec_trig
1011
+ func .trigger_stop ()
1012
+ await func .trigger_init ()
1013
+ func_var = EvalFuncVar (func )
1014
+ else :
1015
+ func_var = func
1016
+
1020
1017
if name in self .sym_table and isinstance (self .sym_table [name ], EvalLocalVar ):
1021
1018
self .sym_table [name ].set (func_var )
1022
1019
else :
@@ -1742,15 +1739,15 @@ async def call_func(self, func, func_name, *args, **kwargs):
1742
1739
"""Call a function with the given arguments."""
1743
1740
if func_name is None :
1744
1741
try :
1745
- if isinstance (func , (EvalFunc , EvalFuncVar , EvalFuncVarAstCtx )):
1742
+ if isinstance (func , (EvalFunc , EvalFuncVar )):
1746
1743
func_name = func .get_name ()
1747
1744
else :
1748
1745
func_name = func .__name__
1749
1746
except Exception :
1750
1747
func_name = "<function>"
1751
1748
arg_str = ", " .join (['"' + elt + '"' if isinstance (elt , str ) else str (elt ) for elt in args ])
1752
1749
_LOGGER .debug ("%s: calling %s(%s, %s)" , self .name , func_name , arg_str , kwargs )
1753
- if isinstance (func , (EvalFunc , EvalFuncVar , EvalFuncVarAstCtx )):
1750
+ if isinstance (func , (EvalFunc , EvalFuncVar )):
1754
1751
return await func .call (self , * args , ** kwargs )
1755
1752
if inspect .isclass (func ) and hasattr (func , "__init__evalfunc_wrap__" ):
1756
1753
inst = func ()
0 commit comments