Skip to content

[4.3] Fix removing connection twice from pool. #598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
]

import abc
from collections import deque
from collections import (
defaultdict,
deque,
)
from logging import getLogger
from random import choice
from select import select
Expand Down Expand Up @@ -610,7 +613,7 @@ def __init__(self, opener, pool_config, workspace_config):
self.opener = opener
self.pool_config = pool_config
self.workspace_config = workspace_config
self.connections = {}
self.connections = defaultdict(deque)
self.lock = RLock()
self.cond = Condition(self.lock)

Expand All @@ -632,35 +635,44 @@ def _acquire(self, address, timeout):
timeout = self.workspace_config.connection_acquisition_timeout

with self.lock:
try:
connections = self.connections[address]
except KeyError:
connections = self.connections[address] = deque()

def time_remaining():
t = timeout - (perf_counter() - t0)
return t if t > 0 else 0

while True:
# try to find a free connection in pool
for connection in list(connections):
for connection in list(self.connections.get(address, [])):
if (connection.closed() or connection.defunct()
or connection.stale()):
# `close` is a noop on already closed connections.
# This is to make sure that the connection is gracefully
# closed, e.g. if it's just marked as `stale` but still
# alive.
connection.close()
connections.remove(connection)
try:
self.connections.get(address, []).remove(connection)
except ValueError:
# If closure fails (e.g. because the server went
# down), all connections to the same address will
# be removed. Therefore, we silently ignore if the
# connection isn't in the pool anymore.
pass
continue
if not connection.in_use:
connection.in_use = True
return connection
# all connections in pool are in-use
infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf"))
can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size
connections = self.connections[address]
max_pool_size = self.pool_config.max_connection_pool_size
infinite_pool_size = (max_pool_size < 0
or max_pool_size == float("inf"))
can_create_new_connection = (
infinite_pool_size
or len(connections) < max_pool_size
)
if can_create_new_connection:
timeout = min(self.pool_config.connection_timeout, time_remaining())
timeout = min(self.pool_config.connection_timeout,
time_remaining())
try:
connection = self.opener(address, timeout)
except ServiceUnavailable:
Expand Down
115 changes: 115 additions & 0 deletions tests/unit/io/test_neo4j_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from unittest.mock import Mock

import pytest

from ..work import FakeConnection

from neo4j import (
READ_ACCESS,
WRITE_ACCESS,
)
from neo4j.addressing import ResolvedAddress
from neo4j.conf import (
PoolConfig,
WorkspaceConfig
)
from neo4j.io import Neo4jPool


ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")


@pytest.fixture()
def opener():
def open_(addr, timeout):
connection = FakeConnection()
connection.addr = addr
connection.timeout = timeout
route_mock = Mock()
route_mock.return_value = [{
"ttl": 1000,
"servers": [
{"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"},
{"addresses": [str(READER_ADDRESS)], "role": "READ"},
{"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
],
}]
connection.attach_mock(route_mock, "route")
opener_.connections.append(connection)
return connection

opener_ = Mock()
opener_.connections = []
opener_.side_effect = open_
return opener_


@pytest.mark.parametrize("type_", ("r", "w"))
def test_chooses_right_connection_type(opener, type_):
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
30, "test_db", None)
pool.release(cx1)
if type_ == "r":
assert cx1.addr == READER_ADDRESS
else:
assert cx1.addr == WRITER_ADDRESS


def test_reuses_connection(opener):
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
assert cx1 is cx2


@pytest.mark.parametrize("break_on_close", (True, False))
def test_closes_stale_connections(opener, break_on_close):
def break_connection():
pool.deactivate(cx1.addr)

if cx_close_mock_side_effect:
cx_close_mock_side_effect()

pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
assert cx1 in pool.connections[cx1.addr]
# simulate connection going stale (e.g. exceeding) and than breaking when
# the pool tries to close the connection
cx1.stale.return_value = True
cx_close_mock = cx1.close
if break_on_close:
cx_close_mock_side_effect = cx_close_mock.side_effect
cx_close_mock.side_effect = break_connection
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx2)
assert cx1.close.called_once()
assert cx2 is not cx1
assert cx2.addr == cx1.addr
assert cx1 not in pool.connections[cx1.addr]
assert cx2 in pool.connections[cx2.addr]
1 change: 1 addition & 0 deletions tests/unit/work/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._fake_connection import FakeConnection
92 changes: 92 additions & 0 deletions tests/unit/work/_fake_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
from unittest import mock

import pytest

from neo4j import ServerInfo


class FakeConnection(mock.NonCallableMagicMock):
callbacks = []
server_info = ServerInfo("127.0.0.1", (4, 3))

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attach_mock(mock.PropertyMock(return_value=True), "is_reset")
self.attach_mock(mock.Mock(return_value=False), "defunct")
self.attach_mock(mock.Mock(return_value=False), "stale")
self.attach_mock(mock.Mock(return_value=False), "closed")

def close_side_effect():
self.closed.return_value = True

self.attach_mock(mock.Mock(side_effect=close_side_effect), "close")

def fetch_message(self, *args, **kwargs):
if self.callbacks:
cb = self.callbacks.pop(0)
cb()
return super().__getattr__("fetch_message")(*args, **kwargs)

def fetch_all(self, *args, **kwargs):
while self.callbacks:
cb = self.callbacks.pop(0)
cb()
return super().__getattr__("fetch_all")(*args, **kwargs)

def __getattr__(self, name):
parent = super()

def build_message_handler(name):
def func(*args, **kwargs):
def callback():
for cb_name, param_count in (
("on_success", 1),
("on_summary", 0)
):
cb = kwargs.get(cb_name, None)
if callable(cb):
try:
param_count = \
len(inspect.signature(cb).parameters)
except ValueError:
# e.g. built-in method as cb
pass
if param_count == 1:
cb({})
else:
cb()
self.callbacks.append(callback)
return parent.__getattr__(name)(*args, **kwargs)

return func

if name in ("run", "commit", "pull", "rollback", "discard"):
return build_message_handler(name)
return parent.__getattr__(name)


@pytest.fixture
def fake_connection():
return FakeConnection()