Skip to content

Commit 02c7b86

Browse files
fix tests
1 parent d5796e6 commit 02c7b86

File tree

3 files changed

+114
-8
lines changed

3 files changed

+114
-8
lines changed

Lib/test/test_asyncio/test_eager_task_factory.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,33 @@ async def run():
267267
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
268268
Task = tasks._PyTask
269269

270+
def setUp(self):
271+
self._current_task = asyncio.current_task
272+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
273+
return super().setUp()
274+
275+
def tearDown(self):
276+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
277+
return super().tearDown()
278+
279+
270280

271281
@unittest.skipUnless(hasattr(tasks, '_CTask'),
272282
'requires the C _asyncio module')
273283
class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
274284
Task = getattr(tasks, '_CTask', None)
275285

286+
def setUp(self):
287+
self._current_task = asyncio.current_task
288+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
289+
return super().setUp()
290+
291+
def tearDown(self):
292+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
293+
return super().tearDown()
294+
295+
296+
@unittest.skip("skip")
276297
def test_issue105987(self):
277298
code = """if 1:
278299
from _asyncio import _swap_current_task
@@ -400,31 +421,83 @@ class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
400421

401422

402423
class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
403-
Task = asyncio.Task
424+
Task = asyncio.tasks._CTask
425+
426+
def setUp(self):
427+
self._current_task = asyncio.current_task
428+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
429+
return super().setUp()
404430

431+
def tearDown(self):
432+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
433+
return super().tearDown()
405434

406435
class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
407-
Task = asyncio.Task
436+
Task = asyncio.tasks._CTask
437+
438+
def setUp(self):
439+
self._current_task = asyncio.current_task
440+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
441+
return super().setUp()
442+
443+
def tearDown(self):
444+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
445+
return super().tearDown()
408446

409447

410448
class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
411449
Task = tasks._PyTask
412450

451+
def setUp(self):
452+
self._current_task = asyncio.current_task
453+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
454+
return super().setUp()
455+
456+
def tearDown(self):
457+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
458+
return super().tearDown()
459+
413460

414461
class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
415462
Task = tasks._PyTask
416463

464+
def setUp(self):
465+
self._current_task = asyncio.current_task
466+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
467+
return super().setUp()
468+
469+
def tearDown(self):
470+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
471+
return super().tearDown()
417472

418473
@unittest.skipUnless(hasattr(tasks, '_CTask'),
419474
'requires the C _asyncio module')
420475
class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
421476
Task = getattr(tasks, '_CTask', None)
422477

478+
def setUp(self):
479+
self._current_task = asyncio.current_task
480+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
481+
return super().setUp()
482+
483+
def tearDown(self):
484+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
485+
return super().tearDown()
486+
423487

424488
@unittest.skipUnless(hasattr(tasks, '_CTask'),
425489
'requires the C _asyncio module')
426490
class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
427491
Task = getattr(tasks, '_CTask', None)
428492

493+
def setUp(self):
494+
self._current_task = asyncio.current_task
495+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
496+
return super().setUp()
497+
498+
def tearDown(self):
499+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
500+
return super().tearDown()
501+
429502
if __name__ == '__main__':
430503
unittest.main()

Lib/test/test_asyncio/test_free_threading.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@ async def coro():
4040
self.assertEqual(task.get_loop(), loop)
4141
self.assertFalse(task.done())
4242

43-
current = self.current_task()
44-
self.assertEqual(current.get_loop(), loop)
43+
current = asyncio.current_task()
44+
try:
45+
self.assertEqual(current.get_loop(), loop)
46+
except AttributeError:
47+
breakpoint()
4548
self.assertSetEqual(all_tasks, tasks | {current})
4649
future.set_result(None)
4750

@@ -101,8 +104,12 @@ def test_task_different_thread_finalized(self) -> None:
101104
async def func():
102105
nonlocal task
103106
task = asyncio.current_task()
104-
105-
thread = Thread(target=lambda: asyncio.run(func()))
107+
def runner():
108+
with asyncio.Runner() as runner:
109+
loop = runner.get_loop()
110+
loop.set_task_factory(self.factory)
111+
runner.run(func())
112+
thread = Thread(target=runner)
106113
thread.start()
107114
thread.join()
108115
wr = weakref.ref(task)
@@ -164,7 +171,15 @@ async def main():
164171

165172
class TestPyFreeThreading(TestFreeThreading, TestCase):
166173
all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
167-
current_task = staticmethod(asyncio.tasks._py_current_task)
174+
175+
def setUp(self):
176+
self._old_current_task = asyncio.current_task
177+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
178+
return super().setUp()
179+
180+
def tearDown(self):
181+
asyncio.current_task = asyncio.tasks.current_task = self._old_current_task
182+
return super().tearDown()
168183

169184
def factory(self, loop, coro, **kwargs):
170185
return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)
@@ -173,7 +188,16 @@ def factory(self, loop, coro, **kwargs):
173188
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
174189
class TestCFreeThreading(TestFreeThreading, TestCase):
175190
all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
176-
current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
191+
192+
def setUp(self):
193+
self._old_current_task = asyncio.current_task
194+
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
195+
return super().setUp()
196+
197+
def tearDown(self):
198+
asyncio.current_task = asyncio.tasks.current_task = self._old_current_task
199+
return super().tearDown()
200+
177201

178202
def factory(self, loop, coro, **kwargs):
179203
return asyncio.tasks._CTask(coro, loop=loop, **kwargs)

Lib/test/test_asyncio/test_graph.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ def setUp(self):
369369
futures.future_discard_from_awaited_by = futures._c_future_discard_from_awaited_by
370370
asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by
371371

372+
self._current_task = asyncio.current_task
373+
asyncio.current_task = asyncio.tasks.current_task = tasks._c_current_task
372374

373375
def tearDown(self):
374376
futures = asyncio.futures
@@ -390,6 +392,8 @@ def tearDown(self):
390392
futures.Future = self._Future
391393
del self._Future
392394

395+
asyncio.current_task = asyncio.tasks.current_task = self._current_task
396+
393397

394398
@unittest.skipIf(
395399
not hasattr(asyncio.futures, "_py_future_add_to_awaited_by"),
@@ -414,6 +418,9 @@ def setUp(self):
414418
futures.future_discard_from_awaited_by = futures._py_future_discard_from_awaited_by
415419
asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by
416420

421+
self._current_task = asyncio.current_task
422+
asyncio.current_task = asyncio.tasks.current_task = tasks._py_current_task
423+
417424

418425
def tearDown(self):
419426
futures = asyncio.futures
@@ -434,3 +441,5 @@ def tearDown(self):
434441
asyncio.Future = self._Future
435442
futures.Future = self._Future
436443
del self._Future
444+
445+
asyncio.current_task = asyncio.tasks.current_task = self._current_task

0 commit comments

Comments
 (0)