Skip to content

Commit 0a118b7

Browse files
authored
Merge pull request #332 from dlashua/fix-persist
fix for state.persist
2 parents f4f32f4 + 88d3c9c commit 0a118b7

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

custom_components/pyscript/entity.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Entity Classes"""
2+
from homeassistant.helpers.restore_state import RestoreEntity
3+
from homeassistant.helpers.typing import StateType
4+
from homeassistant.const import STATE_UNKNOWN
5+
6+
7+
class PyscriptEntity(RestoreEntity):
8+
"""Generic Pyscript Entity"""
9+
10+
_attr_extra_state_attributes: dict
11+
_attr_state: StateType = STATE_UNKNOWN
12+
13+
def set_state(self, state):
14+
"""Set the state"""
15+
self._attr_state = state
16+
17+
def set_attributes(self, attributes):
18+
"""Set Attributes"""
19+
self._attr_extra_state_attributes = attributes

custom_components/pyscript/state.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from homeassistant.helpers.service import async_get_all_descriptions
99

1010
from .const import LOGGER_PATH
11+
from .entity import PyscriptEntity
1112
from .function import Function
1213

1314
_LOGGER = logging.getLogger(LOGGER_PATH + ".state")
@@ -56,7 +57,7 @@ class State:
5657
#
5758
# pyscript vars which have already been registered as persisted
5859
#
59-
persisted_vars = set()
60+
persisted_vars = {}
6061

6162
#
6263
# other parameters of all services that have "entity_id" as a parameter
@@ -198,6 +199,10 @@ def set(cls, var_name, value=None, new_attributes=None, **kwargs):
198199
#
199200
cls.notify_var_last[var_name] = StateVal(cls.hass.states.get(var_name))
200201

202+
if var_name in cls.persisted_vars:
203+
cls.persisted_vars[var_name].set_state(value)
204+
cls.persisted_vars[var_name].set_attributes(new_attributes)
205+
201206
@classmethod
202207
def setattr(cls, var_attr_name, value):
203208
"""Set a state variable's attribute in hass."""
@@ -213,8 +218,13 @@ async def register_persist(cls, var_name):
213218
"""Register pyscript state variable to be persisted with RestoreState."""
214219
if var_name.startswith("pyscript.") and var_name not in cls.persisted_vars:
215220
restore_data = await RestoreStateData.async_get_instance(cls.hass)
216-
restore_data.async_restore_entity_added(var_name)
217-
cls.persisted_vars.add(var_name)
221+
this_entity = PyscriptEntity()
222+
this_entity.entity_id = var_name
223+
cls.persisted_vars[var_name] = this_entity
224+
try:
225+
restore_data.async_restore_entity_added(this_entity)
226+
except TypeError:
227+
restore_data.async_restore_entity_added(var_name)
218228

219229
@classmethod
220230
async def persist(cls, var_name, default_value=None, default_attributes=None):

0 commit comments

Comments
 (0)