Skip to content

Unload scripts and unsubscribe from event listeners during entry unload #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions custom_components/pyscript/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
from .const import (
CONF_ALLOW_ALL_IMPORTS,
CONF_HASS_IS_GLOBAL,
CONFIG_ENTRY,
DOMAIN,
FOLDER,
LOGGER_PATH,
SERVICE_JUPYTER_KERNEL_START,
UNSUB_LISTENERS,
)
from .eval import AstEval
from .event import Event
Expand Down Expand Up @@ -87,7 +89,8 @@ async def async_setup_entry(hass, config_entry):
await hass.async_add_executor_job(os.makedirs, pyscript_folder)

hass.data.setdefault(DOMAIN, {})
hass.data[DOMAIN] = config_entry
hass.data[DOMAIN][CONFIG_ENTRY] = config_entry
hass.data[DOMAIN][UNSUB_LISTENERS] = []

State.set_pyscript_config(config_entry.data)

Expand All @@ -112,15 +115,7 @@ async def reload_scripts_handler(call):

State.set_pyscript_config(config_entry.data)

ctx_delete = {}
for global_ctx_name, global_ctx in GlobalContextMgr.items():
idx = global_ctx_name.find(".")
if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules"}:
continue
global_ctx.stop()
ctx_delete[global_ctx_name] = global_ctx
for global_ctx_name, global_ctx in ctx_delete.items():
await GlobalContextMgr.delete(global_ctx_name)
await unload_scripts()

await load_scripts(hass, config_entry.data)

Expand Down Expand Up @@ -183,7 +178,7 @@ async def state_changed(event):

async def start_triggers(event):
_LOGGER.debug("adding state changed listener and starting triggers")
hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed)
hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed))
for global_ctx_name, global_ctx in GlobalContextMgr.items():
idx = global_ctx_name.find(".")
if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps"}:
Expand All @@ -202,18 +197,44 @@ async def stop_triggers(event):
# tell reaper task to exit (after other tasks are cancelled)
await Function.reaper_stop()

hass.bus.async_listen(EVENT_HOMEASSISTANT_STARTED, start_triggers)
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, stop_triggers)
# Store callbacks to event listeners so we can unsubscribe on unload
hass.data[DOMAIN][UNSUB_LISTENERS].append(
hass.bus.async_listen(EVENT_HOMEASSISTANT_STARTED, start_triggers)
)
hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, stop_triggers))

return True


async def async_unload_entry(hass, config_entry):
"""Unload a config entry."""
# Unload scripts
await unload_scripts()

# tell reaper task to exit (after other tasks are cancelled)
await Function.reaper_stop()

# Unsubscribe from listeners
for unsub_listener in hass.data[DOMAIN][UNSUB_LISTENERS]:
unsub_listener()

hass.data.pop(DOMAIN)
return True


async def unload_scripts():
"""Unload all scripts from GlobalContextMgr."""
ctx_delete = {}
for global_ctx_name, global_ctx in GlobalContextMgr.items():
idx = global_ctx_name.find(".")
if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules"}:
continue
global_ctx.stop()
ctx_delete[global_ctx_name] = global_ctx
for global_ctx_name, global_ctx in ctx_delete.items():
await GlobalContextMgr.delete(global_ctx_name)


@bind_hass
async def load_scripts(hass, data):
"""Load all python scripts in FOLDER."""
Expand Down
3 changes: 3 additions & 0 deletions custom_components/pyscript/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

DOMAIN = "pyscript"

CONFIG_ENTRY = "config_entry"
UNSUB_LISTENERS = "unsub_listeners"

FOLDER = "pyscript"

CONF_ALLOW_ALL_IMPORTS = "allow_all_imports"
Expand Down
11 changes: 9 additions & 2 deletions custom_components/pyscript/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@
from homeassistant.const import SERVICE_RELOAD
from homeassistant.helpers.service import async_set_service_schema

from .const import ALLOWED_IMPORTS, CONF_ALLOW_ALL_IMPORTS, DOMAIN, LOGGER_PATH, SERVICE_JUPYTER_KERNEL_START
from .const import (
ALLOWED_IMPORTS,
CONF_ALLOW_ALL_IMPORTS,
CONFIG_ENTRY,
DOMAIN,
LOGGER_PATH,
SERVICE_JUPYTER_KERNEL_START,
)
from .function import Function
from .state import State

Expand Down Expand Up @@ -745,7 +752,7 @@ def __init__(self, name, global_ctx, logger_name=None):
self.logger_handlers = set()
self.logger = None
self.set_logger_name(logger_name if logger_name is not None else self.name)
self.config_entry = Function.hass.data.get(DOMAIN, {})
self.config_entry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {})

async def ast_not_implemented(self, arg, *args):
"""Raise NotImplementedError exception for unimplemented AST types."""
Expand Down
4 changes: 2 additions & 2 deletions custom_components/pyscript/global_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from types import ModuleType

from .const import CONF_HASS_IS_GLOBAL, DOMAIN, FOLDER, LOGGER_PATH
from .const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH
from .eval import AstEval
from .function import Function
from .trigger import TrigInfo
Expand All @@ -26,7 +26,7 @@ def __init__(self, name, global_sym_table=None, manager=None, rel_import_path=No
self.auto_start = False
self.module = None
self.rel_import_path = rel_import_path
config_entry = Function.hass.data.get(DOMAIN, {})
config_entry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {})
if config_entry.data.get(CONF_HASS_IS_GLOBAL, False):
#
# expose hass as a global variable if configured
Expand Down
42 changes: 29 additions & 13 deletions tests/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
_LOGGER = logging.getLogger(__name__)


@pytest.fixture(name="pyscript_bypass_setup", autouse=True)
@pytest.fixture(name="pyscript_bypass_setup")
def pyscript_bypass_setup_fixture():
"""Mock component setup."""
with patch("custom_components.pyscript.async_setup_entry", return_value=True):
yield


async def test_user_flow_minimum_fields(hass):
async def test_user_flow_minimum_fields(hass, pyscript_bypass_setup):
"""Test user config flow with minimum fields."""
# test form shows
result = await hass.config_entries.flow.async_init(DOMAIN, context={"source": SOURCE_USER})
Expand All @@ -35,7 +35,7 @@ async def test_user_flow_minimum_fields(hass):
assert not result["data"][CONF_HASS_IS_GLOBAL]


async def test_user_flow_all_fields(hass):
async def test_user_flow_all_fields(hass, pyscript_bypass_setup):
"""Test user config flow with all fields."""
# test form shows
result = await hass.config_entries.flow.async_init(DOMAIN, context={"source": SOURCE_USER})
Expand All @@ -53,7 +53,7 @@ async def test_user_flow_all_fields(hass):
assert result["data"][CONF_HASS_IS_GLOBAL]


async def test_user_already_configured(hass):
async def test_user_already_configured(hass, pyscript_bypass_setup):
"""Test service is already configured during user setup."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
Expand Down Expand Up @@ -82,7 +82,7 @@ async def test_import_flow(hass, pyscript_bypass_setup):
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY


async def test_import_flow_update_allow_all_imports(hass):
async def test_import_flow_update_allow_all_imports(hass, pyscript_bypass_setup):
"""Test import config flow updates existing entry when `allow_all_imports` has changed."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({})
Expand All @@ -100,7 +100,7 @@ async def test_import_flow_update_allow_all_imports(hass):
assert result["reason"] == "updated_entry"


async def test_import_flow_update_apps_from_none(hass):
async def test_import_flow_update_apps_from_none(hass, pyscript_bypass_setup):
"""Test import config flow updates existing entry when `apps` has changed from None to something."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({})
Expand All @@ -116,7 +116,7 @@ async def test_import_flow_update_apps_from_none(hass):
assert result["reason"] == "updated_entry"


async def test_import_flow_update_apps_to_none(hass):
async def test_import_flow_update_apps_to_none(hass, pyscript_bypass_setup):
"""Test import config flow updates existing entry when `apps` has changed from something to None."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({"apps": {"test_app": {"param": 1}}})
Expand All @@ -130,7 +130,7 @@ async def test_import_flow_update_apps_to_none(hass):
assert result["reason"] == "updated_entry"


async def test_import_flow_no_update(hass):
async def test_import_flow_no_update(hass, pyscript_bypass_setup):
"""Test import config flow doesn't update existing entry when data is same."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=PYSCRIPT_SCHEMA({})
Expand All @@ -146,7 +146,7 @@ async def test_import_flow_no_update(hass):
assert result["reason"] == "already_configured"


async def test_import_flow_update_user(hass):
async def test_import_flow_update_user(hass, pyscript_bypass_setup):
"""Test import config flow update excludes `allow_all_imports` from being updated when updated entry was a user entry."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
Expand All @@ -170,7 +170,7 @@ async def test_import_flow_update_user(hass):
}


async def test_import_flow_update_import(hass):
async def test_import_flow_update_import(hass, pyscript_bypass_setup):
"""Test import config flow update includes `allow_all_imports` in update when updated entry was imported entry."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
Expand All @@ -190,7 +190,7 @@ async def test_import_flow_update_import(hass):
assert hass.config_entries.async_entries(DOMAIN)[0].data == {"apps": {"test_app": {"param": 1}}}


async def test_options_flow_import(hass):
async def test_options_flow_import(hass, pyscript_bypass_setup):
"""Test options flow aborts because configuration needs to be managed via configuration.yaml."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
Expand All @@ -212,7 +212,7 @@ async def test_options_flow_import(hass):
assert result["title"] == ""


async def test_options_flow_user_change(hass):
async def test_options_flow_user_change(hass, pyscript_bypass_setup):
"""Test options flow updates config entry when options change."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
Expand Down Expand Up @@ -240,7 +240,7 @@ async def test_options_flow_user_change(hass):
assert entry.data[CONF_HASS_IS_GLOBAL] is False


async def test_options_flow_user_no_change(hass):
async def test_options_flow_user_no_change(hass, pyscript_bypass_setup):
"""Test options flow aborts when options don't change."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
Expand All @@ -267,3 +267,19 @@ async def test_options_flow_user_no_change(hass):

assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == ""


async def test_config_entry_reload(hass):
"""Test that config entry reload does not duplicate listeners."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}),
)
await hass.async_block_till_done()
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
entry = result["result"]
listeners = hass.bus.async_listeners()
await hass.config_entries.async_reload(entry.entry_id)
await hass.async_block_till_done()
assert listeners == hass.bus.async_listeners()
6 changes: 3 additions & 3 deletions tests/test_unit_eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Unit tests for Python interpreter."""

from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, DOMAIN
from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, CONFIG_ENTRY, DOMAIN
from custom_components.pyscript.eval import AstEval
from custom_components.pyscript.function import Function
from custom_components.pyscript.global_ctx import GlobalContext, GlobalContextMgr
Expand Down Expand Up @@ -884,7 +884,7 @@ async def run_one_test(test_data):

async def test_eval(hass):
"""Test interpreter."""
hass.data[DOMAIN] = MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False})
hass.data[DOMAIN] = {CONFIG_ENTRY: MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False})}
Function.init(hass)
State.init(hass)
State.register_functions()
Expand Down Expand Up @@ -1065,7 +1065,7 @@ async def run_one_test_exception(test_data):

async def test_eval_exceptions(hass):
"""Test interpreter exceptions."""
hass.data[DOMAIN] = MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False})
hass.data[DOMAIN] = {CONFIG_ENTRY: MockConfigEntry(domain=DOMAIN, data={CONF_ALLOW_ALL_IMPORTS: False})}
Function.init(hass)
State.init(hass)
State.register_functions()
Expand Down