diff --git a/custom_components/pyscript/__init__.py b/custom_components/pyscript/__init__.py index 32c4667..df68a65 100644 --- a/custom_components/pyscript/__init__.py +++ b/custom_components/pyscript/__init__.py @@ -23,10 +23,12 @@ from .const import ( CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, + CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH, SERVICE_JUPYTER_KERNEL_START, + UNSUB_LISTENERS, ) from .eval import AstEval from .event import Event @@ -87,7 +89,8 @@ async def async_setup_entry(hass, config_entry): await hass.async_add_executor_job(os.makedirs, pyscript_folder) hass.data.setdefault(DOMAIN, {}) - hass.data[DOMAIN] = config_entry + hass.data[DOMAIN][CONFIG_ENTRY] = config_entry + hass.data[DOMAIN][UNSUB_LISTENERS] = [] State.set_pyscript_config(config_entry.data) @@ -112,15 +115,7 @@ async def reload_scripts_handler(call): State.set_pyscript_config(config_entry.data) - ctx_delete = {} - for global_ctx_name, global_ctx in GlobalContextMgr.items(): - idx = global_ctx_name.find(".") - if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules"}: - continue - global_ctx.stop() - ctx_delete[global_ctx_name] = global_ctx - for global_ctx_name, global_ctx in ctx_delete.items(): - await GlobalContextMgr.delete(global_ctx_name) + await unload_scripts() await load_scripts(hass, config_entry.data) @@ -183,7 +178,7 @@ async def state_changed(event): async def start_triggers(event): _LOGGER.debug("adding state changed listener and starting triggers") - hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed) + hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed)) for global_ctx_name, global_ctx in GlobalContextMgr.items(): idx = global_ctx_name.find(".") if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps"}: @@ -202,18 +197,44 @@ async def stop_triggers(event): # tell reaper task to exit (after other tasks are cancelled) await Function.reaper_stop() - hass.bus.async_listen(EVENT_HOMEASSISTANT_STARTED, start_triggers) - hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, stop_triggers) + # Store callbacks to event listeners so we can unsubscribe on unload + hass.data[DOMAIN][UNSUB_LISTENERS].append( + hass.bus.async_listen(EVENT_HOMEASSISTANT_STARTED, start_triggers) + ) + hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, stop_triggers)) return True async def async_unload_entry(hass, config_entry): """Unload a config entry.""" + # Unload scripts + await unload_scripts() + + # tell reaper task to exit (after other tasks are cancelled) + await Function.reaper_stop() + + # Unsubscribe from listeners + for unsub_listener in hass.data[DOMAIN][UNSUB_LISTENERS]: + unsub_listener() + hass.data.pop(DOMAIN) return True +async def unload_scripts(): + """Unload all scripts from GlobalContextMgr.""" + ctx_delete = {} + for global_ctx_name, global_ctx in GlobalContextMgr.items(): + idx = global_ctx_name.find(".") + if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules"}: + continue + global_ctx.stop() + ctx_delete[global_ctx_name] = global_ctx + for global_ctx_name, global_ctx in ctx_delete.items(): + await GlobalContextMgr.delete(global_ctx_name) + + @bind_hass async def load_scripts(hass, data): """Load all python scripts in FOLDER.""" diff --git a/custom_components/pyscript/const.py b/custom_components/pyscript/const.py index 78f0660..d453189 100644 --- a/custom_components/pyscript/const.py +++ b/custom_components/pyscript/const.py @@ -2,6 +2,9 @@ DOMAIN = "pyscript" +CONFIG_ENTRY = "config_entry" +UNSUB_LISTENERS = "unsub_listeners" + FOLDER = "pyscript" CONF_ALLOW_ALL_IMPORTS = "allow_all_imports" diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 4906174..afacb4a 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -16,7 +16,14 @@ from homeassistant.const import SERVICE_RELOAD from homeassistant.helpers.service import async_set_service_schema -from .const import ALLOWED_IMPORTS, CONF_ALLOW_ALL_IMPORTS, DOMAIN, LOGGER_PATH, SERVICE_JUPYTER_KERNEL_START +from .const import ( + ALLOWED_IMPORTS, + CONF_ALLOW_ALL_IMPORTS, + CONFIG_ENTRY, + DOMAIN, + LOGGER_PATH, + SERVICE_JUPYTER_KERNEL_START, +) from .function import Function from .state import State @@ -745,7 +752,7 @@ def __init__(self, name, global_ctx, logger_name=None): self.logger_handlers = set() self.logger = None self.set_logger_name(logger_name if logger_name is not None else self.name) - self.config_entry = Function.hass.data.get(DOMAIN, {}) + self.config_entry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {}) async def ast_not_implemented(self, arg, *args): """Raise NotImplementedError exception for unimplemented AST types.""" diff --git a/custom_components/pyscript/global_ctx.py b/custom_components/pyscript/global_ctx.py index 4a70631..eae2beb 100644 --- a/custom_components/pyscript/global_ctx.py +++ b/custom_components/pyscript/global_ctx.py @@ -4,7 +4,7 @@ import os from types import ModuleType -from .const import CONF_HASS_IS_GLOBAL, DOMAIN, FOLDER, LOGGER_PATH +from .const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH from .eval import AstEval from .function import Function from .trigger import TrigInfo @@ -26,7 +26,7 @@ def __init__(self, name, global_sym_table=None, manager=None, rel_import_path=No self.auto_start = False self.module = None self.rel_import_path = rel_import_path - config_entry = Function.hass.data.get(DOMAIN, {}) + config_entry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {}) if config_entry.data.get(CONF_HASS_IS_GLOBAL, False): # # expose hass as a global variable if configured diff --git a/tests/test_config_flow.py b/tests/test_config_flow.py index c2a0080..5df2ec2 100644 --- a/tests/test_config_flow.py +++ b/tests/test_config_flow.py @@ -12,14 +12,14 @@ _LOGGER = logging.getLogger(__name__) -@pytest.fixture(name="pyscript_bypass_setup", autouse=True) +@pytest.fixture(name="pyscript_bypass_setup") def pyscript_bypass_setup_fixture(): """Mock component setup.""" with patch("custom_components.pyscript.async_setup_entry", return_value=True): yield -async def test_user_flow_minimum_fields(hass): +async def test_user_flow_minimum_fields(hass, pyscript_bypass_setup): """Test user config flow with minimum fields.""" # test form shows result = await hass.config_entries.flow.async_init(DOMAIN, context={"source": SOURCE_USER}) @@ -35,7 +35,7 @@ async def test_user_flow_minimum_fields(hass): assert not result["data"][CONF_HASS_IS_GLOBAL] -async def test_user_flow_all_fields(hass): +async def test_user_flow_all_fields(hass, pyscript_bypass_setup): """Test user config flow with all fields.""" # test form shows result = await hass.config_entries.flow.async_init(DOMAIN, context={"source": SOURCE_USER}) @@ -53,7 +53,7 @@ async def test_user_flow_all_fields(hass): assert result["data"][CONF_HASS_IS_GLOBAL] -async def test_user_already_configured(hass): +async def test_user_already_configured(hass, pyscript_bypass_setup): """Test service is already configured during user setup.""" result = await hass.config_entries.flow.async_init( DOMAIN, @@ -82,7 +82,7 @@ async def test_import_flow(hass, pyscript_bypass_setup): assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY -async def test_import_flow_update_allow_all_imports(hass): +async def test_import_flow_update_allow_all_imports(hass, pyscript_bypass_setup): """Test import config flow updates existing entry when `allow_all_imports` has changed.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({}) @@ -100,7 +100,7 @@ async def test_import_flow_update_allow_all_imports(hass): assert result["reason"] == "updated_entry" -async def test_import_flow_update_apps_from_none(hass): +async def test_import_flow_update_apps_from_none(hass, pyscript_bypass_setup): """Test import config flow updates existing entry when `apps` has changed from None to something.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({}) @@ -116,7 +116,7 @@ async def test_import_flow_update_apps_from_none(hass): assert result["reason"] == "updated_entry" -async def test_import_flow_update_apps_to_none(hass): +async def test_import_flow_update_apps_to_none(hass, pyscript_bypass_setup): """Test import config flow updates existing entry when `apps` has changed from something to None.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({"apps": {"test_app": {"param": 1}}}) @@ -130,7 +130,7 @@ async def test_import_flow_update_apps_to_none(hass): assert result["reason"] == "updated_entry" -async def test_import_flow_no_update(hass): +async def test_import_flow_no_update(hass, pyscript_bypass_setup): """Test import config flow doesn't update existing entry when data is same.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({}) @@ -146,7 +146,7 @@ async def test_import_flow_no_update(hass): assert result["reason"] == "already_configured" -async def test_import_flow_update_user(hass): +async def test_import_flow_update_user(hass, pyscript_bypass_setup): """Test import config flow update excludes `allow_all_imports` from being updated when updated entry was a user entry.""" result = await hass.config_entries.flow.async_init( DOMAIN, @@ -170,7 +170,7 @@ async def test_import_flow_update_user(hass): } -async def test_import_flow_update_import(hass): +async def test_import_flow_update_import(hass, pyscript_bypass_setup): """Test import config flow update includes `allow_all_imports` in update when updated entry was imported entry.""" result = await hass.config_entries.flow.async_init( DOMAIN, @@ -190,7 +190,7 @@ async def test_import_flow_update_import(hass): assert hass.config_entries.async_entries(DOMAIN)[0].data == {"apps": {"test_app": {"param": 1}}} -async def test_options_flow_import(hass): +async def test_options_flow_import(hass, pyscript_bypass_setup): """Test options flow aborts because configuration needs to be managed via configuration.yaml.""" result = await hass.config_entries.flow.async_init( DOMAIN, @@ -212,7 +212,7 @@ async def test_options_flow_import(hass): assert result["title"] == "" -async def test_options_flow_user_change(hass): +async def test_options_flow_user_change(hass, pyscript_bypass_setup): """Test options flow updates config entry when options change.""" result = await hass.config_entries.flow.async_init( DOMAIN, @@ -240,7 +240,7 @@ async def test_options_flow_user_change(hass): assert entry.data[CONF_HASS_IS_GLOBAL] is False -async def test_options_flow_user_no_change(hass): +async def test_options_flow_user_no_change(hass, pyscript_bypass_setup): """Test options flow aborts when options don't change.""" result = await hass.config_entries.flow.async_init( DOMAIN, @@ -267,3 +267,19 @@ async def test_options_flow_user_no_change(hass): assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["title"] == "" + + +async def test_config_entry_reload(hass): + """Test that config entry reload does not duplicate listeners.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}), + ) + await hass.async_block_till_done() + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + entry = result["result"] + listeners = hass.bus.async_listeners() + await hass.config_entries.async_reload(entry.entry_id) + await hass.async_block_till_done() + assert listeners == hass.bus.async_listeners() diff --git a/tests/test_unit_eval.py b/tests/test_unit_eval.py index b0f08c9..38f320b 100644 --- a/tests/test_unit_eval.py +++ b/tests/test_unit_eval.py @@ -1,6 +1,6 @@ """Unit tests for Python interpreter.""" -from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, DOMAIN +from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, CONFIG_ENTRY, DOMAIN from custom_components.pyscript.eval import AstEval from custom_components.pyscript.function import Function from custom_components.pyscript.global_ctx import GlobalContext, GlobalContextMgr @@ -884,7 +884,7 @@ async def run_one_test(test_data): async def test_eval(hass): """Test interpreter.""" - hass.data[DOMAIN] = MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False}) + hass.data[DOMAIN] = {CONFIG_ENTRY: MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False})} Function.init(hass) State.init(hass) State.register_functions() @@ -1065,7 +1065,7 @@ async def run_one_test_exception(test_data): async def test_eval_exceptions(hass): """Test interpreter exceptions.""" - hass.data[DOMAIN] = MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False}) + hass.data[DOMAIN] = {CONFIG_ENTRY: MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False})} Function.init(hass) State.init(hass) State.register_functions()