diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 40644e72a..09882f9b4 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -726,9 +726,9 @@ def release(self, *connections): """ with self.lock: for connection in connections: - if not (connection.is_reset - or connection.defunct() - or connection.closed()): + if not (connection.defunct() + or connection.closed() + or connection.is_reset): try: connection.reset() except (Neo4jError, DriverError, BoltError) as e: diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 4ed722d86..f0e49dce5 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -124,11 +124,12 @@ def _on_server_state_change(self, old_state, new_state): @property def is_reset(self): - if self.responses: - # We can't be sure of the server's state as there are still pending - # responses. Unless the last message we sent was RESET. In that case - # the server state will always be READY when we're done. - return self.responses[-1].message == "reset" + # We can't be sure of the server's state if there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. + if (self.responses and self.responses[-1] + and self.responses[-1].message == "reset"): + return True return self._server_state_manager.state == ServerStates.READY @property diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 086f7baf2..4a9c47ab5 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -82,11 +82,12 @@ def _on_server_state_change(self, old_state, new_state): @property def is_reset(self): - if self.responses: - # We can't be sure of the server's state as there are still pending - # responses. Unless the last message we sent was RESET. In that case - # the server state will always be READY when we're done. - return self.responses[-1].message == "reset" + # We can't be sure of the server's state if there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. + if (self.responses and self.responses[-1] + and self.responses[-1].message == "reset"): + return True return self._server_state_manager.state == ServerStates.READY @property diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index 190b39630..e8c4ec8a9 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -91,7 +91,7 @@ def __init__(self, pool, session_config): def __del__(self): try: self.close() - except OSError: + except (OSError, ServiceUnavailable, SessionExpired): pass def __enter__(self): diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index 746767480..8b7739877 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -172,7 +172,9 @@ def rollback(self): metadata = {} try: - if not self._connection.is_reset: + if not (self._connection.defunct() + or self._connection.closed() + or self._connection.is_reset): # DISCARD pending records then do a rollback. self._consume_results() self._connection.rollback(on_success=metadata.update) diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index 458118057..5dbab9b4e 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -161,3 +161,37 @@ def break_connection(): assert cx2.addr == cx1.addr assert cx1 not in pool.connections[cx1.addr] assert cx2 in pool.connections[cx2.addr] + + +def test_release_resets_connections(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1.is_reset_mock.return_value = False + cx1.is_reset_mock.reset_mock() + pool.release(cx1) + cx1.is_reset_mock.assert_called_once() + cx1.reset.assert_called_once() + + +def test_release_does_not_resets_closed_connections(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1.closed.return_value = True + cx1.closed.reset_mock() + cx1.is_reset_mock.reset_mock() + pool.release(cx1) + cx1.closed.assert_called_once() + cx1.is_reset_mock.asset_not_called() + cx1.reset.asset_not_called() + + +def test_release_does_not_resets_defunct_connections(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1.defunct.return_value = True + cx1.defunct.reset_mock() + cx1.is_reset_mock.reset_mock() + pool.release(cx1) + cx1.defunct.assert_called_once() + cx1.is_reset_mock.asset_not_called() + cx1.reset.asset_not_called() diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py index 25b272fea..fef0b580a 100644 --- a/tests/unit/work/_fake_connection.py +++ b/tests/unit/work/_fake_connection.py @@ -33,7 +33,7 @@ class FakeConnection(mock.NonCallableMagicMock): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.attach_mock(mock.PropertyMock(return_value=True), "is_reset") + self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") self.attach_mock(mock.Mock(return_value=False), "defunct") self.attach_mock(mock.Mock(return_value=False), "stale") self.attach_mock(mock.Mock(return_value=False), "closed") @@ -43,6 +43,13 @@ def close_side_effect(): self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError("is_reset should not be called on a closed or " + "defunct connection.") + return self.is_reset_mock() + def fetch_message(self, *args, **kwargs): if self.callbacks: cb = self.callbacks.pop(0) @@ -78,13 +85,13 @@ def callback(): else: cb() self.callbacks.append(callback) - return parent.__getattr__(name)(*args, **kwargs) return func + method_mock = parent.__getattr__(name) if name in ("run", "commit", "pull", "rollback", "discard"): - return build_message_handler(name) - return parent.__getattr__(name) + method_mock.side_effect = build_message_handler(name) + return method_mock @pytest.fixture diff --git a/tests/unit/work/test_transaction.py b/tests/unit/work/test_transaction.py index dd8dbac4c..06e755662 100644 --- a/tests/unit/work/test_transaction.py +++ b/tests/unit/work/test_transaction.py @@ -19,7 +19,10 @@ # limitations under the License. from uuid import uuid4 -from unittest.mock import MagicMock +from unittest.mock import ( + MagicMock, + NonCallableMagicMock, +) import pytest @@ -129,3 +132,59 @@ def test_transaction_run_takes_no_query_object(fake_connection): tx = Transaction(fake_connection, 2, on_closed, on_error) with pytest.raises(ValueError): tx.run(Query("RETURN 1")) + + +def test_transaction_rollbacks_on_open_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.is_reset_mock.return_value = False + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.is_reset_mock.assert_called_once() + fake_connection.reset.assert_not_called() + fake_connection.rollback.assert_called_once() + + +def test_transaction_no_rollback_on_reset_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.is_reset_mock.return_value = True + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.is_reset_mock.assert_called_once() + fake_connection.reset.asset_not_called() + fake_connection.rollback.asset_not_called() + + +def test_transaction_no_rollback_on_closed_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.closed.return_value = True + fake_connection.closed.reset_mock() + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.closed.assert_called_once() + fake_connection.is_reset_mock.asset_not_called() + fake_connection.reset.asset_not_called() + fake_connection.rollback.asset_not_called() + + +def test_transaction_no_rollback_on_defunct_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.defunct.return_value = True + fake_connection.defunct.reset_mock() + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.defunct.assert_called_once() + fake_connection.is_reset_mock.asset_not_called() + fake_connection.reset.asset_not_called() + fake_connection.rollback.asset_not_called()