18
18
# See the License for the specific language governing permissions and
19
19
# limitations under the License.
20
20
21
+ from enum import Enum
21
22
from logging import getLogger
22
23
from ssl import SSLSocket
23
24
52
53
log = getLogger ("neo4j" )
53
54
54
55
56
+ class ServerStates (Enum ):
57
+ CONNECTED = "CONNECTED"
58
+ READY = "READY"
59
+ STREAMING = "STREAMING"
60
+ TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING"
61
+ FAILED = "FAILED"
62
+
63
+
64
+ class ServerStateManager :
65
+ _STATE_TRANSITIONS = {
66
+ ServerStates .CONNECTED : {
67
+ "hello" : ServerStates .READY ,
68
+ },
69
+ ServerStates .READY : {
70
+ "run" : ServerStates .STREAMING ,
71
+ "begin" : ServerStates .TX_READY_OR_TX_STREAMING ,
72
+ },
73
+ ServerStates .STREAMING : {
74
+ "pull" : ServerStates .READY ,
75
+ "discard" : ServerStates .READY ,
76
+ "reset" : ServerStates .READY ,
77
+ },
78
+ ServerStates .TX_READY_OR_TX_STREAMING : {
79
+ "commit" : ServerStates .READY ,
80
+ "rollback" : ServerStates .READY ,
81
+ "reset" : ServerStates .READY ,
82
+ },
83
+ ServerStates .FAILED : {
84
+ "reset" : ServerStates .READY ,
85
+ }
86
+ }
87
+
88
+ def __init__ (self , init_state , on_change = None ):
89
+ self .state = init_state
90
+ self ._on_change = on_change
91
+
92
+ def transition (self , message , metadata ):
93
+ if metadata .get ("has_more" ):
94
+ return
95
+ state_before = self .state
96
+ self .state = self ._STATE_TRANSITIONS \
97
+ .get (self .state , {})\
98
+ .get (message , self .state )
99
+ if state_before != self .state and callable (self ._on_change ):
100
+ self ._on_change (state_before , self .state )
101
+
102
+
55
103
class Bolt3 (Bolt ):
56
104
""" Protocol handler for Bolt 3.
57
105
@@ -64,6 +112,25 @@ class Bolt3(Bolt):
64
112
65
113
supports_multiple_databases = False
66
114
115
+ def __init__ (self , * args , ** kwargs ):
116
+ super ().__init__ (* args , ** kwargs )
117
+ self ._server_state_manager = ServerStateManager (
118
+ ServerStates .CONNECTED , on_change = self ._on_server_state_change
119
+ )
120
+
121
+ def _on_server_state_change (self , old_state , new_state ):
122
+ log .debug ("[#%04X] State: %s > %s" , self .local_port ,
123
+ old_state .name , new_state .name )
124
+
125
+ @property
126
+ def is_reset (self ):
127
+ if self .responses :
128
+ # We can't be sure of the server's state as there are still pending
129
+ # responses. Unless the last message we sent was RESET. In that case
130
+ # the server state will always be READY when we're done.
131
+ return self .responses [- 1 ].message == "reset"
132
+ return self ._server_state_manager .state == ServerStates .READY
133
+
67
134
@property
68
135
def encrypted (self ):
69
136
return isinstance (self .socket , SSLSocket )
@@ -92,7 +159,8 @@ def hello(self):
92
159
logged_headers ["credentials" ] = "*******"
93
160
log .debug ("[#%04X] C: HELLO %r" , self .local_port , logged_headers )
94
161
self ._append (b"\x01 " , (headers ,),
95
- response = InitResponse (self , on_success = self .server_info .update ))
162
+ response = InitResponse (self , "hello" ,
163
+ on_success = self .server_info .update ))
96
164
self .send_all ()
97
165
self .fetch_all ()
98
166
check_supported_server_product (self .server_info .agent )
@@ -155,21 +223,20 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
155
223
fields = (query , parameters , extra )
156
224
log .debug ("[#%04X] C: RUN %s" , self .local_port , " " .join (map (repr , fields )))
157
225
if query .upper () == u"COMMIT" :
158
- self ._append (b"\x10 " , fields , CommitResponse (self , ** handlers ))
226
+ self ._append (b"\x10 " , fields , CommitResponse (self , "run" ,
227
+ ** handlers ))
159
228
else :
160
- self ._append (b"\x10 " , fields , Response (self , ** handlers ))
161
- self ._is_reset = False
229
+ self ._append (b"\x10 " , fields , Response (self , "run" , ** handlers ))
162
230
163
231
def discard (self , n = - 1 , qid = - 1 , ** handlers ):
164
232
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
165
233
log .debug ("[#%04X] C: DISCARD_ALL" , self .local_port )
166
- self ._append (b"\x2F " , (), Response (self , ** handlers ))
234
+ self ._append (b"\x2F " , (), Response (self , "discard" , ** handlers ))
167
235
168
236
def pull (self , n = - 1 , qid = - 1 , ** handlers ):
169
237
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
170
238
log .debug ("[#%04X] C: PULL_ALL" , self .local_port )
171
- self ._append (b"\x3F " , (), Response (self , ** handlers ))
172
- self ._is_reset = False
239
+ self ._append (b"\x3F " , (), Response (self , "pull" , ** handlers ))
173
240
174
241
def begin (self , mode = None , bookmarks = None , metadata = None , timeout = None , db = None , ** handlers ):
175
242
if db is not None :
@@ -193,16 +260,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
193
260
except TypeError :
194
261
raise TypeError ("Timeout must be specified as a number of seconds" )
195
262
log .debug ("[#%04X] C: BEGIN %r" , self .local_port , extra )
196
- self ._append (b"\x11 " , (extra ,), Response (self , ** handlers ))
197
- self ._is_reset = False
263
+ self ._append (b"\x11 " , (extra ,), Response (self , "begin" , ** handlers ))
198
264
199
265
def commit (self , ** handlers ):
200
266
log .debug ("[#%04X] C: COMMIT" , self .local_port )
201
- self ._append (b"\x12 " , (), CommitResponse (self , ** handlers ))
267
+ self ._append (b"\x12 " , (), CommitResponse (self , "commit" , ** handlers ))
202
268
203
269
def rollback (self , ** handlers ):
204
270
log .debug ("[#%04X] C: ROLLBACK" , self .local_port )
205
- self ._append (b"\x13 " , (), Response (self , ** handlers ))
271
+ self ._append (b"\x13 " , (), Response (self , "rollback" , ** handlers ))
206
272
207
273
def reset (self ):
208
274
""" Add a RESET message to the outgoing queue, send
@@ -213,10 +279,9 @@ def fail(metadata):
213
279
raise BoltProtocolError ("RESET failed %r" % metadata , address = self .unresolved_address )
214
280
215
281
log .debug ("[#%04X] C: RESET" , self .local_port )
216
- self ._append (b"\x0F " , response = Response (self , on_failure = fail ))
282
+ self ._append (b"\x0F " , response = Response (self , "reset" , on_failure = fail ))
217
283
self .send_all ()
218
284
self .fetch_all ()
219
- self ._is_reset = True
220
285
221
286
def fetch_message (self ):
222
287
""" Receive at most one message from the server, if available.
@@ -249,12 +314,15 @@ def fetch_message(self):
249
314
response .complete = True
250
315
if summary_signature == b"\x70 " :
251
316
log .debug ("[#%04X] S: SUCCESS %r" , self .local_port , summary_metadata )
317
+ self ._server_state_manager .transition (response .message ,
318
+ summary_metadata )
252
319
response .on_success (summary_metadata or {})
253
320
elif summary_signature == b"\x7E " :
254
321
log .debug ("[#%04X] S: IGNORED" , self .local_port )
255
322
response .on_ignored (summary_metadata or {})
256
323
elif summary_signature == b"\x7F " :
257
324
log .debug ("[#%04X] S: FAILURE %r" , self .local_port , summary_metadata )
325
+ self ._server_state_manager .state = ServerStates .FAILED
258
326
try :
259
327
response .on_failure (summary_metadata or {})
260
328
except (ServiceUnavailable , DatabaseUnavailable ):
0 commit comments