Skip to content

Commit b0883b7

Browse files
Simplify async timeouts and allowing timeout=None in PubSub.get_message() to wait forever (#2295)
* Avoid an extra "can_read" call and use timeout directly. * Remove low-level read timeouts from the Parser, now handled in the Connection * Allow pubsub.get_message(time=None) to block. * update Changes * increase test timeout for robustness * expand with statement to avoid invoking null context managers. remove nullcontext * Remove unused import
1 parent cdbc662 commit b0883b7

File tree

5 files changed

+57
-112
lines changed

5 files changed

+57
-112
lines changed

CHANGES

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Allow `timeout=None` in `PubSub.get_message()` to wait forever
12
* add `nowait` flag to `asyncio.Connection.disconnect()`
23
* Update README.md links
34
* Fix timezone handling for datetime to unixtime conversions

redis/asyncio/client.py

+5-17
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
cast,
2525
)
2626

27-
import async_timeout
28-
2927
from redis.asyncio.connection import (
3028
Connection,
3129
ConnectionPool,
@@ -759,18 +757,8 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
759757
if not conn.is_connected:
760758
await conn.connect()
761759

762-
if not block:
763-
764-
async def read_with_timeout():
765-
try:
766-
async with async_timeout.timeout(timeout):
767-
return await conn.read_response()
768-
except asyncio.TimeoutError:
769-
return None
770-
771-
response = await self._execute(conn, read_with_timeout)
772-
else:
773-
response = await self._execute(conn, conn.read_response)
760+
read_timeout = None if block else timeout
761+
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
774762

775763
if conn.health_check_interval and response == self.health_check_response:
776764
# ignore the health check message as user might not expect it
@@ -882,16 +870,16 @@ async def listen(self) -> AsyncIterator:
882870
yield response
883871

884872
async def get_message(
885-
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
873+
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
886874
):
887875
"""
888876
Get the next message if one is available, otherwise None.
889877
890878
If timeout is specified, the system will wait for `timeout` seconds
891879
before returning. Timeout should be specified as a floating point
892-
number.
880+
number or None to wait indefinitely.
893881
"""
894-
response = await self.parse_response(block=False, timeout=timeout)
882+
response = await self.parse_response(block=(timeout is None), timeout=timeout)
895883
if response:
896884
return await self.handle_message(response, ignore_subscribe_messages)
897885
return None

redis/asyncio/connection.py

+47-91
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import copy
33
import enum
4-
import errno
54
import inspect
65
import io
76
import os
@@ -55,16 +54,6 @@
5554
if HIREDIS_AVAILABLE:
5655
import hiredis
5756

58-
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
59-
BlockingIOError: errno.EWOULDBLOCK,
60-
ssl.SSLWantReadError: 2,
61-
ssl.SSLWantWriteError: 2,
62-
ssl.SSLError: 2,
63-
}
64-
65-
NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())
66-
67-
6857
SYM_STAR = b"*"
6958
SYM_DOLLAR = b"$"
7059
SYM_CRLF = b"\r\n"
@@ -229,11 +218,9 @@ def __init__(
229218
self,
230219
stream_reader: asyncio.StreamReader,
231220
socket_read_size: int,
232-
socket_timeout: Optional[float],
233221
):
234222
self._stream: Optional[asyncio.StreamReader] = stream_reader
235223
self.socket_read_size = socket_read_size
236-
self.socket_timeout = socket_timeout
237224
self._buffer: Optional[io.BytesIO] = io.BytesIO()
238225
# number of bytes written to the buffer from the socket
239226
self.bytes_written = 0
@@ -244,52 +231,35 @@ def __init__(
244231
def length(self):
245232
return self.bytes_written - self.bytes_read
246233

247-
async def _read_from_socket(
248-
self,
249-
length: Optional[int] = None,
250-
timeout: Union[float, None, _Sentinel] = SENTINEL,
251-
raise_on_timeout: bool = True,
252-
) -> bool:
234+
async def _read_from_socket(self, length: Optional[int] = None) -> bool:
253235
buf = self._buffer
254236
if buf is None or self._stream is None:
255237
raise RedisError("Buffer is closed.")
256238
buf.seek(self.bytes_written)
257239
marker = 0
258-
timeout = timeout if timeout is not SENTINEL else self.socket_timeout
259240

260-
try:
261-
while True:
262-
async with async_timeout.timeout(timeout):
263-
data = await self._stream.read(self.socket_read_size)
264-
# an empty string indicates the server shutdown the socket
265-
if isinstance(data, bytes) and len(data) == 0:
266-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
267-
buf.write(data)
268-
data_length = len(data)
269-
self.bytes_written += data_length
270-
marker += data_length
271-
272-
if length is not None and length > marker:
273-
continue
274-
return True
275-
except (socket.timeout, asyncio.TimeoutError):
276-
if raise_on_timeout:
277-
raise TimeoutError("Timeout reading from socket")
278-
return False
279-
except NONBLOCKING_EXCEPTIONS as ex:
280-
# if we're in nonblocking mode and the recv raises a
281-
# blocking error, simply return False indicating that
282-
# there's no data to be read. otherwise raise the
283-
# original exception.
284-
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
285-
if not raise_on_timeout and ex.errno == allowed:
286-
return False
287-
raise ConnectionError(f"Error while reading from socket: {ex.args}")
241+
while True:
242+
data = await self._stream.read(self.socket_read_size)
243+
# an empty string indicates the server shutdown the socket
244+
if isinstance(data, bytes) and len(data) == 0:
245+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
246+
buf.write(data)
247+
data_length = len(data)
248+
self.bytes_written += data_length
249+
marker += data_length
250+
251+
if length is not None and length > marker:
252+
continue
253+
return True
288254

289255
async def can_read_destructive(self) -> bool:
290-
return bool(self.length) or await self._read_from_socket(
291-
timeout=0, raise_on_timeout=False
292-
)
256+
if self.length:
257+
return True
258+
try:
259+
async with async_timeout.timeout(0):
260+
return await self._read_from_socket()
261+
except asyncio.TimeoutError:
262+
return False
293263

294264
async def read(self, length: int) -> bytes:
295265
length = length + 2 # make sure to read the \r\n terminator
@@ -372,9 +342,7 @@ def on_connect(self, connection: "Connection"):
372342
if self._stream is None:
373343
raise RedisError("Buffer is closed.")
374344

375-
self._buffer = SocketBuffer(
376-
self._stream, self._read_size, connection.socket_timeout
377-
)
345+
self._buffer = SocketBuffer(self._stream, self._read_size)
378346
self.encoder = connection.encoder
379347

380348
def on_disconnect(self):
@@ -444,14 +412,13 @@ async def read_response(
444412
class HiredisParser(BaseParser):
445413
"""Parser class for connections using Hiredis"""
446414

447-
__slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")
415+
__slots__ = BaseParser.__slots__ + ("_reader",)
448416

449417
def __init__(self, socket_read_size: int):
450418
if not HIREDIS_AVAILABLE:
451419
raise RedisError("Hiredis is not available.")
452420
super().__init__(socket_read_size=socket_read_size)
453421
self._reader: Optional[hiredis.Reader] = None
454-
self._socket_timeout: Optional[float] = None
455422

456423
def on_connect(self, connection: "Connection"):
457424
self._stream = connection._reader
@@ -464,7 +431,6 @@ def on_connect(self, connection: "Connection"):
464431
kwargs["errors"] = connection.encoder.encoding_errors
465432

466433
self._reader = hiredis.Reader(**kwargs)
467-
self._socket_timeout = connection.socket_timeout
468434

469435
def on_disconnect(self):
470436
self._stream = None
@@ -475,39 +441,20 @@ async def can_read_destructive(self):
475441
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
476442
if self._reader.gets():
477443
return True
478-
return await self.read_from_socket(timeout=0, raise_on_timeout=False)
479-
480-
async def read_from_socket(
481-
self,
482-
timeout: Union[float, None, _Sentinel] = SENTINEL,
483-
raise_on_timeout: bool = True,
484-
):
485-
timeout = self._socket_timeout if timeout is SENTINEL else timeout
486444
try:
487-
if timeout is None:
488-
buffer = await self._stream.read(self._read_size)
489-
else:
490-
async with async_timeout.timeout(timeout):
491-
buffer = await self._stream.read(self._read_size)
492-
if not buffer or not isinstance(buffer, bytes):
493-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
494-
self._reader.feed(buffer)
495-
# data was read from the socket and added to the buffer.
496-
# return True to indicate that data was read.
497-
return True
498-
except (socket.timeout, asyncio.TimeoutError):
499-
if raise_on_timeout:
500-
raise TimeoutError("Timeout reading from socket") from None
445+
async with async_timeout.timeout(0):
446+
return await self.read_from_socket()
447+
except asyncio.TimeoutError:
501448
return False
502-
except NONBLOCKING_EXCEPTIONS as ex:
503-
# if we're in nonblocking mode and the recv raises a
504-
# blocking error, simply return False indicating that
505-
# there's no data to be read. otherwise raise the
506-
# original exception.
507-
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
508-
if not raise_on_timeout and ex.errno == allowed:
509-
return False
510-
raise ConnectionError(f"Error while reading from socket: {ex.args}")
449+
450+
async def read_from_socket(self):
451+
buffer = await self._stream.read(self._read_size)
452+
if not buffer or not isinstance(buffer, bytes):
453+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
454+
self._reader.feed(buffer)
455+
# data was read from the socket and added to the buffer.
456+
# return True to indicate that data was read.
457+
return True
511458

512459
async def read_response(
513460
self, disable_decoding: bool = False
@@ -922,11 +869,16 @@ async def can_read_destructive(self):
922869
f"Error while reading from {self.host}:{self.port}: {e.args}"
923870
)
924871

925-
async def read_response(self, disable_decoding: bool = False):
872+
async def read_response(
873+
self,
874+
disable_decoding: bool = False,
875+
timeout: Optional[float] = None,
876+
):
926877
"""Read the response from a previously sent command"""
878+
read_timeout = timeout if timeout is not None else self.socket_timeout
927879
try:
928-
if self.socket_timeout:
929-
async with async_timeout.timeout(self.socket_timeout):
880+
if read_timeout is not None:
881+
async with async_timeout.timeout(read_timeout):
930882
response = await self._parser.read_response(
931883
disable_decoding=disable_decoding
932884
)
@@ -935,6 +887,10 @@ async def read_response(self, disable_decoding: bool = False):
935887
disable_decoding=disable_decoding
936888
)
937889
except asyncio.TimeoutError:
890+
if timeout is not None:
891+
# user requested timeout, return None
892+
return None
893+
# it was a self.socket_timeout error.
938894
await self.disconnect(nowait=True)
939895
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
940896
except OSError as e:

redis/client.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1637,13 +1637,13 @@ def listen(self):
16371637
if response is not None:
16381638
yield response
16391639

1640-
def get_message(self, ignore_subscribe_messages=False, timeout=0):
1640+
def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
16411641
"""
16421642
Get the next message if one is available, otherwise None.
16431643
16441644
If timeout is specified, the system will wait for `timeout` seconds
16451645
before returning. Timeout should be specified as a floating point
1646-
number.
1646+
number, or None, to wait indefinitely.
16471647
"""
16481648
if not self.subscribed:
16491649
# Wait for subscription
@@ -1659,7 +1659,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
16591659
# so no messages are available
16601660
return None
16611661

1662-
response = self.parse_response(block=False, timeout=timeout)
1662+
response = self.parse_response(block=(timeout is None), timeout=timeout)
16631663
if response:
16641664
return self.handle_message(response, ignore_subscribe_messages)
16651665
return None

tests/test_asyncio/test_pubsub.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async def run(*args, **kwargs):
2929
return wrapper
3030

3131

32-
async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
32+
async def wait_for_message(pubsub, timeout=0.2, ignore_subscribe_messages=False):
3333
now = asyncio.get_event_loop().time()
3434
timeout = now + timeout
3535
while now < timeout:

0 commit comments

Comments
 (0)