Skip to content

Commit 9374d86

Browse files
author
Chris Rossi
authored
Implement Future.cancel() (#204)
This brings our implementation of ``Future`` in parity with the ``Future`` interface defined in the Python 3 standard library, and makes it possible to cancel asynchronous ``grpc`` calls from NDB.
1 parent 98ff809 commit 9374d86

File tree

6 files changed

+205
-17
lines changed

6 files changed

+205
-17
lines changed

packages/google-cloud-ndb/google/cloud/ndb/_remote.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
# In its own module to avoid circular import between _datastore_api and
1818
# tasklets modules.
19+
import grpc
20+
21+
from google.cloud.ndb import exceptions
1922

2023

2124
class RemoteCall:
@@ -36,18 +39,47 @@ class RemoteCall:
3639
def __init__(self, future, info):
3740
self.future = future
3841
self.info = info
42+
self._callbacks = []
43+
44+
future.add_done_callback(self._finish)
3945

4046
def __repr__(self):
4147
return self.info
4248

4349
def exception(self):
4450
"""Calls :meth:`grpc.Future.exception` on attr:`future`."""
45-
return self.future.exception()
51+
# GRPC will actually raise FutureCancelledError.
52+
# We'll translate that to our own Cancelled exception and *return* it,
53+
# which is far more polite for a method that *returns exceptions*.
54+
try:
55+
return self.future.exception()
56+
except grpc.FutureCancelledError:
57+
return exceptions.Cancelled()
4658

4759
def result(self):
4860
"""Calls :meth:`grpc.Future.result` on attr:`future`."""
4961
return self.future.result()
5062

5163
def add_done_callback(self, callback):
52-
"""Calls :meth:`grpc.Future.add_done_callback` on attr:`future`."""
53-
return self.future.add_done_callback(callback)
64+
"""Add a callback function to be run upon task completion. Will run
65+
immediately if task has already finished.
66+
67+
Args:
68+
callback (Callable): The function to execute.
69+
"""
70+
if self.future.done():
71+
callback(self)
72+
else:
73+
self._callbacks.append(callback)
74+
75+
def cancel(self):
76+
"""Calls :meth:`grpc.Future.cancel` on attr:`cancel`."""
77+
return self.future.cancel()
78+
79+
def _finish(self, rpc):
80+
"""Called when remote future is finished.
81+
82+
Used to call our own done callbacks.
83+
"""
84+
for callback in self._callbacks:
85+
callback(self)

packages/google-cloud-ndb/google/cloud/ndb/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,12 @@ class NoLongerImplementedError(NotImplementedError):
112112

113113
def __init__(self):
114114
super(NoLongerImplementedError, self).__init__("No longer implemented")
115+
116+
117+
class Cancelled(Error):
118+
"""An operation has been cancelled by user request.
119+
120+
Raised when trying to get a result from a future that has been cancelled by
121+
a call to ``Future.cancel`` (possibly on a future that depends on this
122+
future).
123+
"""

packages/google-cloud-ndb/google/cloud/ndb/tasklets.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def main():
5858

5959
from google.cloud.ndb import context as context_module
6060
from google.cloud.ndb import _eventloop
61+
from google.cloud.ndb import exceptions
6162
from google.cloud.ndb import _remote
6263

6364
__all__ = [
@@ -232,20 +233,26 @@ def add_done_callback(self, callback):
232233
self._callbacks.append(callback)
233234

234235
def cancel(self):
235-
"""Cancel the task for this future.
236+
"""Attempt to cancel the task for this future.
236237
237-
Raises:
238-
NotImplementedError: Always, not supported.
238+
If the task has already completed, this call will do nothing.
239+
Otherwise, this will attempt to cancel whatever task this future is
240+
waiting on. There is no specific guarantee the underlying task will be
241+
cancelled.
239242
"""
240-
raise NotImplementedError
243+
if not self.done():
244+
self.set_exception(exceptions.Cancelled())
241245

242246
def cancelled(self):
243-
"""Get whether task for this future has been canceled.
247+
"""Get whether the task for this future has been cancelled.
244248
245249
Returns:
246-
:data:`False`: Always.
250+
:data:`True`: If this future's task has been cancelled, otherwise
251+
:data:`False`.
247252
"""
248-
return False
253+
return self._exception is not None and isinstance(
254+
self._exception, exceptions.Cancelled
255+
)
249256

250257
@staticmethod
251258
def wait_any(futures):
@@ -278,6 +285,7 @@ def __init__(self, generator, context, info="Unknown"):
278285
super(_TaskletFuture, self).__init__(info=info)
279286
self.generator = generator
280287
self.context = context
288+
self.waiting_on = None
281289

282290
def _advance_tasklet(self, send_value=None, error=None):
283291
"""Advance a tasklet one step by sending in a value or error."""
@@ -324,6 +332,8 @@ def done_callback(yielded):
324332
# in Legacy) directly. Doing so, it has been found, can lead to
325333
# exceeding the maximum recursion depth. Queing it up to run on the
326334
# event loop avoids this issue by keeping the call stack shallow.
335+
self.waiting_on = None
336+
327337
error = yielded.exception()
328338
if error:
329339
_eventloop.call_soon(self._advance_tasklet, error=error)
@@ -332,19 +342,30 @@ def done_callback(yielded):
332342

333343
if isinstance(yielded, Future):
334344
yielded.add_done_callback(done_callback)
345+
self.waiting_on = yielded
335346

336347
elif isinstance(yielded, _remote.RemoteCall):
337348
_eventloop.queue_rpc(yielded, done_callback)
349+
self.waiting_on = yielded
338350

339351
elif isinstance(yielded, (list, tuple)):
340352
future = _MultiFuture(yielded)
341353
future.add_done_callback(done_callback)
354+
self.waiting_on = future
342355

343356
else:
344357
raise RuntimeError(
345358
"A tasklet yielded an illegal value: {!r}".format(yielded)
346359
)
347360

361+
def cancel(self):
362+
"""Overrides :meth:`Future.cancel`."""
363+
if self.waiting_on:
364+
self.waiting_on.cancel()
365+
366+
else:
367+
super(_TaskletFuture, self).cancel()
368+
348369

349370
def _get_return_value(stop):
350371
"""Inspect `StopIteration` instance for return value of tasklet.
@@ -399,6 +420,11 @@ def _dependency_done(self, dependency):
399420
result = tuple((future.result() for future in self._dependencies))
400421
self.set_result(result)
401422

423+
def cancel(self):
424+
"""Overrides :meth:`Future.cancel`."""
425+
for dependency in self._dependencies:
426+
dependency.cancel()
427+
402428

403429
def tasklet(wrapped):
404430
"""

packages/google-cloud-ndb/tests/system/test_query.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,31 @@ def make_entities():
9696
assert [entity.foo for entity in results][:5] == [0, 1, 2, 3, 4]
9797

9898

99+
@pytest.mark.usefixtures("client_context")
100+
def test_fetch_and_immediately_cancel(dispose_of):
101+
# Make a lot of entities so the query call won't complete before we get to
102+
# call cancel.
103+
n_entities = 500
104+
105+
class SomeKind(ndb.Model):
106+
foo = ndb.IntegerProperty()
107+
108+
@ndb.toplevel
109+
def make_entities():
110+
entities = [SomeKind(foo=i) for i in range(n_entities)]
111+
keys = yield [entity.put_async() for entity in entities]
112+
raise ndb.Return(keys)
113+
114+
for key in make_entities():
115+
dispose_of(key._key)
116+
117+
query = SomeKind.query()
118+
future = query.fetch_async()
119+
future.cancel()
120+
with pytest.raises(ndb.exceptions.Cancelled):
121+
future.result()
122+
123+
99124
@pytest.mark.usefixtures("client_context")
100125
def test_ancestor_query(ds_entity):
101126
root_id = test_utils.system.unique_resource_id()

packages/google-cloud-ndb/tests/unit/test__remote.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,26 @@
1414

1515
from unittest import mock
1616

17+
import grpc
18+
import pytest
19+
20+
from google.cloud.ndb import exceptions
1721
from google.cloud.ndb import _remote
1822
from google.cloud.ndb import tasklets
1923

2024

2125
class TestRemoteCall:
2226
@staticmethod
2327
def test_constructor():
24-
call = _remote.RemoteCall("future", "info")
25-
assert call.future == "future"
28+
future = tasklets.Future()
29+
call = _remote.RemoteCall(future, "info")
30+
assert call.future is future
2631
assert call.info == "info"
2732

2833
@staticmethod
2934
def test_repr():
30-
call = _remote.RemoteCall(None, "a remote call")
35+
future = tasklets.Future()
36+
call = _remote.RemoteCall(future, "a remote call")
3137
assert repr(call) == "a remote call"
3238

3339
@staticmethod
@@ -38,6 +44,14 @@ def test_exception():
3844
call = _remote.RemoteCall(future, "testing")
3945
assert call.exception() is error
4046

47+
@staticmethod
48+
def test_exception_FutureCancelledError():
49+
error = grpc.FutureCancelledError()
50+
future = tasklets.Future()
51+
future.exception = mock.Mock(side_effect=error)
52+
call = _remote.RemoteCall(future, "testing")
53+
assert isinstance(call.exception(), exceptions.Cancelled)
54+
4155
@staticmethod
4256
def test_result():
4357
future = tasklets.Future()
@@ -52,4 +66,22 @@ def test_add_done_callback():
5266
callback = mock.Mock(spec=())
5367
call.add_done_callback(callback)
5468
future.set_result(None)
55-
callback.assert_called_once_with(future)
69+
callback.assert_called_once_with(call)
70+
71+
@staticmethod
72+
def test_add_done_callback_already_done():
73+
future = tasklets.Future()
74+
future.set_result(None)
75+
call = _remote.RemoteCall(future, "testing")
76+
callback = mock.Mock(spec=())
77+
call.add_done_callback(callback)
78+
callback.assert_called_once_with(call)
79+
80+
@staticmethod
81+
def test_cancel():
82+
future = tasklets.Future()
83+
call = _remote.RemoteCall(future, "testing")
84+
call.cancel()
85+
assert future.cancelled()
86+
with pytest.raises(exceptions.Cancelled):
87+
call.result()

packages/google-cloud-ndb/tests/unit/test_tasklets.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from google.cloud.ndb import context as context_module
2222
from google.cloud.ndb import _eventloop
23+
from google.cloud.ndb import exceptions
2324
from google.cloud.ndb import _remote
2425
from google.cloud.ndb import tasklets
2526

@@ -188,10 +189,38 @@ def side_effects(future):
188189
assert _eventloop.run1.call_count == 3
189190

190191
@staticmethod
192+
@pytest.mark.usefixtures("in_context")
191193
def test_cancel():
192-
future = tasklets.Future()
193-
with pytest.raises(NotImplementedError):
194-
future.cancel()
194+
# Integration test. Actually test that a cancel propagates properly.
195+
rpc = tasklets.Future("Fake RPC")
196+
wrapped_rpc = _remote.RemoteCall(rpc, "Wrapped Fake RPC")
197+
198+
@tasklets.tasklet
199+
def inner_tasklet():
200+
yield wrapped_rpc
201+
202+
@tasklets.tasklet
203+
def outer_tasklet():
204+
yield inner_tasklet()
205+
206+
future = outer_tasklet()
207+
assert not future.cancelled()
208+
future.cancel()
209+
assert rpc.cancelled()
210+
211+
with pytest.raises(exceptions.Cancelled):
212+
future.result()
213+
214+
assert future.cancelled()
215+
216+
@staticmethod
217+
@pytest.mark.usefixtures("in_context")
218+
def test_cancel_already_done():
219+
future = tasklets.Future("testing")
220+
future.set_result(42)
221+
future.cancel() # noop
222+
assert not future.cancelled()
223+
assert future.result() == 42
195224

196225
@staticmethod
197226
def test_cancelled():
@@ -358,6 +387,31 @@ def generator_function(dependencies):
358387
assert future.result() == 11
359388
assert future.context is in_context
360389

390+
@staticmethod
391+
def test_cancel_not_waiting(in_context):
392+
dependency = tasklets.Future()
393+
future = tasklets._TaskletFuture(None, in_context)
394+
future.cancel()
395+
396+
assert not dependency.cancelled()
397+
with pytest.raises(exceptions.Cancelled):
398+
future.result()
399+
400+
@staticmethod
401+
def test_cancel_waiting_on_dependency(in_context):
402+
def generator_function(dependency):
403+
yield dependency
404+
405+
dependency = tasklets.Future()
406+
generator = generator_function(dependency)
407+
future = tasklets._TaskletFuture(generator, in_context)
408+
future._advance_tasklet()
409+
future.cancel()
410+
411+
assert dependency.cancelled()
412+
with pytest.raises(exceptions.Cancelled):
413+
future.result()
414+
361415

362416
class Test_MultiFuture:
363417
@staticmethod
@@ -388,6 +442,16 @@ def test_error():
388442
with pytest.raises(Exception):
389443
future.result()
390444

445+
@staticmethod
446+
def test_cancel():
447+
dependencies = (tasklets.Future(), tasklets.Future())
448+
future = tasklets._MultiFuture(dependencies)
449+
future.cancel()
450+
assert dependencies[0].cancelled()
451+
assert dependencies[1].cancelled()
452+
with pytest.raises(exceptions.Cancelled):
453+
future.result()
454+
391455

392456
class Test__get_return_value:
393457
@staticmethod

0 commit comments

Comments
 (0)