Skip to content

Commit 21ea6d5

Browse files
bdracoaaugustin
authored andcommitted
Use asyncio.timeout instead of asyncio.wait_for.
asyncio.wait_for creates a task whereas asyncio.timeout doesn't. Fallback to a vendored version of async_timeout on Python < 3.11. async.timeout will become the underlying implementation for async.wait_for in Python 3.12: python/cpython#98518
1 parent af91737 commit 21ea6d5

File tree

5 files changed

+247
-23
lines changed

5 files changed

+247
-23
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ branch = true
3939
omit = [
4040
# */websockets matches src/websockets and .tox/**/site-packages/websockets
4141
"*/websockets/__main__.py",
42+
"*/websockets/legacy/async_timeout.py",
4243
"*/websockets/legacy/compatibility.py",
4344
"tests/maxi_cov.py",
4445
]
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
2+
# Licensed under the Apache License, Version 2.0.
3+
4+
import asyncio
5+
import enum
6+
import sys
7+
import warnings
8+
from types import TracebackType
9+
from typing import Optional, Type
10+
11+
12+
if sys.version_info >= (3, 8):
13+
from typing import final
14+
else:
15+
from typing_extensions import final
16+
17+
18+
__version__ = "4.0.2"
19+
20+
21+
__all__ = ("timeout", "timeout_at", "Timeout")
22+
23+
24+
def timeout(delay: Optional[float]) -> "Timeout":
25+
"""timeout context manager.
26+
27+
Useful in cases when you want to apply timeout logic around block
28+
of code or in cases when asyncio.wait_for is not suitable. For example:
29+
30+
>>> async with timeout(0.001):
31+
... async with aiohttp.get('https://github.com') as r:
32+
... await r.text()
33+
34+
35+
delay - value in seconds or None to disable timeout logic
36+
"""
37+
loop = asyncio.get_running_loop()
38+
if delay is not None:
39+
deadline = loop.time() + delay # type: Optional[float]
40+
else:
41+
deadline = None
42+
return Timeout(deadline, loop)
43+
44+
45+
def timeout_at(deadline: Optional[float]) -> "Timeout":
46+
"""Schedule the timeout at absolute time.
47+
48+
deadline argument points on the time in the same clock system
49+
as loop.time().
50+
51+
Please note: it is not POSIX time but a time with
52+
undefined starting base, e.g. the time of the system power on.
53+
54+
>>> async with timeout_at(loop.time() + 10):
55+
... async with aiohttp.get('https://github.com') as r:
56+
... await r.text()
57+
58+
59+
"""
60+
loop = asyncio.get_running_loop()
61+
return Timeout(deadline, loop)
62+
63+
64+
class _State(enum.Enum):
65+
INIT = "INIT"
66+
ENTER = "ENTER"
67+
TIMEOUT = "TIMEOUT"
68+
EXIT = "EXIT"
69+
70+
71+
@final
72+
class Timeout:
73+
# Internal class, please don't instantiate it directly
74+
# Use timeout() and timeout_at() public factories instead.
75+
#
76+
# Implementation note: `async with timeout()` is preferred
77+
# over `with timeout()`.
78+
# While technically the Timeout class implementation
79+
# doesn't need to be async at all,
80+
# the `async with` statement explicitly points that
81+
# the context manager should be used from async function context.
82+
#
83+
# This design allows to avoid many silly misusages.
84+
#
85+
# TimeoutError is raised immediately when scheduled
86+
# if the deadline is passed.
87+
# The purpose is to time out as soon as possible
88+
# without waiting for the next await expression.
89+
90+
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
91+
92+
def __init__(
93+
self, deadline: Optional[float], loop: asyncio.AbstractEventLoop
94+
) -> None:
95+
self._loop = loop
96+
self._state = _State.INIT
97+
98+
self._timeout_handler = None # type: Optional[asyncio.Handle]
99+
if deadline is None:
100+
self._deadline = None # type: Optional[float]
101+
else:
102+
self.update(deadline)
103+
104+
def __enter__(self) -> "Timeout":
105+
warnings.warn(
106+
"with timeout() is deprecated, use async with timeout() instead",
107+
DeprecationWarning,
108+
stacklevel=2,
109+
)
110+
self._do_enter()
111+
return self
112+
113+
def __exit__(
114+
self,
115+
exc_type: Optional[Type[BaseException]],
116+
exc_val: Optional[BaseException],
117+
exc_tb: Optional[TracebackType],
118+
) -> Optional[bool]:
119+
self._do_exit(exc_type)
120+
return None
121+
122+
async def __aenter__(self) -> "Timeout":
123+
self._do_enter()
124+
return self
125+
126+
async def __aexit__(
127+
self,
128+
exc_type: Optional[Type[BaseException]],
129+
exc_val: Optional[BaseException],
130+
exc_tb: Optional[TracebackType],
131+
) -> Optional[bool]:
132+
self._do_exit(exc_type)
133+
return None
134+
135+
@property
136+
def expired(self) -> bool:
137+
"""Is timeout expired during execution?"""
138+
return self._state == _State.TIMEOUT
139+
140+
@property
141+
def deadline(self) -> Optional[float]:
142+
return self._deadline
143+
144+
def reject(self) -> None:
145+
"""Reject scheduled timeout if any."""
146+
# cancel is maybe better name but
147+
# task.cancel() raises CancelledError in asyncio world.
148+
if self._state not in (_State.INIT, _State.ENTER):
149+
raise RuntimeError(f"invalid state {self._state.value}")
150+
self._reject()
151+
152+
def _reject(self) -> None:
153+
if self._timeout_handler is not None:
154+
self._timeout_handler.cancel()
155+
self._timeout_handler = None
156+
157+
def shift(self, delay: float) -> None:
158+
"""Advance timeout on delay seconds.
159+
160+
The delay can be negative.
161+
162+
Raise RuntimeError if shift is called when deadline is not scheduled
163+
"""
164+
deadline = self._deadline
165+
if deadline is None:
166+
raise RuntimeError("cannot shift timeout if deadline is not scheduled")
167+
self.update(deadline + delay)
168+
169+
def update(self, deadline: float) -> None:
170+
"""Set deadline to absolute value.
171+
172+
deadline argument points on the time in the same clock system
173+
as loop.time().
174+
175+
If new deadline is in the past the timeout is raised immediately.
176+
177+
Please note: it is not POSIX time but a time with
178+
undefined starting base, e.g. the time of the system power on.
179+
"""
180+
if self._state == _State.EXIT:
181+
raise RuntimeError("cannot reschedule after exit from context manager")
182+
if self._state == _State.TIMEOUT:
183+
raise RuntimeError("cannot reschedule expired timeout")
184+
if self._timeout_handler is not None:
185+
self._timeout_handler.cancel()
186+
self._deadline = deadline
187+
if self._state != _State.INIT:
188+
self._reschedule()
189+
190+
def _reschedule(self) -> None:
191+
assert self._state == _State.ENTER
192+
deadline = self._deadline
193+
if deadline is None:
194+
return
195+
196+
now = self._loop.time()
197+
if self._timeout_handler is not None:
198+
self._timeout_handler.cancel()
199+
200+
task = asyncio.current_task()
201+
if deadline <= now:
202+
self._timeout_handler = self._loop.call_soon(self._on_timeout, task)
203+
else:
204+
self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task)
205+
206+
def _do_enter(self) -> None:
207+
if self._state != _State.INIT:
208+
raise RuntimeError(f"invalid state {self._state.value}")
209+
self._state = _State.ENTER
210+
self._reschedule()
211+
212+
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
213+
if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT:
214+
self._timeout_handler = None
215+
raise asyncio.TimeoutError
216+
# timeout has not expired
217+
self._state = _State.EXIT
218+
self._reject()
219+
return None
220+
221+
def _on_timeout(self, task: "asyncio.Task[None]") -> None:
222+
task.cancel()
223+
self._state = _State.TIMEOUT
224+
# drop the reference early
225+
self._timeout_handler = None

src/websockets/legacy/compatibility.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from typing import Any, Dict
66

77

8+
__all__ = ["asyncio_timeout", "loop_if_py_lt_38"]
9+
10+
811
if sys.version_info[:2] >= (3, 8):
912

1013
def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]:
@@ -22,3 +25,9 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]:
2225
2326
"""
2427
return {"loop": loop}
28+
29+
30+
if sys.version_info[:2] >= (3, 11):
31+
from asyncio import timeout as asyncio_timeout # noqa: F401
32+
else:
33+
from .async_timeout import timeout as asyncio_timeout # noqa: F401

src/websockets/legacy/protocol.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454
from ..protocol import State
5555
from ..typing import Data, LoggerLike, Subprotocol
56-
from .compatibility import loop_if_py_lt_38
56+
from .compatibility import asyncio_timeout, loop_if_py_lt_38
5757
from .framing import Frame
5858

5959

@@ -763,18 +763,15 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
763763
764764
"""
765765
try:
766-
await asyncio.wait_for(
767-
self.write_close_frame(Close(code, reason)),
768-
self.close_timeout,
769-
**loop_if_py_lt_38(self.loop),
770-
)
766+
async with asyncio_timeout(self.close_timeout):
767+
await self.write_close_frame(Close(code, reason))
771768
except asyncio.TimeoutError:
772769
# If the close frame cannot be sent because the send buffers
773770
# are full, the closing handshake won't complete anyway.
774771
# Fail the connection to shut down faster.
775772
self.fail_connection()
776773

777-
# If no close frame is received within the timeout, wait_for() cancels
774+
# If no close frame is received within the timeout, asyncio_timeout() cancels
778775
# the data transfer task and raises TimeoutError.
779776

780777
# If close() is called multiple times concurrently and one of these
@@ -784,11 +781,8 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
784781
try:
785782
# If close() is canceled during the wait, self.transfer_data_task
786783
# is canceled before the timeout elapses.
787-
await asyncio.wait_for(
788-
self.transfer_data_task,
789-
self.close_timeout,
790-
**loop_if_py_lt_38(self.loop),
791-
)
784+
async with asyncio_timeout(self.close_timeout):
785+
await self.transfer_data_task
792786
except (asyncio.TimeoutError, asyncio.CancelledError):
793787
pass
794788

@@ -1270,11 +1264,8 @@ async def keepalive_ping(self) -> None:
12701264

12711265
if self.ping_timeout is not None:
12721266
try:
1273-
await asyncio.wait_for(
1274-
pong_waiter,
1275-
self.ping_timeout,
1276-
**loop_if_py_lt_38(self.loop),
1277-
)
1267+
async with asyncio_timeout(self.ping_timeout):
1268+
await pong_waiter
12781269
self.logger.debug("% received keepalive pong")
12791270
except asyncio.TimeoutError:
12801271
if self.debug:
@@ -1392,11 +1383,8 @@ async def wait_for_connection_lost(self) -> bool:
13921383
"""
13931384
if not self.connection_lost_waiter.done():
13941385
try:
1395-
await asyncio.wait_for(
1396-
asyncio.shield(self.connection_lost_waiter),
1397-
self.close_timeout,
1398-
**loop_if_py_lt_38(self.loop),
1399-
)
1386+
async with asyncio_timeout(self.close_timeout):
1387+
await asyncio.shield(self.connection_lost_waiter)
14001388
except asyncio.TimeoutError:
14011389
pass
14021390
# Re-check self.connection_lost_waiter.done() synchronously because

tests/legacy/test_protocol.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ def test_answer_ping_does_not_crash_if_connection_closed(self):
932932
self.receive_frame(Frame(True, OP_PING, b"test"))
933933
self.receive_eof()
934934

935+
self.run_loop_once()
935936
with self.assertNoLogs():
936937
self.loop.run_until_complete(self.protocol.close())
937938

@@ -1373,7 +1374,7 @@ def test_simultaneous_close(self):
13731374
# Receive the incoming close frame right after self.protocol.close()
13741375
# starts executing. This reproduces the error described in:
13751376
# https://github.com/aaugustin/websockets/issues/339
1376-
self.loop.call_soon(self.receive_frame, self.remote_close)
1377+
self.receive_frame(self.remote_close)
13771378
self.loop.call_soon(self.receive_eof_if_client)
13781379

13791380
self.loop.run_until_complete(self.protocol.close(reason="local"))

0 commit comments

Comments
 (0)