Skip to content

Commit a6e68bc

Browse files
committed
refactor async tree benchmark to work with TaskGroup or gather depending on flags and feature availability
1 parent 6afbaab commit a6e68bc

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

async_tree.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ def parse_args():
8181
default=False,
8282
help="Print the results (runtime and number of Tasks created).",
8383
)
84+
parser.add_argument(
85+
"-g",
86+
"--gather",
87+
action="store_true",
88+
default=False,
89+
help="Use gather (if not specified, use TaskGroup if available, otherwise use gather).",
90+
)
8491
parser.add_argument(
8592
"-e",
8693
"--eager",
@@ -96,11 +103,17 @@ def __init__(
96103
self,
97104
memoizable_percentage=DEFAULT_MEMOIZABLE_PERCENTAGE,
98105
cpu_probability=DEFAULT_CPU_PROBABILITY,
106+
use_gather=None,
107+
use_eager_factory=None,
99108
):
100109
self.suspense_count = 0
101110
self.task_count = 0
102111
self.memoizable_percentage = memoizable_percentage
103112
self.cpu_probability = cpu_probability
113+
has_taskgroups = hasattr(asyncio, "TaskGroup")
114+
self.use_gather = use_gather or (not has_taskgroups)
115+
has_eager_factory = hasattr(asyncio, "create_eager_task_factory")
116+
self.use_eager_factory = use_eager_factory and has_eager_factory
104117
self.cache = {}
105118
# set to deterministic random, so that the results are reproducible
106119
random.seed(0)
@@ -119,14 +132,19 @@ async def recurse(self, recurse_level):
119132
await self.suspense_func()
120133
return
121134

122-
await asyncio.gather(
123-
*[self.recurse(recurse_level - 1) for _ in range(NUM_RECURSE_BRANCHES)]
124-
)
135+
if self.use_gather:
136+
await asyncio.gather(
137+
*[self.recurse(recurse_level - 1) for _ in range(NUM_RECURSE_BRANCHES)]
138+
)
139+
else:
140+
async with asyncio.TaskGroup() as tg:
141+
for _ in range(NUM_RECURSE_BRANCHES):
142+
tg.create_task(self.recurse(recurse_level - 1))
125143

126144
async def run_benchmark(self):
127145
await self.recurse(NUM_RECURSE_LEVELS)
128146

129-
def run(self, use_eager_factory):
147+
def run(self):
130148

131149
def counting_task_constructor(coro, *, loop=None, name=None, context=None, yield_result=None):
132150
if yield_result is None:
@@ -142,7 +160,7 @@ def counting_task_factory(loop, coro, *, name=None, context=None, yield_result=N
142160
self.run_benchmark(),
143161
task_factory=(
144162
asyncio.create_eager_task_factory(counting_task_constructor)
145-
if use_eager_factory else counting_task_factory
163+
if self.use_eager_factory else counting_task_factory
146164
),
147165
)
148166

@@ -192,14 +210,19 @@ async def suspense_func(self):
192210
"cpu_io_mixed": CpuIoMixedAsyncTree,
193211
}
194212
async_tree_class = trees[scenario]
195-
async_tree = async_tree_class(args.memoizable_percentage, args.cpu_probability)
213+
async_tree = async_tree_class(
214+
args.memoizable_percentage, args.cpu_probability, args.gather, args.eager)
196215

197216
start_time = time.perf_counter()
198-
async_tree.run(args.eager)
217+
async_tree.run()
199218
end_time = time.perf_counter()
200219

201220
if args.print:
221+
eager_or_tg = "gather" if async_tree.use_gather else "TaskGroup"
222+
task_factory = "eager" if async_tree.use_eager_factory else "standard"
202223
print(f"Scenario: {scenario}")
224+
print(f"Method: {eager_or_tg}")
225+
print(f"Task factory: {task_factory}")
203226
print(f"Time: {end_time - start_time} s")
204227
print(f"Tasks created: {async_tree.task_count}")
205228
print(f"Suspense called: {async_tree.suspense_count}")

0 commit comments

Comments
 (0)