Skip to content

Commit ee37f2c

Browse files
committed
Remove low-level read timeouts from the Parser, now handled in the Connection
1 parent 15f4d78 commit ee37f2c

File tree

1 file changed

+36
-90
lines changed

1 file changed

+36
-90
lines changed

redis/asyncio/connection.py

Lines changed: 36 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,6 @@ async def __aexit__(self, *args):
7070

7171
nullcontext = NullContext()
7272

73-
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
74-
BlockingIOError: errno.EWOULDBLOCK,
75-
ssl.SSLWantReadError: 2,
76-
ssl.SSLWantWriteError: 2,
77-
ssl.SSLError: 2,
78-
}
79-
80-
NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())
81-
82-
8373
SYM_STAR = b"*"
8474
SYM_DOLLAR = b"$"
8575
SYM_CRLF = b"\r\n"
@@ -233,11 +223,9 @@ def __init__(
233223
self,
234224
stream_reader: asyncio.StreamReader,
235225
socket_read_size: int,
236-
socket_timeout: Optional[float],
237226
):
238227
self._stream: Optional[asyncio.StreamReader] = stream_reader
239228
self.socket_read_size = socket_read_size
240-
self.socket_timeout = socket_timeout
241229
self._buffer: Optional[io.BytesIO] = io.BytesIO()
242230
# number of bytes written to the buffer from the socket
243231
self.bytes_written = 0
@@ -248,52 +236,35 @@ def __init__(
248236
def length(self):
249237
return self.bytes_written - self.bytes_read
250238

251-
async def _read_from_socket(
252-
self,
253-
length: Optional[int] = None,
254-
timeout: Union[float, None, _Sentinel] = SENTINEL,
255-
raise_on_timeout: bool = True,
256-
) -> bool:
239+
async def _read_from_socket(self, length: Optional[int] = None) -> bool:
257240
buf = self._buffer
258241
if buf is None or self._stream is None:
259242
raise RedisError("Buffer is closed.")
260243
buf.seek(self.bytes_written)
261244
marker = 0
262-
timeout = timeout if timeout is not SENTINEL else self.socket_timeout
263245

264-
try:
265-
while True:
266-
async with async_timeout.timeout(timeout):
267-
data = await self._stream.read(self.socket_read_size)
268-
# an empty string indicates the server shutdown the socket
269-
if isinstance(data, bytes) and len(data) == 0:
270-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
271-
buf.write(data)
272-
data_length = len(data)
273-
self.bytes_written += data_length
274-
marker += data_length
275-
276-
if length is not None and length > marker:
277-
continue
278-
return True
279-
except (socket.timeout, asyncio.TimeoutError):
280-
if raise_on_timeout:
281-
raise TimeoutError("Timeout reading from socket")
282-
return False
283-
except NONBLOCKING_EXCEPTIONS as ex:
284-
# if we're in nonblocking mode and the recv raises a
285-
# blocking error, simply return False indicating that
286-
# there's no data to be read. otherwise raise the
287-
# original exception.
288-
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
289-
if not raise_on_timeout and ex.errno == allowed:
290-
return False
291-
raise ConnectionError(f"Error while reading from socket: {ex.args}")
246+
while True:
247+
data = await self._stream.read(self.socket_read_size)
248+
# an empty string indicates the server shutdown the socket
249+
if isinstance(data, bytes) and len(data) == 0:
250+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
251+
buf.write(data)
252+
data_length = len(data)
253+
self.bytes_written += data_length
254+
marker += data_length
255+
256+
if length is not None and length > marker:
257+
continue
258+
return True
292259

293260
async def can_read(self, timeout: float) -> bool:
294-
return bool(self.length) or await self._read_from_socket(
295-
timeout=timeout, raise_on_timeout=False
296-
)
261+
if self.length:
262+
return True
263+
try:
264+
async with async_timeout.timeout(timeout):
265+
return await self._read_from_socket()
266+
except asyncio.TimeoutError:
267+
return False
297268

298269
async def read(self, length: int) -> bytes:
299270
length = length + 2 # make sure to read the \r\n terminator
@@ -376,9 +347,7 @@ def on_connect(self, connection: "Connection"):
376347
if self._stream is None:
377348
raise RedisError("Buffer is closed.")
378349

379-
self._buffer = SocketBuffer(
380-
self._stream, self._read_size, connection.socket_timeout
381-
)
350+
self._buffer = SocketBuffer(self._stream, self._read_size)
382351
self.encoder = connection.encoder
383352

384353
def on_disconnect(self):
@@ -448,7 +417,7 @@ async def read_response(
448417
class HiredisParser(BaseParser):
449418
"""Parser class for connections using Hiredis"""
450419

451-
__slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout")
420+
__slots__ = BaseParser.__slots__ + ("_next_response", "_reader")
452421

453422
_next_response: bool
454423

@@ -457,7 +426,6 @@ def __init__(self, socket_read_size: int):
457426
raise RedisError("Hiredis is not available.")
458427
super().__init__(socket_read_size=socket_read_size)
459428
self._reader: Optional[hiredis.Reader] = None
460-
self._socket_timeout: Optional[float] = None
461429

462430
def on_connect(self, connection: "Connection"):
463431
self._stream = connection._reader
@@ -471,7 +439,6 @@ def on_connect(self, connection: "Connection"):
471439

472440
self._reader = hiredis.Reader(**kwargs)
473441
self._next_response = False
474-
self._socket_timeout = connection.socket_timeout
475442

476443
def on_disconnect(self):
477444
self._stream = None
@@ -485,42 +452,21 @@ async def can_read(self, timeout: float):
485452
if self._next_response is False:
486453
self._next_response = self._reader.gets()
487454
if self._next_response is False:
488-
return await self.read_from_socket(timeout=timeout, raise_on_timeout=False)
455+
try:
456+
with async_timeout.timeout(timeout):
457+
return await self.read_from_socket()
458+
except asyncio.TimeoutError:
459+
return False
489460
return True
490461

491-
async def read_from_socket(
492-
self,
493-
timeout: Union[float, None, _Sentinel] = SENTINEL,
494-
raise_on_timeout: bool = True,
495-
):
496-
timeout = self._socket_timeout if timeout is SENTINEL else timeout
497-
try:
498-
if timeout is None:
499-
buffer = await self._stream.read(self._read_size)
500-
else:
501-
async with async_timeout.timeout(timeout):
502-
buffer = await self._stream.read(self._read_size)
503-
if not buffer or not isinstance(buffer, bytes):
504-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
505-
self._reader.feed(buffer)
506-
# data was read from the socket and added to the buffer.
507-
# return True to indicate that data was read.
508-
return True
509-
except asyncio.CancelledError:
510-
raise
511-
except (socket.timeout, asyncio.TimeoutError):
512-
if raise_on_timeout:
513-
raise TimeoutError("Timeout reading from socket") from None
514-
return False
515-
except NONBLOCKING_EXCEPTIONS as ex:
516-
# if we're in nonblocking mode and the recv raises a
517-
# blocking error, simply return False indicating that
518-
# there's no data to be read. otherwise raise the
519-
# original exception.
520-
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
521-
if not raise_on_timeout and ex.errno == allowed:
522-
return False
523-
raise ConnectionError(f"Error while reading from socket: {ex.args}")
462+
async def read_from_socket(self):
463+
buffer = await self._stream.read(self._read_size)
464+
if not buffer or not isinstance(buffer, bytes):
465+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
466+
self._reader.feed(buffer)
467+
# data was read from the socket and added to the buffer.
468+
# return True to indicate that data was read.
469+
return True
524470

525471
async def read_response(
526472
self, disable_decoding: bool = False

0 commit comments

Comments
 (0)