Skip to content

Commit 2b54a4e

Browse files
[3.12] gh-124309: Modernize the staggered_race implementation to support e… (#124574)
gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390) Co-authored-by: Thomas Grainger <[email protected]> Co-authored-by: Jelle Zijlstra <[email protected]> Co-authored-by: Carol Willing <[email protected]> Co-authored-by: Kumar Aditya <[email protected]> (cherry picked from commit de929f3) Co-authored-by: Peter Bierma <[email protected]>
1 parent 48359c5 commit 2b54a4e

File tree

5 files changed

+193
-73
lines changed

5 files changed

+193
-73
lines changed

Lib/asyncio/base_events.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,7 @@ async def create_connection(
11101110
(functools.partial(self._connect_sock,
11111111
exceptions, addrinfo, laddr_infos)
11121112
for addrinfo in infos),
1113-
happy_eyeballs_delay, loop=self)
1113+
happy_eyeballs_delay)
11141114

11151115
if sock is None:
11161116
exceptions = [exc for sub in exceptions for exc in sub]

Lib/asyncio/staggered.py

+18-72
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,15 @@
33
__all__ = 'staggered_race',
44

55
import contextlib
6-
import typing
76

8-
from . import events
9-
from . import exceptions as exceptions_mod
107
from . import locks
118
from . import tasks
9+
from . import taskgroups
1210

11+
class _Done(Exception):
12+
pass
1313

14-
async def staggered_race(
15-
coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
16-
delay: typing.Optional[float],
17-
*,
18-
loop: events.AbstractEventLoop = None,
19-
) -> typing.Tuple[
20-
typing.Any,
21-
typing.Optional[int],
22-
typing.List[typing.Optional[Exception]]
23-
]:
14+
async def staggered_race(coro_fns, delay):
2415
"""Run coroutines with staggered start times and take the first to finish.
2516
2617
This method takes an iterable of coroutine functions. The first one is
@@ -52,8 +43,6 @@ async def staggered_race(
5243
delay: amount of time, in seconds, between starting coroutines. If
5344
``None``, the coroutines will run sequentially.
5445
55-
loop: the event loop to use.
56-
5746
Returns:
5847
tuple *(winner_result, winner_index, exceptions)* where
5948
@@ -72,37 +61,11 @@ async def staggered_race(
7261
7362
"""
7463
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
75-
loop = loop or events.get_running_loop()
76-
enum_coro_fns = enumerate(coro_fns)
7764
winner_result = None
7865
winner_index = None
7966
exceptions = []
80-
running_tasks = []
81-
82-
async def run_one_coro(
83-
previous_failed: typing.Optional[locks.Event]) -> None:
84-
# Wait for the previous task to finish, or for delay seconds
85-
if previous_failed is not None:
86-
with contextlib.suppress(exceptions_mod.TimeoutError):
87-
# Use asyncio.wait_for() instead of asyncio.wait() here, so
88-
# that if we get cancelled at this point, Event.wait() is also
89-
# cancelled, otherwise there will be a "Task destroyed but it is
90-
# pending" later.
91-
await tasks.wait_for(previous_failed.wait(), delay)
92-
# Get the next coroutine to run
93-
try:
94-
this_index, coro_fn = next(enum_coro_fns)
95-
except StopIteration:
96-
return
97-
# Start task that will run the next coroutine
98-
this_failed = locks.Event()
99-
next_task = loop.create_task(run_one_coro(this_failed))
100-
running_tasks.append(next_task)
101-
assert len(running_tasks) == this_index + 2
102-
# Prepare place to put this coroutine's exceptions if not won
103-
exceptions.append(None)
104-
assert len(exceptions) == this_index + 1
10567

68+
async def run_one_coro(this_index, coro_fn, this_failed):
10669
try:
10770
result = await coro_fn()
10871
except (SystemExit, KeyboardInterrupt):
@@ -116,34 +79,17 @@ async def run_one_coro(
11679
assert winner_index is None
11780
winner_index = this_index
11881
winner_result = result
119-
# Cancel all other tasks. We take care to not cancel the current
120-
# task as well. If we do so, then since there is no `await` after
121-
# here and CancelledError are usually thrown at one, we will
122-
# encounter a curious corner case where the current task will end
123-
# up as done() == True, cancelled() == False, exception() ==
124-
# asyncio.CancelledError. This behavior is specified in
125-
# https://bugs.python.org/issue30048
126-
for i, t in enumerate(running_tasks):
127-
if i != this_index:
128-
t.cancel()
129-
130-
first_task = loop.create_task(run_one_coro(None))
131-
running_tasks.append(first_task)
82+
raise _Done
83+
13284
try:
133-
# Wait for a growing list of tasks to all finish: poor man's version of
134-
# curio's TaskGroup or trio's nursery
135-
done_count = 0
136-
while done_count != len(running_tasks):
137-
done, _ = await tasks.wait(running_tasks)
138-
done_count = len(done)
139-
# If run_one_coro raises an unhandled exception, it's probably a
140-
# programming error, and I want to see it.
141-
if __debug__:
142-
for d in done:
143-
if d.done() and not d.cancelled() and d.exception():
144-
raise d.exception()
145-
return winner_result, winner_index, exceptions
146-
finally:
147-
# Make sure no tasks are left running if we leave this function
148-
for t in running_tasks:
149-
t.cancel()
85+
async with taskgroups.TaskGroup() as tg:
86+
for this_index, coro_fn in enumerate(coro_fns):
87+
this_failed = locks.Event()
88+
exceptions.append(None)
89+
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
90+
with contextlib.suppress(TimeoutError):
91+
await tasks.wait_for(this_failed.wait(), delay)
92+
except* _Done:
93+
pass
94+
95+
return winner_result, winner_index, exceptions

Lib/test/test_asyncio/test_eager_task_factory.py

+47
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,53 @@ async def run():
218218

219219
self.run_coro(run())
220220

221+
def test_staggered_race_with_eager_tasks(self):
222+
# See https://github.com/python/cpython/issues/124309
223+
224+
async def fail():
225+
await asyncio.sleep(0)
226+
raise ValueError("no good")
227+
228+
async def run():
229+
winner, index, excs = await asyncio.staggered.staggered_race(
230+
[
231+
lambda: asyncio.sleep(2, result="sleep2"),
232+
lambda: asyncio.sleep(1, result="sleep1"),
233+
lambda: fail()
234+
],
235+
delay=0.25
236+
)
237+
self.assertEqual(winner, 'sleep1')
238+
self.assertEqual(index, 1)
239+
self.assertIsNone(excs[index])
240+
self.assertIsInstance(excs[0], asyncio.CancelledError)
241+
self.assertIsInstance(excs[2], ValueError)
242+
243+
self.run_coro(run())
244+
245+
def test_staggered_race_with_eager_tasks_no_delay(self):
246+
# See https://github.com/python/cpython/issues/124309
247+
async def fail():
248+
raise ValueError("no good")
249+
250+
async def run():
251+
winner, index, excs = await asyncio.staggered.staggered_race(
252+
[
253+
lambda: fail(),
254+
lambda: asyncio.sleep(1, result="sleep1"),
255+
lambda: asyncio.sleep(0, result="sleep0"),
256+
],
257+
delay=None
258+
)
259+
self.assertEqual(winner, 'sleep1')
260+
self.assertEqual(index, 1)
261+
self.assertIsNone(excs[index])
262+
self.assertIsInstance(excs[0], ValueError)
263+
self.assertEqual(len(excs), 2)
264+
265+
self.run_coro(run())
266+
267+
221268

222269
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
223270
Task = tasks._PyTask
+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import asyncio
2+
import unittest
3+
from asyncio.staggered import staggered_race
4+
5+
from test import support
6+
7+
support.requires_working_socket(module=True)
8+
9+
10+
def tearDownModule():
11+
asyncio.set_event_loop_policy(None)
12+
13+
14+
class StaggeredTests(unittest.IsolatedAsyncioTestCase):
15+
async def test_empty(self):
16+
winner, index, excs = await staggered_race(
17+
[],
18+
delay=None,
19+
)
20+
21+
self.assertIs(winner, None)
22+
self.assertIs(index, None)
23+
self.assertEqual(excs, [])
24+
25+
async def test_one_successful(self):
26+
async def coro(index):
27+
return f'Res: {index}'
28+
29+
winner, index, excs = await staggered_race(
30+
[
31+
lambda: coro(0),
32+
lambda: coro(1),
33+
],
34+
delay=None,
35+
)
36+
37+
self.assertEqual(winner, 'Res: 0')
38+
self.assertEqual(index, 0)
39+
self.assertEqual(excs, [None])
40+
41+
async def test_first_error_second_successful(self):
42+
async def coro(index):
43+
if index == 0:
44+
raise ValueError(index)
45+
return f'Res: {index}'
46+
47+
winner, index, excs = await staggered_race(
48+
[
49+
lambda: coro(0),
50+
lambda: coro(1),
51+
],
52+
delay=None,
53+
)
54+
55+
self.assertEqual(winner, 'Res: 1')
56+
self.assertEqual(index, 1)
57+
self.assertEqual(len(excs), 2)
58+
self.assertIsInstance(excs[0], ValueError)
59+
self.assertIs(excs[1], None)
60+
61+
async def test_first_timeout_second_successful(self):
62+
async def coro(index):
63+
if index == 0:
64+
await asyncio.sleep(10) # much bigger than delay
65+
return f'Res: {index}'
66+
67+
winner, index, excs = await staggered_race(
68+
[
69+
lambda: coro(0),
70+
lambda: coro(1),
71+
],
72+
delay=0.1,
73+
)
74+
75+
self.assertEqual(winner, 'Res: 1')
76+
self.assertEqual(index, 1)
77+
self.assertEqual(len(excs), 2)
78+
self.assertIsInstance(excs[0], asyncio.CancelledError)
79+
self.assertIs(excs[1], None)
80+
81+
async def test_none_successful(self):
82+
async def coro(index):
83+
raise ValueError(index)
84+
85+
for delay in [None, 0, 0.1, 1]:
86+
with self.subTest(delay=delay):
87+
winner, index, excs = await staggered_race(
88+
[
89+
lambda: coro(0),
90+
lambda: coro(1),
91+
],
92+
delay=delay,
93+
)
94+
95+
self.assertIs(winner, None)
96+
self.assertIs(index, None)
97+
self.assertEqual(len(excs), 2)
98+
self.assertIsInstance(excs[0], ValueError)
99+
self.assertIsInstance(excs[1], ValueError)
100+
101+
async def test_long_delay_early_failure(self):
102+
async def coro(index):
103+
await asyncio.sleep(0) # Dummy coroutine for the 1 case
104+
if index == 0:
105+
await asyncio.sleep(0.1) # Dummy coroutine
106+
raise ValueError(index)
107+
108+
return f'Res: {index}'
109+
110+
winner, index, excs = await staggered_race(
111+
[
112+
lambda: coro(0),
113+
lambda: coro(1),
114+
],
115+
delay=10,
116+
)
117+
118+
self.assertEqual(winner, 'Res: 1')
119+
self.assertEqual(index, 1)
120+
self.assertEqual(len(excs), 2)
121+
self.assertIsInstance(excs[0], ValueError)
122+
self.assertIsNone(excs[1])
123+
124+
125+
if __name__ == "__main__":
126+
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.

0 commit comments

Comments
 (0)