29
29
from __future__ import division
30
30
31
31
from base64 import b64encode
32
- from collections import deque
32
+ from collections import deque , namedtuple
33
33
from io import BytesIO
34
34
import logging
35
35
from os import makedirs , open as os_open , write as os_write , close as os_close , O_CREAT , O_APPEND , O_WRONLY
81
81
log_error = log .error
82
82
83
83
84
+ Address = namedtuple ("Address" , ["host" , "port" ])
85
+ ServerInfo = namedtuple ("ServerInfo" , ["address" , "version" ])
86
+
87
+
84
88
class BufferingSocket (object ):
85
89
86
90
def __init__ (self , connection ):
87
91
self .connection = connection
88
92
self .socket = connection .socket
89
- self .address = self .socket .getpeername ()
93
+ self .address = Address ( * self .socket .getpeername () )
90
94
self .buffer = bytearray ()
91
95
92
96
def fill (self ):
@@ -132,7 +136,7 @@ class ChunkChannel(object):
132
136
133
137
def __init__ (self , sock ):
134
138
self .socket = sock
135
- self .address = sock .getpeername ()
139
+ self .address = Address ( * sock .getpeername () )
136
140
self .raw = BytesIO ()
137
141
self .output_buffer = []
138
142
self .output_size = 0
@@ -206,6 +210,22 @@ def on_ignored(self, metadata=None):
206
210
pass
207
211
208
212
213
+ class InitResponse (Response ):
214
+
215
+ def on_success (self , metadata ):
216
+ super (InitResponse , self ).on_success (metadata )
217
+ connection = self .connection
218
+ address = Address (* connection .socket .getpeername ())
219
+ version = metadata .get ("server" )
220
+ connection .server = ServerInfo (address , version )
221
+
222
+ def on_failure (self , metadata ):
223
+ code = metadata .get ("code" )
224
+ error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
225
+ ServiceUnavailable )
226
+ raise error (metadata .get ("message" , "INIT failed" ))
227
+
228
+
209
229
class Connection (object ):
210
230
""" Server connection for Bolt protocol v1.
211
231
@@ -222,15 +242,15 @@ class Connection(object):
222
242
223
243
defunct = False
224
244
225
- server_version = None # TODO: remove this when PR#108 is merged
226
-
227
245
#: The pool of which this connection is a member
228
246
pool = None
229
247
248
+ #: Server version details
249
+ server = None
250
+
230
251
def __init__ (self , sock , ** config ):
231
252
self .socket = sock
232
253
self .buffering_socket = BufferingSocket (self )
233
- self .address = sock .getpeername ()
234
254
self .channel = ChunkChannel (sock )
235
255
self .packer = Packer (self .channel )
236
256
self .unpacker = Unpacker ()
@@ -251,19 +271,7 @@ def __init__(self, sock, **config):
251
271
# Pick up the server certificate, if any
252
272
self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
253
273
254
- def on_success (metadata ):
255
- self .server_version = metadata .get ("server" )
256
-
257
- def on_failure (metadata ):
258
- code = metadata .get ("code" )
259
- error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
260
- ServiceUnavailable )
261
- raise error (metadata .get ("message" , "INIT failed" ))
262
-
263
- response = Response (self )
264
- response .on_success = on_success
265
- response .on_failure = on_failure
266
-
274
+ response = InitResponse (self )
267
275
self .append (INIT , (self .user_agent , self .auth_dict ), response = response )
268
276
self .sync ()
269
277
@@ -323,18 +331,18 @@ def send(self):
323
331
""" Send all queued messages to the server.
324
332
"""
325
333
if self .closed :
326
- raise ServiceUnavailable ("Failed to write to closed connection %r" % (self .address ,))
334
+ raise ServiceUnavailable ("Failed to write to closed connection %r" % (self .server . address ,))
327
335
if self .defunct :
328
- raise ServiceUnavailable ("Failed to write to defunct connection %r" % (self .address ,))
336
+ raise ServiceUnavailable ("Failed to write to defunct connection %r" % (self .server . address ,))
329
337
self .channel .send ()
330
338
331
339
def fetch (self ):
332
340
""" Receive exactly one message from the server.
333
341
"""
334
342
if self .closed :
335
- raise ServiceUnavailable ("Failed to read from closed connection %r" % (self .address ,))
343
+ raise ServiceUnavailable ("Failed to read from closed connection %r" % (self .server . address ,))
336
344
if self .defunct :
337
- raise ServiceUnavailable ("Failed to read from defunct connection %r" % (self .address ,))
345
+ raise ServiceUnavailable ("Failed to read from defunct connection %r" % (self .server . address ,))
338
346
try :
339
347
message_data = self .buffering_socket .read_message ()
340
348
except ProtocolError :
0 commit comments