Skip to content

Commit ffcc17c

Browse files
authored
Merge pull request #176 from lutovich/1.5-forget-addresses-on-errors
Routing driver forgets addresses on some errors
2 parents f243ee6 + 9edcf42 commit ffcc17c

13 files changed

+260
-39
lines changed

neo4j/bolt/connection.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,26 @@ def supports_bytes(self):
119119
return self.version_info() >= (3, 2)
120120

121121

122+
class ConnectionErrorHandler(object):
123+
""" A handler for send and receive errors.
124+
"""
125+
126+
def __init__(self, handlers_by_error_class=None):
127+
if handlers_by_error_class is None:
128+
handlers_by_error_class = {}
129+
130+
self.handlers_by_error_class = handlers_by_error_class
131+
self.known_errors = tuple(handlers_by_error_class.keys())
132+
133+
def handle(self, error, address):
134+
try:
135+
error_class = error.__class__
136+
handler = self.handlers_by_error_class[error_class]
137+
handler(address)
138+
except KeyError:
139+
pass
140+
141+
122142
class Connection(object):
123143
""" Server connection for Bolt protocol v1.
124144
@@ -148,8 +168,10 @@ class Connection(object):
148168

149169
_last_run_statement = None
150170

151-
def __init__(self, sock, **config):
171+
def __init__(self, address, sock, error_handler, **config):
172+
self.address = address
152173
self.socket = sock
174+
self.error_handler = error_handler
153175
self.server = ServerInfo(SocketAddress.from_socket(sock))
154176
self.input_buffer = ChunkedInputBuffer()
155177
self.output_buffer = ChunkedOutputBuffer()
@@ -237,6 +259,13 @@ def reset(self):
237259
self.sync()
238260

239261
def send(self):
262+
try:
263+
self._send()
264+
except self.error_handler.known_errors as error:
265+
self.error_handler.handle(error, self.address)
266+
raise error
267+
268+
def _send(self):
240269
""" Send all queued messages to the server.
241270
"""
242271
data = self.output_buffer.view()
@@ -250,6 +279,13 @@ def send(self):
250279
self.output_buffer.clear()
251280

252281
def fetch(self):
282+
try:
283+
return self._fetch()
284+
except self.error_handler.known_errors as error:
285+
self.error_handler.handle(error, self.address)
286+
raise error
287+
288+
def _fetch(self):
253289
""" Receive at least one message from the server, if available.
254290
255291
:return: 2-tuple of number of detail messages and number of summary messages fetched
@@ -360,8 +396,9 @@ class ConnectionPool(object):
360396

361397
_closed = False
362398

363-
def __init__(self, connector):
399+
def __init__(self, connector, connection_error_handler):
364400
self.connector = connector
401+
self.connection_error_handler = connection_error_handler
365402
self.connections = {}
366403
self.lock = RLock()
367404

@@ -395,7 +432,7 @@ def acquire_direct(self, address):
395432
connection.in_use = True
396433
return connection
397434
try:
398-
connection = self.connector(address)
435+
connection = self.connector(address, self.connection_error_handler)
399436
except ServiceUnavailable:
400437
self.remove(address)
401438
raise
@@ -457,7 +494,7 @@ def closed(self):
457494
return self._closed
458495

459496

460-
def connect(address, ssl_context=None, **config):
497+
def connect(address, ssl_context=None, error_handler=None, **config):
461498
""" Connect and perform a handshake and return a valid Connection object, assuming
462499
a protocol version can be agreed.
463500
"""
@@ -563,7 +600,8 @@ def connect(address, ssl_context=None, **config):
563600
s.shutdown(SHUT_RDWR)
564601
s.close()
565602
elif agreed_version == 1:
566-
return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config)
603+
return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate,
604+
error_handler=error_handler, **config)
567605
elif agreed_version == 0x48545450:
568606
log_error("S: [CLOSE]")
569607
s.close()

neo4j/exceptions.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,9 @@ def hydrate(cls, message=None, code=None, **metadata):
6565
classification = "DatabaseError"
6666
category = "General"
6767
title = "UnknownError"
68-
if classification == "ClientError":
69-
try:
70-
error_class = client_errors[code]
71-
except KeyError:
72-
error_class = ClientError
73-
elif classification == "DatabaseError":
74-
error_class = DatabaseError
75-
elif classification == "TransientError":
76-
error_class = TransientError
77-
else:
78-
error_class = cls
68+
69+
error_class = cls._extract_error_class(classification, code)
70+
7971
inst = error_class(message)
8072
inst.message = message
8173
inst.code = code
@@ -85,6 +77,26 @@ def hydrate(cls, message=None, code=None, **metadata):
8577
inst.metadata = metadata
8678
return inst
8779

80+
@classmethod
81+
def _extract_error_class(cls, classification, code):
82+
if classification == "ClientError":
83+
try:
84+
return client_errors[code]
85+
except KeyError:
86+
return ClientError
87+
88+
elif classification == "TransientError":
89+
try:
90+
return transient_errors[code]
91+
except KeyError:
92+
return TransientError
93+
94+
elif classification == "DatabaseError":
95+
return DatabaseError
96+
97+
else:
98+
return cls
99+
88100

89101
class ClientError(CypherError):
90102
""" The Client sent a bad request - changing the request might yield a successful outcome.
@@ -101,6 +113,11 @@ class TransientError(CypherError):
101113
"""
102114

103115

116+
class DatabaseUnavailableError(TransientError):
117+
"""
118+
"""
119+
120+
104121
class ConstraintError(ClientError):
105122
"""
106123
"""
@@ -116,11 +133,21 @@ class CypherTypeError(ClientError):
116133
"""
117134

118135

136+
class NotALeaderError(ClientError):
137+
"""
138+
"""
139+
140+
119141
class Forbidden(ClientError, SecurityError):
120142
"""
121143
"""
122144

123145

146+
class ForbiddenOnReadOnlyDatabaseError(Forbidden):
147+
"""
148+
"""
149+
150+
124151
class AuthError(ClientError, SecurityError):
125152
""" Raised when authentication failure occurs.
126153
"""
@@ -144,7 +171,7 @@ class AuthError(ClientError, SecurityError):
144171
"Neo.ClientError.Statement.TypeError": CypherTypeError,
145172

146173
# Forbidden
147-
"Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": Forbidden,
174+
"Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": ForbiddenOnReadOnlyDatabaseError,
148175
"Neo.ClientError.General.ReadOnly": Forbidden,
149176
"Neo.ClientError.Schema.ForbiddenOnConstraintIndex": Forbidden,
150177
"Neo.ClientError.Schema.IndexBelongsToConstraint": Forbidden,
@@ -155,4 +182,12 @@ class AuthError(ClientError, SecurityError):
155182
"Neo.ClientError.Security.AuthorizationFailed": AuthError,
156183
"Neo.ClientError.Security.Unauthorized": AuthError,
157184

185+
# NotALeaderError
186+
"Neo.ClientError.Cluster.NotALeader": NotALeaderError
187+
}
188+
189+
transient_errors = {
190+
191+
# DatabaseUnavailableError
192+
"Neo.TransientError.General.DatabaseUnavailable": DatabaseUnavailableError
158193
}

neo4j/v1/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _disconnect(self, sync):
277277
if sync:
278278
try:
279279
self._connection.sync()
280-
except ServiceUnavailable:
280+
except (SessionError, ServiceUnavailable):
281281
pass
282282
if self._connection:
283283
self._connection.in_use = False

neo4j/v1/direct.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,25 @@
2020

2121

2222
from neo4j.addressing import SocketAddress, resolve
23-
from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect
23+
from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler
2424
from neo4j.exceptions import ServiceUnavailable
2525
from neo4j.v1.api import Driver
2626
from neo4j.v1.security import SecurityPlan
2727
from neo4j.v1.session import BoltSession
2828

2929

30+
class DirectConnectionErrorHandler(ConnectionErrorHandler):
31+
""" Handler for errors in direct driver connections.
32+
"""
33+
34+
def __init__(self):
35+
super(DirectConnectionErrorHandler, self).__init__({}) # does not need to handle errors
36+
37+
3038
class DirectConnectionPool(ConnectionPool):
3139

3240
def __init__(self, connector, address):
33-
super(DirectConnectionPool, self).__init__(connector)
41+
super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler())
3442
self.address = address
3543

3644
def acquire(self, access_mode=None):
@@ -61,7 +69,11 @@ def __init__(self, uri, **config):
6169
self.address = SocketAddress.from_uri(uri, DEFAULT_PORT)
6270
self.security_plan = security_plan = SecurityPlan.build(**config)
6371
self.encrypted = security_plan.encrypted
64-
pool = DirectConnectionPool(lambda a: connect(a, security_plan.ssl_context, **config), self.address)
72+
73+
def connector(address, error_handler):
74+
return connect(address, security_plan.ssl_context, error_handler, **config)
75+
76+
pool = DirectConnectionPool(connector, self.address)
6577
pool.release(pool.acquire())
6678
Driver.__init__(self, pool, **config)
6779

neo4j/v1/routing.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,15 @@
2323
from time import clock
2424

2525
from neo4j.addressing import SocketAddress, resolve
26-
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
26+
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect, ConnectionErrorHandler
2727
from neo4j.compat.collections import MutableSet, OrderedDict
28-
from neo4j.exceptions import CypherError
28+
from neo4j.exceptions import CypherError, DatabaseUnavailableError, NotALeaderError, ForbiddenOnReadOnlyDatabaseError
2929
from neo4j.util import ServerVersion
3030
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
3131
from neo4j.v1.exceptions import SessionExpired
3232
from neo4j.v1.security import SecurityPlan
3333
from neo4j.v1.session import BoltSession
3434

35-
3635
LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
3736
LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
3837
LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
@@ -247,12 +246,26 @@ def _select(self, offset, addresses):
247246
return least_connected_address
248247

249248

249+
class RoutingConnectionErrorHandler(ConnectionErrorHandler):
250+
""" Handler for errors in routing driver connections.
251+
"""
252+
253+
def __init__(self, pool):
254+
super(RoutingConnectionErrorHandler, self).__init__({
255+
SessionExpired: lambda address: pool.remove(address),
256+
ServiceUnavailable: lambda address: pool.remove(address),
257+
DatabaseUnavailableError: lambda address: pool.remove(address),
258+
NotALeaderError: lambda address: pool.remove_writer(address),
259+
ForbiddenOnReadOnlyDatabaseError: lambda address: pool.remove_writer(address)
260+
})
261+
262+
250263
class RoutingConnectionPool(ConnectionPool):
251264
""" Connection pool with routing table.
252265
"""
253266

254267
def __init__(self, connector, initial_address, routing_context, *routers, **config):
255-
super(RoutingConnectionPool, self).__init__(connector)
268+
super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self))
256269
self.initial_address = initial_address
257270
self.routing_context = routing_context
258271
self.routing_table = RoutingTable(routers)
@@ -416,6 +429,11 @@ def remove(self, address):
416429
self.routing_table.writers.discard(address)
417430
super(RoutingConnectionPool, self).remove(address)
418431

432+
def remove_writer(self, address):
433+
""" Remove a writer address from the routing table, if present.
434+
"""
435+
self.routing_table.writers.discard(address)
436+
419437

420438
class RoutingDriver(Driver):
421439
""" A :class:`.RoutingDriver` is created from a ``bolt+routing`` URI. The
@@ -433,8 +451,8 @@ def __init__(self, uri, **config):
433451
# scenario right now
434452
raise ValueError("TRUST_ON_FIRST_USE is not compatible with routing")
435453

436-
def connector(a):
437-
return connect(a, security_plan.ssl_context, **config)
454+
def connector(address, error_handler):
455+
return connect(address, security_plan.ssl_context, error_handler, **config)
438456

439457
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config)
440458
try:

test/integration/test_connection.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from socket import create_connection
2323

24-
from neo4j.v1 import ConnectionPool, ServiceUnavailable
24+
from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler
2525

2626
from test.integration.tools import IntegrationTestCase
2727

@@ -45,10 +45,14 @@ def defunct(self):
4545
return False
4646

4747

48+
def connector(address, _):
49+
return QuickConnection(create_connection(address))
50+
51+
4852
class ConnectionPoolTestCase(IntegrationTestCase):
4953

5054
def setUp(self):
51-
self.pool = ConnectionPool(lambda a: QuickConnection(create_connection(a)))
55+
self.pool = ConnectionPool(connector, DirectConnectionErrorHandler())
5256

5357
def tearDown(self):
5458
self.pool.close()
@@ -104,7 +108,7 @@ def test_releasing_twice(self):
104108
self.assert_pool_size(address, 0, 1)
105109

106110
def test_cannot_acquire_after_close(self):
107-
with ConnectionPool(lambda a: QuickConnection(create_connection(a))) as pool:
111+
with ConnectionPool(lambda a: QuickConnection(create_connection(a)), DirectConnectionErrorHandler()) as pool:
108112
pool.close()
109113
with self.assertRaises(ServiceUnavailable):
110114
_ = pool.acquire_direct("X")
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
!: AUTO PULL_ALL
4+
!: AUTO ACK_FAILURE
5+
!: AUTO RUN "ROLLBACK" {}
6+
!: AUTO RUN "BEGIN" {}
7+
!: AUTO RUN "COMMIT" {}
8+
9+
C: RUN "RETURN 1" {}
10+
C: PULL_ALL
11+
S: FAILURE {"code": "Neo.TransientError.General.DatabaseUnavailable", "message": "Database is busy doing store copy"}
12+
S: IGNORED
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
!: AUTO PULL_ALL
4+
!: AUTO ACK_FAILURE
5+
!: AUTO RUN "ROLLBACK" {}
6+
!: AUTO RUN "BEGIN" {}
7+
!: AUTO RUN "COMMIT" {}
8+
9+
C: RUN "CREATE (n {name:'Bob'})" {}
10+
C: PULL_ALL
11+
S: FAILURE {"code": "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase", "message": "Unable to write"}
12+
S: IGNORED

0 commit comments

Comments
 (0)