Skip to content

Commit 3cb094a

Browse files
author
Zhen
committed
Added RoutingSession to call get routing table
1 parent fe7554e commit 3cb094a

File tree

2 files changed

+24
-20
lines changed

2 files changed

+24
-20
lines changed

neo4j/v1/routing.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
2727
from neo4j.compat.collections import MutableSet, OrderedDict
2828
from neo4j.exceptions import CypherError
29-
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS
29+
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
3030
from neo4j.v1.exceptions import SessionExpired
3131
from neo4j.v1.security import SecurityPlan
3232
from neo4j.v1.session import BoltSession
@@ -150,14 +150,28 @@ def update(self, new_routing_table):
150150
self.ttl = new_routing_table.ttl
151151

152152

153-
class RoutingConnectionPool(ConnectionPool):
154-
""" Connection pool with routing table.
155-
"""
153+
class RoutingSession(BoltSession):
156154

157155
call_get_servers = "CALL dbms.cluster.routing.getServers"
158156
get_routing_table_param = "context"
159157
call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param
160158

159+
def routing_info_procedure(self, routing_context):
160+
if ServerVersion.from_str(self._connection.server.version).at_least_version(3, 2):
161+
return self.call_get_routing_table, {self.get_routing_table_param: routing_context}
162+
else:
163+
return self.call_get_servers, {}
164+
165+
def __run__(self, ignored, routing_context):
166+
# the statement is ignored as it will be get routing table procedure call.
167+
statement, parameters = self.routing_info_procedure(routing_context)
168+
return self._run(fix_statement(statement), fix_parameters(parameters))
169+
170+
171+
class RoutingConnectionPool(ConnectionPool):
172+
""" Connection pool with routing table.
173+
"""
174+
161175
def __init__(self, connector, initial_address, routing_context, *routers):
162176
super(RoutingConnectionPool, self).__init__(connector)
163177
self.initial_address = initial_address
@@ -166,12 +180,6 @@ def __init__(self, connector, initial_address, routing_context, *routers):
166180
self.missing_writer = False
167181
self.refresh_lock = Lock()
168182

169-
def routing_info_procedure(self, connection):
170-
if ServerVersion.from_str(connection.server.version).at_least_version(3, 2):
171-
return self.call_get_routing_table, {self.get_routing_table_param: self.routing_context}
172-
else:
173-
return self.call_get_servers, {}
174-
175183
def fetch_routing_info(self, address):
176184
""" Fetch raw routing info from a given router address.
177185
@@ -182,15 +190,8 @@ def fetch_routing_info(self, address):
182190
if routing support is broken
183191
"""
184192
try:
185-
connections = [None]
186-
187-
def connector(_):
188-
connection = self.acquire_direct(address)
189-
connections[0] = connection
190-
return connection
191-
192-
with BoltSession(lambda _: connector) as session:
193-
return list(session.run(*self.routing_info_procedure(connections[0])))
193+
with RoutingSession(lambda _: self.acquire_direct(address)) as session:
194+
return list(session.run("ignored", self.routing_context))
194195
except CypherError as error:
195196
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
196197
raise ServiceUnavailable("Server {!r} does not support routing".format(address))

neo4j/v1/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class BoltSession(Session):
3434
:param bookmark:
3535
"""
3636

37-
def __run__(self, statement, parameters):
37+
def _run(self, statement, parameters):
3838
assert isinstance(statement, unicode)
3939
assert isinstance(parameters, dict)
4040

@@ -52,6 +52,9 @@ def __run__(self, statement, parameters):
5252

5353
return result
5454

55+
def __run__(self, statement, parameters):
56+
return self._run(statement, parameters)
57+
5558
def __begin__(self):
5659
return self.__run__(u"BEGIN", {"bookmark": self._bookmark} if self._bookmark else {})
5760

0 commit comments

Comments
 (0)