@@ -81,6 +81,13 @@ def parse_args():
81
81
default = False ,
82
82
help = "Print the results (runtime and number of Tasks created)." ,
83
83
)
84
+ parser .add_argument (
85
+ "-g" ,
86
+ "--gather" ,
87
+ action = "store_true" ,
88
+ default = False ,
89
+ help = "Use gather (if not specified, use TaskGroup if available, otherwise use gather)." ,
90
+ )
84
91
parser .add_argument (
85
92
"-e" ,
86
93
"--eager" ,
@@ -96,11 +103,17 @@ def __init__(
96
103
self ,
97
104
memoizable_percentage = DEFAULT_MEMOIZABLE_PERCENTAGE ,
98
105
cpu_probability = DEFAULT_CPU_PROBABILITY ,
106
+ use_gather = None ,
107
+ use_eager_factory = None ,
99
108
):
100
109
self .suspense_count = 0
101
110
self .task_count = 0
102
111
self .memoizable_percentage = memoizable_percentage
103
112
self .cpu_probability = cpu_probability
113
+ has_taskgroups = hasattr (asyncio , "TaskGroup" )
114
+ self .use_gather = use_gather or (not has_taskgroups )
115
+ has_eager_factory = hasattr (asyncio , "create_eager_task_factory" )
116
+ self .use_eager_factory = use_eager_factory and has_eager_factory
104
117
self .cache = {}
105
118
# set to deterministic random, so that the results are reproducible
106
119
random .seed (0 )
@@ -119,14 +132,19 @@ async def recurse(self, recurse_level):
119
132
await self .suspense_func ()
120
133
return
121
134
122
- await asyncio .gather (
123
- * [self .recurse (recurse_level - 1 ) for _ in range (NUM_RECURSE_BRANCHES )]
124
- )
135
+ if self .use_gather :
136
+ await asyncio .gather (
137
+ * [self .recurse (recurse_level - 1 ) for _ in range (NUM_RECURSE_BRANCHES )]
138
+ )
139
+ else :
140
+ async with asyncio .TaskGroup () as tg :
141
+ for _ in range (NUM_RECURSE_BRANCHES ):
142
+ tg .create_task (self .recurse (recurse_level - 1 ))
125
143
126
144
async def run_benchmark (self ):
127
145
await self .recurse (NUM_RECURSE_LEVELS )
128
146
129
- def run (self , use_eager_factory ):
147
+ def run (self ):
130
148
131
149
def counting_task_constructor (coro , * , loop = None , name = None , context = None , yield_result = None ):
132
150
if yield_result is None :
@@ -142,7 +160,7 @@ def counting_task_factory(loop, coro, *, name=None, context=None, yield_result=N
142
160
self .run_benchmark (),
143
161
task_factory = (
144
162
asyncio .create_eager_task_factory (counting_task_constructor )
145
- if use_eager_factory else counting_task_factory
163
+ if self . use_eager_factory else counting_task_factory
146
164
),
147
165
)
148
166
@@ -192,14 +210,19 @@ async def suspense_func(self):
192
210
"cpu_io_mixed" : CpuIoMixedAsyncTree ,
193
211
}
194
212
async_tree_class = trees [scenario ]
195
- async_tree = async_tree_class (args .memoizable_percentage , args .cpu_probability )
213
+ async_tree = async_tree_class (
214
+ args .memoizable_percentage , args .cpu_probability , args .gather , args .eager )
196
215
197
216
start_time = time .perf_counter ()
198
- async_tree .run (args . eager )
217
+ async_tree .run ()
199
218
end_time = time .perf_counter ()
200
219
201
220
if args .print :
221
+ eager_or_tg = "gather" if async_tree .use_gather else "TaskGroup"
222
+ task_factory = "eager" if async_tree .use_eager_factory else "standard"
202
223
print (f"Scenario: { scenario } " )
224
+ print (f"Method: { eager_or_tg } " )
225
+ print (f"Task factory: { task_factory } " )
203
226
print (f"Time: { end_time - start_time } s" )
204
227
print (f"Tasks created: { async_tree .task_count } " )
205
228
print (f"Suspense called: { async_tree .suspense_count } " )
0 commit comments