26
26
from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
27
27
from neo4j .compat .collections import MutableSet , OrderedDict
28
28
from neo4j .exceptions import CypherError
29
- from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS
29
+ from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
30
30
from neo4j .v1 .exceptions import SessionExpired
31
31
from neo4j .v1 .security import SecurityPlan
32
32
from neo4j .v1 .session import BoltSession
@@ -150,14 +150,28 @@ def update(self, new_routing_table):
150
150
self .ttl = new_routing_table .ttl
151
151
152
152
153
- class RoutingConnectionPool (ConnectionPool ):
154
- """ Connection pool with routing table.
155
- """
153
+ class RoutingSession (BoltSession ):
156
154
157
155
call_get_servers = "CALL dbms.cluster.routing.getServers"
158
156
get_routing_table_param = "context"
159
157
call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param
160
158
159
+ def routing_info_procedure (self , routing_context ):
160
+ if ServerVersion .from_str (self ._connection .server .version ).at_least_version (3 , 2 ):
161
+ return self .call_get_routing_table , {self .get_routing_table_param : routing_context }
162
+ else :
163
+ return self .call_get_servers , {}
164
+
165
+ def __run__ (self , ignored , routing_context ):
166
+ # the statement is ignored as it will be get routing table procedure call.
167
+ statement , parameters = self .routing_info_procedure (routing_context )
168
+ return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169
+
170
+
171
+ class RoutingConnectionPool (ConnectionPool ):
172
+ """ Connection pool with routing table.
173
+ """
174
+
161
175
def __init__ (self , connector , initial_address , routing_context , * routers ):
162
176
super (RoutingConnectionPool , self ).__init__ (connector )
163
177
self .initial_address = initial_address
@@ -166,12 +180,6 @@ def __init__(self, connector, initial_address, routing_context, *routers):
166
180
self .missing_writer = False
167
181
self .refresh_lock = Lock ()
168
182
169
- def routing_info_procedure (self , connection ):
170
- if ServerVersion .from_str (connection .server .version ).at_least_version (3 , 2 ):
171
- return self .call_get_routing_table , {self .get_routing_table_param : self .routing_context }
172
- else :
173
- return self .call_get_servers , {}
174
-
175
183
def fetch_routing_info (self , address ):
176
184
""" Fetch raw routing info from a given router address.
177
185
@@ -182,15 +190,8 @@ def fetch_routing_info(self, address):
182
190
if routing support is broken
183
191
"""
184
192
try :
185
- connections = [None ]
186
-
187
- def connector (_ ):
188
- connection = self .acquire_direct (address )
189
- connections [0 ] = connection
190
- return connection
191
-
192
- with BoltSession (lambda _ : connector ) as session :
193
- return list (session .run (* self .routing_info_procedure (connections [0 ])))
193
+ with RoutingSession (lambda _ : self .acquire_direct (address )) as session :
194
+ return list (session .run ("ignored" , self .routing_context ))
194
195
except CypherError as error :
195
196
if error .code == "Neo.ClientError.Procedure.ProcedureNotFound" :
196
197
raise ServiceUnavailable ("Server {!r} does not support routing" .format (address ))
0 commit comments