Skip to content

Commit 5113cd3

Browse files
bdracoaaugustin
authored andcommitted
Use asyncio.timeout instead of asyncio.wait_for.
asyncio.wait_for creates a task whereas asyncio.timeout doesn't. Fallback to a vendored version of async_timeout on Python < 3.11. async.timeout will become the underlying implementation for async.wait_for in Python 3.12: python/cpython#98518
1 parent 5c7a442 commit 5113cd3

File tree

4 files changed

+246
-23
lines changed

4 files changed

+246
-23
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ branch = true
3939
omit = [
4040
# */websockets matches src/websockets and .tox/**/site-packages/websockets
4141
"*/websockets/__main__.py",
42+
"*/websockets/legacy/async_timeout.py",
4243
"*/websockets/legacy/compatibility.py",
4344
"tests/maxi_cov.py",
4445
]
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
2+
# Licensed under the Apache License, Version 2.0.
3+
4+
import asyncio
5+
import enum
6+
import sys
7+
import warnings
8+
from types import TracebackType
9+
from typing import Optional, Type
10+
11+
12+
if sys.version_info >= (3, 8):
13+
from typing import final
14+
else:
15+
from typing_extensions import final
16+
17+
18+
__version__ = "4.0.2"
19+
20+
21+
__all__ = ("timeout", "timeout_at", "Timeout")
22+
23+
24+
def timeout(delay: Optional[float]) -> "Timeout":
25+
"""timeout context manager.
26+
27+
Useful in cases when you want to apply timeout logic around block
28+
of code or in cases when asyncio.wait_for is not suitable. For example:
29+
30+
>>> async with timeout(0.001):
31+
... async with aiohttp.get('https://github.com') as r:
32+
... await r.text()
33+
34+
35+
delay - value in seconds or None to disable timeout logic
36+
"""
37+
loop = asyncio.get_running_loop()
38+
if delay is not None:
39+
deadline = loop.time() + delay # type: Optional[float]
40+
else:
41+
deadline = None
42+
return Timeout(deadline, loop)
43+
44+
45+
def timeout_at(deadline: Optional[float]) -> "Timeout":
46+
"""Schedule the timeout at absolute time.
47+
48+
deadline argument points on the time in the same clock system
49+
as loop.time().
50+
51+
Please note: it is not POSIX time but a time with
52+
undefined starting base, e.g. the time of the system power on.
53+
54+
>>> async with timeout_at(loop.time() + 10):
55+
... async with aiohttp.get('https://github.com') as r:
56+
... await r.text()
57+
58+
59+
"""
60+
loop = asyncio.get_running_loop()
61+
return Timeout(deadline, loop)
62+
63+
64+
class _State(enum.Enum):
65+
INIT = "INIT"
66+
ENTER = "ENTER"
67+
TIMEOUT = "TIMEOUT"
68+
EXIT = "EXIT"
69+
70+
71+
@final
72+
class Timeout:
73+
# Internal class, please don't instantiate it directly
74+
# Use timeout() and timeout_at() public factories instead.
75+
#
76+
# Implementation note: `async with timeout()` is preferred
77+
# over `with timeout()`.
78+
# While technically the Timeout class implementation
79+
# doesn't need to be async at all,
80+
# the `async with` statement explicitly points that
81+
# the context manager should be used from async function context.
82+
#
83+
# This design allows to avoid many silly misusages.
84+
#
85+
# TimeoutError is raised immediately when scheduled
86+
# if the deadline is passed.
87+
# The purpose is to time out as soon as possible
88+
# without waiting for the next await expression.
89+
90+
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
91+
92+
def __init__(
93+
self, deadline: Optional[float], loop: asyncio.AbstractEventLoop
94+
) -> None:
95+
self._loop = loop
96+
self._state = _State.INIT
97+
98+
self._timeout_handler = None # type: Optional[asyncio.Handle]
99+
if deadline is None:
100+
self._deadline = None # type: Optional[float]
101+
else:
102+
self.update(deadline)
103+
104+
def __enter__(self) -> "Timeout":
105+
warnings.warn(
106+
"with timeout() is deprecated, use async with timeout() instead",
107+
DeprecationWarning,
108+
stacklevel=2,
109+
)
110+
self._do_enter()
111+
return self
112+
113+
def __exit__(
114+
self,
115+
exc_type: Optional[Type[BaseException]],
116+
exc_val: Optional[BaseException],
117+
exc_tb: Optional[TracebackType],
118+
) -> Optional[bool]:
119+
self._do_exit(exc_type)
120+
return None
121+
122+
async def __aenter__(self) -> "Timeout":
123+
self._do_enter()
124+
return self
125+
126+
async def __aexit__(
127+
self,
128+
exc_type: Optional[Type[BaseException]],
129+
exc_val: Optional[BaseException],
130+
exc_tb: Optional[TracebackType],
131+
) -> Optional[bool]:
132+
self._do_exit(exc_type)
133+
return None
134+
135+
@property
136+
def expired(self) -> bool:
137+
"""Is timeout expired during execution?"""
138+
return self._state == _State.TIMEOUT
139+
140+
@property
141+
def deadline(self) -> Optional[float]:
142+
return self._deadline
143+
144+
def reject(self) -> None:
145+
"""Reject scheduled timeout if any."""
146+
# cancel is maybe better name but
147+
# task.cancel() raises CancelledError in asyncio world.
148+
if self._state not in (_State.INIT, _State.ENTER):
149+
raise RuntimeError(f"invalid state {self._state.value}")
150+
self._reject()
151+
152+
def _reject(self) -> None:
153+
if self._timeout_handler is not None:
154+
self._timeout_handler.cancel()
155+
self._timeout_handler = None
156+
157+
def shift(self, delay: float) -> None:
158+
"""Advance timeout on delay seconds.
159+
160+
The delay can be negative.
161+
162+
Raise RuntimeError if shift is called when deadline is not scheduled
163+
"""
164+
deadline = self._deadline
165+
if deadline is None:
166+
raise RuntimeError("cannot shift timeout if deadline is not scheduled")
167+
self.update(deadline + delay)
168+
169+
def update(self, deadline: float) -> None:
170+
"""Set deadline to absolute value.
171+
172+
deadline argument points on the time in the same clock system
173+
as loop.time().
174+
175+
If new deadline is in the past the timeout is raised immediately.
176+
177+
Please note: it is not POSIX time but a time with
178+
undefined starting base, e.g. the time of the system power on.
179+
"""
180+
if self._state == _State.EXIT:
181+
raise RuntimeError("cannot reschedule after exit from context manager")
182+
if self._state == _State.TIMEOUT:
183+
raise RuntimeError("cannot reschedule expired timeout")
184+
if self._timeout_handler is not None:
185+
self._timeout_handler.cancel()
186+
self._deadline = deadline
187+
if self._state != _State.INIT:
188+
self._reschedule()
189+
190+
def _reschedule(self) -> None:
191+
assert self._state == _State.ENTER
192+
deadline = self._deadline
193+
if deadline is None:
194+
return
195+
196+
now = self._loop.time()
197+
if self._timeout_handler is not None:
198+
self._timeout_handler.cancel()
199+
200+
task = asyncio.current_task()
201+
if deadline <= now:
202+
self._timeout_handler = self._loop.call_soon(self._on_timeout, task)
203+
else:
204+
self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task)
205+
206+
def _do_enter(self) -> None:
207+
if self._state != _State.INIT:
208+
raise RuntimeError(f"invalid state {self._state.value}")
209+
self._state = _State.ENTER
210+
self._reschedule()
211+
212+
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
213+
if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT:
214+
self._timeout_handler = None
215+
raise asyncio.TimeoutError
216+
# timeout has not expired
217+
self._state = _State.EXIT
218+
self._reject()
219+
return None
220+
221+
def _on_timeout(self, task: "asyncio.Task[None]") -> None:
222+
task.cancel()
223+
self._state = _State.TIMEOUT
224+
# drop the reference early
225+
self._timeout_handler = None

src/websockets/legacy/compatibility.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from typing import Any, Dict
66

77

8+
__all__ = ["asyncio_timeout", "loop_if_py_lt_38"]
9+
10+
811
if sys.version_info[:2] >= (3, 8):
912

1013
def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]:
@@ -22,3 +25,9 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]:
2225
2326
"""
2427
return {"loop": loop}
28+
29+
30+
if sys.version_info[:2] >= (3, 11):
31+
from asyncio import timeout as asyncio_timeout # noqa: F401
32+
else:
33+
from .async_timeout import timeout as asyncio_timeout # noqa: F401

src/websockets/legacy/protocol.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454
from ..protocol import State
5555
from ..typing import Data, LoggerLike, Subprotocol
56-
from .compatibility import loop_if_py_lt_38
56+
from .compatibility import asyncio_timeout, loop_if_py_lt_38
5757
from .framing import Frame
5858

5959

@@ -761,19 +761,16 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
761761
762762
"""
763763
try:
764-
await asyncio.wait_for(
765-
self.write_close_frame(Close(code, reason)),
766-
self.close_timeout,
767-
**loop_if_py_lt_38(self.loop),
768-
)
764+
async with asyncio_timeout(self.close_timeout):
765+
await self.write_close_frame(Close(code, reason))
769766
except asyncio.TimeoutError:
770767
# If the close frame cannot be sent because the send buffers
771768
# are full, the closing handshake won't complete anyway.
772769
# Fail the connection to shut down faster.
773770
self.fail_connection()
774771

775-
# If no close frame is received within the timeout, wait_for() cancels
776-
# the data transfer task and raises TimeoutError.
772+
# If no close frame is received within the timeout, asyncio_timeout()
773+
# cancels the data transfer task and raises TimeoutError.
777774

778775
# If close() is called multiple times concurrently and one of these
779776
# calls hits the timeout, the data transfer task will be canceled.
@@ -782,11 +779,8 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
782779
try:
783780
# If close() is canceled during the wait, self.transfer_data_task
784781
# is canceled before the timeout elapses.
785-
await asyncio.wait_for(
786-
self.transfer_data_task,
787-
self.close_timeout,
788-
**loop_if_py_lt_38(self.loop),
789-
)
782+
async with asyncio_timeout(self.close_timeout):
783+
await self.transfer_data_task
790784
except (asyncio.TimeoutError, asyncio.CancelledError):
791785
pass
792786

@@ -1268,11 +1262,8 @@ async def keepalive_ping(self) -> None:
12681262

12691263
if self.ping_timeout is not None:
12701264
try:
1271-
await asyncio.wait_for(
1272-
pong_waiter,
1273-
self.ping_timeout,
1274-
**loop_if_py_lt_38(self.loop),
1275-
)
1265+
async with asyncio_timeout(self.ping_timeout):
1266+
await pong_waiter
12761267
self.logger.debug("% received keepalive pong")
12771268
except asyncio.TimeoutError:
12781269
if self.debug:
@@ -1384,11 +1375,8 @@ async def wait_for_connection_lost(self) -> bool:
13841375
"""
13851376
if not self.connection_lost_waiter.done():
13861377
try:
1387-
await asyncio.wait_for(
1388-
asyncio.shield(self.connection_lost_waiter),
1389-
self.close_timeout,
1390-
**loop_if_py_lt_38(self.loop),
1391-
)
1378+
async with asyncio_timeout(self.close_timeout):
1379+
await asyncio.shield(self.connection_lost_waiter)
13921380
except asyncio.TimeoutError:
13931381
pass
13941382
# Re-check self.connection_lost_waiter.done() synchronously because

0 commit comments

Comments
 (0)