17
17
from .eval import AstEval
18
18
from .event import Event
19
19
from .function import Function
20
- from .state import State
20
+ from .state import State , STATE_VIRTUAL_ATTRS
21
21
22
22
_LOGGER = logging .getLogger (LOGGER_PATH + ".trigger" )
23
23
24
24
25
- STATE_RE = re .compile (r"[a-zA-Z]\w*\.[a-zA-Z]\w*$" )
25
+ STATE_RE = re .compile (r"[a-zA-Z]\w*\.[a-zA-Z]\w*(\.(([a-zA-Z]\w*)|\*))? $" )
26
26
27
27
28
28
def dt_now ():
@@ -48,6 +48,74 @@ def parse_time_offset(offset_str):
48
48
return value * scale
49
49
50
50
51
+ def ident_any_values_changed (func_args , ident ):
52
+ """Check for changes to state or attributes on ident any vars"""
53
+ value = func_args .get ('value' )
54
+ old_value = func_args .get ('old_value' )
55
+ var_name = func_args .get ('var_name' )
56
+
57
+ if var_name is None :
58
+ return False
59
+
60
+ for check_var in ident :
61
+ if check_var == var_name and old_value != value :
62
+ return True
63
+
64
+ if check_var .startswith (f"{ var_name } ." ):
65
+ var_pieces = check_var .split ('.' )
66
+ if len (var_pieces ) == 3 and f"{ var_pieces [0 ]} .{ var_pieces [1 ]} " == var_name :
67
+ if var_pieces [2 ] == "*" :
68
+ # catch all has been requested, check all attributes for change
69
+ all_attributes = set ()
70
+ if value is not None :
71
+ all_attributes |= set (value .__dict__ .keys ())
72
+ if old_value is not None :
73
+ all_attributes |= set (old_value .__dict__ .keys ())
74
+ all_attributes -= STATE_VIRTUAL_ATTRS
75
+ for attribute in all_attributes :
76
+ attrib_val = getattr (value , attribute , None )
77
+ attrib_old_val = getattr (old_value , attribute , None )
78
+ if attrib_old_val != attrib_val :
79
+ return True
80
+ else :
81
+ attrib_val = getattr (value , var_pieces [2 ], None )
82
+ attrib_old_val = getattr (old_value , var_pieces [2 ], None )
83
+ if attrib_old_val != attrib_val :
84
+ return True
85
+
86
+ return False
87
+
88
+ def ident_values_changed (func_args , ident ):
89
+ """Check for changes to state or attributes on ident vars"""
90
+ value = func_args .get ('value' )
91
+ old_value = func_args .get ('old_value' )
92
+ var_name = func_args .get ('var_name' )
93
+
94
+ if var_name is None :
95
+ return False
96
+
97
+ for check_var in ident :
98
+ # if check_var in self.state_trig_ident_any:
99
+ # _LOGGER.debug(
100
+ # "%s ident change skipping %s because also ident_any",
101
+ # self.name,
102
+ # check_var,
103
+ # )
104
+ # continue
105
+ var_pieces = check_var .split ('.' )
106
+ if len (var_pieces ) == 2 and check_var == var_name :
107
+ if value != old_value :
108
+ return True
109
+ elif len (var_pieces ) == 3 and f"{ var_pieces [0 ]} .{ var_pieces [1 ]} " == var_name :
110
+ attrib_val = getattr (value , var_pieces [2 ], None )
111
+ attrib_old_val = getattr (old_value , var_pieces [2 ], None )
112
+ if attrib_old_val != attrib_val :
113
+ return True
114
+
115
+ return False
116
+
117
+
118
+
51
119
class TrigTime :
52
120
"""Class for trigger time functions."""
53
121
@@ -252,14 +320,19 @@ async def wait_until(
252
320
else :
253
321
new_vars , func_args = None , {}
254
322
255
- state_trig_ok = False
256
- if func_args .get ("var_name" , "" ) in state_trig_ident_any :
257
- state_trig_ok = True
258
- elif state_trig_eval :
259
- state_trig_ok = await state_trig_eval .eval (new_vars )
260
- exc = state_trig_eval .get_exception_obj ()
261
- if exc is not None :
262
- break
323
+ state_trig_ok = True
324
+
325
+ if not ident_any_values_changed (func_args , state_trig_ident_any ):
326
+ # if var_name not in func_args we are state_check_now
327
+ if "var_name" in func_args and not ident_values_changed (func_args , state_trig ):
328
+ state_trig_ok = False
329
+
330
+ if state_trig_eval :
331
+ state_trig_ok = await state_trig_eval .eval (new_vars )
332
+ exc = state_trig_eval .get_exception_obj ()
333
+ if exc is not None :
334
+ break
335
+
263
336
if state_hold is not None :
264
337
if state_trig_ok :
265
338
if not state_trig_waiting :
@@ -738,7 +811,11 @@ async def trigger_watch(self):
738
811
elif notify_type == "state" :
739
812
new_vars , func_args = notify_info
740
813
741
- if "var_name" not in func_args or func_args ["var_name" ] not in self .state_trig_ident_any :
814
+ if not ident_any_values_changed (func_args , self .state_trig_ident_any ):
815
+ # if var_name not in func_args we are state_check_now
816
+ if "var_name" in func_args and not ident_values_changed (func_args , self .state_trig_ident ):
817
+ continue
818
+
742
819
if self .state_trig_eval :
743
820
trig_ok = await self .state_trig_eval .eval (new_vars )
744
821
exc = self .state_trig_eval .get_exception_long ()
@@ -747,6 +824,7 @@ async def trigger_watch(self):
747
824
trig_ok = False
748
825
else :
749
826
trig_ok = False
827
+
750
828
if self .state_hold_dur is not None :
751
829
if trig_ok :
752
830
if not state_trig_waiting :
0 commit comments