Skip to content

Commit 6cfd0c6

Browse files
authored
Fix removing connection twice from pool. (#598)
When trying to close a stale connection the driver count realize that the connection is dead on trying to send GOODBYE. This would cause the connection to make sure that all connections to the same address would get removed from the pool as well. Since this removal only happens as a side effect of `connection.close()` and does not always happen, the driver would still try to remove the (now already removed) connection form the pool after closure. Fixes: `ValueError: deque.remove(x): x not in deque` Backport of #598
1 parent 74e5739 commit 6cfd0c6

File tree

4 files changed

+232
-12
lines changed

4 files changed

+232
-12
lines changed

neo4j/io/__init__.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
]
3535

3636
import abc
37-
from collections import deque
37+
from collections import (
38+
defaultdict,
39+
deque,
40+
)
3841
from logging import getLogger
3942
from random import choice
4043
from select import select
@@ -610,7 +613,7 @@ def __init__(self, opener, pool_config, workspace_config):
610613
self.opener = opener
611614
self.pool_config = pool_config
612615
self.workspace_config = workspace_config
613-
self.connections = {}
616+
self.connections = defaultdict(deque)
614617
self.lock = RLock()
615618
self.cond = Condition(self.lock)
616619

@@ -632,35 +635,44 @@ def _acquire(self, address, timeout):
632635
timeout = self.workspace_config.connection_acquisition_timeout
633636

634637
with self.lock:
635-
try:
636-
connections = self.connections[address]
637-
except KeyError:
638-
connections = self.connections[address] = deque()
639-
640638
def time_remaining():
641639
t = timeout - (perf_counter() - t0)
642640
return t if t > 0 else 0
643641

644642
while True:
645643
# try to find a free connection in pool
646-
for connection in list(connections):
644+
for connection in list(self.connections.get(address, [])):
647645
if (connection.closed() or connection.defunct()
648646
or connection.stale()):
649647
# `close` is a noop on already closed connections.
650648
# This is to make sure that the connection is gracefully
651649
# closed, e.g. if it's just marked as `stale` but still
652650
# alive.
653651
connection.close()
654-
connections.remove(connection)
652+
try:
653+
self.connections.get(address, []).remove(connection)
654+
except ValueError:
655+
# If closure fails (e.g. because the server went
656+
# down), all connections to the same address will
657+
# be removed. Therefore, we silently ignore if the
658+
# connection isn't in the pool anymore.
659+
pass
655660
continue
656661
if not connection.in_use:
657662
connection.in_use = True
658663
return connection
659664
# all connections in pool are in-use
660-
infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf"))
661-
can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size
665+
connections = self.connections[address]
666+
max_pool_size = self.pool_config.max_connection_pool_size
667+
infinite_pool_size = (max_pool_size < 0
668+
or max_pool_size == float("inf"))
669+
can_create_new_connection = (
670+
infinite_pool_size
671+
or len(connections) < max_pool_size
672+
)
662673
if can_create_new_connection:
663-
timeout = min(self.pool_config.connection_timeout, time_remaining())
674+
timeout = min(self.pool_config.connection_timeout,
675+
time_remaining())
664676
try:
665677
connection = self.opener(address, timeout)
666678
except ServiceUnavailable:

tests/unit/io/test_neo4j_pool.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) "Neo4j"
5+
# Neo4j Sweden AB [http://neo4j.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
import inspect
22+
from unittest.mock import Mock
23+
24+
import pytest
25+
26+
from ..work import FakeConnection
27+
28+
from neo4j import (
29+
READ_ACCESS,
30+
WRITE_ACCESS,
31+
)
32+
from neo4j.addressing import ResolvedAddress
33+
from neo4j.conf import (
34+
PoolConfig,
35+
WorkspaceConfig
36+
)
37+
from neo4j.io import Neo4jPool
38+
39+
40+
ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
41+
READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
42+
WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")
43+
44+
45+
@pytest.fixture()
46+
def opener():
47+
def open_(addr, timeout):
48+
connection = FakeConnection()
49+
connection.addr = addr
50+
connection.timeout = timeout
51+
route_mock = Mock()
52+
route_mock.return_value = [{
53+
"ttl": 1000,
54+
"servers": [
55+
{"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"},
56+
{"addresses": [str(READER_ADDRESS)], "role": "READ"},
57+
{"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
58+
],
59+
}]
60+
connection.attach_mock(route_mock, "route")
61+
opener_.connections.append(connection)
62+
return connection
63+
64+
opener_ = Mock()
65+
opener_.connections = []
66+
opener_.side_effect = open_
67+
return opener_
68+
69+
70+
@pytest.mark.parametrize("type_", ("r", "w"))
71+
def test_chooses_right_connection_type(opener, type_):
72+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
73+
cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
74+
30, "test_db", None)
75+
pool.release(cx1)
76+
if type_ == "r":
77+
assert cx1.addr == READER_ADDRESS
78+
else:
79+
assert cx1.addr == WRITER_ADDRESS
80+
81+
82+
def test_reuses_connection(opener):
83+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
84+
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
85+
pool.release(cx1)
86+
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
87+
assert cx1 is cx2
88+
89+
90+
@pytest.mark.parametrize("break_on_close", (True, False))
91+
def test_closes_stale_connections(opener, break_on_close):
92+
def break_connection():
93+
pool.deactivate(cx1.addr)
94+
95+
if cx_close_mock_side_effect:
96+
cx_close_mock_side_effect()
97+
98+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
99+
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
100+
pool.release(cx1)
101+
assert cx1 in pool.connections[cx1.addr]
102+
# simulate connection going stale (e.g. exceeding) and than breaking when
103+
# the pool tries to close the connection
104+
cx1.stale.return_value = True
105+
cx_close_mock = cx1.close
106+
if break_on_close:
107+
cx_close_mock_side_effect = cx_close_mock.side_effect
108+
cx_close_mock.side_effect = break_connection
109+
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
110+
pool.release(cx2)
111+
assert cx1.close.called_once()
112+
assert cx2 is not cx1
113+
assert cx2.addr == cx1.addr
114+
assert cx1 not in pool.connections[cx1.addr]
115+
assert cx2 in pool.connections[cx2.addr]

tests/unit/work/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._fake_connection import FakeConnection

tests/unit/work/_fake_connection.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) "Neo4j"
5+
# Neo4j Sweden AB [http://neo4j.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
import inspect
23+
from unittest import mock
24+
25+
import pytest
26+
27+
from neo4j import ServerInfo
28+
29+
30+
class FakeConnection(mock.NonCallableMagicMock):
31+
callbacks = []
32+
server_info = ServerInfo("127.0.0.1", (4, 3))
33+
34+
def __init__(self, *args, **kwargs):
35+
super().__init__(*args, **kwargs)
36+
self.attach_mock(mock.PropertyMock(return_value=True), "is_reset")
37+
self.attach_mock(mock.Mock(return_value=False), "defunct")
38+
self.attach_mock(mock.Mock(return_value=False), "stale")
39+
self.attach_mock(mock.Mock(return_value=False), "closed")
40+
41+
def close_side_effect():
42+
self.closed.return_value = True
43+
44+
self.attach_mock(mock.Mock(side_effect=close_side_effect), "close")
45+
46+
def fetch_message(self, *args, **kwargs):
47+
if self.callbacks:
48+
cb = self.callbacks.pop(0)
49+
cb()
50+
return super().__getattr__("fetch_message")(*args, **kwargs)
51+
52+
def fetch_all(self, *args, **kwargs):
53+
while self.callbacks:
54+
cb = self.callbacks.pop(0)
55+
cb()
56+
return super().__getattr__("fetch_all")(*args, **kwargs)
57+
58+
def __getattr__(self, name):
59+
parent = super()
60+
61+
def build_message_handler(name):
62+
def func(*args, **kwargs):
63+
def callback():
64+
for cb_name, param_count in (
65+
("on_success", 1),
66+
("on_summary", 0)
67+
):
68+
cb = kwargs.get(cb_name, None)
69+
if callable(cb):
70+
try:
71+
param_count = \
72+
len(inspect.signature(cb).parameters)
73+
except ValueError:
74+
# e.g. built-in method as cb
75+
pass
76+
if param_count == 1:
77+
cb({})
78+
else:
79+
cb()
80+
self.callbacks.append(callback)
81+
return parent.__getattr__(name)(*args, **kwargs)
82+
83+
return func
84+
85+
if name in ("run", "commit", "pull", "rollback", "discard"):
86+
return build_message_handler(name)
87+
return parent.__getattr__(name)
88+
89+
90+
@pytest.fixture
91+
def fake_connection():
92+
return FakeConnection()

0 commit comments

Comments
 (0)