Skip to content

Routing driver forgets addresses on some errors #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions neo4j/bolt/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,26 @@ def supports_bytes(self):
return self.version_info() >= (3, 2)


class ConnectionErrorHandler(object):
""" A handler for send and receive errors.
"""

def __init__(self, handlers_by_error_class=None):
if handlers_by_error_class is None:
handlers_by_error_class = {}

self.handlers_by_error_class = handlers_by_error_class
self.known_errors = tuple(handlers_by_error_class.keys())

def handle(self, error, address):
try:
error_class = error.__class__
handler = self.handlers_by_error_class[error_class]
handler(address)
except KeyError:
pass


class Connection(object):
""" Server connection for Bolt protocol v1.

Expand Down Expand Up @@ -148,8 +168,10 @@ class Connection(object):

_last_run_statement = None

def __init__(self, sock, **config):
def __init__(self, address, sock, error_handler, **config):
self.address = address
self.socket = sock
self.error_handler = error_handler
self.server = ServerInfo(SocketAddress.from_socket(sock))
self.input_buffer = ChunkedInputBuffer()
self.output_buffer = ChunkedOutputBuffer()
Expand Down Expand Up @@ -237,6 +259,13 @@ def reset(self):
self.sync()

def send(self):
try:
self._send()
except self.error_handler.known_errors as error:
self.error_handler.handle(error, self.address)
raise error

def _send(self):
""" Send all queued messages to the server.
"""
data = self.output_buffer.view()
Expand All @@ -250,6 +279,13 @@ def send(self):
self.output_buffer.clear()

def fetch(self):
try:
return self._fetch()
except self.error_handler.known_errors as error:
self.error_handler.handle(error, self.address)
raise error

def _fetch(self):
""" Receive at least one message from the server, if available.

:return: 2-tuple of number of detail messages and number of summary messages fetched
Expand Down Expand Up @@ -360,8 +396,9 @@ class ConnectionPool(object):

_closed = False

def __init__(self, connector):
def __init__(self, connector, connection_error_handler):
self.connector = connector
self.connection_error_handler = connection_error_handler
self.connections = {}
self.lock = RLock()

Expand Down Expand Up @@ -395,7 +432,7 @@ def acquire_direct(self, address):
connection.in_use = True
return connection
try:
connection = self.connector(address)
connection = self.connector(address, self.connection_error_handler)
except ServiceUnavailable:
self.remove(address)
raise
Expand Down Expand Up @@ -457,7 +494,7 @@ def closed(self):
return self._closed


def connect(address, ssl_context=None, **config):
def connect(address, ssl_context=None, error_handler=None, **config):
""" Connect and perform a handshake and return a valid Connection object, assuming
a protocol version can be agreed.
"""
Expand Down Expand Up @@ -563,7 +600,8 @@ def connect(address, ssl_context=None, **config):
s.shutdown(SHUT_RDWR)
s.close()
elif agreed_version == 1:
return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config)
return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate,
error_handler=error_handler, **config)
elif agreed_version == 0x48545450:
log_error("S: [CLOSE]")
s.close()
Expand Down
59 changes: 47 additions & 12 deletions neo4j/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,9 @@ def hydrate(cls, message=None, code=None, **metadata):
classification = "DatabaseError"
category = "General"
title = "UnknownError"
if classification == "ClientError":
try:
error_class = client_errors[code]
except KeyError:
error_class = ClientError
elif classification == "DatabaseError":
error_class = DatabaseError
elif classification == "TransientError":
error_class = TransientError
else:
error_class = cls

error_class = cls._extract_error_class(classification, code)

inst = error_class(message)
inst.message = message
inst.code = code
Expand All @@ -85,6 +77,26 @@ def hydrate(cls, message=None, code=None, **metadata):
inst.metadata = metadata
return inst

@classmethod
def _extract_error_class(cls, classification, code):
if classification == "ClientError":
try:
return client_errors[code]
except KeyError:
return ClientError

elif classification == "TransientError":
try:
return transient_errors[code]
except KeyError:
return TransientError

elif classification == "DatabaseError":
return DatabaseError

else:
return cls


class ClientError(CypherError):
""" The Client sent a bad request - changing the request might yield a successful outcome.
Expand All @@ -101,6 +113,11 @@ class TransientError(CypherError):
"""


class DatabaseUnavailableError(TransientError):
"""
"""


class ConstraintError(ClientError):
"""
"""
Expand All @@ -116,11 +133,21 @@ class CypherTypeError(ClientError):
"""


class NotALeaderError(ClientError):
"""
"""


class Forbidden(ClientError, SecurityError):
"""
"""


class ForbiddenOnReadOnlyDatabaseError(Forbidden):
"""
"""


class AuthError(ClientError, SecurityError):
""" Raised when authentication failure occurs.
"""
Expand All @@ -144,7 +171,7 @@ class AuthError(ClientError, SecurityError):
"Neo.ClientError.Statement.TypeError": CypherTypeError,

# Forbidden
"Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": Forbidden,
"Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": ForbiddenOnReadOnlyDatabaseError,
"Neo.ClientError.General.ReadOnly": Forbidden,
"Neo.ClientError.Schema.ForbiddenOnConstraintIndex": Forbidden,
"Neo.ClientError.Schema.IndexBelongsToConstraint": Forbidden,
Expand All @@ -155,4 +182,12 @@ class AuthError(ClientError, SecurityError):
"Neo.ClientError.Security.AuthorizationFailed": AuthError,
"Neo.ClientError.Security.Unauthorized": AuthError,

# NotALeaderError
"Neo.ClientError.Cluster.NotALeader": NotALeaderError
}

transient_errors = {

# DatabaseUnavailableError
"Neo.TransientError.General.DatabaseUnavailable": DatabaseUnavailableError
}
2 changes: 1 addition & 1 deletion neo4j/v1/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _disconnect(self, sync):
if sync:
try:
self._connection.sync()
except ServiceUnavailable:
except (SessionError, ServiceUnavailable):
pass
if self._connection:
self._connection.in_use = False
Expand Down
18 changes: 15 additions & 3 deletions neo4j/v1/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,25 @@


from neo4j.addressing import SocketAddress, resolve
from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect
from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler
from neo4j.exceptions import ServiceUnavailable
from neo4j.v1.api import Driver
from neo4j.v1.security import SecurityPlan
from neo4j.v1.session import BoltSession


class DirectConnectionErrorHandler(ConnectionErrorHandler):
""" Handler for errors in direct driver connections.
"""

def __init__(self):
super(DirectConnectionErrorHandler, self).__init__({}) # does not need to handle errors


class DirectConnectionPool(ConnectionPool):

def __init__(self, connector, address):
super(DirectConnectionPool, self).__init__(connector)
super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler())
self.address = address

def acquire(self, access_mode=None):
Expand Down Expand Up @@ -61,7 +69,11 @@ def __init__(self, uri, **config):
self.address = SocketAddress.from_uri(uri, DEFAULT_PORT)
self.security_plan = security_plan = SecurityPlan.build(**config)
self.encrypted = security_plan.encrypted
pool = DirectConnectionPool(lambda a: connect(a, security_plan.ssl_context, **config), self.address)

def connector(address, error_handler):
return connect(address, security_plan.ssl_context, error_handler, **config)

pool = DirectConnectionPool(connector, self.address)
pool.release(pool.acquire())
Driver.__init__(self, pool, **config)

Expand Down
30 changes: 24 additions & 6 deletions neo4j/v1/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@
from time import clock

from neo4j.addressing import SocketAddress, resolve
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect, ConnectionErrorHandler
from neo4j.compat.collections import MutableSet, OrderedDict
from neo4j.exceptions import CypherError
from neo4j.exceptions import CypherError, DatabaseUnavailableError, NotALeaderError, ForbiddenOnReadOnlyDatabaseError
from neo4j.util import ServerVersion
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
from neo4j.v1.exceptions import SessionExpired
from neo4j.v1.security import SecurityPlan
from neo4j.v1.session import BoltSession


LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
Expand Down Expand Up @@ -247,12 +246,26 @@ def _select(self, offset, addresses):
return least_connected_address


class RoutingConnectionErrorHandler(ConnectionErrorHandler):
""" Handler for errors in routing driver connections.
"""

def __init__(self, pool):
super(RoutingConnectionErrorHandler, self).__init__({
SessionExpired: lambda address: pool.remove(address),
ServiceUnavailable: lambda address: pool.remove(address),
DatabaseUnavailableError: lambda address: pool.remove(address),
NotALeaderError: lambda address: pool.remove_writer(address),
ForbiddenOnReadOnlyDatabaseError: lambda address: pool.remove_writer(address)
})


class RoutingConnectionPool(ConnectionPool):
""" Connection pool with routing table.
"""

def __init__(self, connector, initial_address, routing_context, *routers, **config):
super(RoutingConnectionPool, self).__init__(connector)
super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self))
self.initial_address = initial_address
self.routing_context = routing_context
self.routing_table = RoutingTable(routers)
Expand Down Expand Up @@ -416,6 +429,11 @@ def remove(self, address):
self.routing_table.writers.discard(address)
super(RoutingConnectionPool, self).remove(address)

def remove_writer(self, address):
""" Remove a writer address from the routing table, if present.
"""
self.routing_table.writers.discard(address)


class RoutingDriver(Driver):
""" A :class:`.RoutingDriver` is created from a ``bolt+routing`` URI. The
Expand All @@ -433,8 +451,8 @@ def __init__(self, uri, **config):
# scenario right now
raise ValueError("TRUST_ON_FIRST_USE is not compatible with routing")

def connector(a):
return connect(a, security_plan.ssl_context, **config)
def connector(address, error_handler):
return connect(address, security_plan.ssl_context, error_handler, **config)

pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config)
try:
Expand Down
10 changes: 7 additions & 3 deletions test/integration/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from socket import create_connection

from neo4j.v1 import ConnectionPool, ServiceUnavailable
from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler

from test.integration.tools import IntegrationTestCase

Expand All @@ -45,10 +45,14 @@ def defunct(self):
return False


def connector(address, _):
return QuickConnection(create_connection(address))


class ConnectionPoolTestCase(IntegrationTestCase):

def setUp(self):
self.pool = ConnectionPool(lambda a: QuickConnection(create_connection(a)))
self.pool = ConnectionPool(connector, DirectConnectionErrorHandler())

def tearDown(self):
self.pool.close()
Expand Down Expand Up @@ -104,7 +108,7 @@ def test_releasing_twice(self):
self.assert_pool_size(address, 0, 1)

def test_cannot_acquire_after_close(self):
with ConnectionPool(lambda a: QuickConnection(create_connection(a))) as pool:
with ConnectionPool(lambda a: QuickConnection(create_connection(a)), DirectConnectionErrorHandler()) as pool:
pool.close()
with self.assertRaises(ServiceUnavailable):
_ = pool.acquire_direct("X")
Expand Down
12 changes: 12 additions & 0 deletions test/stub/scripts/database_unavailable.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
!: AUTO INIT
!: AUTO RESET
!: AUTO PULL_ALL
!: AUTO ACK_FAILURE
!: AUTO RUN "ROLLBACK" {}
!: AUTO RUN "BEGIN" {}
!: AUTO RUN "COMMIT" {}

C: RUN "RETURN 1" {}
C: PULL_ALL
S: FAILURE {"code": "Neo.TransientError.General.DatabaseUnavailable", "message": "Database is busy doing store copy"}
S: IGNORED
12 changes: 12 additions & 0 deletions test/stub/scripts/forbidden_on_read_only_database.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
!: AUTO INIT
!: AUTO RESET
!: AUTO PULL_ALL
!: AUTO ACK_FAILURE
!: AUTO RUN "ROLLBACK" {}
!: AUTO RUN "BEGIN" {}
!: AUTO RUN "COMMIT" {}

C: RUN "CREATE (n {name:'Bob'})" {}
C: PULL_ALL
S: FAILURE {"code": "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase", "message": "Unable to write"}
S: IGNORED
Loading