diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 7302e64e8..a83e58a62 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -373,7 +373,11 @@ def time_remaining(): await connection.hello() finally: connection.socket.set_deadline(None) - except Exception as e: + except ( + Exception, + # Python 3.8+: CancelledError is a subclass of BaseException + asyncio.CancelledError, + ) as e: log.debug("[#%04X] C: %r", connection.local_port, e) connection.kill() raise diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index f652968b9..93e8207ac 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -373,7 +373,11 @@ def time_remaining(): connection.hello() finally: connection.socket.set_deadline(None) - except Exception as e: + except ( + Exception, + # Python 3.8+: CancelledError is a subclass of BaseException + asyncio.CancelledError, + ) as e: log.debug("[#%04X] C: %r", connection.local_port, e) connection.kill() raise diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 3bda07ea5..771c4d73d 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -16,9 +16,14 @@ # limitations under the License. +import asyncio + import pytest from neo4j._async.io import AsyncBolt +from neo4j._async_compat.network import AsyncBoltSocket + +from ...._async_compat import AsyncTestDecorators # python -m pytest tests/unit/io/test_class_bolt.py -s -v @@ -74,3 +79,27 @@ def test_magic_preamble(): preamble = 0x6060B017 preamble_bytes = preamble.to_bytes(4, byteorder="big") assert AsyncBolt.MAGIC_PREAMBLE == preamble_bytes + + +@AsyncTestDecorators.mark_async_only_test +async def test_cancel_hello_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) + + socket_cls_mock = mocker.patch("neo4j._async.io._bolt.AsyncBoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._async.io._bolt5.AsyncBolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.hello.side_effect = asyncio.CancelledError() + bolt_mock.local_port = 1234 + + with pytest.raises(asyncio.CancelledError): + await AsyncBolt.open(address) + + bolt_mock.kill.assert_called_once_with() diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index 03377cf3d..dc323021b 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -16,10 +16,15 @@ # limitations under the License. +import asyncio + import pytest +from neo4j._async_compat.network import BoltSocket from neo4j._sync.io import Bolt +from ...._async_compat import TestDecorators + # python -m pytest tests/unit/io/test_class_bolt.py -s -v @@ -74,3 +79,27 @@ def test_magic_preamble(): preamble = 0x6060B017 preamble_bytes = preamble.to_bytes(4, byteorder="big") assert Bolt.MAGIC_PREAMBLE == preamble_bytes + + +@TestDecorators.mark_async_only_test +def test_cancel_hello_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.Mock(spec=BoltSocket) + + socket_cls_mock = mocker.patch("neo4j._sync.io._bolt.BoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._sync.io._bolt5.Bolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.hello.side_effect = asyncio.CancelledError() + bolt_mock.local_port = 1234 + + with pytest.raises(asyncio.CancelledError): + Bolt.open(address) + + bolt_mock.kill.assert_called_once_with()