Skip to content

Commit d851216

Browse files
authored
bpo-46752: Slight improvements to TaskGroup API (GH-31398)
* Remove task group names (for now) We're not sure that they are needed, and once in the code we would never be able to get rid of them. Yury wrote: > Ideally, there should be a way for someone to build a "trace" > of taskgroups/task leading to the current running task. > We could do that using contextvars, but I'm not sure we should > do that in 3.11. * Pass optional name on to task in create_task() * Remove a bunch of unused stuff
1 parent 2a38e1a commit d851216

File tree

2 files changed

+17
-24
lines changed

2 files changed

+17
-24
lines changed

Lib/asyncio/taskgroups.py

+4-18
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33

44
__all__ = ["TaskGroup"]
55

6-
import itertools
7-
import textwrap
8-
import traceback
9-
import types
106
import weakref
117

128
from . import events
@@ -15,12 +11,7 @@
1511

1612
class TaskGroup:
1713

18-
def __init__(self, *, name=None):
19-
if name is None:
20-
self._name = f'tg-{_name_counter()}'
21-
else:
22-
self._name = str(name)
23-
14+
def __init__(self):
2415
self._entered = False
2516
self._exiting = False
2617
self._aborting = False
@@ -33,11 +24,8 @@ def __init__(self, *, name=None):
3324
self._base_error = None
3425
self._on_completed_fut = None
3526

36-
def get_name(self):
37-
return self._name
38-
3927
def __repr__(self):
40-
msg = f'<TaskGroup {self._name!r}'
28+
msg = f'<TaskGroup'
4129
if self._tasks:
4230
msg += f' tasks:{len(self._tasks)}'
4331
if self._unfinished_tasks:
@@ -152,12 +140,13 @@ async def __aexit__(self, et, exc, tb):
152140
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
153141
raise me from None
154142

155-
def create_task(self, coro):
143+
def create_task(self, coro, *, name=None):
156144
if not self._entered:
157145
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
158146
if self._exiting and self._unfinished_tasks == 0:
159147
raise RuntimeError(f"TaskGroup {self!r} is finished")
160148
task = self._loop.create_task(coro)
149+
tasks._set_task_name(task, name)
161150
task.add_done_callback(self._on_task_done)
162151
self._unfinished_tasks += 1
163152
self._tasks.add(task)
@@ -230,6 +219,3 @@ def _on_task_done(self, task):
230219
# # after TaskGroup is finished.
231220
self._parent_cancel_requested = True
232221
self._parent_task.cancel()
233-
234-
235-
_name_counter = itertools.count(1).__next__

Lib/test/test_asyncio/test_taskgroups.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,10 @@ async def crash_after(t):
368368
raise ValueError(t)
369369

370370
async def runner():
371-
async with taskgroups.TaskGroup(name='g1') as g1:
371+
async with taskgroups.TaskGroup() as g1:
372372
g1.create_task(crash_after(0.1))
373373

374-
async with taskgroups.TaskGroup(name='g2') as g2:
374+
async with taskgroups.TaskGroup() as g2:
375375
g2.create_task(crash_after(0.2))
376376

377377
r = asyncio.create_task(runner())
@@ -387,10 +387,10 @@ async def crash_after(t):
387387
raise ValueError(t)
388388

389389
async def runner():
390-
async with taskgroups.TaskGroup(name='g1') as g1:
390+
async with taskgroups.TaskGroup() as g1:
391391
g1.create_task(crash_after(10))
392392

393-
async with taskgroups.TaskGroup(name='g2') as g2:
393+
async with taskgroups.TaskGroup() as g2:
394394
g2.create_task(crash_after(0.1))
395395

396396
r = asyncio.create_task(runner())
@@ -407,7 +407,7 @@ async def crash_soon():
407407
1 / 0
408408

409409
async def runner():
410-
async with taskgroups.TaskGroup(name='g1') as g1:
410+
async with taskgroups.TaskGroup() as g1:
411411
g1.create_task(crash_soon())
412412
try:
413413
await asyncio.sleep(10)
@@ -430,7 +430,7 @@ async def crash_soon():
430430
1 / 0
431431

432432
async def nested_runner():
433-
async with taskgroups.TaskGroup(name='g1') as g1:
433+
async with taskgroups.TaskGroup() as g1:
434434
g1.create_task(crash_soon())
435435
try:
436436
await asyncio.sleep(10)
@@ -692,3 +692,10 @@ async def runner():
692692

693693
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
694694
self.assertGreaterEqual(nhydras, 10)
695+
696+
async def test_taskgroup_task_name(self):
697+
async def coro():
698+
await asyncio.sleep(0)
699+
async with taskgroups.TaskGroup() as g:
700+
t = g.create_task(coro(), name="yolo")
701+
self.assertEqual(t.get_name(), "yolo")

0 commit comments

Comments
 (0)