Skip to content

Commit 9551fdb

Browse files
committed
updates to decorators; see #43, #122
1 parent 23386e3 commit 9551fdb

File tree

4 files changed

+337
-74
lines changed

4 files changed

+337
-74
lines changed

custom_components/pyscript/eval.py

+64-67
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
ALL_DECORATORS = TRIG_DECORATORS.union({"service"})
5555

56+
5657
def ast_eval_exec_factory(ast_ctx, mode):
5758
"""Generate a function that executes eval() or exec() with given ast_ctx."""
5859

@@ -291,11 +292,20 @@ async def trigger_init(self):
291292
"event_trigger",
292293
"mqtt_trigger",
293294
}
295+
arg_check = {
296+
"event_trigger": {"arg_cnt": {1, 2}},
297+
"mqtt_trigger": {"arg_cnt": {1, 2}},
298+
"state_active": {"arg_cnt": {1}},
299+
"state_trigger": {"arg_cnt": {"*"}, "type": {list, set}},
300+
"task_unique": {"arg_cnt": {1}},
301+
"time_active": {"arg_cnt": {0, "*"}},
302+
"time_trigger": {"arg_cnt": {0, "*"}},
303+
}
294304

295305
decorator_used = set()
296306
for dec in self.decorators:
297307
dec_name, dec_args, dec_kwargs = dec[0], dec[1], dec[2]
298-
if dec_name in decorator_used:
308+
if dec_name in decorator_used and "*" not in arg_check.get(dec_name, {"arg_cnt": {}})["arg_cnt"]:
299309
self.logger.error(
300310
"%s defined in %s: decorator %s repeated; ignoring decorator",
301311
self.name,
@@ -313,7 +323,10 @@ async def trigger_init(self):
313323
if dec_args is not None:
314324
trig_args[dec_name]["args"] += dec_args
315325
if dec_kwargs is not None:
316-
trig_args[dec_name]["kwargs"] = dec_kwargs
326+
if "kwargs" in trig_args[dec_name]:
327+
trig_args[dec_name]["kwargs"].update(dec_kwargs)
328+
else:
329+
trig_args[dec_name]["kwargs"] = dec_kwargs
317330
elif dec_name == "service":
318331
if dec_args is not None:
319332
self.logger.error(
@@ -394,15 +407,6 @@ async def do_service_call(func, ast_ctx, data):
394407
# check that we have the right number of arguments, and that they are
395408
# strings
396409
#
397-
arg_check = {
398-
"event_trigger": {"arg_cnt": {1, 2}},
399-
"mqtt_trigger": {"arg_cnt": {1, 2}},
400-
"state_active": {"arg_cnt": {1}},
401-
"state_trigger": {"arg_cnt": {"*"}, "type": {list, set}},
402-
"task_unique": {"arg_cnt": {1}},
403-
"time_active": {"arg_cnt": {0, "*"}},
404-
"time_trigger": {"arg_cnt": {0, "*"}},
405-
}
406410
for dec_name, arg_info in arg_check.items():
407411
arg_cnt = arg_info["arg_cnt"]
408412
if dec_name not in trig_args:
@@ -518,36 +522,27 @@ def trigger_stop(self):
518522

519523
async def eval_decorators(self, ast_ctx):
520524
"""Evaluate the function decorators arguments."""
521-
self.decorators = []
522525
code_str, code_list = ast_ctx.code_str, ast_ctx.code_list
523526
ast_ctx.code_str, ast_ctx.code_list = self.code_str, self.code_list
524527

525-
dec_funcs = []
528+
dec_other = []
529+
dec_trig = []
526530
for dec in self.func_def.decorator_list:
527-
if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name) and dec.func.id in ALL_DECORATORS:
531+
if (
532+
isinstance(dec, ast.Call)
533+
and isinstance(dec.func, ast.Name)
534+
and dec.func.id in ALL_DECORATORS
535+
):
528536
args = [await ast_ctx.aeval(arg) for arg in dec.args]
529537
kwargs = {keyw.arg: await ast_ctx.aeval(keyw.value) for keyw in dec.keywords}
530-
if len(kwargs) == 0:
531-
kwargs = None
532-
self.decorators.append([dec.func.id, args, kwargs])
538+
dec_trig.append([dec.func.id, args, kwargs if len(kwargs) > 0 else None])
533539
elif isinstance(dec, ast.Name) and dec.id in ALL_DECORATORS:
534-
self.decorators.append([dec.id, None, None])
540+
dec_trig.append([dec.id, None, None])
535541
else:
536-
dec_funcs.append(await ast_ctx.aeval(dec))
537-
538-
def make_dec_call(func):
539-
async def dec_call(*args_tuple, **kwargs):
540-
args = list(args_tuple)
541-
if len(args) > 0 and isinstance(args[0], AstEval):
542-
args.pop(0)
543-
return await func(ast_ctx, *args, **kwargs)
544-
545-
return dec_call
546-
547-
for func in reversed(dec_funcs):
548-
self.call = await ast_ctx.call_func(func, None, make_dec_call(self.call))
542+
dec_other.append(await ast_ctx.aeval(dec))
549543

550544
ast_ctx.code_str, ast_ctx.code_list = code_str, code_list
545+
return dec_trig, reversed(dec_other)
551546

552547
async def resolve_nonlocals(self, ast_ctx):
553548
"""Tag local variables and resolve nonlocals."""
@@ -729,11 +724,18 @@ class EvalFuncVar:
729724
def __init__(self, func):
730725
"""Initialize instance with given EvalFunc function."""
731726
self.func = func
727+
self.ast_ctx = None
732728

733729
def get_func(self):
734730
"""Return the EvalFunc function."""
735731
return self.func
736732

733+
def remove_func(self):
734+
"""Remove and return the EvalFunc function."""
735+
func = self.func
736+
self.func = None
737+
return func
738+
737739
async def call(self, ast_ctx, *args, **kwargs):
738740
"""Call the EvalFunc function."""
739741
return await self.func.call(ast_ctx, *args, **kwargs)
@@ -742,9 +744,22 @@ def get_name(self):
742744
"""Return the function name."""
743745
return self.func.get_name()
744746

747+
def set_ast_ctx(self, ast_ctx):
748+
"""Set the ast context."""
749+
self.ast_ctx = ast_ctx
750+
751+
def get_ast_ctx(self):
752+
"""Return the ast context."""
753+
return self.ast_ctx
754+
745755
def __del__(self):
746756
"""On deletion, stop any triggers for this function."""
747-
self.func.trigger_stop()
757+
if self.func:
758+
self.func.trigger_stop()
759+
760+
async def __call__(self, *args, **kwargs):
761+
"""Call the EvalFunc function using our saved ast ctx."""
762+
return await self.func.call(self.ast_ctx, *args, **kwargs)
748763

749764

750765
class EvalFuncVarClassInst(EvalFuncVar):
@@ -760,35 +775,6 @@ async def call(self, ast_ctx, *args, **kwargs):
760775
return await self.func.call(ast_ctx, self.class_inst, *args, **kwargs)
761776

762777

763-
class EvalFuncVarAstCtx:
764-
"""Class for a callable pyscript function with ast context."""
765-
766-
def __init__(self, ast_ctx, eval_func_var):
767-
"""Initialize instance with given EvalFunc function."""
768-
self.eval_func_var = eval_func_var
769-
self.ast_ctx = ast_ctx
770-
771-
async def call(self, ast_ctx, *args, **kwargs):
772-
"""Call the EvalFunc function."""
773-
return await self.eval_func_var.call(ast_ctx, *args, **kwargs)
774-
775-
def get_name(self):
776-
"""Return the function name."""
777-
return self.eval_func_var.get_name()
778-
779-
def get_ast_ctx(self):
780-
"""Return the ast context."""
781-
return self.ast_ctx
782-
783-
def get_eval_func_var(self):
784-
"""Return the eval_func_var."""
785-
return self.eval_func_var
786-
787-
async def __call__(self, *args, **kwargs):
788-
"""Call the EvalFunc function using our saved ast ctx."""
789-
return await self.eval_func_var.call(self.ast_ctx, *args, **kwargs)
790-
791-
792778
class AstEval:
793779
"""Python interpreter AST object evaluator."""
794780

@@ -833,7 +819,7 @@ async def aeval(self, arg, undefined_check=True):
833819
if undefined_check and isinstance(val, EvalName):
834820
raise NameError(f"name '{val.name}' is not defined")
835821
if isinstance(val, EvalFuncVar):
836-
return EvalFuncVarAstCtx(self, val)
822+
val.set_ast_ctx(self)
837823
return val
838824
except Exception as err:
839825
if not self.exception_obj:
@@ -1012,11 +998,22 @@ async def ast_functiondef(self, arg):
1012998

1013999
func = EvalFunc(arg, self.code_list, self.code_str, self.global_ctx)
10141000
await func.eval_defaults(self)
1015-
await func.eval_decorators(self)
10161001
await func.resolve_nonlocals(self)
1017-
await func.trigger_init()
10181002
name = func.get_name()
1019-
func_var = EvalFuncVar(func)
1003+
dec_trig, dec_other = await func.eval_decorators(self)
1004+
for dec_func in dec_other:
1005+
func = await self.call_func(dec_func, None, func)
1006+
if isinstance(func, EvalFuncVar):
1007+
func = func.remove_func()
1008+
dec_trig += func.decorators
1009+
if isinstance(func, EvalFunc):
1010+
func.decorators = dec_trig
1011+
func.trigger_stop()
1012+
await func.trigger_init()
1013+
func_var = EvalFuncVar(func)
1014+
else:
1015+
func_var = func
1016+
10201017
if name in self.sym_table and isinstance(self.sym_table[name], EvalLocalVar):
10211018
self.sym_table[name].set(func_var)
10221019
else:
@@ -1742,15 +1739,15 @@ async def call_func(self, func, func_name, *args, **kwargs):
17421739
"""Call a function with the given arguments."""
17431740
if func_name is None:
17441741
try:
1745-
if isinstance(func, (EvalFunc, EvalFuncVar, EvalFuncVarAstCtx)):
1742+
if isinstance(func, (EvalFunc, EvalFuncVar)):
17461743
func_name = func.get_name()
17471744
else:
17481745
func_name = func.__name__
17491746
except Exception:
17501747
func_name = "<function>"
17511748
arg_str = ", ".join(['"' + elt + '"' if isinstance(elt, str) else str(elt) for elt in args])
17521749
_LOGGER.debug("%s: calling %s(%s, %s)", self.name, func_name, arg_str, kwargs)
1753-
if isinstance(func, (EvalFunc, EvalFuncVar, EvalFuncVarAstCtx)):
1750+
if isinstance(func, (EvalFunc, EvalFuncVar)):
17541751
return await func.call(self, *args, **kwargs)
17551752
if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"):
17561753
inst = func()

custom_components/pyscript/trigger.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import homeassistant.helpers.sun as sun
1515

1616
from .const import LOGGER_PATH
17-
from .eval import AstEval, EvalFunc, EvalFuncVar, EvalFuncVarAstCtx
17+
from .eval import AstEval, EvalFunc, EvalFuncVar
1818
from .event import Event
1919
from .function import Function
2020
from .mqtt import Mqtt
@@ -147,7 +147,7 @@ async def func_call(func, func_name, new_ast_ctx, *args, **kwargs):
147147
return ret
148148

149149
try:
150-
if isinstance(func, (EvalFunc, EvalFuncVar, EvalFuncVarAstCtx)):
150+
if isinstance(func, (EvalFunc, EvalFuncVar)):
151151
func_name = func.get_name()
152152
else:
153153
func_name = func.__name__
@@ -183,15 +183,12 @@ async def user_task_wait(aws):
183183
async def user_task_add_done_callback(task, callback, *args, **kwargs):
184184
"""Implement task.add_done_callback()."""
185185
ast_ctx = None
186-
if type(callback) is EvalFuncVarAstCtx:
186+
if type(callback) is EvalFuncVar:
187187
ast_ctx = callback.get_ast_ctx()
188-
callback = callback.get_eval_func_var()
189188
Function.task_add_done_callback(task, ast_ctx, callback, *args, **kwargs)
190189

191190
async def user_task_remove_done_callback(task, callback):
192191
"""Implement task.remove_done_callback()."""
193-
if type(callback) is EvalFuncVarAstCtx:
194-
callback = callback.get_eval_func_var()
195192
Function.task_remove_done_callback(task, callback)
196193

197194
funcs = {

0 commit comments

Comments
 (0)