Skip to content

Commit 9fe8366

Browse files
Catch Exception and not BaseException in the Connection (#2104)
* Add failing unittests for passing BaseException through * Resolve failing unittest * Remove redundant checks for asyncio.CancelledError
1 parent fbf68dd commit 9fe8366

File tree

4 files changed

+121
-7
lines changed

4 files changed

+121
-7
lines changed

redis/asyncio/connection.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,6 @@ async def read_from_socket(
502502
# data was read from the socket and added to the buffer.
503503
# return True to indicate that data was read.
504504
return True
505-
except asyncio.CancelledError:
506-
raise
507505
except (socket.timeout, asyncio.TimeoutError):
508506
if raise_on_timeout:
509507
raise TimeoutError("Timeout reading from socket") from None
@@ -721,7 +719,7 @@ async def connect(self):
721719
lambda: self._connect(), lambda error: self.disconnect()
722720
)
723721
except asyncio.CancelledError:
724-
raise
722+
raise # in 3.7 and earlier, this is an Exception, not BaseException
725723
except (socket.timeout, asyncio.TimeoutError):
726724
raise TimeoutError("Timeout connecting to server")
727725
except OSError as e:
@@ -916,7 +914,7 @@ async def send_packed_command(
916914
raise ConnectionError(
917915
f"Error {err_no} while writing to socket. {errmsg}."
918916
) from e
919-
except BaseException:
917+
except Exception:
920918
await self.disconnect()
921919
raise
922920

@@ -958,7 +956,7 @@ async def read_response(self, disable_decoding: bool = False):
958956
raise ConnectionError(
959957
f"Error while reading from {self.host}:{self.port} : {e.args}"
960958
)
961-
except BaseException:
959+
except Exception:
962960
await self.disconnect()
963961
raise
964962

redis/connection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def send_packed_command(self, command, check_health=True):
766766
errno = e.args[0]
767767
errmsg = e.args[1]
768768
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
769-
except BaseException:
769+
except Exception:
770770
self.disconnect()
771771
raise
772772

@@ -804,7 +804,7 @@ def read_response(self, disable_decoding=False):
804804
except OSError as e:
805805
self.disconnect()
806806
raise ConnectionError(f"Error while reading from {hosterr}" f" : {e.args}")
807-
except BaseException:
807+
except Exception:
808808
self.disconnect()
809809
raise
810810

tests/test_asyncio/test_pubsub.py

+74
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
22
import functools
33
import socket
4+
import sys
45
from typing import Optional
6+
from unittest.mock import patch
57

68
import async_timeout
79
import pytest
@@ -914,3 +916,75 @@ async def loop_step_listen(self):
914916
return True
915917
except asyncio.TimeoutError:
916918
return False
919+
920+
921+
@pytest.mark.onlynoncluster
922+
class TestBaseException:
923+
@pytest.mark.skipif(
924+
sys.version_info < (3, 8), reason="requires python 3.8 or higher"
925+
)
926+
async def test_outer_timeout(self, r: redis.Redis):
927+
"""
928+
Using asyncio_timeout manually outside the inner method timeouts works.
929+
This works on Python versions 3.8 and greater, at which time asyncio.
930+
CancelledError became a BaseException instead of an Exception before.
931+
"""
932+
pubsub = r.pubsub()
933+
await pubsub.subscribe("foo")
934+
assert pubsub.connection.is_connected
935+
936+
async def get_msg_or_timeout(timeout=0.1):
937+
async with async_timeout.timeout(timeout):
938+
# blocking method to return messages
939+
while True:
940+
response = await pubsub.parse_response(block=True)
941+
message = await pubsub.handle_message(
942+
response, ignore_subscribe_messages=False
943+
)
944+
if message is not None:
945+
return message
946+
947+
# get subscribe message
948+
msg = await get_msg_or_timeout(10)
949+
assert msg is not None
950+
# timeout waiting for another message which never arrives
951+
assert pubsub.connection.is_connected
952+
with pytest.raises(asyncio.TimeoutError):
953+
await get_msg_or_timeout()
954+
# the timeout on the read should not cause disconnect
955+
assert pubsub.connection.is_connected
956+
957+
async def test_base_exception(self, r: redis.Redis):
958+
"""
959+
Manually trigger a BaseException inside the parser's .read_response method
960+
and verify that it isn't caught
961+
"""
962+
pubsub = r.pubsub()
963+
await pubsub.subscribe("foo")
964+
assert pubsub.connection.is_connected
965+
966+
async def get_msg():
967+
# blocking method to return messages
968+
while True:
969+
response = await pubsub.parse_response(block=True)
970+
message = await pubsub.handle_message(
971+
response, ignore_subscribe_messages=False
972+
)
973+
if message is not None:
974+
return message
975+
976+
# get subscribe message
977+
msg = await get_msg()
978+
assert msg is not None
979+
# timeout waiting for another message which never arrives
980+
assert pubsub.connection.is_connected
981+
with patch("redis.asyncio.connection.PythonParser.read_response") as mock1:
982+
mock1.side_effect = BaseException("boom")
983+
with patch("redis.asyncio.connection.HiredisParser.read_response") as mock2:
984+
mock2.side_effect = BaseException("boom")
985+
986+
with pytest.raises(BaseException):
987+
await get_msg()
988+
989+
# the timeout on the read should not cause disconnect
990+
assert pubsub.connection.is_connected

tests/test_pubsub.py

+42
Original file line numberDiff line numberDiff line change
@@ -735,3 +735,45 @@ def loop_step_listen(self):
735735
for message in self.pubsub.listen():
736736
self.messages.put(message)
737737
return True
738+
739+
740+
@pytest.mark.onlynoncluster
741+
class TestBaseException:
742+
def test_base_exception(self, r: redis.Redis):
743+
"""
744+
Manually trigger a BaseException inside the parser's .read_response method
745+
and verify that it isn't caught
746+
"""
747+
pubsub = r.pubsub()
748+
pubsub.subscribe("foo")
749+
750+
def is_connected():
751+
return pubsub.connection._sock is not None
752+
753+
assert is_connected()
754+
755+
def get_msg():
756+
# blocking method to return messages
757+
while True:
758+
response = pubsub.parse_response(block=True)
759+
message = pubsub.handle_message(
760+
response, ignore_subscribe_messages=False
761+
)
762+
if message is not None:
763+
return message
764+
765+
# get subscribe message
766+
msg = get_msg()
767+
assert msg is not None
768+
# timeout waiting for another message which never arrives
769+
assert is_connected()
770+
with patch("redis.connection.PythonParser.read_response") as mock1:
771+
mock1.side_effect = BaseException("boom")
772+
with patch("redis.connection.HiredisParser.read_response") as mock2:
773+
mock2.side_effect = BaseException("boom")
774+
775+
with pytest.raises(BaseException):
776+
get_msg()
777+
778+
# the timeout on the read should not cause disconnect
779+
assert is_connected()

0 commit comments

Comments
 (0)