Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 9d11fee

Browse files
authored
Improve exception handling for concurrent execution (#12109)
* Fix incorrect `unwrapFirstError` import — it was being imported from the wrong place.
* Refactor `concurrently_execute` to use `yieldable_gather_results`.
* Improve exception handling in `yieldable_gather_results`; try to avoid swallowing so many stack traces.
* Mark `unwrapFirstError` as deprecated.
* Add changelog entry.
1 parent 952efd0 commit 9d11fee

File tree

5 files changed

+151
-27
lines changed

5 files changed

+151
-27
lines changed

changelog.d/12109.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve exception handling for concurrent execution.

synapse/handlers/message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
5656
from synapse.storage.state import StateFilter
5757
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
58-
from synapse.util import json_decoder, json_encoder, log_failure
59-
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
58+
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
59+
from synapse.util.async_helpers import Linearizer, gather_results
6060
from synapse.util.caches.expiringcache import ExpiringCache
6161
from synapse.util.metrics import measure_func
6262
from synapse.visibility import filter_events_for_client

synapse/util/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
8181

8282

8383
def unwrapFirstError(failure: Failure) -> Failure:
84-
# defer.gatherResults and DeferredLists wrap failures.
84+
# Deprecated: you probably just want to catch defer.FirstError and reraise
85+
# the subFailure's value, which will do a better job of preserving stacktraces.
86+
# (actually, you probably want to use yieldable_gather_results anyway)
8587
failure.trap(defer.FirstError)
8688
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
8789

synapse/util/async_helpers.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Hashable,
3030
Iterable,
3131
Iterator,
32+
List,
3233
Optional,
3334
Set,
3435
Tuple,
@@ -51,7 +52,7 @@
5152
make_deferred_yieldable,
5253
run_in_background,
5354
)
54-
from synapse.util import Clock, unwrapFirstError
55+
from synapse.util import Clock
5556

5657
logger = logging.getLogger(__name__)
5758

@@ -193,9 +194,9 @@ def __repr__(self) -> str:
193194
T = TypeVar("T")
194195

195196

196-
def concurrently_execute(
197+
async def concurrently_execute(
197198
func: Callable[[T], Any], args: Iterable[T], limit: int
198-
) -> defer.Deferred:
199+
) -> None:
199200
"""Executes the function with each argument concurrently while limiting
200201
the number of concurrent executions.
201202
@@ -221,20 +222,14 @@ async def _concurrently_execute_inner(value: T) -> None:
221222
# We use `itertools.islice` to handle the case where the number of args is
222223
# less than the limit, avoiding needlessly spawning unnecessary background
223224
# tasks.
224-
return make_deferred_yieldable(
225-
defer.gatherResults(
226-
[
227-
run_in_background(_concurrently_execute_inner, value)
228-
for value in itertools.islice(it, limit)
229-
],
230-
consumeErrors=True,
231-
)
232-
).addErrback(unwrapFirstError)
225+
await yieldable_gather_results(
226+
_concurrently_execute_inner, (value for value in itertools.islice(it, limit))
227+
)
233228

234229

235-
def yieldable_gather_results(
236-
func: Callable, iter: Iterable, *args: Any, **kwargs: Any
237-
) -> defer.Deferred:
230+
async def yieldable_gather_results(
231+
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
232+
) -> List[T]:
238233
"""Executes the function with each argument concurrently.
239234
240235
Args:
@@ -245,15 +240,30 @@ def yieldable_gather_results(
245240
**kwargs: Keyword arguments to be passed to each call to func
246241
247242
Returns
248-
Deferred[list]: Resolved when all functions have been invoked, or errors if
249-
one of the function calls fails.
243+
A list containing the results of the function
250244
"""
251-
return make_deferred_yieldable(
252-
defer.gatherResults(
253-
[run_in_background(func, item, *args, **kwargs) for item in iter],
254-
consumeErrors=True,
245+
try:
246+
return await make_deferred_yieldable(
247+
defer.gatherResults(
248+
[run_in_background(func, item, *args, **kwargs) for item in iter],
249+
consumeErrors=True,
250+
)
255251
)
256-
).addErrback(unwrapFirstError)
252+
except defer.FirstError as dfe:
253+
# unwrap the error from defer.gatherResults.
254+
255+
# The raised exception's traceback only includes func() etc if
256+
# the 'await' happens before the exception is thrown - ie if the failure
257+
# happens *asynchronously* - otherwise Twisted throws away the traceback as it
258+
# could be large.
259+
#
260+
# We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
261+
# we could throw Twisted into the fires of Mordor.
262+
263+
# suppress exception chaining, because the FirstError doesn't tell us anything
264+
# very interesting.
265+
assert isinstance(dfe.subFailure.value, BaseException)
266+
raise dfe.subFailure.value from None
257267

258268

259269
T1 = TypeVar("T1")

tests/util/test_async_helpers.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,24 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import traceback
15+
1416
from twisted.internet import defer
15-
from twisted.internet.defer import CancelledError, Deferred
17+
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
1618
from twisted.internet.task import Clock
19+
from twisted.python.failure import Failure
1720

1821
from synapse.logging.context import (
1922
SENTINEL_CONTEXT,
2023
LoggingContext,
2124
PreserveLoggingContext,
2225
current_context,
2326
)
24-
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
27+
from synapse.util.async_helpers import (
28+
ObservableDeferred,
29+
concurrently_execute,
30+
timeout_deferred,
31+
)
2532

2633
from tests.unittest import TestCase
2734

@@ -171,3 +178,107 @@ def errback(res, deferred_name):
171178
)
172179
self.failureResultOf(timing_out_d, defer.TimeoutError)
173180
self.assertIs(current_context(), context_one)
181+
182+
183+
class _TestException(Exception):
184+
pass
185+
186+
187+
class ConcurrentlyExecuteTest(TestCase):
188+
def test_limits_runners(self):
189+
"""If we have more tasks than runners, we should get the limit of runners"""
190+
started = 0
191+
waiters = []
192+
processed = []
193+
194+
async def callback(v):
195+
# when we first enter, bump the start count
196+
nonlocal started
197+
started += 1
198+
199+
# record the fact we got an item
200+
processed.append(v)
201+
202+
# wait for the goahead before returning
203+
d2 = Deferred()
204+
waiters.append(d2)
205+
await d2
206+
207+
# set it going
208+
d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
209+
210+
# check we got exactly 3 processes
211+
self.assertEqual(started, 3)
212+
self.assertEqual(len(waiters), 3)
213+
214+
# let one finish
215+
waiters.pop().callback(0)
216+
217+
# ... which should start another
218+
self.assertEqual(started, 4)
219+
self.assertEqual(len(waiters), 3)
220+
221+
# we still shouldn't be done
222+
self.assertNoResult(d2)
223+
224+
# finish the job
225+
while waiters:
226+
waiters.pop().callback(0)
227+
228+
# check everything got done
229+
self.assertEqual(started, 5)
230+
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
231+
self.successResultOf(d2)
232+
233+
def test_preserves_stacktraces(self):
234+
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
235+
d1 = Deferred()
236+
237+
async def callback(v):
238+
# alas, this doesn't work at all without an await here
239+
await d1
240+
raise _TestException("bah")
241+
242+
async def caller():
243+
try:
244+
await concurrently_execute(callback, [1], 2)
245+
except _TestException as e:
246+
tb = traceback.extract_tb(e.__traceback__)
247+
# we expect to see "caller", "concurrently_execute" and "callback".
248+
self.assertEqual(tb[0].name, "caller")
249+
self.assertEqual(tb[1].name, "concurrently_execute")
250+
self.assertEqual(tb[-1].name, "callback")
251+
else:
252+
self.fail("No exception thrown")
253+
254+
d2 = ensureDeferred(caller())
255+
d1.callback(0)
256+
self.successResultOf(d2)
257+
258+
def test_preserves_stacktraces_on_preformed_failure(self):
259+
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
260+
d1 = Deferred()
261+
f = Failure(_TestException("bah"))
262+
263+
async def callback(v):
264+
# alas, this doesn't work at all without an await here
265+
await d1
266+
await defer.fail(f)
267+
268+
async def caller():
269+
try:
270+
await concurrently_execute(callback, [1], 2)
271+
except _TestException as e:
272+
tb = traceback.extract_tb(e.__traceback__)
273+
# we expect to see "caller", "concurrently_execute", "callback",
274+
# and some magic from inside ensureDeferred that happens when .fail
275+
# is called.
276+
self.assertEqual(tb[0].name, "caller")
277+
self.assertEqual(tb[1].name, "concurrently_execute")
278+
self.assertEqual(tb[-2].name, "callback")
279+
else:
280+
self.fail("No exception thrown")
281+
282+
d2 = ensureDeferred(caller())
283+
d1.callback(0)
284+
self.successResultOf(d2)

0 commit comments

Comments (0)