1
1
import asyncio
2
2
import copy
3
3
import enum
4
- import errno
5
4
import inspect
6
5
import io
7
6
import os
55
54
if HIREDIS_AVAILABLE :
56
55
import hiredis
57
56
58
- NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
59
- BlockingIOError : errno .EWOULDBLOCK ,
60
- ssl .SSLWantReadError : 2 ,
61
- ssl .SSLWantWriteError : 2 ,
62
- ssl .SSLError : 2 ,
63
- }
64
-
65
- NONBLOCKING_EXCEPTIONS = tuple (NONBLOCKING_EXCEPTION_ERROR_NUMBERS .keys ())
66
-
67
-
68
57
SYM_STAR = b"*"
69
58
SYM_DOLLAR = b"$"
70
59
SYM_CRLF = b"\r \n "
@@ -229,11 +218,9 @@ def __init__(
229
218
self ,
230
219
stream_reader : asyncio .StreamReader ,
231
220
socket_read_size : int ,
232
- socket_timeout : Optional [float ],
233
221
):
234
222
self ._stream : Optional [asyncio .StreamReader ] = stream_reader
235
223
self .socket_read_size = socket_read_size
236
- self .socket_timeout = socket_timeout
237
224
self ._buffer : Optional [io .BytesIO ] = io .BytesIO ()
238
225
# number of bytes written to the buffer from the socket
239
226
self .bytes_written = 0
@@ -244,52 +231,35 @@ def __init__(
244
231
def length (self ):
245
232
return self .bytes_written - self .bytes_read
246
233
247
- async def _read_from_socket (
248
- self ,
249
- length : Optional [int ] = None ,
250
- timeout : Union [float , None , _Sentinel ] = SENTINEL ,
251
- raise_on_timeout : bool = True ,
252
- ) -> bool :
234
+ async def _read_from_socket (self , length : Optional [int ] = None ) -> bool :
253
235
buf = self ._buffer
254
236
if buf is None or self ._stream is None :
255
237
raise RedisError ("Buffer is closed." )
256
238
buf .seek (self .bytes_written )
257
239
marker = 0
258
- timeout = timeout if timeout is not SENTINEL else self .socket_timeout
259
240
260
- try :
261
- while True :
262
- async with async_timeout .timeout (timeout ):
263
- data = await self ._stream .read (self .socket_read_size )
264
- # an empty string indicates the server shutdown the socket
265
- if isinstance (data , bytes ) and len (data ) == 0 :
266
- raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
267
- buf .write (data )
268
- data_length = len (data )
269
- self .bytes_written += data_length
270
- marker += data_length
271
-
272
- if length is not None and length > marker :
273
- continue
274
- return True
275
- except (socket .timeout , asyncio .TimeoutError ):
276
- if raise_on_timeout :
277
- raise TimeoutError ("Timeout reading from socket" )
278
- return False
279
- except NONBLOCKING_EXCEPTIONS as ex :
280
- # if we're in nonblocking mode and the recv raises a
281
- # blocking error, simply return False indicating that
282
- # there's no data to be read. otherwise raise the
283
- # original exception.
284
- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS .get (ex .__class__ , - 1 )
285
- if not raise_on_timeout and ex .errno == allowed :
286
- return False
287
- raise ConnectionError (f"Error while reading from socket: { ex .args } " )
241
+ while True :
242
+ data = await self ._stream .read (self .socket_read_size )
243
+ # an empty string indicates the server shutdown the socket
244
+ if isinstance (data , bytes ) and len (data ) == 0 :
245
+ raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
246
+ buf .write (data )
247
+ data_length = len (data )
248
+ self .bytes_written += data_length
249
+ marker += data_length
250
+
251
+ if length is not None and length > marker :
252
+ continue
253
+ return True
288
254
289
255
async def can_read_destructive (self ) -> bool :
290
- return bool (self .length ) or await self ._read_from_socket (
291
- timeout = 0 , raise_on_timeout = False
292
- )
256
+ if self .length :
257
+ return True
258
+ try :
259
+ async with async_timeout .timeout (0 ):
260
+ return await self ._read_from_socket ()
261
+ except asyncio .TimeoutError :
262
+ return False
293
263
294
264
async def read (self , length : int ) -> bytes :
295
265
length = length + 2 # make sure to read the \r\n terminator
@@ -372,9 +342,7 @@ def on_connect(self, connection: "Connection"):
372
342
if self ._stream is None :
373
343
raise RedisError ("Buffer is closed." )
374
344
375
- self ._buffer = SocketBuffer (
376
- self ._stream , self ._read_size , connection .socket_timeout
377
- )
345
+ self ._buffer = SocketBuffer (self ._stream , self ._read_size )
378
346
self .encoder = connection .encoder
379
347
380
348
def on_disconnect (self ):
@@ -444,14 +412,13 @@ async def read_response(
444
412
class HiredisParser (BaseParser ):
445
413
"""Parser class for connections using Hiredis"""
446
414
447
- __slots__ = BaseParser .__slots__ + ("_reader" , "_socket_timeout" )
415
+ __slots__ = BaseParser .__slots__ + ("_reader" ,)
448
416
449
417
def __init__ (self , socket_read_size : int ):
450
418
if not HIREDIS_AVAILABLE :
451
419
raise RedisError ("Hiredis is not available." )
452
420
super ().__init__ (socket_read_size = socket_read_size )
453
421
self ._reader : Optional [hiredis .Reader ] = None
454
- self ._socket_timeout : Optional [float ] = None
455
422
456
423
def on_connect (self , connection : "Connection" ):
457
424
self ._stream = connection ._reader
@@ -464,7 +431,6 @@ def on_connect(self, connection: "Connection"):
464
431
kwargs ["errors" ] = connection .encoder .encoding_errors
465
432
466
433
self ._reader = hiredis .Reader (** kwargs )
467
- self ._socket_timeout = connection .socket_timeout
468
434
469
435
def on_disconnect (self ):
470
436
self ._stream = None
@@ -475,39 +441,20 @@ async def can_read_destructive(self):
475
441
raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
476
442
if self ._reader .gets ():
477
443
return True
478
- return await self .read_from_socket (timeout = 0 , raise_on_timeout = False )
479
-
480
- async def read_from_socket (
481
- self ,
482
- timeout : Union [float , None , _Sentinel ] = SENTINEL ,
483
- raise_on_timeout : bool = True ,
484
- ):
485
- timeout = self ._socket_timeout if timeout is SENTINEL else timeout
486
444
try :
487
- if timeout is None :
488
- buffer = await self ._stream .read (self ._read_size )
489
- else :
490
- async with async_timeout .timeout (timeout ):
491
- buffer = await self ._stream .read (self ._read_size )
492
- if not buffer or not isinstance (buffer , bytes ):
493
- raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR ) from None
494
- self ._reader .feed (buffer )
495
- # data was read from the socket and added to the buffer.
496
- # return True to indicate that data was read.
497
- return True
498
- except (socket .timeout , asyncio .TimeoutError ):
499
- if raise_on_timeout :
500
- raise TimeoutError ("Timeout reading from socket" ) from None
445
+ async with async_timeout .timeout (0 ):
446
+ return await self .read_from_socket ()
447
+ except asyncio .TimeoutError :
501
448
return False
502
- except NONBLOCKING_EXCEPTIONS as ex :
503
- # if we're in nonblocking mode and the recv raises a
504
- # blocking error, simply return False indicating that
505
- # there's no data to be read. otherwise raise the
506
- # original exception.
507
- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS . get ( ex . __class__ , - 1 )
508
- if not raise_on_timeout and ex . errno == allowed :
509
- return False
510
- raise ConnectionError ( f"Error while reading from socket: { ex . args } " )
449
+
450
+ async def read_from_socket ( self ):
451
+ buffer = await self . _stream . read ( self . _read_size )
452
+ if not buffer or not isinstance ( buffer , bytes ):
453
+ raise ConnectionError ( SERVER_CLOSED_CONNECTION_ERROR ) from None
454
+ self . _reader . feed ( buffer )
455
+ # data was read from the socket and added to the buffer.
456
+ # return True to indicate that data was read.
457
+ return True
511
458
512
459
async def read_response (
513
460
self , disable_decoding : bool = False
@@ -922,11 +869,16 @@ async def can_read_destructive(self):
922
869
f"Error while reading from { self .host } :{ self .port } : { e .args } "
923
870
)
924
871
925
- async def read_response (self , disable_decoding : bool = False ):
872
+ async def read_response (
873
+ self ,
874
+ disable_decoding : bool = False ,
875
+ timeout : Optional [float ] = None ,
876
+ ):
926
877
"""Read the response from a previously sent command"""
878
+ read_timeout = timeout if timeout is not None else self .socket_timeout
927
879
try :
928
- if self . socket_timeout :
929
- async with async_timeout .timeout (self . socket_timeout ):
880
+ if read_timeout is not None :
881
+ async with async_timeout .timeout (read_timeout ):
930
882
response = await self ._parser .read_response (
931
883
disable_decoding = disable_decoding
932
884
)
@@ -935,6 +887,10 @@ async def read_response(self, disable_decoding: bool = False):
935
887
disable_decoding = disable_decoding
936
888
)
937
889
except asyncio .TimeoutError :
890
+ if timeout is not None :
891
+ # user requested timeout, return None
892
+ return None
893
+ # it was a self.socket_timeout error.
938
894
await self .disconnect (nowait = True )
939
895
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
940
896
except OSError as e :
0 commit comments