|
40 | 40 | from .constants import DEFAULT_PORT, ENCRYPTION_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES, \ |
41 | 41 | TRUST_ON_FIRST_USE, READ_ACCESS, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, \ |
42 | 42 | TRUST_ALL_CERTIFICATES, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES |
43 | | -from .exceptions import CypherError, ProtocolError, ResultError, TransactionError |
| 43 | +from .exceptions import CypherError, ProtocolError, ResultError, TransactionError, \ |
| 44 | + ServiceUnavailable |
44 | 45 | from .ssl_compat import SSL_AVAILABLE, SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED |
45 | 46 | from .summary import ResultSummary |
46 | 47 | from .types import hydrated |
@@ -196,53 +197,88 @@ def session(self, access_mode=None): |
196 | 197 | return Session(self.pool.acquire(self.address)) |
197 | 198 |
|
198 | 199 |
|
199 | | -class RoutingDriver(Driver): |
200 | | - """ A :class:`.RoutingDriver` is created from a `bolt+routing` URI. |
| 200 | +def parse_address(address): |
| 201 | + """ Convert an address string to a tuple. |
201 | 202 | """ |
| 203 | + host, _, port = address.partition(":") |
| 204 | + return host, int(port) |
202 | 205 |
|
203 | | - def __init__(self, address, **config): |
204 | | - self.address = address |
205 | | - self.security_plan = security_plan = SecurityPlan.build(address, **config) |
206 | | - self.encrypted = security_plan.encrypted |
207 | | - if not security_plan.routing_compatible: |
208 | | - # this error message is case-specific as there is only one incompatible |
209 | | - # scenario right now |
210 | | - raise RuntimeError("TRUST_ON_FIRST_USE is not compatible with routing") |
211 | | - Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config)) |
212 | | - self._lock = Lock() |
213 | | - self._expiry_time = None |
214 | | - self._routers = RoundRobinSet([address]) |
215 | | - self._readers = RoundRobinSet() |
216 | | - self._writers = RoundRobinSet() |
217 | | - self.discover() |
| 206 | + |
| 207 | +class Router(object): |
| 208 | + |
| 209 | + timer = monotonic |
| 210 | + |
| 211 | + def __init__(self, pool, initial_address): |
| 212 | + self.pool = pool |
| 213 | + self.lock = Lock() |
| 214 | + self.expiry_time = None |
| 215 | + self.routers = RoundRobinSet([initial_address]) |
| 216 | + self.readers = RoundRobinSet() |
| 217 | + self.writers = RoundRobinSet() |
| 218 | + |
| 219 | + def stale(self): |
| 220 | + expired = self.expiry_time is None or self.expiry_time <= self.timer() |
| 221 | + return expired or len(self.routers) <= 1 or not self.readers or not self.writers |
218 | 222 |
|
219 | 223 | def discover(self): |
220 | | - with self._lock: |
221 | | - for router in list(self._routers): |
| 224 | + with self.lock: |
| 225 | + if not self.routers: |
| 226 | + raise ServiceUnavailable("No routers available") |
| 227 | + for router in list(self.routers): |
222 | 228 | session = Session(self.pool.acquire(router)) |
223 | 229 | try: |
224 | 230 | record = session.run("CALL dbms.cluster.routing.getServers").single() |
| 231 | + except CypherError as error: |
| 232 | + if error.code == "Neo.ClientError.Procedure.ProcedureNotFound": |
| 233 | + raise ServiceUnavailable("Server does not support routing") |
| 234 | + raise |
225 | 235 | except ResultError: |
226 | 236 | raise RuntimeError("TODO") |
227 | | - new_expiry_time = monotonic() + record["ttl"] |
| 237 | + new_expiry_time = self.timer() + record["ttl"] |
228 | 238 | servers = record["servers"] |
229 | 239 | new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0] |
230 | 240 | new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0] |
231 | 241 | new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0] |
232 | 242 | if new_routers and new_readers and new_writers: |
233 | | - self._expiry_time = new_expiry_time |
234 | | - self._routers.replace(new_routers) |
235 | | - self._readers.replace(new_readers) |
236 | | - self._writers.replace(new_writers) |
237 | | - else: |
238 | | - raise RuntimeError("TODO") |
| 243 | + self.expiry_time = new_expiry_time |
| 244 | + self.routers.replace(map(parse_address, new_routers)) |
| 245 | + self.readers.replace(map(parse_address, new_readers)) |
| 246 | + self.writers.replace(map(parse_address, new_writers)) |
| 247 | + return |
| 248 | + raise ServiceUnavailable("Unable to establish routing information") |
| 249 | + |
| 250 | + def acquire_read_connection(self): |
| 251 | + if self.stale(): |
| 252 | + self.discover() |
| 253 | + return self.pool.acquire(next(self.readers)) |
| 254 | + |
| 255 | + def acquire_write_connection(self): |
| 256 | + if self.stale(): |
| 257 | + self.discover() |
| 258 | + return self.pool.acquire(next(self.writers)) |
| 259 | + |
| 260 | + |
| 261 | +class RoutingDriver(Driver): |
| 262 | + """ A :class:`.RoutingDriver` is created from a `bolt+routing` URI. |
| 263 | + """ |
| 264 | + |
| 265 | + def __init__(self, address, **config): |
| 266 | + self.security_plan = security_plan = SecurityPlan.build(address, **config) |
| 267 | + self.encrypted = security_plan.encrypted |
| 268 | + if not security_plan.routing_compatible: |
| 269 | + # this error message is case-specific as there is only one incompatible |
| 270 | + # scenario right now |
| 271 | + raise RuntimeError("TRUST_ON_FIRST_USE is not compatible with routing") |
| 272 | + Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config)) |
| 273 | + self.router = Router(self.pool, address) |
| 274 | + self.router.discover() |
239 | 275 |
|
240 | 276 | def session(self, access_mode=None): |
241 | 277 | if access_mode == READ_ACCESS: |
242 | | - address = next(self._readers) |
| 278 | + connection = self.router.acquire_read_connection() |
243 | 279 | else: |
244 | | - address = next(self._writers) |
245 | | - return Session(self.pool.acquire(address)) |
| 280 | + connection = self.router.acquire_write_connection() |
| 281 | + return Session(connection) |
246 | 282 |
|
247 | 283 |
|
248 | 284 | class StatementResult(object): |
|
0 commit comments