Skip to content

Commit e5a746e

Browse files
committed
fixed shutdown trigger for case where it calls task.unique(); see #117
1 parent 8e84abe commit e5a746e

File tree

6 files changed

+82
-30
lines changed

6 files changed

+82
-30
lines changed

custom_components/pyscript/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,11 @@ async def jupyter_kernel_start(call):
166166
global_ctx_name, global_sym_table={"__name__": global_ctx_name}, manager=GlobalContextMgr
167167
)
168168
global_ctx.set_auto_start(True)
169-
170169
GlobalContextMgr.set(global_ctx_name, global_ctx)
171170

172171
ast_ctx = AstEval(global_ctx_name, global_ctx)
173172
Function.install_ast_funcs(ast_ctx)
174-
kernel = Kernel(call.data, ast_ctx, global_ctx_name)
173+
kernel = Kernel(call.data, ast_ctx, global_ctx, global_ctx_name)
175174
await kernel.session_start()
176175
hass.states.async_set(call.data["state_var"], json.dumps(kernel.get_ports()))
177176

@@ -215,7 +214,9 @@ async def hass_started(event):
215214
async def hass_stop(event):
216215
_LOGGER.debug("stopping global contexts")
217216
await unload_scripts(unload_all=True)
218-
# tell reaper task to exit (after other tasks are cancelled)
217+
# sync with waiter, and then tell waiter and reaper tasks to exit
218+
await Function.waiter_sync()
219+
await Function.waiter_stop()
219220
await Function.reaper_stop()
220221

221222
# Store callbacks to event listeners so we can unsubscribe on unload
@@ -255,7 +256,7 @@ async def unload_scripts(global_ctx_only=None, unload_all=False):
255256
ctx_delete[global_ctx_name] = global_ctx
256257
for global_ctx_name, global_ctx in ctx_delete.items():
257258
GlobalContextMgr.delete(global_ctx_name)
258-
await Function.reaper_sync()
259+
await Function.waiter_sync()
259260

260261

261262
@bind_hass
@@ -501,7 +502,7 @@ def import_recurse(ctx_name, visited, ctx2imports):
501502
if global_ctx_name not in ctx2files or not ctx2files[global_ctx_name].autoload:
502503
_LOGGER.info("Unloaded %s", global_ctx.get_file_path())
503504
GlobalContextMgr.delete(global_ctx_name)
504-
await Function.reaper_sync()
505+
await Function.waiter_sync()
505506

506507
#
507508
# now load the requested files, and files that depend on loaded files

custom_components/pyscript/function.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,15 @@ class Function:
4949

5050
#
5151
# task id of the task that cancels and waits for other tasks,
52-
# and also awaits on coros
5352
#
5453
task_reaper = None
54+
task_reaper_q = None
55+
56+
#
57+
# task id of the task that awaits for coros (used by shutdown triggers)
58+
#
59+
task_waiter = None
60+
task_waiter_q = None
5561

5662
def __init__(self):
5763
"""Warn on Function instantiation."""
@@ -97,10 +103,6 @@ async def task_reaper(reaper_q):
97103
await cmd[1]
98104
except asyncio.CancelledError:
99105
pass
100-
elif cmd[0] == "await":
101-
await cmd[1]
102-
elif cmd[0] == "sync":
103-
await cmd[1].put(0)
104106
else:
105107
_LOGGER.error("task_reaper: unknown command %s", cmd[0])
106108
except asyncio.CancelledError:
@@ -112,36 +114,71 @@ async def task_reaper(reaper_q):
112114
cls.task_reaper_q = asyncio.Queue(0)
113115
cls.task_reaper = Function.create_task(task_reaper(cls.task_reaper_q))
114116

117+
#
118+
# start a task which creates tasks to run coros, and then syncs on their completion;
119+
# this is used by the shutdown trigger
120+
#
121+
async def task_waiter(waiter_q):
122+
aws = []
123+
while True:
124+
try:
125+
cmd = await waiter_q.get()
126+
if cmd[0] == "exit":
127+
return
128+
if cmd[0] == "await":
129+
aws.append(cls.create_task(cmd[1]))
130+
elif cmd[0] == "sync":
131+
if len(aws) > 0:
132+
await asyncio.gather(*aws)
133+
aws = []
134+
await cmd[1].put(0)
135+
else:
136+
_LOGGER.error("task_waiter: unknown command %s", cmd[0])
137+
except asyncio.CancelledError:
138+
raise
139+
except Exception:
140+
_LOGGER.error("task_waiter: got exception %s", traceback.format_exc(-1))
141+
142+
if not cls.task_waiter:
143+
cls.task_waiter_q = asyncio.Queue(0)
144+
cls.task_waiter = Function.create_task(task_waiter(cls.task_waiter_q))
145+
_LOGGER.debug("task_waiter: started")
146+
147+
@classmethod
148+
def reaper_cancel(cls, task):
149+
"""Send a task to be canceled by the reaper."""
150+
cls.task_reaper_q.put_nowait(["cancel", task])
151+
115152
@classmethod
116153
async def reaper_stop(cls):
117-
"""Tell the reaper task to exit by sending a special task None."""
154+
"""Tell the reaper task to exit."""
118155
if cls.task_reaper:
119156
cls.task_reaper_q.put_nowait(["exit"])
120157
await cls.task_reaper
121158
cls.task_reaper = None
122159
cls.task_reaper_q = None
123160

124161
@classmethod
125-
def reaper_cancel(cls, task):
126-
"""Send a task to be canceled by the reaper."""
127-
cls.task_reaper_q.put_nowait(["cancel", task])
128-
129-
@classmethod
130-
def reaper_await(cls, coro):
131-
"""Send a coro to be awaited by the reaper."""
132-
cls.task_reaper_q.put_nowait(["await", coro])
162+
def waiter_await(cls, coro):
163+
"""Send a coro to be awaited by the waiter task."""
164+
cls.task_waiter_q.put_nowait(["await", coro])
133165

134166
@classmethod
135-
async def reaper_sync(cls):
136-
"""Wait until the reaper queue is empty."""
137-
sync_q = asyncio.Queue(0)
138-
sync_q.put_nowait(["sync", sync_q])
139-
await sync_q.get()
167+
async def waiter_sync(cls):
168+
"""Wait until the waiter queue is empty."""
169+
if cls.task_waiter:
170+
sync_q = asyncio.Queue(0)
171+
cls.task_waiter_q.put_nowait(["sync", sync_q])
172+
await sync_q.get()
140173

141174
@classmethod
142-
def reaper_exit(cls):
143-
"""Send an exit request to the reaper."""
144-
cls.task_reaper_q.put_nowait(["exit"])
175+
async def waiter_stop(cls):
176+
"""Tell the waiter task to exit."""
177+
if cls.task_waiter:
178+
cls.task_waiter_q.put_nowait(["exit"])
179+
await cls.task_waiter
180+
cls.task_waiter = None
181+
cls.task_waiter_q = None
145182

146183
@classmethod
147184
async def async_sleep(cls, duration):

custom_components/pyscript/jupyter_kernel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,10 @@ def close(self):
189189
class Kernel:
190190
"""Define a Jupyter Kernel class."""
191191

192-
def __init__(self, config, ast_ctx, global_ctx_name):
192+
def __init__(self, config, ast_ctx, global_ctx, global_ctx_name):
193193
"""Initialize a Kernel object, one instance per session."""
194194
self.config = config.copy()
195+
self.global_ctx = global_ctx
195196
self.global_ctx_name = global_ctx_name
196197
self.ast_ctx = ast_ctx
197198

@@ -346,11 +347,15 @@ async def shell_handler(self, shell_socket, wire_msg):
346347
if code.startswith("_rwho_ls = %who_ls"):
347348
code = "print([])"
348349

350+
self.global_ctx.set_auto_start(False)
349351
self.ast_ctx.parse(code)
350352
exc = self.ast_ctx.get_exception_obj()
351353
if exc is None:
352354
result = await self.ast_ctx.eval()
353355
exc = self.ast_ctx.get_exception_obj()
356+
await Function.waiter_sync()
357+
self.global_ctx.set_auto_start(True)
358+
self.global_ctx.start()
354359
if exc:
355360
traceback_mesg = self.ast_ctx.get_exception_long().split("\n")
356361

custom_components/pyscript/trigger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,11 +832,12 @@ def stop(self):
832832
Mqtt.notify_del(self.mqtt_trigger[0], self.notify_q)
833833
if self.task:
834834
Function.reaper_cancel(self.task)
835+
self.task = None
835836
if self.run_on_shutdown:
836837
notify_type = "shutdown"
837838
notify_info = {"trigger_type": "time", "trigger_time": "shutdown"}
838839
action_future = self.call_action(notify_type, notify_info, run_task=False)
839-
Function.reaper_await(action_future)
840+
Function.waiter_await(action_future)
840841

841842
def start(self):
842843
"""Start this trigger task."""

docs/new_features.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,6 @@ The new features since 1.1.0 in master include:
3636
Breaking changes since 1.1.0 include:
3737

3838
Bug fixes since 1.1.0 include:
39+
40+
- Fixed shutdown trigger for case where it calls ``task.unique()`` (#117).
41+
- Added error message for invalid ``@time_active`` argument, by @dlashua (#118).

tests/test_function.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1134,10 +1134,11 @@ async def test_service_call_params(hass):
11341134
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}
11351135

11361136
# Stop all tasks to avoid conflicts with other tests
1137+
await Function.waiter_stop()
11371138
await Function.reaper_stop()
11381139

11391140

1140-
async def test_serive_call_blocking(hass, caplog):
1141+
async def test_service_call_blocking(hass, caplog):
11411142
"""Test that service calls with blocking=True actually block."""
11421143
notify_q = asyncio.Queue(0)
11431144

@@ -1171,6 +1172,9 @@ def func_startup():
11711172
service.call("pyscript", "long_sleep", blocking=True, limit=1e-6)
11721173
pyscript.done = [seq_num, pyscript.var1]
11731174
1175+
seq_num += 1
1176+
pyscript.done = [seq_num]
1177+
11741178
@service
11751179
def long_sleep():
11761180
task.sleep(10000)
@@ -1187,3 +1191,4 @@ def service1():
11871191
assert literal_eval(await wait_until_done(notify_q)) == [2, "3"]
11881192
assert literal_eval(await wait_until_done(notify_q)) == [3, "4"]
11891193
assert literal_eval(await wait_until_done(notify_q)) == [4, "5"]
1194+
assert literal_eval(await wait_until_done(notify_q)) == [5]

0 commit comments

Comments
 (0)