1
1
from __future__ import annotations
2
+ from typing import TYPE_CHECKING , TypeVar
3
+
4
+ if TYPE_CHECKING :
5
+ from typing import Any , AsyncIterator , Coroutine
2
6
3
7
import asyncio
8
+ from concurrent .futures import wait
4
9
import threading
5
- from typing import (
6
- Any ,
7
- AsyncIterator ,
8
- Coroutine ,
9
- List ,
10
- Optional ,
11
- TypeVar ,
12
- )
10
+
13
11
from typing_extensions import ParamSpec
14
12
15
13
from zarr .config import SyncConfiguration
16
14
15
+ P = ParamSpec ("P" )
16
+ T = TypeVar ("T" )
17
17
18
18
# From https://github.com/fsspec/filesystem_spec/blob/master/fsspec/asyn.py
19
19
20
- iothread : List [ Optional [ threading .Thread ] ] = [None ] # dedicated IO thread
21
- loop : List [ Optional [ asyncio .AbstractEventLoop ] ] = [
20
+ iothread : list [ threading .Thread | None ] = [None ] # dedicated IO thread
21
+ loop : list [ asyncio .AbstractEventLoop | None ] = [
22
22
None
23
23
] # global event loop for any non-async instance
24
- _lock : Optional [ threading .Lock ] = None # global lock placeholder
24
+ _lock : threading .Lock | None = None # global lock placeholder
25
25
get_running_loop = asyncio .get_running_loop
26
26
27
27
28
+ class SyncError (Exception ):
29
+ pass
30
+
31
+
28
32
def _get_lock () -> threading .Lock :
29
33
"""Allocate or return a threading lock.
30
34
@@ -36,16 +40,22 @@ def _get_lock() -> threading.Lock:
36
40
return _lock
37
41
38
42
39
- async def _runner (event : threading .Event , coro : Coroutine , result_box : List [Optional [Any ]]):
43
+ async def _runner (coro : Coroutine [Any , Any , T ]) -> T | BaseException :
44
+ """
45
+ Await a coroutine and return the result of running it. If awaiting the coroutine raises an
46
+ exception, the exception will be returned.
47
+ """
40
48
try :
41
- result_box [ 0 ] = await coro
49
+ return await coro
42
50
except Exception as ex :
43
- result_box [0 ] = ex
44
- finally :
45
- event .set ()
51
+ return ex
46
52
47
53
48
- def sync (coro : Coroutine , loop : Optional [asyncio .AbstractEventLoop ] = None ):
54
+ def sync (
55
+ coro : Coroutine [Any , Any , T ],
56
+ loop : asyncio .AbstractEventLoop | None = None ,
57
+ timeout : float | None = None ,
58
+ ) -> T :
49
59
"""
50
60
Make loop run coroutine until it returns. Runs in other thread
51
61
@@ -57,23 +67,25 @@ def sync(coro: Coroutine, loop: Optional[asyncio.AbstractEventLoop] = None):
57
67
# NB: if the loop is not running *yet*, it is OK to submit work
58
68
# and we will wait for it
59
69
loop = _get_loop ()
60
- if loop is None or loop .is_closed ():
70
+ if not isinstance (loop , asyncio .AbstractEventLoop ):
71
+ raise TypeError (f"loop cannot be of type { type (loop )} " )
72
+ if loop .is_closed ():
61
73
raise RuntimeError ("Loop is not running" )
62
74
try :
63
75
loop0 = asyncio .events .get_running_loop ()
64
76
if loop0 is loop :
65
- raise NotImplementedError ("Calling sync() from within a running loop" )
77
+ raise SyncError ("Calling sync() from within a running loop" )
66
78
except RuntimeError :
67
79
pass
68
- result_box : List [ Optional [ Any ]] = [ None ]
69
- event = threading . Event ( )
70
- asyncio . run_coroutine_threadsafe ( _runner ( event , coro , result_box ), loop )
71
- while True :
72
- # this loops allows thread to get interrupted
73
- if event . wait ( 1 ):
74
- break
75
-
76
- return_result = result_box [ 0 ]
80
+
81
+ future = asyncio . run_coroutine_threadsafe ( _runner ( coro ), loop )
82
+
83
+ finished , unfinished = wait ([ future ], return_when = asyncio . ALL_COMPLETED , timeout = timeout )
84
+ if len ( unfinished ) > 0 :
85
+ raise asyncio . TimeoutError ( f"Coroutine { coro } failed to finish in within { timeout } s" )
86
+ assert len ( finished ) == 1
87
+ return_result = list ( finished )[ 0 ]. result ()
88
+
77
89
if isinstance (return_result , BaseException ):
78
90
raise return_result
79
91
else :
@@ -96,14 +108,8 @@ def _get_loop() -> asyncio.AbstractEventLoop:
96
108
th .daemon = True
97
109
th .start ()
98
110
iothread [0 ] = th
99
-
100
- return new_loop
101
- else :
102
- return loop [0 ]
103
-
104
-
105
- P = ParamSpec ("P" )
106
- T = TypeVar ("T" )
111
+ assert loop [0 ] is not None
112
+ return loop [0 ]
107
113
108
114
109
115
class SyncMixin :
@@ -112,12 +118,14 @@ class SyncMixin:
112
118
def _sync (self , coroutine : Coroutine [Any , Any , T ]) -> T :
113
119
# TODO: refactor this to to take *args and **kwargs and pass those to the method
114
120
# this should allow us to better type the sync wrapper
115
- return sync (coroutine , loop = self ._sync_configuration .asyncio_loop )
116
-
117
- def _sync_iter (self , coroutine : Coroutine [Any , Any , AsyncIterator [T ]]) -> List [T ]:
118
- async def iter_to_list () -> List [T ]:
119
- # TODO: replace with generators so we don't materialize the entire iterator at once
120
- async_iterator = await coroutine
121
+ return sync (
122
+ coroutine ,
123
+ loop = self ._sync_configuration .asyncio_loop ,
124
+ timeout = self ._sync_configuration .timeout ,
125
+ )
126
+
127
+ def _sync_iter (self , async_iterator : AsyncIterator [T ]) -> list [T ]:
128
+ async def iter_to_list () -> list [T ]:
121
129
return [item async for item in async_iterator ]
122
130
123
131
return self ._sync (iter_to_list ())
0 commit comments