Skip to content

Simplify async timeouts and allowing timeout=None in PubSub.get_message() to wait forever #2295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Allow `timeout=None` in `PubSub.get_message()` to wait forever
* add `nowait` flag to `asyncio.Connection.disconnect()`
* Update README.md links
* Fix timezone handling for datetime to unixtime conversions
Expand Down
22 changes: 5 additions & 17 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
cast,
)

import async_timeout

from redis.asyncio.connection import (
Connection,
ConnectionPool,
Expand Down Expand Up @@ -759,18 +757,8 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
if not conn.is_connected:
await conn.connect()

if not block:

async def read_with_timeout():
try:
async with async_timeout.timeout(timeout):
return await conn.read_response()
except asyncio.TimeoutError:
return None

response = await self._execute(conn, read_with_timeout)
else:
response = await self._execute(conn, conn.read_response)
read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down Expand Up @@ -882,16 +870,16 @@ async def listen(self) -> AsyncIterator:
yield response

async def get_message(
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.

If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number.
number or None to wait indefinitely.
"""
response = await self.parse_response(block=False, timeout=timeout)
response = await self.parse_response(block=(timeout is None), timeout=timeout)
if response:
return await self.handle_message(response, ignore_subscribe_messages)
return None
Expand Down
138 changes: 47 additions & 91 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import copy
import enum
import errno
import inspect
import io
import os
Expand Down Expand Up @@ -55,16 +54,6 @@
if HIREDIS_AVAILABLE:
import hiredis

NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
BlockingIOError: errno.EWOULDBLOCK,
ssl.SSLWantReadError: 2,
ssl.SSLWantWriteError: 2,
ssl.SSLError: 2,
}

NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())


SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
Expand Down Expand Up @@ -229,11 +218,9 @@ def __init__(
self,
stream_reader: asyncio.StreamReader,
socket_read_size: int,
socket_timeout: Optional[float],
):
self._stream: Optional[asyncio.StreamReader] = stream_reader
self.socket_read_size = socket_read_size
self.socket_timeout = socket_timeout
self._buffer: Optional[io.BytesIO] = io.BytesIO()
# number of bytes written to the buffer from the socket
self.bytes_written = 0
Expand All @@ -244,52 +231,35 @@ def __init__(
def length(self):
return self.bytes_written - self.bytes_read

async def _read_from_socket(
self,
length: Optional[int] = None,
timeout: Union[float, None, _Sentinel] = SENTINEL,
raise_on_timeout: bool = True,
) -> bool:
async def _read_from_socket(self, length: Optional[int] = None) -> bool:
buf = self._buffer
if buf is None or self._stream is None:
raise RedisError("Buffer is closed.")
buf.seek(self.bytes_written)
marker = 0
timeout = timeout if timeout is not SENTINEL else self.socket_timeout

try:
while True:
async with async_timeout.timeout(timeout):
data = await self._stream.read(self.socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
self.bytes_written += data_length
marker += data_length

if length is not None and length > marker:
continue
return True
except (socket.timeout, asyncio.TimeoutError):
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket")
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
while True:
data = await self._stream.read(self.socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
self.bytes_written += data_length
marker += data_length

if length is not None and length > marker:
continue
return True

async def can_read_destructive(self) -> bool:
return bool(self.length) or await self._read_from_socket(
timeout=0, raise_on_timeout=False
)
if self.length:
return True
try:
async with async_timeout.timeout(0):
return await self._read_from_socket()
except asyncio.TimeoutError:
return False

async def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
Expand Down Expand Up @@ -372,9 +342,7 @@ def on_connect(self, connection: "Connection"):
if self._stream is None:
raise RedisError("Buffer is closed.")

self._buffer = SocketBuffer(
self._stream, self._read_size, connection.socket_timeout
)
self._buffer = SocketBuffer(self._stream, self._read_size)
self.encoder = connection.encoder

def on_disconnect(self):
Expand Down Expand Up @@ -444,14 +412,13 @@ async def read_response(
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""

__slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")
__slots__ = BaseParser.__slots__ + ("_reader",)

def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader: Optional[hiredis.Reader] = None
self._socket_timeout: Optional[float] = None

def on_connect(self, connection: "Connection"):
self._stream = connection._reader
Expand All @@ -464,7 +431,6 @@ def on_connect(self, connection: "Connection"):
kwargs["errors"] = connection.encoder.encoding_errors

self._reader = hiredis.Reader(**kwargs)
self._socket_timeout = connection.socket_timeout

def on_disconnect(self):
self._stream = None
Expand All @@ -475,39 +441,20 @@ async def can_read_destructive(self):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._reader.gets():
return True
return await self.read_from_socket(timeout=0, raise_on_timeout=False)

async def read_from_socket(
self,
timeout: Union[float, None, _Sentinel] = SENTINEL,
raise_on_timeout: bool = True,
):
timeout = self._socket_timeout if timeout is SENTINEL else timeout
try:
if timeout is None:
buffer = await self._stream.read(self._read_size)
else:
async with async_timeout.timeout(timeout):
buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
except (socket.timeout, asyncio.TimeoutError):
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket") from None
async with async_timeout.timeout(0):
return await self.read_from_socket()
except asyncio.TimeoutError:
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")

async def read_from_socket(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dvora-h can you dig into this function?

buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True

async def read_response(
self, disable_decoding: bool = False
Expand Down Expand Up @@ -922,11 +869,16 @@ async def can_read_destructive(self):
f"Error while reading from {self.host}:{self.port}: {e.args}"
)

async def read_response(self, disable_decoding: bool = False):
async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
try:
if self.socket_timeout:
async with async_timeout.timeout(self.socket_timeout):
if read_timeout is not None:
async with async_timeout.timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
Expand All @@ -935,6 +887,10 @@ async def read_response(self, disable_decoding: bool = False):
disable_decoding=disable_decoding
)
except asyncio.TimeoutError:
if timeout is not None:
# user requested timeout, return None
return None
# it was a self.socket_timeout error.
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
Expand Down
6 changes: 3 additions & 3 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,13 +1637,13 @@ def listen(self):
if response is not None:
yield response

def get_message(self, ignore_subscribe_messages=False, timeout=0):
def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
"""
Get the next message if one is available, otherwise None.

If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number.
number, or None, to wait indefinitely.
"""
if not self.subscribed:
# Wait for subscription
Expand All @@ -1659,7 +1659,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
# so no messages are available
return None

response = self.parse_response(block=False, timeout=timeout)
response = self.parse_response(block=(timeout is None), timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
return None
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def run(*args, **kwargs):
return wrapper


async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
async def wait_for_message(pubsub, timeout=0.2, ignore_subscribe_messages=False):
now = asyncio.get_event_loop().time()
timeout = now + timeout
while now < timeout:
Expand Down