Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class Bolt(abc.ABC):
PROTOCOL_VERSION = None

# flag if connection needs RESET to go back to READY state
_is_reset = True
is_reset = False

# The socket
in_use = False
Expand Down Expand Up @@ -460,10 +460,6 @@ def rollback(self, **handlers):
""" Appends a ROLLBACK message to the output queue."""
pass

@property
def is_reset(self):
return self._is_reset

@abc.abstractmethod
def reset(self):
""" Appends a RESET message to the outgoing queue, sends it and consumes
Expand Down
78 changes: 68 additions & 10 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from logging import getLogger
from ssl import SSLSocket

Expand Down Expand Up @@ -52,6 +53,38 @@
log = getLogger("neo4j")


class ServerStates(Enum):
CONNECTED = "CONNECTED"
READY = "READY"
STREAMING = "STREAMING"
TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING"
FAILED = "FAILED"


STATE_TRANSITIONS = {
ServerStates.CONNECTED: {
"hello": ServerStates.READY,
},
ServerStates.READY: {
"run": ServerStates.STREAMING,
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
},
ServerStates.STREAMING: {
"pull": ServerStates.READY,
"discard": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.TX_READY_OR_TX_STREAMING: {
"commit": ServerStates.READY,
"rollback": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.FAILED: {
"reset": ServerStates.READY,
}
}


class Bolt3(Bolt):
""" Protocol handler for Bolt 3.

Expand All @@ -64,6 +97,16 @@ class Bolt3(Bolt):

supports_multiple_databases = False

_server_state = ServerStates.CONNECTED

@property
def is_reset(self):
if self.responses:
# we can't be sure of the server's state as there are still pending
# responses.
return False
return self._server_state == ServerStates.READY

@property
def encrypted(self):
return isinstance(self.socket, SSLSocket)
Expand Down Expand Up @@ -92,7 +135,8 @@ def hello(self):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=self.server_info.update))
response=InitResponse(self, "hello",
on_success=self.server_info.update))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down Expand Up @@ -155,20 +199,20 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
self._append(b"\x10", fields, CommitResponse(self, **handlers))
self._append(b"\x10", fields, CommitResponse(self, "run",
**handlers))
else:
self._append(b"\x10", fields, Response(self, **handlers))
self._is_reset = False
self._append(b"\x10", fields, Response(self, "run", **handlers))

def discard(self, n=-1, qid=-1, **handlers):
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
self._append(b"\x2F", (), Response(self, **handlers))
self._append(b"\x2F", (), Response(self, "discard", **handlers))

def pull(self, n=-1, qid=-1, **handlers):
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: PULL_ALL", self.local_port)
self._append(b"\x3F", (), Response(self, **handlers))
self._append(b"\x3F", (), Response(self, "pull", **handlers))
self._is_reset = False

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
Expand All @@ -193,16 +237,16 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, **handlers))
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
self._is_reset = False

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
self._append(b"\x12", (), CommitResponse(self, **handlers))
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))

def rollback(self, **handlers):
log.debug("[#%04X] C: ROLLBACK", self.local_port)
self._append(b"\x13", (), Response(self, **handlers))
self._append(b"\x13", (), Response(self, "rollback", **handlers))

def reset(self):
""" Add a RESET message to the outgoing queue, send
Expand All @@ -213,11 +257,22 @@ def fail(metadata):
raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)

log.debug("[#%04X] C: RESET", self.local_port)
self._append(b"\x0F", response=Response(self, on_failure=fail))
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def _update_server_state_on_success(self, metadata, message):
if metadata.get("has_more"):
return
state_before = self._server_state
self._server_state = STATE_TRANSITIONS\
.get(self._server_state, {})\
.get(message, self._server_state)
if state_before != self._server_state:
log.debug("[#%04X] State: %s", self.local_port,
self._server_state.name)

def fetch_message(self):
""" Receive at most one message from the server, if available.

Expand Down Expand Up @@ -249,12 +304,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._update_server_state_on_success(summary_metadata,
response.message)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down
57 changes: 45 additions & 12 deletions neo4j/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from logging import getLogger
from ssl import SSLSocket

Expand All @@ -37,7 +38,6 @@
Neo4jError,
NotALeader,
ServiceUnavailable,
SessionExpired,
)
from neo4j.io import (
Bolt,
Expand All @@ -48,6 +48,10 @@
InitResponse,
Response,
)
from neo4j.io._bolt3 import (
ServerStates,
STATE_TRANSITIONS,
)


log = getLogger("neo4j")
Expand All @@ -65,6 +69,16 @@ class Bolt4x0(Bolt):

supports_multiple_databases = True

_server_state = ServerStates.CONNECTED

@property
def is_reset(self):
if self.responses:
# we can't be sure of the server's state as there are still pending
# responses.
return False
return self._server_state == ServerStates.READY

@property
def encrypted(self):
return isinstance(self.socket, SSLSocket)
Expand Down Expand Up @@ -93,7 +107,8 @@ def hello(self):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=self.server_info.update))
response=InitResponse(self, "hello",
on_success=self.server_info.update))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down Expand Up @@ -162,24 +177,25 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
self._append(b"\x10", fields, CommitResponse(self, **handlers))
self._append(b"\x10", fields, CommitResponse(self, "run",
**handlers))
else:
self._append(b"\x10", fields, Response(self, **handlers))
self._append(b"\x10", fields, Response(self, "run", **handlers))
self._is_reset = False

def discard(self, n=-1, qid=-1, **handlers):
extra = {"n": n}
if qid != -1:
extra["qid"] = qid
log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
self._append(b"\x2F", (extra,), Response(self, **handlers))
self._append(b"\x2F", (extra,), Response(self, "discard", **handlers))

def pull(self, n=-1, qid=-1, **handlers):
extra = {"n": n}
if qid != -1:
extra["qid"] = qid
log.debug("[#%04X] C: PULL %r", self.local_port, extra)
self._append(b"\x3F", (extra,), Response(self, **handlers))
self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))
self._is_reset = False

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
Expand All @@ -205,16 +221,16 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, **handlers))
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
self._is_reset = False

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
self._append(b"\x12", (), CommitResponse(self, **handlers))
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))

def rollback(self, **handlers):
log.debug("[#%04X] C: ROLLBACK", self.local_port)
self._append(b"\x13", (), Response(self, **handlers))
self._append(b"\x13", (), Response(self, "rollback", **handlers))

def reset(self):
""" Add a RESET message to the outgoing queue, send
Expand All @@ -225,11 +241,22 @@ def fail(metadata):
raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)

log.debug("[#%04X] C: RESET", self.local_port)
self._append(b"\x0F", response=Response(self, on_failure=fail))
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def _update_server_state_on_success(self, metadata, message):
if metadata.get("has_more"):
return
state_before = self._server_state
self._server_state = STATE_TRANSITIONS\
.get(self._server_state, {})\
.get(message, self._server_state)
if state_before != self._server_state:
log.debug("[#%04X] [%s]", self.local_port,
self._server_state.name)

def fetch_message(self):
""" Receive at most one message from the server, if available.

Expand Down Expand Up @@ -261,12 +288,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._update_server_state_on_success(summary_metadata,
response.message)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down Expand Up @@ -372,7 +402,9 @@ def fail(md):
else:
bookmarks = list(bookmarks)
self._append(b"\x66", (routing_context, bookmarks, database),
response=Response(self, on_success=metadata.update, on_failure=fail))
response=Response(self, "route",
on_success=metadata.update,
on_failure=fail))
self.send_all()
self.fetch_all()
return [metadata.get("rt")]
Expand Down Expand Up @@ -400,7 +432,8 @@ def on_success(metadata):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=on_success))
response=InitResponse(self, "hello",
on_success=on_success))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down
3 changes: 2 additions & 1 deletion neo4j/io/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ class Response:
more detail messages followed by one summary message).
"""

def __init__(self, connection, **handlers):
def __init__(self, connection, message, **handlers):
self.connection = connection
self.handlers = handlers
self.message = message
self.complete = False

def on_records(self, records):
Expand Down
2 changes: 1 addition & 1 deletion testkitbackend/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"features": {
"AuthorizationExpiredTreatment": true,
"Optimization:ImplicitDefaultArguments": true,
"Optimization:MinimalResets": "Driver resets some clean connections when put back into pool",
"Optimization:MinimalResets": true,
"Optimization:ConnectionReuse": true,
"Optimization:PullPipelining": true,
"ConfHint:connection.recv_timeout_seconds": true,
Expand Down