Skip to content

Commit 3b020eb

Browse files
authored
Merge pull request #109 from neo4j/1.1-server-version
Server address and version
2 parents a39c181 + 6848bb1 commit 3b020eb

File tree

5 files changed

+39
-31
lines changed

5 files changed

+39
-31
lines changed

neo4j/v1/bolt.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from __future__ import division
3030

3131
from base64 import b64encode
32-
from collections import deque
32+
from collections import deque, namedtuple
3333
from io import BytesIO
3434
import logging
3535
from os import makedirs, open as os_open, write as os_write, close as os_close, O_CREAT, O_APPEND, O_WRONLY
@@ -81,12 +81,16 @@
8181
log_error = log.error
8282

8383

84+
Address = namedtuple("Address", ["host", "port"])
85+
ServerInfo = namedtuple("ServerInfo", ["address", "version"])
86+
87+
8488
class BufferingSocket(object):
8589

8690
def __init__(self, connection):
8791
self.connection = connection
8892
self.socket = connection.socket
89-
self.address = self.socket.getpeername()
93+
self.address = Address(*self.socket.getpeername())
9094
self.buffer = bytearray()
9195

9296
def fill(self):
@@ -132,7 +136,7 @@ class ChunkChannel(object):
132136

133137
def __init__(self, sock):
134138
self.socket = sock
135-
self.address = sock.getpeername()
139+
self.address = Address(*sock.getpeername())
136140
self.raw = BytesIO()
137141
self.output_buffer = []
138142
self.output_size = 0
@@ -206,6 +210,22 @@ def on_ignored(self, metadata=None):
206210
pass
207211

208212

213+
class InitResponse(Response):
214+
215+
def on_success(self, metadata):
216+
super(InitResponse, self).on_success(metadata)
217+
connection = self.connection
218+
address = Address(*connection.socket.getpeername())
219+
version = metadata.get("server")
220+
connection.server = ServerInfo(address, version)
221+
222+
def on_failure(self, metadata):
223+
code = metadata.get("code")
224+
error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
225+
ServiceUnavailable)
226+
raise error(metadata.get("message", "INIT failed"))
227+
228+
209229
class Connection(object):
210230
""" Server connection for Bolt protocol v1.
211231
@@ -222,15 +242,15 @@ class Connection(object):
222242

223243
defunct = False
224244

225-
server_version = None # TODO: remove this when PR#108 is merged
226-
227245
#: The pool of which this connection is a member
228246
pool = None
229247

248+
#: Server version details
249+
server = None
250+
230251
def __init__(self, sock, **config):
231252
self.socket = sock
232253
self.buffering_socket = BufferingSocket(self)
233-
self.address = sock.getpeername()
234254
self.channel = ChunkChannel(sock)
235255
self.packer = Packer(self.channel)
236256
self.unpacker = Unpacker()
@@ -251,19 +271,7 @@ def __init__(self, sock, **config):
251271
# Pick up the server certificate, if any
252272
self.der_encoded_server_certificate = config.get("der_encoded_server_certificate")
253273

254-
def on_success(metadata):
255-
self.server_version = metadata.get("server")
256-
257-
def on_failure(metadata):
258-
code = metadata.get("code")
259-
error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
260-
ServiceUnavailable)
261-
raise error(metadata.get("message", "INIT failed"))
262-
263-
response = Response(self)
264-
response.on_success = on_success
265-
response.on_failure = on_failure
266-
274+
response = InitResponse(self)
267275
self.append(INIT, (self.user_agent, self.auth_dict), response=response)
268276
self.sync()
269277

@@ -323,18 +331,18 @@ def send(self):
323331
""" Send all queued messages to the server.
324332
"""
325333
if self.closed:
326-
raise ServiceUnavailable("Failed to write to closed connection %r" % (self.address,))
334+
raise ServiceUnavailable("Failed to write to closed connection %r" % (self.server.address,))
327335
if self.defunct:
328-
raise ServiceUnavailable("Failed to write to defunct connection %r" % (self.address,))
336+
raise ServiceUnavailable("Failed to write to defunct connection %r" % (self.server.address,))
329337
self.channel.send()
330338

331339
def fetch(self):
332340
""" Receive exactly one message from the server.
333341
"""
334342
if self.closed:
335-
raise ServiceUnavailable("Failed to read from closed connection %r" % (self.address,))
343+
raise ServiceUnavailable("Failed to read from closed connection %r" % (self.server.address,))
336344
if self.defunct:
337-
raise ServiceUnavailable("Failed to read from defunct connection %r" % (self.address,))
345+
raise ServiceUnavailable("Failed to read from defunct connection %r" % (self.server.address,))
338346
try:
339347
message_data = self.buffering_socket.read_message()
340348
except ProtocolError:

neo4j/v1/routing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from threading import Lock
2323
from time import clock
2424

25-
from .bolt import ConnectionPool
25+
from .bolt import Address, ConnectionPool
2626
from .compat.collections import MutableSet, OrderedDict
2727
from .exceptions import CypherError, ProtocolError, ServiceUnavailable
2828

@@ -94,7 +94,7 @@ def parse_address(cls, address):
9494
""" Convert an address string to a tuple.
9595
"""
9696
host, _, port = address.partition(":")
97-
return host, int(port)
97+
return Address(host, int(port))
9898

9999
@classmethod
100100
def parse_routing_info(cls, records):

test/test_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_should_be_able_to_read(self):
159159
result = session.run("RETURN $x", {"x": 1})
160160
for record in result:
161161
assert record["x"] == 1
162-
assert session.connection.address == ('127.0.0.1', 9004)
162+
assert session.connection.server.address == ('127.0.0.1', 9004)
163163

164164
def test_should_be_able_to_write(self):
165165
with StubCluster({9001: "router.script", 9006: "create_a.script"}):
@@ -168,7 +168,7 @@ def test_should_be_able_to_write(self):
168168
with driver.session(WRITE_ACCESS) as session:
169169
result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}})
170170
assert not list(result)
171-
assert session.connection.address == ('127.0.0.1', 9006)
171+
assert session.connection.server.address == ('127.0.0.1', 9006)
172172

173173
def test_should_be_able_to_write_as_default(self):
174174
with StubCluster({9001: "router.script", 9006: "create_a.script"}):
@@ -177,7 +177,7 @@ def test_should_be_able_to_write_as_default(self):
177177
with driver.session() as session:
178178
result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}})
179179
assert not list(result)
180-
assert session.connection.address == ('127.0.0.1', 9006)
180+
assert session.connection.server.address == ('127.0.0.1', 9006)
181181

182182
def test_routing_disconnect_on_run(self):
183183
with StubCluster({9001: "router.script", 9004: "disconnect_on_run.script"}):

test/test_routing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def test_connected_to_reader(self):
575575
with RoutingConnectionPool(connector, address) as pool:
576576
assert not pool.routing_table.is_fresh()
577577
connection = pool.acquire_for_read()
578-
assert connection.address in pool.routing_table.readers
578+
assert connection.server.address in pool.routing_table.readers
579579

580580
def test_should_retry_if_first_reader_fails(self):
581581
with StubCluster({9001: "router.script",
@@ -605,7 +605,7 @@ def test_connected_to_writer(self):
605605
with RoutingConnectionPool(connector, address) as pool:
606606
assert not pool.routing_table.is_fresh()
607607
connection = pool.acquire_for_write()
608-
assert connection.address in pool.routing_table.writers
608+
assert connection.server.address in pool.routing_table.writers
609609

610610
def test_should_retry_if_first_writer_fails(self):
611611
with StubCluster({9001: "router_with_multiple_writers.script",

test/test_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
def get_server_version():
3838
driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False)
3939
with driver.session() as session:
40-
full_version = session.connection.server_version
40+
full_version = session.connection.server.version
4141
if full_version is None:
4242
return "Neo4j", (3, 0), ()
4343
product, _, tagged_version = full_version.partition("/")

0 commit comments

Comments
 (0)