@@ -493,24 +493,34 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
493493 ):
494494 raise error
495495
496- # COMMAND EXECUTION AND PROTOCOL PARSING
497- async def execute_command (self , * args , ** options ):
498- """Execute a command and return a parsed response"""
499- await self .initialize ()
500- pool = self .connection_pool
501- command_name = args [0 ]
502- conn = self .connection or await pool .get_connection (command_name , ** options )
503-
496+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
504497 try :
505498 return await conn .retry .call_with_retry (
506499 lambda : self ._send_command_parse_response (
507- conn , command_name , * args , ** options
500+ conn , args [ 0 ] , * args , ** options
508501 ),
509502 lambda error : self ._disconnect_raise (conn , error ),
510503 )
504+ except asyncio .CancelledError :
505+ await conn .disconnect (nowait = True )
506+ raise
511507 finally :
508+ if self .single_connection_client :
509+ self ._single_conn_lock .release ()
512510 if not self .connection :
513- await pool .release (conn )
511+ await self .connection_pool .release (conn )
512+
513+ # COMMAND EXECUTION AND PROTOCOL PARSING
514+ async def execute_command (self , * args , ** options ):
515+ """Execute a command and return a parsed response"""
516+ await self .initialize ()
517+ pool = self .connection_pool
518+ command_name = args [0 ]
519+ conn = self .connection or await pool .get_connection (command_name , ** options )
520+
521+ return await asyncio .shield (
522+ self ._try_send_command_parse_response (conn , * args , ** options )
523+ )
514524
515525 async def parse_response (
516526 self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -749,10 +759,18 @@ async def _disconnect_raise_connect(self, conn, error):
749759 is not a TimeoutError. Otherwise, try to reconnect
750760 """
751761 await conn .disconnect ()
762+
752763 if not (conn .retry_on_timeout and isinstance (error , TimeoutError )):
753764 raise error
754765 await conn .connect ()
755766
767+ async def _try_execute (self , conn , command , * arg , ** kwargs ):
768+ try :
769+ return await command (* arg , ** kwargs )
770+ except asyncio .CancelledError :
771+ await conn .disconnect ()
772+ raise
773+
756774 async def _execute (self , conn , command , * args , ** kwargs ):
757775 """
758776 Connect manually upon disconnection. If the Redis server is down,
@@ -761,9 +779,11 @@ async def _execute(self, conn, command, *args, **kwargs):
761779 called by the # connection to resubscribe us to any channels and
762780 patterns we were previously listening to
763781 """
764- return await conn .retry .call_with_retry (
765- lambda : command (* args , ** kwargs ),
766- lambda error : self ._disconnect_raise_connect (conn , error ),
782+ return await asyncio .shield (
783+ conn .retry .call_with_retry (
784+ lambda : self ._try_execute (conn , command , * args , ** kwargs ),
785+ lambda error : self ._disconnect_raise_connect (conn , error ),
786+ )
767787 )
768788
769789 async def parse_response (self , block : bool = True , timeout : float = 0 ):
@@ -1165,6 +1185,18 @@ async def _disconnect_reset_raise(self, conn, error):
11651185 await self .reset ()
11661186 raise
11671187
1188+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
1189+ try :
1190+ return await conn .retry .call_with_retry (
1191+ lambda : self ._send_command_parse_response (
1192+ conn , args [0 ], * args , ** options
1193+ ),
1194+ lambda error : self ._disconnect_reset_raise (conn , error ),
1195+ )
1196+ except asyncio .CancelledError :
1197+ await conn .disconnect ()
1198+ raise
1199+
11681200 async def immediate_execute_command (self , * args , ** options ):
11691201 """
11701202 Execute a command immediately, but don't auto-retry on a
@@ -1180,13 +1212,13 @@ async def immediate_execute_command(self, *args, **options):
11801212 command_name , self .shard_hint
11811213 )
11821214 self .connection = conn
1183-
1184- return await conn . retry . call_with_retry (
1185- lambda : self ._send_command_parse_response (
1186- conn , command_name , * args , ** options
1187- ),
1188- lambda error : self . _disconnect_reset_raise ( conn , error ),
1189- )
1215+ try :
1216+ return await asyncio . shield (
1217+ self ._try_send_command_parse_response ( conn , * args , ** options )
1218+ )
1219+ except asyncio . CancelledError :
1220+ await conn . disconnect ()
1221+ raise
11901222
11911223 def pipeline_execute_command (self , * args , ** options ):
11921224 """
@@ -1353,6 +1385,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
13531385 await self .reset ()
13541386 raise
13551387
1388+ async def _try_execute (self , conn , execute , stack , raise_on_error ):
1389+ try :
1390+ return await conn .retry .call_with_retry (
1391+ lambda : execute (conn , stack , raise_on_error ),
1392+ lambda error : self ._disconnect_raise_reset (conn , error ),
1393+ )
1394+ except asyncio .CancelledError :
1395+ # not supposed to be possible, yet here we are
1396+ await conn .disconnect (nowait = True )
1397+ raise
1398+ finally :
1399+ await self .reset ()
1400+
13561401 async def execute (self , raise_on_error : bool = True ):
13571402 """Execute all the commands in the current pipeline"""
13581403 stack = self .command_stack
@@ -1375,15 +1420,10 @@ async def execute(self, raise_on_error: bool = True):
13751420
13761421 try :
13771422 return await asyncio .shield (
1378- conn .retry .call_with_retry (
1379- lambda : execute (conn , stack , raise_on_error ),
1380- lambda error : self ._disconnect_raise_reset (conn , error ),
1381- )
1423+ self ._try_execute (conn , execute , stack , raise_on_error )
13821424 )
1383- except asyncio .CancelledError :
1384- # not supposed to be possible, yet here we are
1385- await conn .disconnect (nowait = True )
1386- raise
1425+ except RuntimeError :
1426+ await self .reset ()
13871427 finally :
13881428 await self .reset ()
13891429
0 commit comments