Skip to content

Commit 713f949

Browse files
authored
Merge pull request #85 from raman325/service_blocking
Add support for service blocking
2 parents beaf1b5 + 1dfab0f commit 713f949

File tree

6 files changed

+133
-21
lines changed

6 files changed

+133
-21
lines changed

custom_components/pyscript/function.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,18 @@ def service_has_service(cls, domain, name):
190190
async def service_call(cls, domain, name, **kwargs):
191191
"""Implement service.call()."""
192192
curr_task = asyncio.current_task()
193-
if "context" in kwargs and isinstance(kwargs["context"], Context):
194-
context = kwargs["context"]
195-
del kwargs["context"]
196-
else:
197-
context = cls.task2context.get(curr_task, None)
198-
199-
await cls.hass.services.async_call(domain, name, kwargs, context=context)
193+
hass_args = {}
194+
for keyword, typ, default in [
195+
("context", [Context], cls.task2context.get(curr_task, None)),
196+
("blocking", [bool], None),
197+
("limit", [float, int], None),
198+
]:
199+
if keyword in kwargs and type(kwargs[keyword]) in typ:
200+
hass_args[keyword] = kwargs.pop(keyword)
201+
elif default:
202+
hass_args[keyword] = default
203+
204+
await cls.hass.services.async_call(domain, name, kwargs, **hass_args)
200205

201206
@classmethod
202207
async def service_completions(cls, root):
@@ -255,15 +260,21 @@ def get(cls, name):
255260
def service_call_factory(domain, service):
256261
async def service_call(*args, **kwargs):
257262
curr_task = asyncio.current_task()
258-
if "context" in kwargs and isinstance(kwargs["context"], Context):
259-
context = kwargs["context"]
260-
del kwargs["context"]
261-
else:
262-
context = cls.task2context.get(curr_task, None)
263+
hass_args = {}
264+
for keyword, typ, default in [
265+
("context", [Context], cls.task2context.get(curr_task, None)),
266+
("blocking", [bool], None),
267+
("limit", [float, int], None),
268+
]:
269+
if keyword in kwargs and type(kwargs[keyword]) in typ:
270+
hass_args[keyword] = kwargs.pop(keyword)
271+
elif default:
272+
hass_args[keyword] = default
263273

264274
if len(args) != 0:
265275
raise (TypeError, f"service {domain}.{service} takes no positional arguments")
266-
await cls.hass.services.async_call(domain, service, kwargs, context=context)
276+
277+
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
267278

268279
return service_call
269280

custom_components/pyscript/state.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,16 @@ async def get(cls, var_name):
258258
def service_call_factory(domain, service, entity_id, params):
259259
async def service_call(*args, **kwargs):
260260
curr_task = asyncio.current_task()
261-
if "context" in kwargs and isinstance(kwargs["context"], Context):
262-
context = kwargs["context"]
263-
del kwargs["context"]
264-
else:
265-
context = Function.task2context.get(curr_task, None)
261+
hass_args = {}
262+
for keyword, typ, default in [
263+
("context", [Context], Function.task2context.get(curr_task, None)),
264+
("blocking", [bool], None),
265+
("limit", [float, int], None),
266+
]:
267+
if keyword in kwargs and type(kwargs[keyword]) in typ:
268+
hass_args[keyword] = kwargs.pop(keyword)
269+
elif default:
270+
hass_args[keyword] = default
266271

267272
kwargs["entity_id"] = entity_id
268273
if len(args) == 1 and len(params) == 1:
@@ -273,7 +278,7 @@ async def service_call(*args, **kwargs):
273278
kwargs[param_name] = args[0]
274279
elif len(args) != 0:
275280
raise TypeError(f"service {domain}.{service} takes no positional arguments")
276-
await cls.hass.services.async_call(domain, service, kwargs, context=context)
281+
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
277282

278283
return service_call
279284

docs/new_features.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ The new features since 0.32 in master include:
5454
messages get displayed in VSCode. One benign but unresolved bug with VSCode is that when you connect
5555
to the pyscript kernel, VSCode starts a second pyscript Jupyter kernel, before shutting that second one
5656
down.
57+
- Service calls now accept ``blocking`` and ``limit`` parameters. The default behavior for a service call is
58+
to run it in the background, but using ``blocking=True`` will force a task to wait up to ``limit`` seconds
59+
for the service call to finish executing before continuing. Contributed by @raman325 (#85)
5760

5861
The bug fixes since 0.32 in master include:
5962

docs/reference.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ called by:
210210
211211
service.call("myservice", "flash_light", light_name="front", light_color="red")
212212
213+
When making a service call, either using the ``service.call`` function or the service name as the
214+
function, you can optionally pass the keyword argument ``blocking=True`` if you would like to wait
215+
for the service to finish execution before continuing execution in your function. You can also
216+
specify a timeout for a blocking service call using the ``limit=<number_of_seconds>`` parameters.
217+
213218
Firing events
214219
-------------
215220

@@ -691,8 +696,11 @@ or ``float()``). Attributes keep their native type.
691696
Service Calls
692697
^^^^^^^^^^^^^
693698

694-
``service.call(domain, name, **kwargs)``
695-
calls the service ``domain.name`` with the given keyword arguments as parameters.
699+
``service.call(domain, name, blocking=False, limit=10, **kwargs)``
700+
calls the service ``domain.name`` with the given keyword arguments as parameters. If ``blocking``
701+
is ``True``, pyscript will wait for the service to finish executing before continuing the current
702+
routine, or will wait a maximum of the number of seconds specified in the `limit` keyword
703+
argument.
696704
``service.has_service(domain, name)``
697705
returns whether the service ``domain.name`` exists.
698706

tests/test_function.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytest_homeassistant_custom_component.async_mock import MagicMock, Mock, mock_open, patch
1313

1414
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED
15+
from homeassistant.core import Context
1516
from homeassistant.setup import async_setup_component
1617

1718

@@ -804,3 +805,44 @@ def set_add(entity_id=None, val1=None, val2=None):
804805
assert literal_eval(await wait_until_done(notify_q)) == [4, "pyscript.var1", "32"]
805806
assert literal_eval(await wait_until_done(notify_q)) == [5, "pyscript.var1", "50", "HomeAssistant"]
806807
assert "TypeError: service pyscript.set_add takes no positional arguments" in caplog.text
808+
809+
810+
async def test_service_call_params(hass):
811+
"""Test that hass params get set properly on service calls."""
812+
with patch.object(hass.services, "async_call") as call, patch.object(
813+
Function, "service_has_service", return_value=True
814+
):
815+
Function.init(hass)
816+
await Function.service_call(
817+
"test", "test", context=Context(id="test"), blocking=True, limit=1, other_service_data="test"
818+
)
819+
assert call.called
820+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
821+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": True, "limit": 1}
822+
call.reset_mock()
823+
824+
await Function.service_call(
825+
"test", "test", context=Context(id="test"), blocking=False, other_service_data="test"
826+
)
827+
assert call.called
828+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
829+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}
830+
call.reset_mock()
831+
832+
await Function.get("test.test")(
833+
context=Context(id="test"), blocking=True, limit=1, other_service_data="test"
834+
)
835+
assert call.called
836+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
837+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": True, "limit": 1}
838+
call.reset_mock()
839+
840+
await Function.get("test.test")(
841+
context=Context(id="test"), blocking=False, other_service_data="test"
842+
)
843+
assert call.called
844+
assert call.call_args[0] == ("test", "test", {"other_service_data": "test"})
845+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}
846+
847+
# Stop all tasks to avoid conflicts with other tests
848+
await Function.reaper_stop()

tests/test_state.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Test pyscripts test module."""
2+
from custom_components.pyscript.state import State
3+
from pytest_homeassistant_custom_component.async_mock import patch
4+
5+
from homeassistant.core import Context
6+
from homeassistant.helpers.state import State as HassState
7+
8+
9+
async def test_service_call(hass):
10+
"""Test calling a service using the entity_id as a property."""
11+
with patch(
12+
"custom_components.pyscript.state.async_get_all_descriptions",
13+
return_value={
14+
"test": {
15+
"test": {"description": None, "fields": {"entity_id": "blah", "other_service_data": "blah"}}
16+
}
17+
},
18+
), patch.object(hass.states, "get", return_value=HassState("test.entity", "True")), patch.object(
19+
hass.services, "async_call"
20+
) as call:
21+
State.init(hass)
22+
await State.get_service_params()
23+
24+
func = await State.get("test.entity.test")
25+
await func(context=Context(id="test"), blocking=True, limit=1, other_service_data="test")
26+
assert call.called
27+
assert call.call_args[0] == (
28+
"test",
29+
"test",
30+
{"other_service_data": "test", "entity_id": "test.entity"},
31+
)
32+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": True, "limit": 1}
33+
call.reset_mock()
34+
35+
func = await State.get("test.entity.test")
36+
await func(context=Context(id="test"), blocking=False, other_service_data="test")
37+
assert call.called
38+
assert call.call_args[0] == (
39+
"test",
40+
"test",
41+
{"other_service_data": "test", "entity_id": "test.entity"},
42+
)
43+
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}

0 commit comments

Comments
 (0)