Skip to content

Commit cad3718

Browse files
authored
Don't send RESET on READY (clean) connections (#572)
1 parent 6679891 commit cad3718

File tree

5 files changed

+129
-39
lines changed

5 files changed

+129
-39
lines changed

neo4j/io/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class Bolt(abc.ABC):
123123
PROTOCOL_VERSION = None
124124

125125
# flag if connection needs RESET to go back to READY state
126-
_is_reset = True
126+
is_reset = False
127127

128128
# The socket
129129
in_use = False
@@ -460,10 +460,6 @@ def rollback(self, **handlers):
460460
""" Appends a ROLLBACK message to the output queue."""
461461
pass
462462

463-
@property
464-
def is_reset(self):
465-
return self._is_reset
466-
467463
@abc.abstractmethod
468464
def reset(self):
469465
""" Appends a RESET message to the outgoing queue, sends it and consumes

neo4j/io/_bolt3.py

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
2020

21+
from enum import Enum
2122
from logging import getLogger
2223
from ssl import SSLSocket
2324

@@ -52,6 +53,53 @@
5253
log = getLogger("neo4j")
5354

5455

56+
class ServerStates(Enum):
57+
CONNECTED = "CONNECTED"
58+
READY = "READY"
59+
STREAMING = "STREAMING"
60+
TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING"
61+
FAILED = "FAILED"
62+
63+
64+
class ServerStateManager:
65+
_STATE_TRANSITIONS = {
66+
ServerStates.CONNECTED: {
67+
"hello": ServerStates.READY,
68+
},
69+
ServerStates.READY: {
70+
"run": ServerStates.STREAMING,
71+
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
72+
},
73+
ServerStates.STREAMING: {
74+
"pull": ServerStates.READY,
75+
"discard": ServerStates.READY,
76+
"reset": ServerStates.READY,
77+
},
78+
ServerStates.TX_READY_OR_TX_STREAMING: {
79+
"commit": ServerStates.READY,
80+
"rollback": ServerStates.READY,
81+
"reset": ServerStates.READY,
82+
},
83+
ServerStates.FAILED: {
84+
"reset": ServerStates.READY,
85+
}
86+
}
87+
88+
def __init__(self, init_state, on_change=None):
89+
self.state = init_state
90+
self._on_change = on_change
91+
92+
def transition(self, message, metadata):
93+
if metadata.get("has_more"):
94+
return
95+
state_before = self.state
96+
self.state = self._STATE_TRANSITIONS\
97+
.get(self.state, {})\
98+
.get(message, self.state)
99+
if state_before != self.state and callable(self._on_change):
100+
self._on_change(state_before, self.state)
101+
102+
55103
class Bolt3(Bolt):
56104
""" Protocol handler for Bolt 3.
57105
@@ -64,6 +112,25 @@ class Bolt3(Bolt):
64112

65113
supports_multiple_databases = False
66114

115+
def __init__(self, *args, **kwargs):
116+
super().__init__(*args, **kwargs)
117+
self._server_state_manager = ServerStateManager(
118+
ServerStates.CONNECTED, on_change=self._on_server_state_change
119+
)
120+
121+
def _on_server_state_change(self, old_state, new_state):
122+
log.debug("[#%04X] State: %s > %s", self.local_port,
123+
old_state.name, new_state.name)
124+
125+
@property
126+
def is_reset(self):
127+
if self.responses:
128+
# We can't be sure of the server's state as there are still pending
129+
# responses. Unless the last message we sent was RESET. In that case
130+
# the server state will always be READY when we're done.
131+
return self.responses[-1].message == "reset"
132+
return self._server_state_manager.state == ServerStates.READY
133+
67134
@property
68135
def encrypted(self):
69136
return isinstance(self.socket, SSLSocket)
@@ -92,7 +159,8 @@ def hello(self):
92159
logged_headers["credentials"] = "*******"
93160
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
94161
self._append(b"\x01", (headers,),
95-
response=InitResponse(self, on_success=self.server_info.update))
162+
response=InitResponse(self, "hello",
163+
on_success=self.server_info.update))
96164
self.send_all()
97165
self.fetch_all()
98166
check_supported_server_product(self.server_info.agent)
@@ -155,21 +223,20 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
155223
fields = (query, parameters, extra)
156224
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
157225
if query.upper() == u"COMMIT":
158-
self._append(b"\x10", fields, CommitResponse(self, **handlers))
226+
self._append(b"\x10", fields, CommitResponse(self, "run",
227+
**handlers))
159228
else:
160-
self._append(b"\x10", fields, Response(self, **handlers))
161-
self._is_reset = False
229+
self._append(b"\x10", fields, Response(self, "run", **handlers))
162230

163231
def discard(self, n=-1, qid=-1, **handlers):
164232
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
165233
log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
166-
self._append(b"\x2F", (), Response(self, **handlers))
234+
self._append(b"\x2F", (), Response(self, "discard", **handlers))
167235

168236
def pull(self, n=-1, qid=-1, **handlers):
169237
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
170238
log.debug("[#%04X] C: PULL_ALL", self.local_port)
171-
self._append(b"\x3F", (), Response(self, **handlers))
172-
self._is_reset = False
239+
self._append(b"\x3F", (), Response(self, "pull", **handlers))
173240

174241
def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
175242
if db is not None:
@@ -193,16 +260,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
193260
except TypeError:
194261
raise TypeError("Timeout must be specified as a number of seconds")
195262
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
196-
self._append(b"\x11", (extra,), Response(self, **handlers))
197-
self._is_reset = False
263+
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
198264

199265
def commit(self, **handlers):
200266
log.debug("[#%04X] C: COMMIT", self.local_port)
201-
self._append(b"\x12", (), CommitResponse(self, **handlers))
267+
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))
202268

203269
def rollback(self, **handlers):
204270
log.debug("[#%04X] C: ROLLBACK", self.local_port)
205-
self._append(b"\x13", (), Response(self, **handlers))
271+
self._append(b"\x13", (), Response(self, "rollback", **handlers))
206272

207273
def reset(self):
208274
""" Add a RESET message to the outgoing queue, send
@@ -213,10 +279,9 @@ def fail(metadata):
213279
raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)
214280

215281
log.debug("[#%04X] C: RESET", self.local_port)
216-
self._append(b"\x0F", response=Response(self, on_failure=fail))
282+
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
217283
self.send_all()
218284
self.fetch_all()
219-
self._is_reset = True
220285

221286
def fetch_message(self):
222287
""" Receive at most one message from the server, if available.
@@ -249,12 +314,15 @@ def fetch_message(self):
249314
response.complete = True
250315
if summary_signature == b"\x70":
251316
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
317+
self._server_state_manager.transition(response.message,
318+
summary_metadata)
252319
response.on_success(summary_metadata or {})
253320
elif summary_signature == b"\x7E":
254321
log.debug("[#%04X] S: IGNORED", self.local_port)
255322
response.on_ignored(summary_metadata or {})
256323
elif summary_signature == b"\x7F":
257324
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
325+
self._server_state_manager.state = ServerStates.FAILED
258326
try:
259327
response.on_failure(summary_metadata or {})
260328
except (ServiceUnavailable, DatabaseUnavailable):

neo4j/io/_bolt4.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
2020

21+
from enum import Enum
2122
from logging import getLogger
2223
from ssl import SSLSocket
2324

@@ -37,7 +38,6 @@
3738
Neo4jError,
3839
NotALeader,
3940
ServiceUnavailable,
40-
SessionExpired,
4141
)
4242
from neo4j.io import (
4343
Bolt,
@@ -48,6 +48,10 @@
4848
InitResponse,
4949
Response,
5050
)
51+
from neo4j.io._bolt3 import (
52+
ServerStateManager,
53+
ServerStates,
54+
)
5155

5256

5357
log = getLogger("neo4j")
@@ -65,6 +69,25 @@ class Bolt4x0(Bolt):
6569

6670
supports_multiple_databases = True
6771

72+
def __init__(self, *args, **kwargs):
73+
super().__init__(*args, **kwargs)
74+
self._server_state_manager = ServerStateManager(
75+
ServerStates.CONNECTED, on_change=self._on_server_state_change
76+
)
77+
78+
def _on_server_state_change(self, old_state, new_state):
79+
log.debug("[#%04X] State: %s > %s", self.local_port,
80+
old_state.name, new_state.name)
81+
82+
@property
83+
def is_reset(self):
84+
if self.responses:
85+
# We can't be sure of the server's state as there are still pending
86+
# responses. Unless the last message we sent was RESET. In that case
87+
# the server state will always be READY when we're done.
88+
return self.responses[-1].message == "reset"
89+
return self._server_state_manager.state == ServerStates.READY
90+
6891
@property
6992
def encrypted(self):
7093
return isinstance(self.socket, SSLSocket)
@@ -93,7 +116,8 @@ def hello(self):
93116
logged_headers["credentials"] = "*******"
94117
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
95118
self._append(b"\x01", (headers,),
96-
response=InitResponse(self, on_success=self.server_info.update))
119+
response=InitResponse(self, "hello",
120+
on_success=self.server_info.update))
97121
self.send_all()
98122
self.fetch_all()
99123
check_supported_server_product(self.server_info.agent)
@@ -162,25 +186,24 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
162186
fields = (query, parameters, extra)
163187
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
164188
if query.upper() == u"COMMIT":
165-
self._append(b"\x10", fields, CommitResponse(self, **handlers))
189+
self._append(b"\x10", fields, CommitResponse(self, "run",
190+
**handlers))
166191
else:
167-
self._append(b"\x10", fields, Response(self, **handlers))
168-
self._is_reset = False
192+
self._append(b"\x10", fields, Response(self, "run", **handlers))
169193

170194
def discard(self, n=-1, qid=-1, **handlers):
171195
extra = {"n": n}
172196
if qid != -1:
173197
extra["qid"] = qid
174198
log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
175-
self._append(b"\x2F", (extra,), Response(self, **handlers))
199+
self._append(b"\x2F", (extra,), Response(self, "discard", **handlers))
176200

177201
def pull(self, n=-1, qid=-1, **handlers):
178202
extra = {"n": n}
179203
if qid != -1:
180204
extra["qid"] = qid
181205
log.debug("[#%04X] C: PULL %r", self.local_port, extra)
182-
self._append(b"\x3F", (extra,), Response(self, **handlers))
183-
self._is_reset = False
206+
self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))
184207

185208
def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
186209
db=None, **handlers):
@@ -205,16 +228,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
205228
except TypeError:
206229
raise TypeError("Timeout must be specified as a number of seconds")
207230
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
208-
self._append(b"\x11", (extra,), Response(self, **handlers))
209-
self._is_reset = False
231+
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
210232

211233
def commit(self, **handlers):
212234
log.debug("[#%04X] C: COMMIT", self.local_port)
213-
self._append(b"\x12", (), CommitResponse(self, **handlers))
235+
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))
214236

215237
def rollback(self, **handlers):
216238
log.debug("[#%04X] C: ROLLBACK", self.local_port)
217-
self._append(b"\x13", (), Response(self, **handlers))
239+
self._append(b"\x13", (), Response(self, "rollback", **handlers))
218240

219241
def reset(self):
220242
""" Add a RESET message to the outgoing queue, send
@@ -225,10 +247,9 @@ def fail(metadata):
225247
raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)
226248

227249
log.debug("[#%04X] C: RESET", self.local_port)
228-
self._append(b"\x0F", response=Response(self, on_failure=fail))
250+
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
229251
self.send_all()
230252
self.fetch_all()
231-
self._is_reset = True
232253

233254
def fetch_message(self):
234255
""" Receive at most one message from the server, if available.
@@ -261,12 +282,15 @@ def fetch_message(self):
261282
response.complete = True
262283
if summary_signature == b"\x70":
263284
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
285+
self._server_state_manager.transition(response.message,
286+
summary_metadata)
264287
response.on_success(summary_metadata or {})
265288
elif summary_signature == b"\x7E":
266289
log.debug("[#%04X] S: IGNORED", self.local_port)
267290
response.on_ignored(summary_metadata or {})
268291
elif summary_signature == b"\x7F":
269292
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
293+
self._server_state_manager.state = ServerStates.FAILED
270294
try:
271295
response.on_failure(summary_metadata or {})
272296
except (ServiceUnavailable, DatabaseUnavailable):
@@ -372,7 +396,9 @@ def fail(md):
372396
else:
373397
bookmarks = list(bookmarks)
374398
self._append(b"\x66", (routing_context, bookmarks, database),
375-
response=Response(self, on_success=metadata.update, on_failure=fail))
399+
response=Response(self, "route",
400+
on_success=metadata.update,
401+
on_failure=fail))
376402
self.send_all()
377403
self.fetch_all()
378404
return [metadata.get("rt")]
@@ -400,7 +426,8 @@ def on_success(metadata):
400426
logged_headers["credentials"] = "*******"
401427
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
402428
self._append(b"\x01", (headers,),
403-
response=InitResponse(self, on_success=on_success))
429+
response=InitResponse(self, "hello",
430+
on_success=on_success))
404431
self.send_all()
405432
self.fetch_all()
406433
check_supported_server_product(self.server_info.agent)

neo4j/io/_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,10 @@ class Response:
144144
more detail messages followed by one summary message).
145145
"""
146146

147-
def __init__(self, connection, **handlers):
147+
def __init__(self, connection, message, **handlers):
148148
self.connection = connection
149149
self.handlers = handlers
150+
self.message = message
150151
self.complete = False
151152

152153
def on_records(self, records):

testkitbackend/test_config.json

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,12 @@
2929
"stub.session_run_parameters.test_session_run_parameters.TestSessionRunParameters.test_empty_query":
3030
"Driver rejects empty queries before sending it to the server",
3131
"tls.tlsversions.TestTlsVersions.test_1_1":
32-
"TLSv1.1 and below are disabled in the driver",
33-
"stub.disconnects.test_disconnects.TestDisconnects.test_fail_on_reset":
34-
"Driver silently ignores all errors on releasing connections back into the pool."
32+
"TLSv1.1 and below are disabled in the driver"
3533
},
3634
"features": {
3735
"AuthorizationExpiredTreatment": true,
3836
"Optimization:ImplicitDefaultArguments": true,
39-
"Optimization:MinimalResets": "Driver resets some clean connections when put back into pool",
37+
"Optimization:MinimalResets": true,
4038
"Optimization:ConnectionReuse": true,
4139
"Optimization:PullPipelining": true,
4240
"ConfHint:connection.recv_timeout_seconds": true,

0 commit comments

Comments
 (0)