@@ -61,28 +61,43 @@ class ServerStates(Enum):
6161 FAILED = "FAILED"
6262
6363
64- STATE_TRANSITIONS = {
65- ServerStates .CONNECTED : {
66- "hello" : ServerStates .READY ,
67- },
68- ServerStates .READY : {
69- "run" : ServerStates .STREAMING ,
70- "begin" : ServerStates .TX_READY_OR_TX_STREAMING ,
71- },
72- ServerStates .STREAMING : {
73- "pull" : ServerStates .READY ,
74- "discard" : ServerStates .READY ,
75- "reset" : ServerStates .READY ,
76- },
77- ServerStates .TX_READY_OR_TX_STREAMING : {
78- "commit" : ServerStates .READY ,
79- "rollback" : ServerStates .READY ,
80- "reset" : ServerStates .READY ,
81- },
82- ServerStates .FAILED : {
83- "reset" : ServerStates .READY ,
64+ class ServerStateManager :
65+ _STATE_TRANSITIONS = {
66+ ServerStates .CONNECTED : {
67+ "hello" : ServerStates .READY ,
68+ },
69+ ServerStates .READY : {
70+ "run" : ServerStates .STREAMING ,
71+ "begin" : ServerStates .TX_READY_OR_TX_STREAMING ,
72+ },
73+ ServerStates .STREAMING : {
74+ "pull" : ServerStates .READY ,
75+ "discard" : ServerStates .READY ,
76+ "reset" : ServerStates .READY ,
77+ },
78+ ServerStates .TX_READY_OR_TX_STREAMING : {
79+ "commit" : ServerStates .READY ,
80+ "rollback" : ServerStates .READY ,
81+ "reset" : ServerStates .READY ,
82+ },
83+ ServerStates .FAILED : {
84+ "reset" : ServerStates .READY ,
85+ }
8486 }
85- }
87+
88+ def __init__ (self , init_state , on_change = None ):
89+ self .state = init_state
90+ self ._on_change = on_change
91+
92+ def transition (self , metadata , message ):
93+ if metadata .get ("has_more" ):
94+ return
95+ state_before = self .state
96+ self .state = self ._STATE_TRANSITIONS \
97+ .get (self .state , {})\
98+ .get (message , self .state )
99+ if state_before != self .state and callable (self ._on_change ):
100+ self ._on_change (state_before , self .state )
86101
87102
88103class Bolt3 (Bolt ):
@@ -97,15 +112,23 @@ class Bolt3(Bolt):
97112
98113 supports_multiple_databases = False
99114
100- _server_state = ServerStates .CONNECTED
115+ def __init__ (self , * args , ** kwargs ):
116+ super ().__init__ (* args , ** kwargs )
117+ self ._server_state_manager = ServerStateManager (
118+ ServerStates .CONNECTED , on_change = self ._on_server_state_change
119+ )
120+
121+ def _on_server_state_change (self , old_state , new_state ):
122+ log .debug ("[#%04X] State: %s > %s" , self .local_port ,
123+ old_state .name , new_state .name )
101124
102125 @property
103126 def is_reset (self ):
104127 if self .responses :
105128 # we can't be sure of the server's state as there are still pending
106129 # responses.
107130 return False
108- return self ._server_state == ServerStates .READY
131+ return self ._server_state_manager . state == ServerStates .READY
109132
110133 @property
111134 def encrypted (self ):
@@ -213,7 +236,6 @@ def pull(self, n=-1, qid=-1, **handlers):
213236 # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
214237 log .debug ("[#%04X] C: PULL_ALL" , self .local_port )
215238 self ._append (b"\x3F " , (), Response (self , "pull" , ** handlers ))
216- self ._is_reset = False
217239
218240 def begin (self , mode = None , bookmarks = None , metadata = None , timeout = None , db = None , ** handlers ):
219241 if db is not None :
@@ -238,7 +260,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
238260 raise TypeError ("Timeout must be specified as a number of seconds" )
239261 log .debug ("[#%04X] C: BEGIN %r" , self .local_port , extra )
240262 self ._append (b"\x11 " , (extra ,), Response (self , "begin" , ** handlers ))
241- self ._is_reset = False
242263
243264 def commit (self , ** handlers ):
244265 log .debug ("[#%04X] C: COMMIT" , self .local_port )
@@ -260,18 +281,6 @@ def fail(metadata):
260281 self ._append (b"\x0F " , response = Response (self , "reset" , on_failure = fail ))
261282 self .send_all ()
262283 self .fetch_all ()
263- self ._is_reset = True
264-
265- def _update_server_state_on_success (self , metadata , message ):
266- if metadata .get ("has_more" ):
267- return
268- state_before = self ._server_state
269- self ._server_state = STATE_TRANSITIONS \
270- .get (self ._server_state , {})\
271- .get (message , self ._server_state )
272- if state_before != self ._server_state :
273- log .debug ("[#%04X] State: %s" , self .local_port ,
274- self ._server_state .name )
275284
276285 def fetch_message (self ):
277286 """ Receive at most one message from the server, if available.
@@ -304,15 +313,15 @@ def fetch_message(self):
304313 response .complete = True
305314 if summary_signature == b"\x70 " :
306315 log .debug ("[#%04X] S: SUCCESS %r" , self .local_port , summary_metadata )
307- self ._update_server_state_on_success ( summary_metadata ,
308- response . message )
316+ self ._server_state_manager . transition ( response . message ,
317+ summary_metadata )
309318 response .on_success (summary_metadata or {})
310319 elif summary_signature == b"\x7E " :
311320 log .debug ("[#%04X] S: IGNORED" , self .local_port )
312321 response .on_ignored (summary_metadata or {})
313322 elif summary_signature == b"\x7F " :
314323 log .debug ("[#%04X] S: FAILURE %r" , self .local_port , summary_metadata )
315- self ._server_state = ServerStates .FAILED
324+ self ._server_state_manager . state = ServerStates .FAILED
316325 try :
317326 response .on_failure (summary_metadata or {})
318327 except (ServiceUnavailable , DatabaseUnavailable ):
0 commit comments