From d5fc8990291623bff609f58cc42abfcc54f2b422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 20 Jul 2022 10:40:04 +0000 Subject: [PATCH 1/7] Avoid an extra "can_read" call and use timeout directly. --- redis/asyncio/client.py | 14 ++------------ redis/asyncio/connection.py | 32 +++++++++++++++++++++++++------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index c13054b227..0bd58c3b7a 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -759,18 +759,8 @@ async def parse_response(self, block: bool = True, timeout: float = 0): if not conn.is_connected: await conn.connect() - if not block: - - async def read_with_timeout(): - try: - async with async_timeout.timeout(timeout): - return await conn.read_response() - except asyncio.TimeoutError: - return None - - response = await self._execute(conn, read_with_timeout) - else: - response = await self._execute(conn, conn.read_response) + read_timeout = None if block else timeout + response = await self._execute(conn, conn.read_response, timeout=read_timeout) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 53b41af7f8..a83a3b735e 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import copy import enum import errno @@ -55,6 +56,19 @@ if HIREDIS_AVAILABLE: import hiredis +if sys.version_info[:2] >= (3, 10): + nullcontext = contextlib.nullcontext() +else: + + class NullContext: + async def __aenter__(self): + pass + + async def __aexit__(self, *args): + pass + + nullcontext = NullContext() + NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { BlockingIOError: errno.EWOULDBLOCK, ssl.SSLWantReadError: 2, @@ -922,19 +936,23 @@ async def can_read_destructive(self): f"Error while reading from {self.host}:{self.port}: {e.args}" ) - async def read_response(self, disable_decoding: bool = False): + async def read_response( + self, + disable_decoding: bool = False, + timeout: Optional[float] = None, + ): """Read the response from a previously sent command""" + read_timeout = timeout if timeout is not None else self.socket_timeout try: - if self.socket_timeout: - async with async_timeout.timeout(self.socket_timeout): - response = await self._parser.read_response( - disable_decoding=disable_decoding - ) - else: + async with async_timeout.timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) except asyncio.TimeoutError: + if timeout is not None: + # user requested timeout, return None + return None + # it was a self.socket_timeout error. await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: From 0da78c8bb019cc6adcd603a2c684282d7cf8b6b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 20 Jul 2022 12:26:40 +0000 Subject: [PATCH 2/7] Remove low-level read timeouts from the Parser, now handled in the Connection --- redis/asyncio/connection.py | 123 ++++++++++-------------------------- 1 file changed, 35 insertions(+), 88 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index a83a3b735e..c0927c3965 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -2,7 +2,6 @@ import contextlib import copy import enum -import errno import inspect import io import os @@ -69,16 +68,6 @@ async def __aexit__(self, *args): nullcontext = NullContext() -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { - BlockingIOError: errno.EWOULDBLOCK, - ssl.SSLWantReadError: 2, - ssl.SSLWantWriteError: 2, - ssl.SSLError: 2, -} - -NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) - - SYM_STAR = b"*" SYM_DOLLAR = b"$" SYM_CRLF = b"\r\n" @@ -243,11 +232,9 @@ def __init__( self, stream_reader: asyncio.StreamReader, socket_read_size: int, - socket_timeout: Optional[float], ): self._stream: Optional[asyncio.StreamReader] = stream_reader self.socket_read_size = socket_read_size - self.socket_timeout = socket_timeout self._buffer: Optional[io.BytesIO] = io.BytesIO() # number of bytes written to the buffer from the socket self.bytes_written = 0 @@ -258,52 +245,35 @@ def __init__( def length(self): return self.bytes_written - self.bytes_read - async def _read_from_socket( - self, - length: Optional[int] = None, - timeout: Union[float, None, _Sentinel] = SENTINEL, - raise_on_timeout: bool = True, - ) -> bool: + async def _read_from_socket(self, length: Optional[int] = None) -> bool: buf = self._buffer if buf is None or self._stream is None: raise RedisError("Buffer is closed.") buf.seek(self.bytes_written) marker = 0 - timeout = timeout if timeout is not SENTINEL else self.socket_timeout - try: - while True: - async with async_timeout.timeout(timeout): - data = await self._stream.read(self.socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - self.bytes_written += data_length - marker += data_length - - if length is not None and length > marker: - continue - return True - except (socket.timeout, asyncio.TimeoutError): - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") + while True: + data = await self._stream.read(self.socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + self.bytes_written += data_length + marker += data_length + + if length is not None and length > marker: + continue + return True async def can_read_destructive(self) -> bool: - return bool(self.length) or await self._read_from_socket( - timeout=0, raise_on_timeout=False - ) + if self.length: + return True + try: + async with async_timeout.timeout(0): + return await self._read_from_socket() + except asyncio.TimeoutError: + return False async def read(self, length: int) -> bytes: length = length + 2 # make sure to read the \r\n terminator @@ -386,9 +356,7 @@ def on_connect(self, connection: "Connection"): if self._stream is None: raise RedisError("Buffer is closed.") - self._buffer = SocketBuffer( - self._stream, self._read_size, connection.socket_timeout - ) + self._buffer = SocketBuffer(self._stream, self._read_size) self.encoder = connection.encoder def on_disconnect(self): @@ -458,14 +426,13 @@ async def read_response( class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" - __slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout") + __slots__ = BaseParser.__slots__ + ("_reader",) def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader: Optional[hiredis.Reader] = None - self._socket_timeout: Optional[float] = None def on_connect(self, connection: "Connection"): self._stream = connection._reader @@ -478,7 +445,6 @@ def on_connect(self, connection: "Connection"): kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) - self._socket_timeout = connection.socket_timeout def on_disconnect(self): self._stream = None @@ -489,39 +455,20 @@ async def can_read_destructive(self): raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._reader.gets(): return True - return await self.read_from_socket(timeout=0, raise_on_timeout=False) - - async def read_from_socket( - self, - timeout: Union[float, None, _Sentinel] = SENTINEL, - raise_on_timeout: bool = True, - ): - timeout = self._socket_timeout if timeout is SENTINEL else timeout try: - if timeout is None: - buffer = await self._stream.read(self._read_size) - else: - async with async_timeout.timeout(timeout): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - except (socket.timeout, asyncio.TimeoutError): - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") from None + async with async_timeout.timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True async def read_response( self, disable_decoding: bool = False From 9a97fe36fbc739f3242a6de367f134b2b960e5cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 20 Jul 2022 12:32:52 +0000 Subject: [PATCH 3/7] Allow pubsub.get_message(time=None) to block. --- redis/asyncio/client.py | 6 +++--- redis/client.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 0bd58c3b7a..671a973149 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -872,16 +872,16 @@ async def listen(self) -> AsyncIterator: yield response async def get_message( - self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 ): """ Get the next message if one is available, otherwise None. If timeout is specified, the system will wait for `timeout` seconds before returning. Timeout should be specified as a floating point - number. + number or None to wait indefinitely. """ - response = await self.parse_response(block=False, timeout=timeout) + response = await self.parse_response(block=(timeout is None), timeout=timeout) if response: return await self.handle_message(response, ignore_subscribe_messages) return None diff --git a/redis/client.py b/redis/client.py index 0662a99ea1..75a0dac226 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1637,13 +1637,13 @@ def listen(self): if response is not None: yield response - def get_message(self, ignore_subscribe_messages=False, timeout=0): + def get_message(self, ignore_subscribe_messages=False, timeout=0.0): """ Get the next message if one is available, otherwise None. If timeout is specified, the system will wait for `timeout` seconds before returning. Timeout should be specified as a floating point - number. + number, or None, to wait indefinitely. """ if not self.subscribed: # Wait for subscription @@ -1659,7 +1659,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0): # so no messages are available return None - response = self.parse_response(block=False, timeout=timeout) + response = self.parse_response(block=(timeout is None), timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) return None From 9abe3eb1062105e48646f488834ad9c7089d9bd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 21 Jul 2022 09:59:12 +0000 Subject: [PATCH 4/7] update Changes --- CHANGES | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES b/CHANGES index a5b5029681..2ced3d8f96 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Allow `timeout=None` in `PubSub.get_message()` to wait forever * add `nowait` flag to `asyncio.Connection.disconnect()` * Update README.md links * Fix timezone handling for datetime to unixtime conversions From 36fda7bc639ff8dc85b3aade08b72bf945b8ddb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 21 Jul 2022 14:09:32 +0000 Subject: [PATCH 5/7] increase test timeout for robustness --- tests/test_asyncio/test_pubsub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 86584e4715..6dedca9ab5 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -29,7 +29,7 @@ async def run(*args, **kwargs): return wrapper -async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): +async def wait_for_message(pubsub, timeout=0.2, ignore_subscribe_messages=False): now = asyncio.get_event_loop().time() timeout = now + timeout while now < timeout: From 305f848a02046212b4501256ada4ee3a16ebc505 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 25 Jul 2022 12:44:13 +0000 Subject: [PATCH 6/7] expand with statement to avoid invoking null context managers. remove nullcontext --- redis/asyncio/connection.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index c0927c3965..c8834c9286 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1,5 +1,4 @@ import asyncio -import contextlib import copy import enum import inspect @@ -55,19 +54,6 @@ if HIREDIS_AVAILABLE: import hiredis -if sys.version_info[:2] >= (3, 10): - nullcontext = contextlib.nullcontext() -else: - - class NullContext: - async def __aenter__(self): - pass - - async def __aexit__(self, *args): - pass - - nullcontext = NullContext() - SYM_STAR = b"*" SYM_DOLLAR = b"$" SYM_CRLF = b"\r\n" @@ -891,7 +877,12 @@ async def read_response( """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout try: - async with async_timeout.timeout(read_timeout): + if read_timeout is not None: + async with async_timeout.timeout(read_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + else: response = await self._parser.read_response( disable_decoding=disable_decoding ) From 7fa672d1873d1ff61bdabb4771c5271aacd8993a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 29 Sep 2022 13:24:42 +0000 Subject: [PATCH 7/7] Remove unused import --- redis/asyncio/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 671a973149..0e40ed70f8 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -24,8 +24,6 @@ cast, ) -import async_timeout - from redis.asyncio.connection import ( Connection, ConnectionPool,