diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index d35a603ec..28e4adcac 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -119,6 +119,26 @@ def supports_bytes(self): return self.version_info() >= (3, 2) +class ConnectionErrorHandler(object): + """ A handler for send and receive errors. + """ + + def __init__(self, handlers_by_error_class=None): + if handlers_by_error_class is None: + handlers_by_error_class = {} + + self.handlers_by_error_class = handlers_by_error_class + self.known_errors = tuple(handlers_by_error_class.keys()) + + def handle(self, error, address): + try: + error_class = error.__class__ + handler = self.handlers_by_error_class[error_class] + handler(address) + except KeyError: + pass + + class Connection(object): """ Server connection for Bolt protocol v1. @@ -148,8 +168,10 @@ class Connection(object): _last_run_statement = None - def __init__(self, sock, **config): + def __init__(self, address, sock, error_handler, **config): + self.address = address self.socket = sock + self.error_handler = error_handler self.server = ServerInfo(SocketAddress.from_socket(sock)) self.input_buffer = ChunkedInputBuffer() self.output_buffer = ChunkedOutputBuffer() @@ -237,6 +259,13 @@ def reset(self): self.sync() def send(self): + try: + self._send() + except self.error_handler.known_errors as error: + self.error_handler.handle(error, self.address) + raise error + + def _send(self): """ Send all queued messages to the server. """ data = self.output_buffer.view() @@ -250,6 +279,13 @@ def send(self): self.output_buffer.clear() def fetch(self): + try: + return self._fetch() + except self.error_handler.known_errors as error: + self.error_handler.handle(error, self.address) + raise error + + def _fetch(self): """ Receive at least one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched @@ -360,8 +396,9 @@ class ConnectionPool(object): _closed = False - def __init__(self, connector): + def __init__(self, connector, connection_error_handler): self.connector = connector + self.connection_error_handler = connection_error_handler self.connections = {} self.lock = RLock() @@ -395,7 +432,7 @@ def acquire_direct(self, address): connection.in_use = True return connection try: - connection = self.connector(address) + connection = self.connector(address, self.connection_error_handler) except ServiceUnavailable: self.remove(address) raise @@ -457,7 +494,7 @@ def closed(self): return self._closed -def connect(address, ssl_context=None, **config): +def connect(address, ssl_context=None, error_handler=None, **config): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -563,7 +600,8 @@ def connect(address, ssl_context=None, **config): s.shutdown(SHUT_RDWR) s.close() elif agreed_version == 1: - return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config) + return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, + error_handler=error_handler, **config) elif agreed_version == 0x48545450: log_error("S: [CLOSE]") s.close() diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index 1afae6967..069b01761 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -65,17 +65,9 @@ def hydrate(cls, message=None, code=None, **metadata): classification = "DatabaseError" category = "General" title = "UnknownError" - if classification == "ClientError": - try: - error_class = client_errors[code] - except KeyError: - error_class = ClientError - elif classification == "DatabaseError": - error_class = DatabaseError - elif classification == "TransientError": - error_class = TransientError - else: - error_class = cls + + error_class = cls._extract_error_class(classification, code) + inst = error_class(message) inst.message = message inst.code = code @@ -85,6 +77,26 @@ def hydrate(cls, message=None, code=None, **metadata): inst.metadata = metadata return inst + @classmethod + def _extract_error_class(cls, classification, code): + if classification == "ClientError": + try: + return client_errors[code] + except KeyError: + return ClientError + + elif classification == "TransientError": + try: + return transient_errors[code] + except KeyError: + return TransientError + + elif classification == "DatabaseError": + return DatabaseError + + else: + return cls + class ClientError(CypherError): """ The Client sent a bad request - changing the request might yield a successful outcome. @@ -101,6 +113,11 @@ class TransientError(CypherError): """ +class DatabaseUnavailableError(TransientError): + """ + """ + + class ConstraintError(ClientError): """ """ @@ -116,11 +133,21 @@ class CypherTypeError(ClientError): """ +class NotALeaderError(ClientError): + """ + """ + + class Forbidden(ClientError, SecurityError): """ """ +class ForbiddenOnReadOnlyDatabaseError(Forbidden): + """ + """ + + class AuthError(ClientError, SecurityError): """ Raised when authentication failure occurs. """ @@ -144,7 +171,7 @@ class AuthError(ClientError, SecurityError): "Neo.ClientError.Statement.TypeError": CypherTypeError, # Forbidden - "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": Forbidden, + "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": ForbiddenOnReadOnlyDatabaseError, "Neo.ClientError.General.ReadOnly": Forbidden, "Neo.ClientError.Schema.ForbiddenOnConstraintIndex": Forbidden, "Neo.ClientError.Schema.IndexBelongsToConstraint": Forbidden, @@ -155,4 +182,12 @@ class AuthError(ClientError, SecurityError): "Neo.ClientError.Security.AuthorizationFailed": AuthError, "Neo.ClientError.Security.Unauthorized": AuthError, + # NotALeaderError + "Neo.ClientError.Cluster.NotALeader": NotALeaderError +} + +transient_errors = { + + # DatabaseUnavailableError + "Neo.TransientError.General.DatabaseUnavailable": DatabaseUnavailableError } diff --git a/neo4j/v1/api.py b/neo4j/v1/api.py index 65b28a809..4e063da36 100644 --- a/neo4j/v1/api.py +++ b/neo4j/v1/api.py @@ -277,7 +277,7 @@ def _disconnect(self, sync): if sync: try: self._connection.sync() - except ServiceUnavailable: + except (SessionError, ServiceUnavailable): pass if self._connection: self._connection.in_use = False diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index a5910db7a..c63419ee7 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -20,17 +20,25 @@ from neo4j.addressing import SocketAddress, resolve -from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect +from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler from neo4j.exceptions import ServiceUnavailable from neo4j.v1.api import Driver from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession +class DirectConnectionErrorHandler(ConnectionErrorHandler): + """ Handler for errors in direct driver connections. + """ + + def __init__(self): + super(DirectConnectionErrorHandler, self).__init__({}) # does not need to handle errors + + class DirectConnectionPool(ConnectionPool): def __init__(self, connector, address): - super(DirectConnectionPool, self).__init__(connector) + super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler()) self.address = address def acquire(self, access_mode=None): @@ -61,7 +69,11 @@ def __init__(self, uri, **config): self.address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted - pool = DirectConnectionPool(lambda a: connect(a, security_plan.ssl_context, **config), self.address) + + def connector(address, error_handler): + return connect(address, security_plan.ssl_context, error_handler, **config) + + pool = DirectConnectionPool(connector, self.address) pool.release(pool.acquire()) Driver.__init__(self, pool, **config) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 7d3ffbda0..4bfe688c0 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -23,16 +23,15 @@ from time import clock from neo4j.addressing import SocketAddress, resolve -from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect +from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect, ConnectionErrorHandler from neo4j.compat.collections import MutableSet, OrderedDict -from neo4j.exceptions import CypherError +from neo4j.exceptions import CypherError, DatabaseUnavailableError, NotALeaderError, ForbiddenOnReadOnlyDatabaseError from neo4j.util import ServerVersion from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters from neo4j.v1.exceptions import SessionExpired from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession - LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0 LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1 LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED @@ -247,12 +246,26 @@ def _select(self, offset, addresses): return least_connected_address +class RoutingConnectionErrorHandler(ConnectionErrorHandler): + """ Handler for errors in routing driver connections. + """ + + def __init__(self, pool): + super(RoutingConnectionErrorHandler, self).__init__({ + SessionExpired: lambda address: pool.remove(address), + ServiceUnavailable: lambda address: pool.remove(address), + DatabaseUnavailableError: lambda address: pool.remove(address), + NotALeaderError: lambda address: pool.remove_writer(address), + ForbiddenOnReadOnlyDatabaseError: lambda address: pool.remove_writer(address) + }) + + class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ def __init__(self, connector, initial_address, routing_context, *routers, **config): - super(RoutingConnectionPool, self).__init__(connector) + super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self)) self.initial_address = initial_address self.routing_context = routing_context self.routing_table = RoutingTable(routers) @@ -416,6 +429,11 @@ def remove(self, address): self.routing_table.writers.discard(address) super(RoutingConnectionPool, self).remove(address) + def remove_writer(self, address): + """ Remove a writer address from the routing table, if present. + """ + self.routing_table.writers.discard(address) + class RoutingDriver(Driver): """ A :class:`.RoutingDriver` is created from a ``bolt+routing`` URI. The @@ -433,8 +451,8 @@ def __init__(self, uri, **config): # scenario right now raise ValueError("TRUST_ON_FIRST_USE is not compatible with routing") - def connector(a): - return connect(a, security_plan.ssl_context, **config) + def connector(address, error_handler): + return connect(address, security_plan.ssl_context, error_handler, **config) pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config) try: diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py index 1a28c2ff9..703f97df0 100644 --- a/test/integration/test_connection.py +++ b/test/integration/test_connection.py @@ -21,7 +21,7 @@ from socket import create_connection -from neo4j.v1 import ConnectionPool, ServiceUnavailable +from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler from test.integration.tools import IntegrationTestCase @@ -45,10 +45,14 @@ def defunct(self): return False +def connector(address, _): + return QuickConnection(create_connection(address)) + + class ConnectionPoolTestCase(IntegrationTestCase): def setUp(self): - self.pool = ConnectionPool(lambda a: QuickConnection(create_connection(a))) + self.pool = ConnectionPool(connector, DirectConnectionErrorHandler()) def tearDown(self): self.pool.close() @@ -104,7 +108,7 @@ def test_releasing_twice(self): self.assert_pool_size(address, 0, 1) def test_cannot_acquire_after_close(self): - with ConnectionPool(lambda a: QuickConnection(create_connection(a))) as pool: + with ConnectionPool(lambda a: QuickConnection(create_connection(a)), DirectConnectionErrorHandler()) as pool: pool.close() with self.assertRaises(ServiceUnavailable): _ = pool.acquire_direct("X") diff --git a/test/stub/scripts/database_unavailable.script b/test/stub/scripts/database_unavailable.script new file mode 100644 index 000000000..c482f17f8 --- /dev/null +++ b/test/stub/scripts/database_unavailable.script @@ -0,0 +1,12 @@ +!: AUTO INIT +!: AUTO RESET +!: AUTO PULL_ALL +!: AUTO ACK_FAILURE +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} +!: AUTO RUN "COMMIT" {} + +C: RUN "RETURN 1" {} +C: PULL_ALL +S: FAILURE {"code": "Neo.TransientError.General.DatabaseUnavailable", "message": "Database is busy doing store copy"} +S: IGNORED diff --git a/test/stub/scripts/forbidden_on_read_only_database.script b/test/stub/scripts/forbidden_on_read_only_database.script new file mode 100644 index 000000000..3385b0974 --- /dev/null +++ b/test/stub/scripts/forbidden_on_read_only_database.script @@ -0,0 +1,12 @@ +!: AUTO INIT +!: AUTO RESET +!: AUTO PULL_ALL +!: AUTO ACK_FAILURE +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} +!: AUTO RUN "COMMIT" {} + +C: RUN "CREATE (n {name:'Bob'})" {} +C: PULL_ALL +S: FAILURE {"code": "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase", "message": "Unable to write"} +S: IGNORED diff --git a/test/stub/scripts/not_a_leader.script b/test/stub/scripts/not_a_leader.script new file mode 100644 index 000000000..8716466c3 --- /dev/null +++ b/test/stub/scripts/not_a_leader.script @@ -0,0 +1,12 @@ +!: AUTO INIT +!: AUTO RESET +!: AUTO PULL_ALL +!: AUTO ACK_FAILURE +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} +!: AUTO RUN "COMMIT" {} + +C: RUN "CREATE (n {name:'Bob'})" {} +C: PULL_ALL +S: FAILURE {"code": "Neo.ClientError.Cluster.NotALeader", "message": "Leader switched has happened"} +S: IGNORED diff --git a/test/stub/scripts/rude_reader.script b/test/stub/scripts/rude_reader.script new file mode 100644 index 000000000..1b1f7d48c --- /dev/null +++ b/test/stub/scripts/rude_reader.script @@ -0,0 +1,7 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "RETURN 1" {} + PULL_ALL +S: + diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index 2c9a23cea..05f2bfca2 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -50,8 +50,8 @@ UNREACHABLE_ADDRESS = ("127.0.0.1", 8080) -def connector(address): - return connect(address, auth=basic_auth("neotest", "neotest")) +def connector(address, error_handler): + return connect(address, error_handler=error_handler, auth=basic_auth("neotest", "neotest")) def RoutingPool(*routers): diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index 0edefe666..ae12fa8ab 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -21,7 +21,7 @@ from neo4j.v1 import GraphDatabase, READ_ACCESS, WRITE_ACCESS, SessionExpired, \ RoutingDriver, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, LOAD_BALANCING_STRATEGY_ROUND_ROBIN, \ - RoundRobinLoadBalancingStrategy + RoundRobinLoadBalancingStrategy, TransientError, ClientError from neo4j.bolt import ProtocolError, ServiceUnavailable from test.stub.tools import StubTestCase, StubCluster @@ -236,8 +236,79 @@ def test_can_select_round_robin_load_balancing_strategy(self): self.assertIsInstance(driver._pool.load_balancing_strategy, RoundRobinLoadBalancingStrategy) def test_no_other_load_balancing_strategies_are_available(self): - with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with self.assertRaises(ValueError): + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, load_balancing_strategy=-1): + pass + + def test_forgets_address_on_not_a_leader_error(self): + with StubCluster({9001: "router.script", 9006: "not_a_leader.script"}): uri = "bolt+routing://127.0.0.1:9001" - with self.assertRaises(ValueError): - with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, load_balancing_strategy=-1): - pass + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(WRITE_ACCESS) as session: + with self.assertRaises(ClientError): + _ = session.run("CREATE (n {name:'Bob'})") + + pool = driver._pool + table = pool.routing_table + + # address might still have connections in the pool, failed instance just can't serve writes + assert ('127.0.0.1', 9006) in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + assert table.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)} + # writer 127.0.0.1:9006 should've been forgotten because of an error + assert len(table.writers) == 0 + + def test_forgets_address_on_forbidden_on_read_only_database_error(self): + with StubCluster({9001: "router.script", 9006: "forbidden_on_read_only_database.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(WRITE_ACCESS) as session: + with self.assertRaises(ClientError): + _ = session.run("CREATE (n {name:'Bob'})") + + pool = driver._pool + table = pool.routing_table + + # address might still have connections in the pool, failed instance just can't serve writes + assert ('127.0.0.1', 9006) in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + assert table.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)} + # writer 127.0.0.1:9006 should've been forgotten because of an error + assert len(table.writers) == 0 + + def test_forgets_address_on_service_unavailable_error(self): + with StubCluster({9001: "router.script", 9004: "rude_reader.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + with self.assertRaises(SessionExpired): + _ = session.run("RETURN 1") + + pool = driver._pool + table = pool.routing_table + + # address should not have connections in the pool, it has failed + assert ('127.0.0.1', 9004) not in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + # reader 127.0.0.1:9004 should've been forgotten because of an error + assert table.readers == {('127.0.0.1', 9005)} + assert table.writers == {('127.0.0.1', 9006)} + + def test_forgets_address_on_database_unavailable_error(self): + with StubCluster({9001: "router.script", 9004: "database_unavailable.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + with self.assertRaises(TransientError): + _ = session.run("RETURN 1") + + pool = driver._pool + table = pool.routing_table + + # address should not have connections in the pool, it has failed + assert ('127.0.0.1', 9004) not in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + # reader 127.0.0.1:9004 should've been forgotten because of an error + assert table.readers == {('127.0.0.1', 9005)} + assert table.writers == {('127.0.0.1', 9006)} diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 92ca5507d..a7d12c4df 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -52,8 +52,8 @@ } -def connector(address): - return connect(address, auth=basic_auth("neotest", "neotest")) +def connector(address, error_handler): + return connect(address, error_handler=error_handler, auth=basic_auth("neotest", "neotest")) class RoundRobinSetTestCase(TestCase):