1
1
"""Handles state variable access and change notification."""
2
2
3
3
import logging
4
+ from typing import Dict , Optional
4
5
6
+ from homeassistant .core import valid_entity_id , split_entity_id
7
+ from homeassistant .helpers .template import AllStates
5
8
from .const import LOGGER_PATH
6
9
7
10
_LOGGER = logging .getLogger (LOGGER_PATH + ".state" )
@@ -28,6 +31,14 @@ def __init__(self, hass, handler_func):
28
31
#
29
32
self .notify_var_last = {}
30
33
34
+ #
35
+ # `states` object from template extensions, working exactly the
36
+ # same way as documented.
37
+ # Attribute "_is_coroutine" is needed for "asyncio.iscoroutinefunction()"
38
+ # fails in eval.
39
+ self ._states = AllStates (hass )
40
+ setattr (self ._states , "_is_coroutine" , None )
41
+
31
42
def notify_add (self , var_names , queue ):
32
43
"""Register to notify state variables changes to be sent to queue."""
33
44
@@ -91,6 +102,34 @@ def set(self, var_name, value, attributes=None, **kwargs):
91
102
_LOGGER .debug ("setting %s = %s, attr = %s" , var_name , value , attributes )
92
103
self .hass .states .async_set (var_name , value , attributes )
93
104
105
+ def set_new (
106
+ self ,
107
+ entity_id : str ,
108
+ new_state : str ,
109
+ attributes : Optional [Dict ] = None ,
110
+ ) -> None :
111
+ """Set the state of an entity, add entity if it does not exist.
112
+
113
+ Attributes is an optional dict to specify attributes of this state.
114
+ To remove existing attributes, set to empty dict.
115
+ Default is to preserve them."""
116
+ if not valid_entity_id (entity_id ):
117
+ _LOGGER .error (
118
+ "invalid entity_id %s (should be 'domain.entity')" , entity_id
119
+ )
120
+ return
121
+ if attributes == {}:
122
+ _LOGGER .debug ("setting %s = %s, attr = %s" , entity_id , new_state , None )
123
+ self .hass .states .async_set (entity_id , new_state )
124
+ else :
125
+ old_state = self .hass .states .get (entity_id )
126
+ old_attrs = getattr (old_state , "attributes" , {})
127
+ updated_attrs = dict (old_attrs )
128
+ if attributes :
129
+ updated_attrs .update (attributes )
130
+ _LOGGER .debug ("setting %s = %s, attr = %s" , entity_id , new_state , updated_attrs )
131
+ self .hass .states .async_set (entity_id , new_state , updated_attrs )
132
+
94
133
def exist (self , var_name ):
95
134
"""Check if a state variable value or attribute exists in hass."""
96
135
parts = var_name .split ("." )
@@ -101,15 +140,22 @@ def exist(self, var_name):
101
140
102
141
def get (self , var_name ):
103
142
"""Get a state variable value or attribute from hass."""
104
- parts = var_name . split ( "." )
105
- if len ( parts ) != 2 and len ( parts ) != 3 :
106
- return None
107
- value = self . hass . states . get ( f" { parts [ 0 ] } . { parts [ 1 ] } " )
108
- if not value :
143
+ entity_id , attr_name = var_name , None
144
+ num_period = var_name . count ( "." )
145
+ if num_period == 2 :
146
+ entity_id , attr_name = var_name . rsplit ( "." , maxsplit = 1 )
147
+ if num_period > 2 or valid_entity_id ( entity_id ) is False :
109
148
return None
110
- if len (parts ) == 2 :
111
- return value .state
112
- return value .attributes .get (parts [2 ])
149
+
150
+ state = self .hass .states .get (entity_id )
151
+ if state and attr_name :
152
+ return state .attributes .get (attr_name )
153
+ else :
154
+ return getattr (state , "state" , None )
155
+
156
+ def get_new (self , entity_id ):
157
+ """Retrieve state of entity_id or None if not found. """
158
+ return self .hass .states .get (entity_id )
113
159
114
160
def completions (self , root ):
115
161
"""Return possible completions of state variables."""
@@ -139,7 +185,8 @@ def completions(self, root):
139
185
def register_functions (self ):
140
186
"""Register state functions."""
141
187
functions = {
142
- "state.get" : self .get ,
143
- "state.set" : self .set ,
188
+ "state.get" : self .get_new ,
189
+ "state.set" : self .set_new ,
190
+ "states" : self ._states ,
144
191
}
145
192
self .handler .register (functions )
0 commit comments