Skip to content

Commit 803bcca

Browse files
committed
added tests for #82 and fixed a couple of issues
1 parent 26f327f commit 803bcca

File tree

4 files changed

+185
-63
lines changed

4 files changed

+185
-63
lines changed

custom_components/pyscript/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,17 @@ def state_var_remove():
188188

189189
async def state_changed(event):
190190
var_name = event.data["entity_id"]
191-
# attr = event.data["new_state"].attributes
192-
if "new_state" not in event.data or event.data["new_state"] is None:
191+
if event.data.get("new_state", None):
192+
new_val = StateVar(event.data["new_state"])
193+
else:
193194
# state variable has been deleted
194195
new_val = None
195-
else:
196-
new_val = StateVar(event.data['new_state'])
197196

198-
if "old_state" not in event.data or event.data["old_state"] is None:
197+
if event.data.get("old_state", None):
198+
old_val = StateVar(event.data["old_state"])
199+
else:
199200
# no previous state
200201
old_val = None
201-
else:
202-
old_val = StateVar(event.data['old_state'])
203202

204203
new_vars = {var_name: new_val, f"{var_name}.old": old_val}
205204
func_args = {

custom_components/pyscript/state.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
STATE_VIRTUAL_ATTRS = {"last_changed", "last_updated"}
1616

17+
1718
class StateVar(str):
1819
"""Class for representing the value and attributes of a state variable."""
1920

@@ -127,14 +128,21 @@ async def update(cls, new_vars, func_args):
127128

128129
@classmethod
129130
def notify_var_get(cls, var_names, new_vars):
130-
"""Return the most recent value of a state variable change."""
131-
notify_vars = {}
131+
"""Add values of var_names to new_vars, or default to None."""
132+
notify_vars = new_vars.copy()
132133
for var_name in var_names if var_names is not None else []:
134+
if var_name in notify_vars:
135+
continue
136+
parts = var_name.split(".")
133137
if var_name in cls.notify_var_last:
134138
notify_vars[var_name] = cls.notify_var_last[var_name]
135-
elif var_name in new_vars:
136-
notify_vars[var_name] = new_vars[var_name]
137-
elif 1 <= var_name.count(".") <= 2 and not cls.exist(var_name):
139+
elif len(parts) == 3 and f"{parts[0]}.{parts[1]}" in cls.notify_var_last:
140+
notify_vars[var_name] = getattr(
141+
cls.notify_var_last[f"{parts[0]}.{parts[1]}"], parts[2], None
142+
)
143+
elif len(parts) == 4 and parts[2] == "old" and f"{parts[0]}.{parts[1]}.old" in notify_vars:
144+
notify_vars[var_name] = getattr(notify_vars[f"{parts[0]}.{parts[1]}.old"], parts[3], None)
145+
elif 1 <= var_name.count(".") <= 3 and not cls.exist(var_name):
138146
notify_vars[var_name] = None
139147
return notify_vars
140148

@@ -144,13 +152,14 @@ async def set(cls, var_name, value=None, new_attributes=None, **kwargs):
144152
if var_name.count(".") != 1:
145153
raise NameError(f"invalid name {var_name} (should be 'domain.entity')")
146154

147-
if isinstance(value, StateVar) and new_attributes is None:
148-
#
149-
# value is a StateVar, so extract the attributes and value
150-
#
151-
new_attributes = value.__dict__.copy()
152-
for discard in STATE_VIRTUAL_ATTRS:
153-
new_attributes.pop(discard, None)
155+
if isinstance(value, StateVar):
156+
if new_attributes is None:
157+
#
158+
# value is a StateVar, so extract the attributes and value
159+
#
160+
new_attributes = value.__dict__.copy()
161+
for discard in STATE_VIRTUAL_ATTRS:
162+
new_attributes.pop(discard, None)
154163
value = str(value)
155164

156165
state_value = None
@@ -176,10 +185,15 @@ async def set(cls, var_name, value=None, new_attributes=None, **kwargs):
176185
if kwargs:
177186
new_attributes = new_attributes.copy()
178187
new_attributes.update(kwargs)
179-
_LOGGER.debug("setting %s = %s, attr = %s", var_name, value, new_attributes)
180-
cls.notify_var_last[var_name] = str(value)
181188

189+
_LOGGER.debug("setting %s = %s, attr = %s", var_name, value, new_attributes)
182190
cls.hass.states.async_set(var_name, value, new_attributes, context=context)
191+
if var_name in cls.notify_var_last or var_name in cls.notify:
192+
#
193+
# immediately update a variable we are monitoring since it could take a while
194+
# for the state changed event to propagate
195+
#
196+
cls.notify_var_last[var_name] = StateVar(cls.hass.states.get(var_name))
183197

184198
@classmethod
185199
async def setattr(cls, var_attr_name, value):

custom_components/pyscript/trigger.py

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .eval import AstEval
1818
from .event import Event
1919
from .function import Function
20-
from .state import State, STATE_VIRTUAL_ATTRS
20+
from .state import STATE_VIRTUAL_ATTRS, State
2121

2222
_LOGGER = logging.getLogger(LOGGER_PATH + ".trigger")
2323

@@ -49,73 +49,58 @@ def parse_time_offset(offset_str):
4949

5050

5151
def ident_any_values_changed(func_args, ident):
52-
"""Check for changes to state or attributes on ident any vars"""
53-
value = func_args.get('value')
54-
old_value = func_args.get('old_value')
55-
var_name = func_args.get('var_name')
52+
"""Check for any changes to state or attributes on ident vars."""
53+
var_name = func_args.get("var_name", None)
5654

5755
if var_name is None:
5856
return False
57+
value = func_args["value"]
58+
old_value = func_args["old_value"]
5959

6060
for check_var in ident:
6161
if check_var == var_name and old_value != value:
6262
return True
6363

6464
if check_var.startswith(f"{var_name}."):
65-
var_pieces = check_var.split('.')
65+
var_pieces = check_var.split(".")
6666
if len(var_pieces) == 3 and f"{var_pieces[0]}.{var_pieces[1]}" == var_name:
6767
if var_pieces[2] == "*":
6868
# catch all has been requested, check all attributes for change
69-
all_attributes = set()
69+
all_attrs = set()
7070
if value is not None:
71-
all_attributes |= set(value.__dict__.keys())
71+
all_attrs |= set(value.__dict__.keys())
7272
if old_value is not None:
73-
all_attributes |= set(old_value.__dict__.keys())
74-
all_attributes -= STATE_VIRTUAL_ATTRS
75-
for attribute in all_attributes:
76-
attrib_val = getattr(value, attribute, None)
77-
attrib_old_val = getattr(old_value, attribute, None)
78-
if attrib_old_val != attrib_val:
73+
all_attrs |= set(old_value.__dict__.keys())
74+
for attr in all_attrs - STATE_VIRTUAL_ATTRS:
75+
if getattr(value, attr, None) != getattr(old_value, attr, None):
7976
return True
80-
else:
81-
attrib_val = getattr(value, var_pieces[2], None)
82-
attrib_old_val = getattr(old_value, var_pieces[2], None)
83-
if attrib_old_val != attrib_val:
84-
return True
85-
77+
elif getattr(value, var_pieces[2], None) != getattr(old_value, var_pieces[2], None):
78+
return True
79+
8680
return False
8781

82+
8883
def ident_values_changed(func_args, ident):
89-
"""Check for changes to state or attributes on ident vars"""
90-
value = func_args.get('value')
91-
old_value = func_args.get('old_value')
92-
var_name = func_args.get('var_name')
84+
"""Check for changes to state or attributes on ident vars."""
85+
var_name = func_args.get("var_name", None)
9386

9487
if var_name is None:
9588
return False
89+
value = func_args["value"]
90+
old_value = func_args["old_value"]
9691

9792
for check_var in ident:
98-
# if check_var in self.state_trig_ident_any:
99-
# _LOGGER.debug(
100-
# "%s ident change skipping %s because also ident_any",
101-
# self.name,
102-
# check_var,
103-
# )
104-
# continue
105-
var_pieces = check_var.split('.')
93+
var_pieces = check_var.split(".")
10694
if len(var_pieces) == 2 and check_var == var_name:
10795
if value != old_value:
10896
return True
10997
elif len(var_pieces) == 3 and f"{var_pieces[0]}.{var_pieces[1]}" == var_name:
110-
attrib_val = getattr(value, var_pieces[2], None)
111-
attrib_old_val = getattr(old_value, var_pieces[2], None)
112-
if attrib_old_val != attrib_val:
98+
if getattr(value, var_pieces[2], None) != getattr(old_value, var_pieces[2], None):
11399
return True
114100

115101
return False
116102

117103

118-
119104
class TrigTime:
120105
"""Class for trigger time functions."""
121106

@@ -222,7 +207,8 @@ async def wait_until(
222207
#
223208
# check straight away to see if the condition is met (to avoid race conditions)
224209
#
225-
state_trig_ok = await state_trig_eval.eval(State.notify_var_get(state_trig_ident, {}))
210+
new_vars = State.notify_var_get(state_trig_ident, {})
211+
state_trig_ok = await state_trig_eval.eval(new_vars)
226212
exc = state_trig_eval.get_exception_obj()
227213
if exc is not None:
228214
raise exc
@@ -324,8 +310,8 @@ async def wait_until(
324310

325311
if not ident_any_values_changed(func_args, state_trig_ident_any):
326312
# if var_name not in func_args we are state_check_now
327-
if "var_name" in func_args and not ident_values_changed(func_args, state_trig):
328-
state_trig_ok = False
313+
if "var_name" in func_args and not ident_values_changed(func_args, state_trig_ident):
314+
continue
329315

330316
if state_trig_eval:
331317
state_trig_ok = await state_trig_eval.eval(new_vars)
@@ -813,7 +799,9 @@ async def trigger_watch(self):
813799

814800
if not ident_any_values_changed(func_args, self.state_trig_ident_any):
815801
# if var_name not in func_args we are state_check_now
816-
if "var_name" in func_args and not ident_values_changed(func_args, self.state_trig_ident):
802+
if "var_name" in func_args and not ident_values_changed(
803+
func_args, self.state_trig_ident
804+
):
817805
continue
818806

819807
if self.state_trig_eval:
@@ -965,8 +953,9 @@ async def do_func_call(func, ast_ctx, task_unique, task_unique_func, hass_contex
965953
except asyncio.CancelledError:
966954
raise
967955

968-
except Exception:
956+
except Exception as exc:
969957
# _LOGGER.error(f"{self.name}: " + traceback.format_exc(-1))
958+
_LOGGER.error("%s: %s", self.name, exc)
970959
if self.state_trig_ident:
971960
State.notify_del(self.state_trig_ident, self.notify_q)
972961
if self.event_trigger is not None:

0 commit comments

Comments
 (0)