Skip to content

Commit 0f6d749

Browse files
author
Zhen Li
authored
Merge pull request #106 from neo4j/1.1-remove-on-conn-fail
Remove from pool on connection failure
2 parents 96fcb55 + f405905 commit 0f6d749

File tree

6 files changed

+70
-10
lines changed

6 files changed

+70
-10
lines changed

neo4j/v1/bolt.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@
8383

8484
class BufferingSocket(object):
8585

86-
def __init__(self, socket):
87-
self.address = socket.getpeername()
88-
self.socket = socket
86+
def __init__(self, connection):
87+
self.connection = connection
88+
self.socket = connection.socket
89+
self.address = self.socket.getpeername()
8990
self.buffer = bytearray()
9091

9192
def fill(self):
@@ -96,6 +97,10 @@ def fill(self):
9697
self.buffer[len(self.buffer):] = received
9798
else:
9899
if ready_to_read is not None:
100+
# If this connection fails, remove this address from the
101+
# connection pool to which this connection belongs.
102+
if self.connection.pool:
103+
self.connection.pool.remove(self.address)
99104
raise ServiceUnavailable("Failed to read from connection %r" % (self.address,))
100105

101106
def read_message(self):
@@ -211,9 +216,12 @@ class Connection(object):
211216
.. note:: logs at INFO level
212217
"""
213218

219+
#: The pool of which this connection is a member
220+
pool = None
221+
214222
def __init__(self, sock, **config):
215223
self.socket = sock
216-
self.buffering_socket = BufferingSocket(sock)
224+
self.buffering_socket = BufferingSocket(self)
217225
self.address = sock.getpeername()
218226
self.channel = ChunkChannel(sock)
219227
self.packer = Packer(self.channel)
@@ -411,6 +419,7 @@ def acquire(self, address):
411419
connection.in_use = True
412420
return connection
413421
connection = self.connector(address)
422+
connection.pool = self
414423
connection.in_use = True
415424
connections.append(connection)
416425
return connection

neo4j/v1/routing.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,22 +263,40 @@ def refresh_routing_table(self):
263263
def acquire_for_read(self):
264264
""" Acquire a connection to a read server.
265265
"""
266-
self.refresh_routing_table()
267-
return self.acquire(next(self.routing_table.readers))
266+
while True:
267+
address = None
268+
while address is None:
269+
self.refresh_routing_table()
270+
address = next(self.routing_table.readers)
271+
try:
272+
connection = self.acquire(address)
273+
except ServiceUnavailable:
274+
self.remove(address)
275+
else:
276+
return connection
268277

269278
def acquire_for_write(self):
270279
""" Acquire a connection to a write server.
271280
"""
272-
self.refresh_routing_table()
273-
return self.acquire(next(self.routing_table.writers))
281+
while True:
282+
address = None
283+
while address is None:
284+
self.refresh_routing_table()
285+
address = next(self.routing_table.writers)
286+
try:
287+
connection = self.acquire(address)
288+
except ServiceUnavailable:
289+
self.remove(address)
290+
else:
291+
return connection
274292

275293
def remove(self, address):
276294
""" Remove an address from the connection pool, if present, closing
277295
all connections to that address. Also remove from the routing table.
278296
"""
279-
super(RoutingConnectionPool, self).remove(address)
280297
# We use `discard` instead of `remove` here since the former
281298
# will not fail if the address has already been removed.
282299
self.routing_table.routers.discard(address)
283300
self.routing_table.readers.discard(address)
284301
self.routing_table.writers.discard(address)
302+
super(RoutingConnectionPool, self).remove(address)

test/resources/fail_on_init.script

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
S: <EXIT>
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
C: RUN "CALL dbms.cluster.routing.getServers" {}
5+
PULL_ALL
6+
S: SUCCESS {"fields": ["ttl", "servers"]}
7+
RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002","127.0.0.1:9003"]},{"role":"READ","addresses":["127.0.0.1:9004","127.0.0.1:9005"]},{"role":"WRITE","addresses":["127.0.0.1:9006","127.0.0.1:9007"]}]]
8+
SUCCESS {}

test/test_routing.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,17 @@ def test_connected_to_reader(self):
577577
connection = pool.acquire_for_read()
578578
assert connection.address in pool.routing_table.readers
579579

580+
def test_should_retry_if_first_reader_fails(self):
581+
with StubCluster({9001: "router.script",
582+
9004: "fail_on_init.script",
583+
9005: "empty.script"}):
584+
address = ("127.0.0.1", 9001)
585+
with RoutingConnectionPool(connector, address) as pool:
586+
assert not pool.routing_table.is_fresh()
587+
_ = pool.acquire_for_read()
588+
assert ("127.0.0.1", 9004) not in pool.routing_table.readers
589+
assert ("127.0.0.1", 9005) in pool.routing_table.readers
590+
580591

581592
class RoutingConnectionPoolAcquireForWriteTestCase(ServerTestCase):
582593

@@ -596,6 +607,17 @@ def test_connected_to_writer(self):
596607
connection = pool.acquire_for_write()
597608
assert connection.address in pool.routing_table.writers
598609

610+
def test_should_retry_if_first_writer_fails(self):
611+
with StubCluster({9001: "router_with_multiple_writers.script",
612+
9006: "fail_on_init.script",
613+
9007: "empty.script"}):
614+
address = ("127.0.0.1", 9001)
615+
with RoutingConnectionPool(connector, address) as pool:
616+
assert not pool.routing_table.is_fresh()
617+
_ = pool.acquire_for_write()
618+
assert ("127.0.0.1", 9006) not in pool.routing_table.writers
619+
assert ("127.0.0.1", 9007) in pool.routing_table.writers
620+
599621

600622
class RoutingConnectionPoolRemoveTestCase(ServerTestCase):
601623

test/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ class ServerTestCase(TestCase):
8585

8686
known_hosts = KNOWN_HOSTS
8787
known_hosts_backup = known_hosts + ".backup"
88-
servers = []
8988

9089
def setUp(self):
9190
if isfile(self.known_hosts):

0 commit comments

Comments
 (0)