diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index 28e4adcac..9c5bcb0fa 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -32,22 +32,27 @@ from select import select from socket import socket, SOL_SOCKET, SO_KEEPALIVE, SHUT_RDWR, error as SocketError, timeout as SocketTimeout, AF_INET, AF_INET6 from struct import pack as struct_pack, unpack as struct_unpack -from threading import RLock +from threading import RLock, Condition from neo4j.addressing import SocketAddress, is_ip_address from neo4j.bolt.cert import KNOWN_HOSTS from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse from neo4j.compat.ssl import SSL_AVAILABLE, HAS_SNI, SSLError -from neo4j.exceptions import ProtocolError, SecurityError, ServiceUnavailable +from neo4j.exceptions import ClientError, ProtocolError, SecurityError, ServiceUnavailable from neo4j.meta import version from neo4j.packstream import Packer, Unpacker from neo4j.util import import_best as _import_best +from time import clock ChunkedInputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedInputBuffer ChunkedOutputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedOutputBuffer +INFINITE = -1 +DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE +DEFAULT_MAX_CONNECTION_POOL_SIZE = INFINITE DEFAULT_CONNECTION_TIMEOUT = 5.0 +DEFAULT_CONNECTION_ACQUISITION_TIMEOUT = 60 DEFAULT_PORT = 7687 DEFAULT_USER_AGENT = "neo4j-python/%s" % version @@ -178,6 +183,8 @@ def __init__(self, address, sock, error_handler, **config): self.packer = Packer(self.output_buffer) self.unpacker = Unpacker() self.responses = deque() + self._max_connection_lifetime = config.get("max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME) + self._creation_timestamp = clock() # Determine the user agent and ensure it is a Unicode value user_agent = config.get("user_agent", DEFAULT_USER_AGENT) @@ -201,6 +208,7 @@ def __init__(self, address, sock, error_handler, **config): # Pick up the server certificate, if any self.der_encoded_server_certificate = config.get("der_encoded_server_certificate") + def Init(self): response = InitResponse(self) self.append(INIT, (self.user_agent, self.auth_dict), response=response) self.sync() @@ -360,6 +368,9 @@ def _unpack(self): more = False return details, summary_signature, summary_metadata + def timedout(self): + return 0 <= self._max_connection_lifetime <= clock() - self._creation_timestamp + def sync(self): """ Send and fetch all outstanding messages. @@ -396,11 +407,14 @@ class ConnectionPool(object): _closed = False - def __init__(self, connector, connection_error_handler): + def __init__(self, connector, connection_error_handler, **config): self.connector = connector self.connection_error_handler = connection_error_handler self.connections = {} self.lock = RLock() + self.cond = Condition(self.lock) + self._max_connection_pool_size = config.get("max_connection_pool_size", DEFAULT_MAX_CONNECTION_POOL_SIZE) + self._connection_acquisition_timeout = config.get("connection_acquisition_timeout", DEFAULT_CONNECTION_ACQUISITION_TIMEOUT) def __enter__(self): return self @@ -424,23 +438,42 @@ def acquire_direct(self, address): connections = self.connections[address] except KeyError: connections = self.connections[address] = deque() - for connection in list(connections): - if connection.closed() or connection.defunct(): - connections.remove(connection) - continue - if not connection.in_use: - connection.in_use = True - return connection - try: - connection = self.connector(address, self.connection_error_handler) - except ServiceUnavailable: - self.remove(address) - raise - else: - connection.pool = self - connection.in_use = True - connections.append(connection) - return connection + + connection_acquisition_start_timestamp = clock() + while True: + # try to find a free connection in pool + for connection in list(connections): + if connection.closed() or connection.defunct() or connection.timedout(): + connections.remove(connection) + continue + if not connection.in_use: + connection.in_use = True + return connection + # all connections in pool are in-use + can_create_new_connection = self._max_connection_pool_size == INFINITE or len(connections) < self._max_connection_pool_size + if can_create_new_connection: + try: + connection = self.connector(address, self.connection_error_handler) + except ServiceUnavailable: + self.remove(address) + raise + else: + connection.pool = self + connection.in_use = True + connections.append(connection) + return connection + + # failed to obtain a connection from pool because the pool is full and no free connection in the pool + span_timeout = self._connection_acquisition_timeout - (clock() - connection_acquisition_start_timestamp) + if span_timeout > 0: + self.cond.wait(span_timeout) + # if timed out, then we throw error. This time computation is needed, as with python 2.7, we cannot + # tell if the condition is notified or timed out when we come to this line + if self._connection_acquisition_timeout <= (clock() - connection_acquisition_start_timestamp): + raise ClientError("Failed to obtain a connection from pool within {!r}s".format( + self._connection_acquisition_timeout)) + else: + raise ClientError("Failed to obtain a connection from pool within {!r}s".format(self._connection_acquisition_timeout)) def acquire(self, access_mode=None): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -454,6 +487,7 @@ def release(self, connection): """ with self.lock: connection.in_use = False + self.cond.notify_all() def in_use_connection_count(self, address): """ Count the number of connections currently in use to a given @@ -600,8 +634,10 @@ def connect(address, ssl_context=None, error_handler=None, **config): s.shutdown(SHUT_RDWR) s.close() elif agreed_version == 1: - return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, + connection = Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, error_handler=error_handler, **config) + connection.Init() + return connection elif agreed_version == 0x48545450: log_error("S: [CLOSE]") s.close() diff --git a/neo4j/v1/__init__.py b/neo4j/v1/__init__.py index fa13808af..a2eacc335 100644 --- a/neo4j/v1/__init__.py +++ b/neo4j/v1/__init__.py @@ -18,9 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from neo4j.exceptions import * - from .api import * from .direct import * from .exceptions import * diff --git a/neo4j/v1/api.py b/neo4j/v1/api.py index 4e063da36..e6d2358f6 100644 --- a/neo4j/v1/api.py +++ b/neo4j/v1/api.py @@ -25,8 +25,8 @@ from time import time, sleep from warnings import warn -from neo4j.bolt import ProtocolError, ServiceUnavailable -from neo4j.compat import unicode, urlparse +from neo4j.exceptions import ProtocolError, ServiceUnavailable +from neo4j.compat import urlparse from neo4j.exceptions import CypherError, TransientError from .exceptions import DriverError, SessionError, SessionExpired, TransactionError diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index c63419ee7..2a8f6e6f4 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -20,7 +20,7 @@ from neo4j.addressing import SocketAddress, resolve -from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler +from neo4j.bolt.connection import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler from neo4j.exceptions import ServiceUnavailable from neo4j.v1.api import Driver from neo4j.v1.security import SecurityPlan @@ -37,8 +37,8 @@ def __init__(self): class DirectConnectionPool(ConnectionPool): - def __init__(self, connector, address): - super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler()) + def __init__(self, connector, address, **config): + super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler(), **config) self.address = address def acquire(self, access_mode=None): @@ -73,7 +73,7 @@ def __init__(self, uri, **config): def connector(address, error_handler): return connect(address, security_plan.ssl_context, error_handler, **config) - pool = DirectConnectionPool(connector, self.address) + pool = DirectConnectionPool(connector, self.address, **config) pool.release(pool.acquire()) Driver.__init__(self, pool, **config) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 4bfe688c0..7eb845d6a 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -34,7 +34,7 @@ LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0 LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1 -LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED +DEFAULT_LOAD_BALANCING_STRATEGY = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED class OrderedSet(MutableSet): @@ -166,7 +166,7 @@ class LoadBalancingStrategy(object): @classmethod def build(cls, connection_pool, **config): - load_balancing_strategy = config.get("load_balancing_strategy", LOAD_BALANCING_STRATEGY_DEFAULT) + load_balancing_strategy = config.get("load_balancing_strategy", DEFAULT_LOAD_BALANCING_STRATEGY) if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED: return LeastConnectedLoadBalancingStrategy(connection_pool) elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN: @@ -265,7 +265,7 @@ class RoutingConnectionPool(ConnectionPool): """ def __init__(self, connector, initial_address, routing_context, *routers, **config): - super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self)) + super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self), **config) self.initial_address = initial_address self.routing_context = routing_context self.routing_table = RoutingTable(routers) diff --git a/test/integration/test_driver.py b/test/integration/test_driver.py index f4466d127..4afc1d916 100644 --- a/test/integration/test_driver.py +++ b/test/integration/test_driver.py @@ -19,8 +19,8 @@ # limitations under the License. -from neo4j.v1 import GraphDatabase, ProtocolError, ServiceUnavailable - +from neo4j.v1 import GraphDatabase, ServiceUnavailable +from neo4j.exceptions import ProtocolError from test.integration.tools import IntegrationTestCase diff --git a/test/integration/test_security.py b/test/integration/test_security.py index 48ed29da1..9751a1b67 100644 --- a/test/integration/test_security.py +++ b/test/integration/test_security.py @@ -23,7 +23,8 @@ from ssl import SSLSocket from unittest import skipUnless -from neo4j.v1 import GraphDatabase, SSL_AVAILABLE, TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES, AuthError +from neo4j.v1 import GraphDatabase, SSL_AVAILABLE, TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES +from neo4j.exceptions import AuthError from test.integration.tools import IntegrationTestCase diff --git a/test/integration/test_session.py b/test/integration/test_session.py index e991c2e5d..6291f611b 100644 --- a/test/integration/test_session.py +++ b/test/integration/test_session.py @@ -24,7 +24,8 @@ from neo4j.v1 import \ READ_ACCESS, WRITE_ACCESS, \ CypherError, SessionError, TransactionError, \ - Node, Relationship, Path, CypherSyntaxError + Node, Relationship, Path +from neo4j.exceptions import CypherSyntaxError from test.integration.tools import DirectIntegrationTestCase diff --git a/test/integration/tools.py b/test/integration/tools.py index cd60a1cb7..604bafd83 100644 --- a/test/integration/tools.py +++ b/test/integration/tools.py @@ -32,7 +32,8 @@ from boltkit.controller import WindowsController, UnixController -from neo4j.v1 import GraphDatabase, AuthError +from neo4j.v1 import GraphDatabase +from neo4j.exceptions import AuthError from neo4j.util import ServerVersion from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD diff --git a/test/performance/tools.py b/test/performance/tools.py index b216ff232..fbeca9ce7 100644 --- a/test/performance/tools.py +++ b/test/performance/tools.py @@ -34,7 +34,8 @@ from boltkit.controller import WindowsController, UnixController -from neo4j.v1 import GraphDatabase, AuthError +from neo4j.v1 import GraphDatabase +from neo4j.exceptions import AuthError from neo4j.util import ServerVersion from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index ae12fa8ab..ab2163ccc 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -21,7 +21,8 @@ from neo4j.v1 import GraphDatabase, READ_ACCESS, WRITE_ACCESS, SessionExpired, \ RoutingDriver, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, LOAD_BALANCING_STRATEGY_ROUND_ROBIN, \ - RoundRobinLoadBalancingStrategy, TransientError, ClientError + RoundRobinLoadBalancingStrategy, TransientError +from neo4j.exceptions import ClientError from neo4j.bolt import ProtocolError, ServiceUnavailable from test.stub.tools import StubTestCase, StubCluster diff --git a/test/integration/test_connection.py b/test/unit/test_connection.py similarity index 51% rename from test/integration/test_connection.py rename to test/unit/test_connection.py index 703f97df0..f0226e8e1 100644 --- a/test/integration/test_connection.py +++ b/test/unit/test_connection.py @@ -17,13 +17,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function +from unittest import TestCase +from threading import Thread, Event +from neo4j.v1 import DirectConnectionErrorHandler, ServiceUnavailable +from neo4j.bolt import Connection, ConnectionPool +from neo4j.exceptions import ClientError +class FakeSocket(object): + def __init__(self, address): + self.address = address -from socket import create_connection + def getpeername(self): + return self.address -from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler + def sendall(self, data): + return -from test.integration.tools import IntegrationTestCase + def close(self): + return class QuickConnection(object): @@ -44,22 +56,46 @@ def closed(self): def defunct(self): return False + def timedout(self): + return False + def connector(address, _): - return QuickConnection(create_connection(address)) + return QuickConnection(FakeSocket(address)) -class ConnectionPoolTestCase(IntegrationTestCase): +class ConnectionTestCase(TestCase): + def test_conn_timedout(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), max_connection_lifetime=0) + self.assertEqual(connection.timedout(), True) + + def test_conn_not_timedout_if_not_enabled(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), + max_connection_lifetime=-1) + self.assertEqual(connection.timedout(), False) + + def test_conn_not_timedout(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), + max_connection_lifetime=999999999) + self.assertEqual(connection.timedout(), False) + + +class ConnectionPoolTestCase(TestCase): def setUp(self): self.pool = ConnectionPool(connector, DirectConnectionErrorHandler()) def tearDown(self): self.pool.close() - def assert_pool_size(self, address, expected_active, expected_inactive): + def assert_pool_size(self, address, expected_active, expected_inactive, pool=None): + if pool is None: + pool = self.pool try: - connections = self.pool.connections[address] + connections = pool.connections[address] except KeyError: assert 0 == expected_active assert 0 == expected_inactive @@ -108,7 +144,7 @@ def test_releasing_twice(self): self.assert_pool_size(address, 0, 1) def test_cannot_acquire_after_close(self): - with ConnectionPool(lambda a: QuickConnection(create_connection(a)), DirectConnectionErrorHandler()) as pool: + with ConnectionPool(lambda a: QuickConnection(FakeSocket(a)), DirectConnectionErrorHandler()) as pool: pool.close() with self.assertRaises(ServiceUnavailable): _ = pool.acquire_direct("X") @@ -120,3 +156,43 @@ def test_in_use_count(self): self.assertEqual(self.pool.in_use_connection_count(address), 1) self.pool.release(connection) self.assertEqual(self.pool.in_use_connection_count(address), 0) + + def test_max_conn_pool_size(self): + with ConnectionPool(connector, DirectConnectionErrorHandler, + max_connection_pool_size=1, connection_acquisition_timeout=0) as pool: + address = ("127.0.0.1", 7687) + pool.acquire_direct(address) + self.assertEqual(pool.in_use_connection_count(address), 1) + with self.assertRaises(ClientError): + pool.acquire_direct(address) + self.assertEqual(pool.in_use_connection_count(address), 1) + + def test_multithread(self): + with ConnectionPool(connector, DirectConnectionErrorHandler, + max_connection_pool_size=5, connection_acquisition_timeout=10) as pool: + address = ("127.0.0.1", 7687) + releasing_event = Event() + + # We start 10 threads to compete connections from pool with size of 5 + threads = [] + for i in range(10): + t = Thread(target=acquire_release_conn, args=(pool, address, releasing_event)) + t.start() + threads.append(t) + + # The pool size should be 5, all are in-use + self.assert_pool_size(address, 5, 0, pool) + # Now we allow thread to release connections they obtained from pool + releasing_event.set() + + # wait for all threads to release connections back to pool + for t in threads: + t.join() + # The pool size is still 5, but all are free + self.assert_pool_size(address, 0, 5, pool) + + +def acquire_release_conn(pool, address, releasing_event): + conn = pool.acquire_direct(address) + releasing_event.wait() + pool.release(conn) \ No newline at end of file