Skip to content

[4.4] Don't close stale connections while in use #631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
defaultdict,
deque,
)
import logging
from logging import getLogger
from random import choice
import selectors
Expand Down Expand Up @@ -652,11 +653,19 @@ def time_remaining():
# try to find a free connection in pool
for connection in list(self.connections.get(address, [])):
if (connection.closed() or connection.defunct()
or connection.stale()):
or (connection.stale() and not connection.in_use)):
# `close` is a noop on already closed connections.
# This is to make sure that the connection is gracefully
# closed, e.g. if it's just marked as `stale` but still
# alive.
if log.isEnabledFor(logging.DEBUG):
log.debug(
"[#%04X] C: <POOL> removing old connection "
"(closed=%s, defunct=%s, stale=%s, in_use=%s)",
connection.local_port,
connection.closed(), connection.defunct(),
connection.stale(), connection.in_use
)
connection.close()
try:
self.connections.get(address, []).remove(connection)
Expand Down
39 changes: 35 additions & 4 deletions tests/unit/io/test_neo4j_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_chooses_right_connection_type(opener, type_):
cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
30, "test_db", None)
pool.release(cx1)
if type_ == "r":
if type_ == "r":
assert cx1.addr == READER_ADDRESS
else:
assert cx1.addr == WRITER_ADDRESS
Expand All @@ -147,7 +147,7 @@ def break_connection():
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
assert cx1 in pool.connections[cx1.addr]
# simulate connection going stale (e.g. exceeding) and than breaking when
# simulate connection going stale (e.g. exceeding) and then breaking when
# the pool tries to close the connection
cx1.stale.return_value = True
cx_close_mock = cx1.close
Expand All @@ -156,13 +156,44 @@ def break_connection():
cx_close_mock.side_effect = break_connection
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx2)
assert cx1.close.called_once()
if break_on_close:
cx1.close.assert_called()
else:
cx1.close.assert_called_once()
assert cx2 is not cx1
assert cx2.addr == cx1.addr
assert cx1 not in pool.connections[cx1.addr]
assert cx2 in pool.connections[cx2.addr]


def test_does_not_close_stale_connections_in_use(opener):
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
assert cx1 in pool.connections[cx1.addr]
# simulate connection going stale (e.g. exceeding) while being in use
cx1.stale.return_value = True
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx2)
cx1.close.assert_not_called()
assert cx2 is not cx1
assert cx2.addr == cx1.addr
assert cx1 in pool.connections[cx1.addr]
assert cx2 in pool.connections[cx2.addr]

pool.release(cx1)
# now that cx1 is back in the pool and still stale,
# it should be closed when trying to acquire the next connection
cx1.close.assert_not_called()

cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx3)
cx1.close.assert_called_once()
assert cx2 is cx3
assert cx3.addr == cx1.addr
assert cx1 not in pool.connections[cx1.addr]
assert cx3 in pool.connections[cx2.addr]


def test_release_resets_connections(opener):
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
Expand Down