From 1e822f30e605f9df2c53a8d24a84afdf1fd5a772 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 20 Dec 2021 11:24:44 +0100 Subject: [PATCH 1/5] Prepare async: move files --- neo4j/{__init__.py => _async/driver.py} | 0 neo4j/{io/__init__.py => _async/io/_bolt.py} | 0 neo4j/{ => _async}/io/_bolt3.py | 0 neo4j/{ => _async}/io/_bolt4.py | 0 neo4j/{ => _async}/io/_common.py | 0 neo4j/{ => _async}/work/result.py | 0 .../simple.py => _async/work/session.py} | 0 neo4j/{ => _async}/work/transaction.py | 0 .../__init__.py => _async/work/workspace.py} | 0 neo4j/_driver.py | 281 ------------------ neo4j/work/pipelining.py | 136 --------- tests/unit/{ => mixed}/io/test_direct.py | 0 12 files changed, 417 deletions(-) rename neo4j/{__init__.py => _async/driver.py} (100%) rename neo4j/{io/__init__.py => _async/io/_bolt.py} (100%) rename neo4j/{ => _async}/io/_bolt3.py (100%) rename neo4j/{ => _async}/io/_bolt4.py (100%) rename neo4j/{ => _async}/io/_common.py (100%) rename neo4j/{ => _async}/work/result.py (100%) rename neo4j/{work/simple.py => _async/work/session.py} (100%) rename neo4j/{ => _async}/work/transaction.py (100%) rename neo4j/{work/__init__.py => _async/work/workspace.py} (100%) delete mode 100644 neo4j/_driver.py delete mode 100644 neo4j/work/pipelining.py rename tests/unit/{ => mixed}/io/test_direct.py (100%) diff --git a/neo4j/__init__.py b/neo4j/_async/driver.py similarity index 100% rename from neo4j/__init__.py rename to neo4j/_async/driver.py diff --git a/neo4j/io/__init__.py b/neo4j/_async/io/_bolt.py similarity index 100% rename from neo4j/io/__init__.py rename to neo4j/_async/io/_bolt.py diff --git a/neo4j/io/_bolt3.py b/neo4j/_async/io/_bolt3.py similarity index 100% rename from neo4j/io/_bolt3.py rename to neo4j/_async/io/_bolt3.py diff --git a/neo4j/io/_bolt4.py b/neo4j/_async/io/_bolt4.py similarity index 100% rename from neo4j/io/_bolt4.py rename to neo4j/_async/io/_bolt4.py diff --git a/neo4j/io/_common.py b/neo4j/_async/io/_common.py similarity index 100% rename from neo4j/io/_common.py rename to neo4j/_async/io/_common.py diff --git a/neo4j/work/result.py b/neo4j/_async/work/result.py similarity index 100% rename from neo4j/work/result.py rename to neo4j/_async/work/result.py diff --git a/neo4j/work/simple.py b/neo4j/_async/work/session.py similarity index 100% rename from neo4j/work/simple.py rename to neo4j/_async/work/session.py diff --git a/neo4j/work/transaction.py b/neo4j/_async/work/transaction.py similarity index 100% rename from neo4j/work/transaction.py rename to neo4j/_async/work/transaction.py diff --git a/neo4j/work/__init__.py b/neo4j/_async/work/workspace.py similarity index 100% rename from neo4j/work/__init__.py rename to neo4j/_async/work/workspace.py diff --git a/neo4j/_driver.py b/neo4j/_driver.py deleted file mode 100644 index c8aa688b9..000000000 --- a/neo4j/_driver.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from .addressing import Address -from .api import READ_ACCESS -from .conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) -from .meta import experimental -from .work.simple import Session - - -class Direct: - - default_host = "localhost" - default_port = 7687 - - default_target = ":" - - def __init__(self, address): - self._address = address - - @property - def address(self): - return self._address - - @classmethod - def parse_target(cls, target): - """ Parse a target string to produce an address. - """ - if not target: - target = cls.default_target - address = Address.parse(target, default_host=cls.default_host, - default_port=cls.default_port) - return address - - -class Routing: - - default_host = "localhost" - default_port = 7687 - - default_targets = ": :17601 :17687" - - def __init__(self, initial_addresses): - self._initial_addresses = initial_addresses - - @property - def initial_addresses(self): - return self._initial_addresses - - @classmethod - def parse_targets(cls, *targets): - """ Parse a sequence of target strings to produce an address - list. - """ - targets = " ".join(targets) - if not targets: - targets = cls.default_targets - addresses = Address.parse_list(targets, default_host=cls.default_host, default_port=cls.default_port) - return addresses - - -class Driver: - """ Base class for all types of :class:`neo4j.Driver`, instances of which are - used as the primary access point to Neo4j. - """ - - #: Connection pool - _pool = None - - def __init__(self, pool): - assert pool is not None - self._pool = pool - - def __del__(self): - self.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - @property - def encrypted(self): - return bool(self._pool.pool_config.encrypted) - - def session(self, **config): - """Create a session, see :ref:`session-construction-ref` - - :param config: session configuration key-word arguments, see :ref:`session-configuration-ref` for available key-word arguments. - - :returns: new :class:`neo4j.Session` object - """ - raise NotImplementedError - - @experimental("The pipeline API is experimental and may be removed or changed in a future release") - def pipeline(self, **config): - """ Create a pipeline. - """ - raise NotImplementedError - - def close(self): - """ Shut down, closing any open connections in the pool. - """ - self._pool.close() - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - """ This verifies if the driver can connect to a remote server or a cluster - by establishing a network connection with the remote and possibly exchanging - a few data before closing the connection. It throws exception if fails to connect. - - Use the exception to further understand the cause of the connectivity problem. - - Note: Even if this method throws an exception, the driver still need to be closed via close() to free up all resources. - """ - raise NotImplementedError - - @experimental("Feature support query, based on Bolt Protocol Version and Neo4j Server Version will change in the future.") - def supports_multi_db(self): - """ Check if the server or cluster supports multi-databases. - - :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. 
- :rtype: bool - """ - with self.session() as session: - session._connect(READ_ACCESS) - return session._connection.supports_multiple_databases - - -class BoltDriver(Direct, Driver): - """ A :class:`.BoltDriver` is created from a ``bolt`` URI and addresses - a single database machine. This may be a standalone server or could be a - specific member of a cluster. - - Connections established by a :class:`.BoltDriver` are always made to the - exact host and port detailed in the URI. - """ - - @classmethod - def open(cls, target, *, auth=None, **config): - """ - :param target: - :param auth: - :param config: The values that can be specified are found in :class: `neo4j.PoolConfig` and :class: `neo4j.WorkspaceConfig` - - :return: - :rtype: :class: `neo4j.BoltDriver` - """ - from neo4j.io import BoltPool - address = cls.parse_target(target) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) - return cls(pool, default_workspace_config) - - def __init__(self, pool, default_workspace_config): - Direct.__init__(self, pool.address) - Driver.__init__(self, pool) - self._default_workspace_config = default_workspace_config - - def session(self, **config): - """ - :param config: The values that can be specified are found in :class: `neo4j.SessionConfig` - - :return: - :rtype: :class: `neo4j.Session` - """ - from neo4j.work.simple import Session - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) - - def pipeline(self, **config): - from neo4j.work.pipelining import ( - Pipeline, - PipelineConfig, - ) - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - server_agent = None - config["fetch_size"] = -1 - with self.session(**config) as session: - result = session.run("RETURN 1 AS x") - value = result.single().value() - summary = result.consume() - server_agent = summary.server.agent - return server_agent - - -class Neo4jDriver(Routing, Driver): - """ A :class:`.Neo4jDriver` is created from a ``neo4j`` URI. The - routing behaviour works in tandem with Neo4j's `Causal Clustering - `_ - feature by directing read and write behaviour to appropriate - cluster members. 
- """ - - @classmethod - def open(cls, *targets, auth=None, routing_context=None, **config): - from neo4j.io import Neo4jPool - addresses = cls.parse_targets(*targets) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) - return cls(pool, default_workspace_config) - - def __init__(self, pool, default_workspace_config): - Routing.__init__(self, pool.get_default_database_initial_router_addresses()) - Driver.__init__(self, pool) - self._default_workspace_config = default_workspace_config - - def session(self, **config): - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) - - def pipeline(self, **config): - from neo4j.work.pipelining import ( - Pipeline, - PipelineConfig, - ) - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - """ - :raise ServiceUnavailable: raised if the server does not support routing or if routing support is broken. - """ - # TODO: Improve and update Stub Test Server to be able to test. - return self._verify_routing_connectivity() - - def _verify_routing_connectivity(self): - from neo4j.exceptions import ( - Neo4jError, - ServiceUnavailable, - SessionExpired, - ) - - table = self._pool.get_routing_table_for_default_database() - routing_info = {} - for ix in list(table.routers): - try: - routing_info[ix] = self._pool.fetch_routing_info( - address=table.routers[0], - database=self._default_workspace_config.database, - imp_user=self._default_workspace_config.impersonated_user, - bookmarks=None, - timeout=self._default_workspace_config - .connection_acquisition_timeout - ) - except (ServiceUnavailable, SessionExpired, Neo4jError): - routing_info[ix] = None - for key, val in routing_info.items(): - if val is not None: - return routing_info - raise ServiceUnavailable("Could not connect to any routing servers.") diff --git a/neo4j/work/pipelining.py b/neo4j/work/pipelining.py deleted file mode 100644 index ccca84202..000000000 --- a/neo4j/work/pipelining.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from collections import deque -from threading import Thread, Lock -from time import sleep - -from neo4j.work import Workspace -from neo4j.conf import WorkspaceConfig -from neo4j.api import ( - WRITE_ACCESS, -) - -class PipelineConfig(WorkspaceConfig): - - #: - flush_every = 8192 # bytes - - -class Pipeline(Workspace): - - def __init__(self, pool, config): - assert isinstance(config, PipelineConfig) - super(Pipeline, self).__init__(pool, config) - self._connect(WRITE_ACCESS) - self._flush_every = config.flush_every - self._data = deque() - self._pull_lock = Lock() - - def push(self, statement, parameters=None): - self._connection.run(statement, parameters) - self._connection.pull(on_records=self._data.extend) - output_buffer_size = len(self._connection.outbox.view()) - if output_buffer_size >= self._flush_every: - self._connection.send_all() - - def _results_generator(self): - results_returned_count = 0 - try: - summary = 0 - while summary == 0: - _, summary = self._connection.fetch_message() - summary = 0 - while summary == 0: - detail, summary = self._connection.fetch_message() - for n in range(detail): - response = self._data.popleft() - results_returned_count += 1 - yield response - finally: - self._pull_lock.release() - - def pull(self): - """Returns a generator containing the results of the next query in the pipeline""" - # n.b. pull is now somewhat misleadingly named because it doesn't do anything - # the connection isn't touched until you try and iterate the generator we return - lock_acquired = self._pull_lock.acquire(blocking=False) - if not lock_acquired: - raise PullOrderException() - return self._results_generator() - - -class PullOrderException(Exception): - """Raise when calling pull if a previous pull result has not been fully consumed""" - - -class Pusher(Thread): - - def __init__(self, pipeline): - super(Pusher, self).__init__() - self.pipeline = pipeline - self.running = True - self.count = 0 - - def run(self): - while self.running: - self.pipeline.push("RETURN $x", {"x": self.count}) - self.count += 1 - - -class Puller(Thread): - - def __init__(self, pipeline): - super(Puller, self).__init__() - self.pipeline = pipeline - self.running = True - self.count = 0 - - def run(self): - while self.running: - for _ in self.pipeline.pull(): - pass # consume and discard records - self.count += 1 - - -def main(): - from neo4j import Driver - # from neo4j.bolt.diagnostics import watch - # watch("neobolt") - with Driver("bolt://", auth=("neo4j", "password")) as dx: - p = dx.pipeline(flush_every=1024) - pusher = Pusher(p) - puller = Puller(p) - try: - pusher.start() - puller.start() - while True: - print("sent %d, received %d, backlog %d" % (pusher.count, puller.count, pusher.count - puller.count)) - sleep(1) - except KeyboardInterrupt: - pusher.running = False - pusher.join() - puller.running = False - puller.join() - - -if __name__ == "__main__": - main() diff --git a/tests/unit/io/test_direct.py b/tests/unit/mixed/io/test_direct.py similarity index 100% rename from tests/unit/io/test_direct.py rename to tests/unit/mixed/io/test_direct.py From fe6ccedb0ec8555d1a87b0c45cda55f829c821d0 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 20 Dec 2021 11:24:44 +0100 Subject: [PATCH 2/5] Prepare async: move files --- neo4j/{ => _async}/io/_bolt3.py | 0 neo4j/{ => _async}/io/_bolt4.py | 0 neo4j/{ => _async}/io/_common.py | 0 neo4j/{io/__init__.py => _async/io/_pool.py} | 0 neo4j/{ => _async}/work/result.py | 0 .../simple.py => _async/work/session.py} | 0 neo4j/{ => 
_async}/work/summary.py | 0 neo4j/{ => _async}/work/transaction.py | 0 neo4j/_driver.py | 281 ------------------ neo4j/io/README.rst | 1 - neo4j/work/pipelining.py | 136 --------- neo4j/work/{__init__.py => query.py} | 0 testkitbackend/{ => _async}/backend.py | 0 testkitbackend/{ => _async}/requests.py | 0 tests/unit/{data => async_/io}/__init__.py | 0 tests/unit/{ => async_}/io/conftest.py | 0 tests/unit/{ => async_}/io/test__common.py | 0 tests/unit/{ => async_}/io/test_class_bolt.py | 0 .../unit/{ => async_}/io/test_class_bolt3.py | 0 .../{ => async_}/io/test_class_bolt4x0.py | 0 .../{ => async_}/io/test_class_bolt4x1.py | 0 .../{ => async_}/io/test_class_bolt4x2.py | 0 .../{ => async_}/io/test_class_bolt4x3.py | 0 .../{ => async_}/io/test_class_bolt4x4.py | 0 tests/unit/{ => async_}/io/test_direct.py | 0 tests/unit/{ => async_}/io/test_neo4j_pool.py | 0 tests/unit/{ => async_}/work/__init__.py | 0 .../{ => async_}/work/_fake_connection.py | 0 tests/unit/{ => async_}/work/test_result.py | 0 tests/unit/{ => async_}/work/test_session.py | 0 .../{ => async_}/work/test_transaction.py | 0 tests/unit/{io => common/data}/__init__.py | 0 tests/unit/{ => common}/data/test_packing.py | 0 tests/unit/{ => common}/io/test_routing.py | 0 tests/unit/{ => common}/spatial/__init__.py | 0 .../spatial/test_cartesian_point.py | 0 tests/unit/{ => common}/spatial/test_point.py | 0 .../{ => common}/spatial/test_wgs84_point.py | 0 tests/unit/{ => common}/test_addressing.py | 0 tests/unit/{ => common}/test_api.py | 0 tests/unit/{ => common}/test_conf.py | 0 tests/unit/{ => common}/test_data.py | 0 tests/unit/{ => common}/test_driver.py | 0 tests/unit/{ => common}/test_exceptions.py | 0 tests/unit/{ => common}/test_import_neo4j.py | 0 tests/unit/{ => common}/test_record.py | 0 tests/unit/{ => common}/test_security.py | 0 tests/unit/{ => common}/test_types.py | 0 tests/unit/{ => common}/time/__init__.py | 0 tests/unit/{ => common}/time/test_clock.py | 0 .../unit/{ => common}/time/test_clocktime.py | 0 tests/unit/{ => common}/time/test_date.py | 0 tests/unit/{ => common}/time/test_datetime.py | 0 tests/unit/{ => common}/time/test_duration.py | 0 .../unit/{ => common}/time/test_hydration.py | 0 tests/unit/{ => common}/time/test_time.py | 0 56 files changed, 418 deletions(-) rename neo4j/{ => _async}/io/_bolt3.py (100%) rename neo4j/{ => _async}/io/_bolt4.py (100%) rename neo4j/{ => _async}/io/_common.py (100%) rename neo4j/{io/__init__.py => _async/io/_pool.py} (100%) rename neo4j/{ => _async}/work/result.py (100%) rename neo4j/{work/simple.py => _async/work/session.py} (100%) rename neo4j/{ => _async}/work/summary.py (100%) rename neo4j/{ => _async}/work/transaction.py (100%) delete mode 100644 neo4j/_driver.py delete mode 100644 neo4j/io/README.rst delete mode 100644 neo4j/work/pipelining.py rename neo4j/work/{__init__.py => query.py} (100%) rename testkitbackend/{ => _async}/backend.py (100%) rename testkitbackend/{ => _async}/requests.py (100%) rename tests/unit/{data => async_/io}/__init__.py (100%) rename tests/unit/{ => async_}/io/conftest.py (100%) rename tests/unit/{ => async_}/io/test__common.py (100%) rename tests/unit/{ => async_}/io/test_class_bolt.py (100%) rename tests/unit/{ => async_}/io/test_class_bolt3.py (100%) rename tests/unit/{ => async_}/io/test_class_bolt4x0.py (100%) rename tests/unit/{ => async_}/io/test_class_bolt4x1.py (100%) rename tests/unit/{ => async_}/io/test_class_bolt4x2.py (100%) rename tests/unit/{ => async_}/io/test_class_bolt4x3.py (100%) rename tests/unit/{ => 
async_}/io/test_class_bolt4x4.py (100%) rename tests/unit/{ => async_}/io/test_direct.py (100%) rename tests/unit/{ => async_}/io/test_neo4j_pool.py (100%) rename tests/unit/{ => async_}/work/__init__.py (100%) rename tests/unit/{ => async_}/work/_fake_connection.py (100%) rename tests/unit/{ => async_}/work/test_result.py (100%) rename tests/unit/{ => async_}/work/test_session.py (100%) rename tests/unit/{ => async_}/work/test_transaction.py (100%) rename tests/unit/{io => common/data}/__init__.py (100%) rename tests/unit/{ => common}/data/test_packing.py (100%) rename tests/unit/{ => common}/io/test_routing.py (100%) rename tests/unit/{ => common}/spatial/__init__.py (100%) rename tests/unit/{ => common}/spatial/test_cartesian_point.py (100%) rename tests/unit/{ => common}/spatial/test_point.py (100%) rename tests/unit/{ => common}/spatial/test_wgs84_point.py (100%) rename tests/unit/{ => common}/test_addressing.py (100%) rename tests/unit/{ => common}/test_api.py (100%) rename tests/unit/{ => common}/test_conf.py (100%) rename tests/unit/{ => common}/test_data.py (100%) rename tests/unit/{ => common}/test_driver.py (100%) rename tests/unit/{ => common}/test_exceptions.py (100%) rename tests/unit/{ => common}/test_import_neo4j.py (100%) rename tests/unit/{ => common}/test_record.py (100%) rename tests/unit/{ => common}/test_security.py (100%) rename tests/unit/{ => common}/test_types.py (100%) rename tests/unit/{ => common}/time/__init__.py (100%) rename tests/unit/{ => common}/time/test_clock.py (100%) rename tests/unit/{ => common}/time/test_clocktime.py (100%) rename tests/unit/{ => common}/time/test_date.py (100%) rename tests/unit/{ => common}/time/test_datetime.py (100%) rename tests/unit/{ => common}/time/test_duration.py (100%) rename tests/unit/{ => common}/time/test_hydration.py (100%) rename tests/unit/{ => common}/time/test_time.py (100%) diff --git a/neo4j/io/_bolt3.py b/neo4j/_async/io/_bolt3.py similarity index 100% rename from neo4j/io/_bolt3.py rename to neo4j/_async/io/_bolt3.py diff --git a/neo4j/io/_bolt4.py b/neo4j/_async/io/_bolt4.py similarity index 100% rename from neo4j/io/_bolt4.py rename to neo4j/_async/io/_bolt4.py diff --git a/neo4j/io/_common.py b/neo4j/_async/io/_common.py similarity index 100% rename from neo4j/io/_common.py rename to neo4j/_async/io/_common.py diff --git a/neo4j/io/__init__.py b/neo4j/_async/io/_pool.py similarity index 100% rename from neo4j/io/__init__.py rename to neo4j/_async/io/_pool.py diff --git a/neo4j/work/result.py b/neo4j/_async/work/result.py similarity index 100% rename from neo4j/work/result.py rename to neo4j/_async/work/result.py diff --git a/neo4j/work/simple.py b/neo4j/_async/work/session.py similarity index 100% rename from neo4j/work/simple.py rename to neo4j/_async/work/session.py diff --git a/neo4j/work/summary.py b/neo4j/_async/work/summary.py similarity index 100% rename from neo4j/work/summary.py rename to neo4j/_async/work/summary.py diff --git a/neo4j/work/transaction.py b/neo4j/_async/work/transaction.py similarity index 100% rename from neo4j/work/transaction.py rename to neo4j/_async/work/transaction.py diff --git a/neo4j/_driver.py b/neo4j/_driver.py deleted file mode 100644 index c8aa688b9..000000000 --- a/neo4j/_driver.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from .addressing import Address -from .api import READ_ACCESS -from .conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) -from .meta import experimental -from .work.simple import Session - - -class Direct: - - default_host = "localhost" - default_port = 7687 - - default_target = ":" - - def __init__(self, address): - self._address = address - - @property - def address(self): - return self._address - - @classmethod - def parse_target(cls, target): - """ Parse a target string to produce an address. - """ - if not target: - target = cls.default_target - address = Address.parse(target, default_host=cls.default_host, - default_port=cls.default_port) - return address - - -class Routing: - - default_host = "localhost" - default_port = 7687 - - default_targets = ": :17601 :17687" - - def __init__(self, initial_addresses): - self._initial_addresses = initial_addresses - - @property - def initial_addresses(self): - return self._initial_addresses - - @classmethod - def parse_targets(cls, *targets): - """ Parse a sequence of target strings to produce an address - list. - """ - targets = " ".join(targets) - if not targets: - targets = cls.default_targets - addresses = Address.parse_list(targets, default_host=cls.default_host, default_port=cls.default_port) - return addresses - - -class Driver: - """ Base class for all types of :class:`neo4j.Driver`, instances of which are - used as the primary access point to Neo4j. - """ - - #: Connection pool - _pool = None - - def __init__(self, pool): - assert pool is not None - self._pool = pool - - def __del__(self): - self.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - @property - def encrypted(self): - return bool(self._pool.pool_config.encrypted) - - def session(self, **config): - """Create a session, see :ref:`session-construction-ref` - - :param config: session configuration key-word arguments, see :ref:`session-configuration-ref` for available key-word arguments. - - :returns: new :class:`neo4j.Session` object - """ - raise NotImplementedError - - @experimental("The pipeline API is experimental and may be removed or changed in a future release") - def pipeline(self, **config): - """ Create a pipeline. - """ - raise NotImplementedError - - def close(self): - """ Shut down, closing any open connections in the pool. - """ - self._pool.close() - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - """ This verifies if the driver can connect to a remote server or a cluster - by establishing a network connection with the remote and possibly exchanging - a few data before closing the connection. It throws exception if fails to connect. - - Use the exception to further understand the cause of the connectivity problem. - - Note: Even if this method throws an exception, the driver still need to be closed via close() to free up all resources. 
- """ - raise NotImplementedError - - @experimental("Feature support query, based on Bolt Protocol Version and Neo4j Server Version will change in the future.") - def supports_multi_db(self): - """ Check if the server or cluster supports multi-databases. - - :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. - :rtype: bool - """ - with self.session() as session: - session._connect(READ_ACCESS) - return session._connection.supports_multiple_databases - - -class BoltDriver(Direct, Driver): - """ A :class:`.BoltDriver` is created from a ``bolt`` URI and addresses - a single database machine. This may be a standalone server or could be a - specific member of a cluster. - - Connections established by a :class:`.BoltDriver` are always made to the - exact host and port detailed in the URI. - """ - - @classmethod - def open(cls, target, *, auth=None, **config): - """ - :param target: - :param auth: - :param config: The values that can be specified are found in :class: `neo4j.PoolConfig` and :class: `neo4j.WorkspaceConfig` - - :return: - :rtype: :class: `neo4j.BoltDriver` - """ - from neo4j.io import BoltPool - address = cls.parse_target(target) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) - return cls(pool, default_workspace_config) - - def __init__(self, pool, default_workspace_config): - Direct.__init__(self, pool.address) - Driver.__init__(self, pool) - self._default_workspace_config = default_workspace_config - - def session(self, **config): - """ - :param config: The values that can be specified are found in :class: `neo4j.SessionConfig` - - :return: - :rtype: :class: `neo4j.Session` - """ - from neo4j.work.simple import Session - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) - - def pipeline(self, **config): - from neo4j.work.pipelining import ( - Pipeline, - PipelineConfig, - ) - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - server_agent = None - config["fetch_size"] = -1 - with self.session(**config) as session: - result = session.run("RETURN 1 AS x") - value = result.single().value() - summary = result.consume() - server_agent = summary.server.agent - return server_agent - - -class Neo4jDriver(Routing, Driver): - """ A :class:`.Neo4jDriver` is created from a ``neo4j`` URI. The - routing behaviour works in tandem with Neo4j's `Causal Clustering - `_ - feature by directing read and write behaviour to appropriate - cluster members. 
- """ - - @classmethod - def open(cls, *targets, auth=None, routing_context=None, **config): - from neo4j.io import Neo4jPool - addresses = cls.parse_targets(*targets) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) - return cls(pool, default_workspace_config) - - def __init__(self, pool, default_workspace_config): - Routing.__init__(self, pool.get_default_database_initial_router_addresses()) - Driver.__init__(self, pool) - self._default_workspace_config = default_workspace_config - - def session(self, **config): - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) - - def pipeline(self, **config): - from neo4j.work.pipelining import ( - Pipeline, - PipelineConfig, - ) - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - """ - :raise ServiceUnavailable: raised if the server does not support routing or if routing support is broken. - """ - # TODO: Improve and update Stub Test Server to be able to test. - return self._verify_routing_connectivity() - - def _verify_routing_connectivity(self): - from neo4j.exceptions import ( - Neo4jError, - ServiceUnavailable, - SessionExpired, - ) - - table = self._pool.get_routing_table_for_default_database() - routing_info = {} - for ix in list(table.routers): - try: - routing_info[ix] = self._pool.fetch_routing_info( - address=table.routers[0], - database=self._default_workspace_config.database, - imp_user=self._default_workspace_config.impersonated_user, - bookmarks=None, - timeout=self._default_workspace_config - .connection_acquisition_timeout - ) - except (ServiceUnavailable, SessionExpired, Neo4jError): - routing_info[ix] = None - for key, val in routing_info.items(): - if val is not None: - return routing_info - raise ServiceUnavailable("Could not connect to any routing servers.") diff --git a/neo4j/io/README.rst b/neo4j/io/README.rst deleted file mode 100644 index dfe8742a7..000000000 --- a/neo4j/io/README.rst +++ /dev/null @@ -1 +0,0 @@ -Regular (non-async) I/O for Neo4j. \ No newline at end of file diff --git a/neo4j/work/pipelining.py b/neo4j/work/pipelining.py deleted file mode 100644 index ccca84202..000000000 --- a/neo4j/work/pipelining.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from collections import deque -from threading import Thread, Lock -from time import sleep - -from neo4j.work import Workspace -from neo4j.conf import WorkspaceConfig -from neo4j.api import ( - WRITE_ACCESS, -) - -class PipelineConfig(WorkspaceConfig): - - #: - flush_every = 8192 # bytes - - -class Pipeline(Workspace): - - def __init__(self, pool, config): - assert isinstance(config, PipelineConfig) - super(Pipeline, self).__init__(pool, config) - self._connect(WRITE_ACCESS) - self._flush_every = config.flush_every - self._data = deque() - self._pull_lock = Lock() - - def push(self, statement, parameters=None): - self._connection.run(statement, parameters) - self._connection.pull(on_records=self._data.extend) - output_buffer_size = len(self._connection.outbox.view()) - if output_buffer_size >= self._flush_every: - self._connection.send_all() - - def _results_generator(self): - results_returned_count = 0 - try: - summary = 0 - while summary == 0: - _, summary = self._connection.fetch_message() - summary = 0 - while summary == 0: - detail, summary = self._connection.fetch_message() - for n in range(detail): - response = self._data.popleft() - results_returned_count += 1 - yield response - finally: - self._pull_lock.release() - - def pull(self): - """Returns a generator containing the results of the next query in the pipeline""" - # n.b. pull is now somewhat misleadingly named because it doesn't do anything - # the connection isn't touched until you try and iterate the generator we return - lock_acquired = self._pull_lock.acquire(blocking=False) - if not lock_acquired: - raise PullOrderException() - return self._results_generator() - - -class PullOrderException(Exception): - """Raise when calling pull if a previous pull result has not been fully consumed""" - - -class Pusher(Thread): - - def __init__(self, pipeline): - super(Pusher, self).__init__() - self.pipeline = pipeline - self.running = True - self.count = 0 - - def run(self): - while self.running: - self.pipeline.push("RETURN $x", {"x": self.count}) - self.count += 1 - - -class Puller(Thread): - - def __init__(self, pipeline): - super(Puller, self).__init__() - self.pipeline = pipeline - self.running = True - self.count = 0 - - def run(self): - while self.running: - for _ in self.pipeline.pull(): - pass # consume and discard records - self.count += 1 - - -def main(): - from neo4j import Driver - # from neo4j.bolt.diagnostics import watch - # watch("neobolt") - with Driver("bolt://", auth=("neo4j", "password")) as dx: - p = dx.pipeline(flush_every=1024) - pusher = Pusher(p) - puller = Puller(p) - try: - pusher.start() - puller.start() - while True: - print("sent %d, received %d, backlog %d" % (pusher.count, puller.count, pusher.count - puller.count)) - sleep(1) - except KeyboardInterrupt: - pusher.running = False - pusher.join() - puller.running = False - puller.join() - - -if __name__ == "__main__": - main() diff --git a/neo4j/work/__init__.py b/neo4j/work/query.py similarity index 100% rename from neo4j/work/__init__.py rename to neo4j/work/query.py diff --git a/testkitbackend/backend.py b/testkitbackend/_async/backend.py similarity index 100% rename from testkitbackend/backend.py rename to testkitbackend/_async/backend.py diff --git a/testkitbackend/requests.py b/testkitbackend/_async/requests.py similarity index 100% rename from testkitbackend/requests.py rename to testkitbackend/_async/requests.py diff --git a/tests/unit/data/__init__.py b/tests/unit/async_/io/__init__.py similarity index 100% rename from 
tests/unit/data/__init__.py rename to tests/unit/async_/io/__init__.py diff --git a/tests/unit/io/conftest.py b/tests/unit/async_/io/conftest.py similarity index 100% rename from tests/unit/io/conftest.py rename to tests/unit/async_/io/conftest.py diff --git a/tests/unit/io/test__common.py b/tests/unit/async_/io/test__common.py similarity index 100% rename from tests/unit/io/test__common.py rename to tests/unit/async_/io/test__common.py diff --git a/tests/unit/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py similarity index 100% rename from tests/unit/io/test_class_bolt.py rename to tests/unit/async_/io/test_class_bolt.py diff --git a/tests/unit/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py similarity index 100% rename from tests/unit/io/test_class_bolt3.py rename to tests/unit/async_/io/test_class_bolt3.py diff --git a/tests/unit/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py similarity index 100% rename from tests/unit/io/test_class_bolt4x0.py rename to tests/unit/async_/io/test_class_bolt4x0.py diff --git a/tests/unit/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py similarity index 100% rename from tests/unit/io/test_class_bolt4x1.py rename to tests/unit/async_/io/test_class_bolt4x1.py diff --git a/tests/unit/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py similarity index 100% rename from tests/unit/io/test_class_bolt4x2.py rename to tests/unit/async_/io/test_class_bolt4x2.py diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py similarity index 100% rename from tests/unit/io/test_class_bolt4x3.py rename to tests/unit/async_/io/test_class_bolt4x3.py diff --git a/tests/unit/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py similarity index 100% rename from tests/unit/io/test_class_bolt4x4.py rename to tests/unit/async_/io/test_class_bolt4x4.py diff --git a/tests/unit/io/test_direct.py b/tests/unit/async_/io/test_direct.py similarity index 100% rename from tests/unit/io/test_direct.py rename to tests/unit/async_/io/test_direct.py diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py similarity index 100% rename from tests/unit/io/test_neo4j_pool.py rename to tests/unit/async_/io/test_neo4j_pool.py diff --git a/tests/unit/work/__init__.py b/tests/unit/async_/work/__init__.py similarity index 100% rename from tests/unit/work/__init__.py rename to tests/unit/async_/work/__init__.py diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/async_/work/_fake_connection.py similarity index 100% rename from tests/unit/work/_fake_connection.py rename to tests/unit/async_/work/_fake_connection.py diff --git a/tests/unit/work/test_result.py b/tests/unit/async_/work/test_result.py similarity index 100% rename from tests/unit/work/test_result.py rename to tests/unit/async_/work/test_result.py diff --git a/tests/unit/work/test_session.py b/tests/unit/async_/work/test_session.py similarity index 100% rename from tests/unit/work/test_session.py rename to tests/unit/async_/work/test_session.py diff --git a/tests/unit/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py similarity index 100% rename from tests/unit/work/test_transaction.py rename to tests/unit/async_/work/test_transaction.py diff --git a/tests/unit/io/__init__.py b/tests/unit/common/data/__init__.py similarity index 100% rename from tests/unit/io/__init__.py rename to tests/unit/common/data/__init__.py diff --git 
a/tests/unit/data/test_packing.py b/tests/unit/common/data/test_packing.py similarity index 100% rename from tests/unit/data/test_packing.py rename to tests/unit/common/data/test_packing.py diff --git a/tests/unit/io/test_routing.py b/tests/unit/common/io/test_routing.py similarity index 100% rename from tests/unit/io/test_routing.py rename to tests/unit/common/io/test_routing.py diff --git a/tests/unit/spatial/__init__.py b/tests/unit/common/spatial/__init__.py similarity index 100% rename from tests/unit/spatial/__init__.py rename to tests/unit/common/spatial/__init__.py diff --git a/tests/unit/spatial/test_cartesian_point.py b/tests/unit/common/spatial/test_cartesian_point.py similarity index 100% rename from tests/unit/spatial/test_cartesian_point.py rename to tests/unit/common/spatial/test_cartesian_point.py diff --git a/tests/unit/spatial/test_point.py b/tests/unit/common/spatial/test_point.py similarity index 100% rename from tests/unit/spatial/test_point.py rename to tests/unit/common/spatial/test_point.py diff --git a/tests/unit/spatial/test_wgs84_point.py b/tests/unit/common/spatial/test_wgs84_point.py similarity index 100% rename from tests/unit/spatial/test_wgs84_point.py rename to tests/unit/common/spatial/test_wgs84_point.py diff --git a/tests/unit/test_addressing.py b/tests/unit/common/test_addressing.py similarity index 100% rename from tests/unit/test_addressing.py rename to tests/unit/common/test_addressing.py diff --git a/tests/unit/test_api.py b/tests/unit/common/test_api.py similarity index 100% rename from tests/unit/test_api.py rename to tests/unit/common/test_api.py diff --git a/tests/unit/test_conf.py b/tests/unit/common/test_conf.py similarity index 100% rename from tests/unit/test_conf.py rename to tests/unit/common/test_conf.py diff --git a/tests/unit/test_data.py b/tests/unit/common/test_data.py similarity index 100% rename from tests/unit/test_data.py rename to tests/unit/common/test_data.py diff --git a/tests/unit/test_driver.py b/tests/unit/common/test_driver.py similarity index 100% rename from tests/unit/test_driver.py rename to tests/unit/common/test_driver.py diff --git a/tests/unit/test_exceptions.py b/tests/unit/common/test_exceptions.py similarity index 100% rename from tests/unit/test_exceptions.py rename to tests/unit/common/test_exceptions.py diff --git a/tests/unit/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py similarity index 100% rename from tests/unit/test_import_neo4j.py rename to tests/unit/common/test_import_neo4j.py diff --git a/tests/unit/test_record.py b/tests/unit/common/test_record.py similarity index 100% rename from tests/unit/test_record.py rename to tests/unit/common/test_record.py diff --git a/tests/unit/test_security.py b/tests/unit/common/test_security.py similarity index 100% rename from tests/unit/test_security.py rename to tests/unit/common/test_security.py diff --git a/tests/unit/test_types.py b/tests/unit/common/test_types.py similarity index 100% rename from tests/unit/test_types.py rename to tests/unit/common/test_types.py diff --git a/tests/unit/time/__init__.py b/tests/unit/common/time/__init__.py similarity index 100% rename from tests/unit/time/__init__.py rename to tests/unit/common/time/__init__.py diff --git a/tests/unit/time/test_clock.py b/tests/unit/common/time/test_clock.py similarity index 100% rename from tests/unit/time/test_clock.py rename to tests/unit/common/time/test_clock.py diff --git a/tests/unit/time/test_clocktime.py b/tests/unit/common/time/test_clocktime.py similarity index 100% 
rename from tests/unit/time/test_clocktime.py rename to tests/unit/common/time/test_clocktime.py diff --git a/tests/unit/time/test_date.py b/tests/unit/common/time/test_date.py similarity index 100% rename from tests/unit/time/test_date.py rename to tests/unit/common/time/test_date.py diff --git a/tests/unit/time/test_datetime.py b/tests/unit/common/time/test_datetime.py similarity index 100% rename from tests/unit/time/test_datetime.py rename to tests/unit/common/time/test_datetime.py diff --git a/tests/unit/time/test_duration.py b/tests/unit/common/time/test_duration.py similarity index 100% rename from tests/unit/time/test_duration.py rename to tests/unit/common/time/test_duration.py diff --git a/tests/unit/time/test_hydration.py b/tests/unit/common/time/test_hydration.py similarity index 100% rename from tests/unit/time/test_hydration.py rename to tests/unit/common/time/test_hydration.py diff --git a/tests/unit/time/test_time.py b/tests/unit/common/time/test_time.py similarity index 100% rename from tests/unit/time/test_time.py rename to tests/unit/common/time/test_time.py From 255b321cd60a796e92a07c8e9ceed075adaba2fc Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Tue, 14 Dec 2021 20:54:26 +0100 Subject: [PATCH 3/5] Docs for async --- bin/make-unasync | 78 ++++- docs/source/api.rst | 39 ++- docs/source/async_api.rst | 497 +++++++++++++++++++++++++++++++ docs/source/index.rst | 3 + neo4j/_async/driver.py | 38 ++- neo4j/_async/work/__init__.py | 1 + neo4j/_async/work/result.py | 27 +- neo4j/_async/work/session.py | 83 +++--- neo4j/_async/work/transaction.py | 16 +- neo4j/api.py | 12 +- 10 files changed, 689 insertions(+), 105 deletions(-) create mode 100644 docs/source/async_api.rst diff --git a/bin/make-unasync b/bin/make-unasync index d7e053d39..c217a539e 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -20,9 +20,9 @@ import collections import errno -from functools import reduce import os from pathlib import Path +import re import sys import tokenize as std_tokenize @@ -103,20 +103,86 @@ class CustomRule(unasync.Rule): super(CustomRule, self).__init__(*args, **kwargs) self.out_files = [] - def _unasync_name(self, name): - # copy from unasync to customize renaming rules + def _unasync_tokens(self, tokens): + # copy from unasync to hook into string handling # https://github.com/python-trio/unasync # License: MIT and Apache2 - if name in self.token_replacements: - return self.token_replacements[name] + # TODO __await__, ...? 
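+        # ``tokens`` yields (leading whitespace, token type, token value)
+        # triples; names and string contents get rewritten below, while
+        # ``async``/``await`` keywords are dropped entirely.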
+ used_space = None + for space, toknum, tokval in tokens: + if tokval in ["async", "await"]: + # When removing async or await, we want to use the whitespace + # that was before async/await before the next token so that + # `print(await stuff)` becomes `print(stuff)` and not + # `print( stuff)` + used_space = space + else: + if toknum == std_tokenize.NAME: + tokval = self._unasync_name(tokval) + elif toknum == std_tokenize.STRING: + if tokval[0] == tokval[1] and len(tokval) > 2: + # multiline string (`"""..."""` or `'''...'''`) + left_quote, name, right_quote = \ + tokval[:3], tokval[3:-3], tokval[-3:] + else: + # simple string (`"..."` or `'...'`) + left_quote, name, right_quote = \ + tokval[:1], tokval[1:-1], tokval[-1:] + tokval = \ + left_quote + self._unasync_string(name) + right_quote + if used_space is None: + used_space = space + yield (used_space, tokval) + used_space = None + + def _unasync_string(self, name): + start = 0 + end = 1 + out = "" + while end < len(name): + sub_name = name[start:end] + if sub_name.isidentifier(): + end += 1 + else: + if end == start + 1: + out += sub_name + start += 1 + end += 1 + else: + out += self._unasync_prefix(name[start:(end - 1)]) + start = end - 1 + + sub_name = name[start:] + if sub_name.isidentifier(): + out += self._unasync_prefix(name[start:]) + else: + out += sub_name + + # very boiled down unasync version that removes "async" and "await" + # substrings. + out = re.subn(r"(^|\s+|(?<=\W))(?:async|await)\s+", r"\1", out, + flags=re.MULTILINE)[0] + # Convert doc-reference names from 'async-xyz' to 'xyz' + out = re.subn(r":ref:`async-", ":ref:`", out)[0] + return out + + def _unasync_prefix(self, name): # Convert class names from 'AsyncXyz' to 'Xyz' - elif len(name) > 5 and name.startswith("Async") and name[5].isupper(): + if len(name) > 5 and name.startswith("Async") and name[5].isupper(): return name[5:] # Convert variable/method/function names from 'async_xyz' to 'xyz' elif len(name) > 6 and name.startswith("async_"): return name[6:] return name + def _unasync_name(self, name): + # copy from unasync to customize renaming rules + # https://github.com/python-trio/unasync + # License: MIT and Apache2 + if name in self.token_replacements: + return self.token_replacements[name] + return self._unasync_prefix(name) + def _unasync_file(self, filepath): # copy from unasync to append file suffix to out path # https://github.com/python-trio/unasync diff --git a/docs/source/api.rst b/docs/source/api.rst index 3481480fe..d9187feaa 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -11,7 +11,7 @@ GraphDatabase Driver Construction =================== -The :class:`neo4j.Driver` construction is via a `classmethod` on the :class:`neo4j.GraphDatabase` class. +The :class:`neo4j.Driver` construction is done via a `classmethod` on the :class:`neo4j.GraphDatabase` class. .. autoclass:: neo4j.GraphDatabase :members: driver @@ -24,18 +24,19 @@ Example, driver creation: from neo4j import GraphDatabase uri = "neo4j://example.com:7687" - driver = GraphDatabase.driver(uri, auth=("neo4j", "password"), max_connection_lifetime=1000) + driver = GraphDatabase.driver(uri, auth=("neo4j", "password")) driver.close() # close the driver object -For basic auth, this can be a simple tuple, for example: +For basic authentication, `auth` can be a simple tuple, for example: .. 
code-block:: python auth = ("neo4j", "password") -This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"`` +This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. +Other authentication methods are described under :ref:`auth-ref`. Example, with block context: @@ -345,11 +346,9 @@ BoltDriver URI schemes: ``bolt``, ``bolt+ssc``, ``bolt+s`` -Driver subclass: - :class:`neo4j.BoltDriver` +Will result in: -.. - .. autoclass:: neo4j.BoltDriver +.. autoclass:: neo4j.BoltDriver .. _neo4j-driver-ref: @@ -360,11 +359,9 @@ Neo4jDriver URI schemes: ``neo4j``, ``neo4j+ssc``, ``neo4j+s`` -Driver subclass: - :class:`neo4j.Neo4jDriver` +Will result in: -.. - .. autoclass:: neo4j.Neo4jDriver +.. autoclass:: neo4j.Neo4jDriver *********************** @@ -604,7 +601,8 @@ Example: def create_person(driver, name): with driver.session(default_access_mode=neo4j.WRITE_ACCESS) as session: - result = session.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name) + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" + result = session.run(query, name=name) record = result.single() return record["node_id"] @@ -665,13 +663,15 @@ Example: tx.close() def create_person_node(tx): + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" name = "default_name" - result = tx.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name) + result = tx.run(query, name=name) record = result.single() return record["node_id"] def set_person_name(tx, node_id, name): - result = tx.run("MATCH (a:Person) WHERE id(a) = $id SET a.name = $name", id=node_id, name=name) + query = "MATCH (a:Person) WHERE id(a) = $id SET a.name = $name" + result = tx.run(query, id=node_id, name=name) info = result.consume() # use the info for logging etc. @@ -698,7 +698,8 @@ Example: node_id = session.write_transaction(create_person_tx, name) def create_person_tx(tx, name): - result = tx.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name) + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" + result = tx.run(query, name=name) record = result.single() return record["node_id"] @@ -708,12 +709,6 @@ To exert more control over how a transaction function is carried out, the :func: - - - - - - ****** Result ****** diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst new file mode 100644 index 000000000..c22ccbda2 --- /dev/null +++ b/docs/source/async_api.rst @@ -0,0 +1,497 @@ +.. _async-api-documentation: + +####################### +Async API Documentation +####################### + +.. warning:: + The whole async API is currently in experimental phase. + + This means everything documented on this page might be removed or change + its API at any time (including in patch releases). + +****************** +AsyncGraphDatabase +****************** + +Async Driver Construction +========================= + +The :class:`neo4j.AsyncDriver` construction is done via a `classmethod` on the :class:`neo4j.AsyncGraphDatabase` class. + +.. autoclass:: neo4j.AsyncGraphDatabase + :members: driver + + +Example, driver creation: + +.. code-block:: python + + import asyncio + + from neo4j import AsyncGraphDatabase + + async def main(): + uri = "neo4j://example.com:7687" + driver = AsyncGraphDatabase.driver(uri, auth=("neo4j", "password")) + + await driver.close() # close the driver object + + asyncio.run(main()) + + +For basic authentication, ``auth`` can be a simple tuple, for example: + +.. 
code-block:: python + + auth = ("neo4j", "password") + +This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. +Other authentication methods are described under :ref:`auth-ref`. + +Example, with block context: + +.. code-block:: python + + import asyncio + + from neo4j import AsyncGraphDatabase + + async def main(): + uri = "neo4j://example.com:7687" + auth = ("neo4j", "password") + async with AsyncGraphDatabase.driver(uri, auth=auth) as driver: + # use the driver + ... + + asyncio.run(main()) + + +.. _async-uri-ref: + +URI +=== + +On construction, the `scheme` of the URI determines the type of :class:`neo4j.AsyncDriver` object created. + +Available valid URIs: + ++ ``bolt://host[:port]`` ++ ``bolt+ssc://host[:port]`` ++ ``bolt+s://host[:port]`` ++ ``neo4j://host[:port][?routing_context]`` ++ ``neo4j+ssc://host[:port][?routing_context]`` ++ ``neo4j+s://host[:port][?routing_context]`` + +.. code-block:: python + + uri = "bolt://example.com:7687" + +.. code-block:: python + + uri = "neo4j://example.com:7687" + +Each supported scheme maps to a particular :class:`neo4j.AsyncDriver` subclass that implements a specific behaviour. + ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| URI Scheme | Driver Object and Setting | ++========================+=============================================================================================================================================+ +| bolt | :ref:`async-bolt-driver-ref` with no encryption. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| bolt+ssc | :ref:`async-bolt-driver-ref` with encryption (accepts self signed certificates). | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| bolt+s | :ref:`async-bolt-driver-ref` with encryption (accepts only certificates signed by a certificate authority), full certificate checks. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| neo4j | :ref:`async-neo4j-driver-ref` with no encryption. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| neo4j+ssc | :ref:`async-neo4j-driver-ref` with encryption (accepts self signed certificates). | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| neo4j+s | :ref:`async-neo4j-driver-ref` with encryption (accepts only certificates signed by a certificate authority), full certificate checks. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ + +.. note:: + + See https://neo4j.com/docs/operations-manual/current/configuration/ports/ for Neo4j ports. + + + +*********** +AsyncDriver +*********** + +Every Neo4j-backed application will require a :class:`neo4j.AsyncDriver` object. 
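+
+As a quick illustration, here is a minimal sketch (host and credentials are
+placeholders) showing how the URI scheme from the table above selects the
+concrete driver class:
+
+.. code-block:: python
+
+    import asyncio
+
+    from neo4j import AsyncBoltDriver, AsyncGraphDatabase, AsyncNeo4jDriver
+
+    async def main():
+        auth = ("neo4j", "password")
+        # a "bolt" URI yields a direct driver without routing
+        async with AsyncGraphDatabase.driver("bolt://localhost:7687",
+                                             auth=auth) as driver:
+            assert isinstance(driver, AsyncBoltDriver)
+        # a "neo4j" URI yields a routing driver
+        async with AsyncGraphDatabase.driver("neo4j://localhost:7687",
+                                             auth=auth) as driver:
+            assert isinstance(driver, AsyncNeo4jDriver)
+
+    asyncio.run(main())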
+
+This object holds the details required to establish connections with a Neo4j database, including server URIs, credentials and other configuration.
+:class:`neo4j.AsyncDriver` objects hold a connection pool from which :class:`neo4j.AsyncSession` objects can borrow connections.
+Closing a driver will immediately shut down all connections in the pool.
+
+.. autoclass:: neo4j.AsyncDriver()
+    :members: session, close
+
+
+.. _async-driver-configuration-ref:
+
+Async Driver Configuration
+==========================
+
+:class:`neo4j.AsyncDriver` is configured exactly like :class:`neo4j.Driver`
+(see :ref:`driver-configuration-ref`). The only difference is that the async
+driver accepts an async custom resolver function:
+
+.. _async-resolver-ref:
+
+``resolver``
+------------
+A custom resolver function to resolve host and port values ahead of DNS resolution.
+This function is called with a 2-tuple of (host, port) and should return an iterable of 2-tuples (host, port).
+
+If no custom resolver function is supplied, the internal resolver moves straight to regular DNS resolution.
+
+The custom resolver function can, but does not have to, be a coroutine.
+
+For example:
+
+.. code-block:: python
+
+    from neo4j import AsyncGraphDatabase
+
+    async def custom_resolver(socket_address):
+        if socket_address == ("example.com", 9999):
+            yield "::1", 7687
+            yield "127.0.0.1", 7687
+        else:
+            from socket import gaierror
+            raise gaierror("Unexpected socket address %r" % socket_address)
+
+    # alternatively
+    def custom_resolver(socket_address):
+        ...
+
+    driver = AsyncGraphDatabase.driver("neo4j://example.com:9999",
+                                       auth=("neo4j", "password"),
+                                       resolver=custom_resolver)
+
+
+:Default: ``None``
+
+
+
+Driver Object Lifetime
+======================
+
+For general applications, it is recommended to create one top-level :class:`neo4j.AsyncDriver` object that lives for the lifetime of the application.
+
+For example:
+
+.. code-block:: python
+
+    from neo4j import AsyncGraphDatabase
+
+    class Application:
+
+        def __init__(self, uri, user, password):
+            self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
+
+        async def close(self):
+            await self.driver.close()
+
+Connection details held by the :class:`neo4j.AsyncDriver` are immutable.
+Therefore if, for example, a password is changed, a replacement :class:`neo4j.AsyncDriver` object must be created.
+More than one :class:`.AsyncDriver` may be required if connections to multiple databases, or connections as multiple users, are required.
+
+:class:`neo4j.AsyncDriver` objects are safe to be used in concurrent coroutines.
+They are not thread-safe.
+
+
+.. _async-bolt-driver-ref:
+
+AsyncBoltDriver
+===============
+
+URI schemes:
+    ``bolt``, ``bolt+ssc``, ``bolt+s``
+
+Will result in:
+
+.. autoclass:: neo4j.AsyncBoltDriver
+
+
+.. _async-neo4j-driver-ref:
+
+AsyncNeo4jDriver
+================
+
+URI schemes:
+    ``neo4j``, ``neo4j+ssc``, ``neo4j+s``
+
+Will result in:
+
+.. autoclass:: neo4j.AsyncNeo4jDriver
+
+
+*********************************
+AsyncSessions & AsyncTransactions
+*********************************
+All database activity is co-ordinated through two mechanisms: the :class:`neo4j.AsyncSession` and the :class:`neo4j.AsyncTransaction`.
+
+A :class:`neo4j.AsyncSession` is a logical container for any number of causally-related transactional units of work.
+Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required.
+Sessions provide the top-level of containment for database activity.
+Session creation is a lightweight operation and *sessions cannot be shared between coroutines*.
+
+Connections are drawn from the :class:`neo4j.AsyncDriver` connection pool as required.
+
+A :class:`neo4j.AsyncTransaction` is a unit of work that is either committed in its entirety or is rolled back on failure.
+
+
+.. _async-session-construction-ref:
+
+*************************
+AsyncSession Construction
+*************************
+
+To construct a :class:`neo4j.AsyncSession` use the :meth:`neo4j.AsyncDriver.session` method.
+
+.. code-block:: python
+
+    import asyncio
+
+    from neo4j import AsyncGraphDatabase
+
+    async def main():
+        driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
+        session = driver.session()
+        result = await session.run("MATCH (a:Person) RETURN a.name AS name")
+        names = [record["name"] async for record in result]
+        await session.close()
+        await driver.close()
+
+    asyncio.run(main())
+
+
+Sessions will often be created and destroyed using a *with block context*.
+
+.. code-block:: python
+
+    async with driver.session() as session:
+        result = await session.run("MATCH (a:Person) RETURN a.name AS name")
+        # do something with the result...
+
+
+Sessions will often be created with some configuration settings, see :ref:`async-session-configuration-ref`.
+
+.. code-block:: python
+
+    async with driver.session(database="example_database",
+                              fetch_size=100) as session:
+        result = await session.run("MATCH (a:Person) RETURN a.name AS name")
+        # do something with the result...
+
+
+************
+AsyncSession
+************
+
+.. autoclass:: neo4j.AsyncSession()
+
+    .. automethod:: close
+
+    .. automethod:: run
+
+    .. automethod:: last_bookmark
+
+    .. automethod:: begin_transaction
+
+    .. automethod:: read_transaction
+
+    .. automethod:: write_transaction
+
+
+
+.. _async-session-configuration-ref:
+
+Session Configuration
+=====================
+
+:class:`neo4j.AsyncSession` is configured exactly like :class:`neo4j.Session`
+(see :ref:`session-configuration-ref`).
+
+
+****************
+AsyncTransaction
+****************
+
+Neo4j supports three kinds of async transaction:
+
++ :ref:`async-auto-commit-transactions-ref`
++ :ref:`async-explicit-transactions-ref`
++ :ref:`async-managed-transactions-ref`
+
+Each has pros and cons but if in doubt, use a managed transaction with a `transaction function`.
+
+
+.. _async-auto-commit-transactions-ref:
+
+Async Auto-commit Transactions
+==============================
+Auto-commit transactions are the simplest form of transaction, available via :py:meth:`neo4j.AsyncSession.run`.
+
+These are easy to use but support only one statement per transaction and are not automatically retried on failure.
+Auto-commit transactions are also the only way to run ``PERIODIC COMMIT`` statements, since this Cypher clause manages its own transactions internally.
+
+Example:
+
+.. code-block:: python
+
+    import neo4j
+
+    async def create_person(driver, name):
+        async with driver.session(
+            default_access_mode=neo4j.WRITE_ACCESS
+        ) as session:
+            query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+            result = await session.run(query, name=name)
+            record = await result.single()
+            return record["node_id"]
+
+Example:
+
+.. code-block:: python
+
+    import neo4j
+
+    async def get_numbers(driver):
+        numbers = []
+        async with driver.session(
+            default_access_mode=neo4j.READ_ACCESS
+        ) as session:
+            result = await session.run("UNWIND [1, 2, 3] AS x RETURN x")
+            async for record in result:
+                numbers.append(record["x"])
+        return numbers
+
+
+.. _async-explicit-transactions-ref:
+
+Explicit Async Transactions
+===========================
+Explicit transactions support multiple statements and must be created with an explicit :py:meth:`neo4j.AsyncSession.begin_transaction` call.
+
+This creates a new :class:`neo4j.AsyncTransaction` object that can be used to run Cypher.
+
+It also gives applications the ability to directly control `commit` and `rollback` activity.
+
+.. autoclass:: neo4j.AsyncTransaction()
+
+    .. automethod:: run
+
+    .. automethod:: close
+
+    .. automethod:: closed
+
+    .. automethod:: commit
+
+    .. automethod:: rollback
+
+Closing an explicit transaction can either happen automatically at the end of an ``async with`` block,
+or can be explicitly controlled through the :py:meth:`neo4j.AsyncTransaction.commit`, :py:meth:`neo4j.AsyncTransaction.rollback` or :py:meth:`neo4j.AsyncTransaction.close` methods.
+
+Explicit transactions are most useful for applications that need to distribute Cypher execution across multiple functions for the same transaction.
+
+Example:
+
+.. code-block:: python
+
+    import neo4j
+
+    async def create_person(driver, name):
+        async with driver.session(
+            default_access_mode=neo4j.WRITE_ACCESS
+        ) as session:
+            tx = await session.begin_transaction()
+            node_id = await create_person_node(tx)
+            await set_person_name(tx, node_id, name)
+            await tx.commit()
+            await tx.close()
+
+    async def create_person_node(tx):
+        query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+        name = "default_name"
+        result = await tx.run(query, name=name)
+        record = await result.single()
+        return record["node_id"]
+
+    async def set_person_name(tx, node_id, name):
+        query = "MATCH (a:Person) WHERE id(a) = $id SET a.name = $name"
+        result = await tx.run(query, id=node_id, name=name)
+        info = await result.consume()
+        # use the info for logging etc.
+
+.. _async-managed-transactions-ref:
+
+
+Managed Async Transactions (`transaction functions`)
+====================================================
+Transaction functions are the most powerful form of transaction, providing access mode override and retry capabilities.
+
++ :py:meth:`neo4j.AsyncSession.write_transaction`
++ :py:meth:`neo4j.AsyncSession.read_transaction`
+
+These allow a function object representing the transactional unit of work to be passed as a parameter.
+This function is called one or more times, within a configurable time limit, until it succeeds.
+Results should be fully consumed within the function and only aggregate or status values should be returned.
+Returning a live result object would prevent the driver from correctly managing connections and would break retry guarantees.
+
+Example:
+
+.. code-block:: python
+
+    async def create_person(driver, name):
+        async with driver.session() as session:
+            node_id = await session.write_transaction(create_person_tx, name)
+
+    async def create_person_tx(tx, name):
+        query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+        result = await tx.run(query, name=name)
+        record = await result.single()
+        return record["node_id"]
+
+To exert more control over how a transaction function is carried out, the :func:`neo4j.unit_of_work` decorator can be used.
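+
+For example, the decorator can attach a transaction timeout that is applied
+to each attempt of the transaction function. The following is a minimal
+sketch only: ``unit_of_work`` is the actual decorator, but the function
+names, the query and the five-second timeout are illustrative values, not
+recommendations.
+
+.. code-block:: python
+
+    from neo4j import unit_of_work
+
+    # illustrative timeout; pick a value that suits the workload
+    @unit_of_work(timeout=5)
+    async def count_people_tx(tx):
+        result = await tx.run("MATCH (a:Person) RETURN count(a) AS persons")
+        record = await result.single()
+        return record["persons"]
+
+    async def count_people(driver):
+        async with driver.session() as session:
+            return await session.read_transaction(count_people_tx)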
+ + + +*********** +AsyncResult +*********** + +Every time a query is executed, a :class:`neo4j.AsyncResult` is returned. + +This provides a handle to the result of the query, giving access to the records within it as well as the result metadata. + +Results also contain a buffer that automatically stores unconsumed records when results are consumed out of order. + +A :class:`neo4j.AsyncResult` is attached to an active connection, through a :class:`neo4j.AsyncSession`, until all its content has been buffered or consumed. + +.. autoclass:: neo4j.AsyncResult() + + .. describe:: iter(result) + + .. automethod:: keys + + .. automethod:: consume + + .. automethod:: single + + .. automethod:: peek + + .. automethod:: graph + + **This is experimental.** (See :ref:`filter-warnings-ref`) + + .. automethod:: value + + .. automethod:: values + + .. automethod:: data + +See https://neo4j.com/docs/driver-manual/current/cypher-workflow/#driver-type-mapping for more about type mapping. diff --git a/docs/source/index.rst b/docs/source/index.rst index ac4bda164..24ef23865 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,8 @@ Topics + :ref:`api-documentation` ++ :ref:`async-api-documentation` (experimental) + + :ref:`temporal-data-types` + :ref:`breaking-changes` @@ -32,6 +34,7 @@ Topics :hidden: api.rst + async_api.rst temporal_types.rst breaking_changes.rst diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index e8dc634b9..a5746e7bd 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -41,16 +41,16 @@ class AsyncGraphDatabase: @classmethod @AsyncUtil.experimental_async( "neo4j async is in experimental phase. It might be removed or change " - "it's API at any time (including patch releases)." + "its API at any time (including patch releases)." ) def driver(cls, uri, *, auth=None, **config): """Create a driver. - :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs. + :param uri: the connection URI for the driver, see :ref:`async-uri-ref` for available URIs. :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. - :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments. + :param config: driver configuration key-word arguments, see :ref:`async-driver-configuration-ref` for available key-word arguments. - :return: :ref:`neo4j-driver-ref` or :ref:`bolt-driver-ref` + :rtype: AsyncNeo4jDriver or AsyncBoltDriver """ from ..api import ( @@ -193,8 +193,8 @@ def parse_targets(cls, *targets): class AsyncDriver: - """ Base class for all types of :class:`neo4j.Driver`, instances of which are - used as the primary access point to Neo4j. + """ Base class for all types of :class:`neo4j.AsyncDriver`, instances of + which are used as the primary access point to Neo4j. """ #: Connection pool @@ -219,9 +219,11 @@ def encrypted(self): return bool(self._pool.pool_config.encrypted) def session(self, **config): - """Create a session, see :ref:`session-construction-ref` + """Create a session, see :ref:`async-session-construction-ref` - :param config: session configuration key-word arguments, see :ref:`session-configuration-ref` for available key-word arguments. + :param config: session configuration key-word arguments, + see :ref:`async-session-configuration-ref` for available key-word + arguments. 
:returns: new :class:`neo4j.AsyncSession` object
         """
@@ -257,12 +259,15 @@ async def supports_multi_db(self):
 
 
 class AsyncBoltDriver(_Direct, AsyncDriver):
-    """ A :class:`.BoltDriver` is created from a ``bolt`` URI and addresses
-    a single database machine. This may be a standalone server or could be a
-    specific member of a cluster.
+    """:class:`.AsyncBoltDriver` is instantiated for ``bolt`` URIs and
+    addresses a single database machine. This may be a standalone server or
+    could be a specific member of a cluster.
 
-    Connections established by a :class:`.BoltDriver` are always made to the
-    exact host and port detailed in the URI.
+    Connections established by a :class:`.AsyncBoltDriver` are always made to
+    the exact host and port detailed in the URI.
+
+    This class is not supposed to be instantiated externally. Use
+    :meth:`AsyncGraphDatabase.driver` instead.
     """
 
     @classmethod
@@ -288,7 +293,7 @@ def __init__(self, pool, default_workspace_config):
 
     def session(self, **config):
         """
-        :param config: The values that can be specified are found in :class: `neo4j.AsyncSessionConfig`
+        :param config: The values that can be specified are found in :class: `neo4j.SessionConfig`
 
         :return:
         :rtype: :class: `neo4j.AsyncSession`
@@ -311,11 +316,14 @@ async def verify_connectivity(self, **config):
 
 
 class AsyncNeo4jDriver(_Routing, AsyncDriver):
-    """ A :class:`.Neo4jDriver` is created from a ``neo4j`` URI. The
+    """:class:`.AsyncNeo4jDriver` is instantiated for ``neo4j`` URIs. The
     routing behaviour works in tandem with Neo4j's `Causal Clustering
     `_
     feature by directing read and write behaviour to appropriate
     cluster members.
+
+    This class is not supposed to be instantiated externally. Use
+    :meth:`AsyncGraphDatabase.driver` instead.
     """
 
     @classmethod
diff --git a/neo4j/_async/work/__init__.py b/neo4j/_async/work/__init__.py
index 8c6ea430a..e48e1c212 100644
--- a/neo4j/_async/work/__init__.py
+++ b/neo4j/_async/work/__init__.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 from .session import (
     AsyncResult,
     AsyncSession,
diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py
index 08fe16822..6184d9634 100644
--- a/neo4j/_async/work/result.py
+++ b/neo4j/_async/work/result.py
@@ -28,7 +28,7 @@ class AsyncResult:
     """A handler for the result of Cypher query execution. Instances
     of this class are typically constructed and returned by
-    :meth:`.Session.run` and :meth:`.Transaction.run`.
+    :meth:`.AsyncSession.run` and :meth:`.AsyncTransaction.run`.
     """
 
     def __init__(self, connection, hydrant, fetch_size, on_closed,
@@ -204,7 +204,7 @@ async def _attach(self):
             await self._connection.fetch_message()
 
     async def _buffer(self, n=None):
-        """Try to fill `self_record_buffer` with n records.
+        """Try to fill `self._record_buffer` with n records.
 
         Might end up with more records in the buffer if the fetch size makes
         it overshoot.
@@ -260,25 +260,28 @@ async def consume(self):
 
         Example::
 
-            def create_node_tx(tx, name):
-                result = tx.run("CREATE (n:ExampleNode { name: $name }) RETURN n", name=name)
-                record = result.single()
+            async def create_node_tx(tx, name):
+                result = await tx.run(
+                    "CREATE (n:ExampleNode { name: $name }) RETURN n", name=name
+                )
+                record = await result.single()
                 value = record.value()
-                info = result.consume()
+                info = await result.consume()
                 return value, info
 
-            with driver.session() as session:
-                node_id, info = session.write_transaction(create_node_tx, "example")
+            async with driver.session() as session:
+                node_id, info = await session.write_transaction(create_node_tx, "example")
 
         Example::
 
-            def get_two_tx(tx):
-                result = tx.run("UNWIND [1,2,3,4] AS x RETURN x")
+            async def get_two_tx(tx):
+                result = await tx.run("UNWIND [1,2,3,4] AS x RETURN x")
                 values = []
-                for ix, record in enumerate(result):
-                    if x > 1:
+                async for record in result:
+                    if len(values) >= 2:
                         break
                     values.append(record.values())
-                info = result.consume()  # discard the remaining records if there are any
+                # discard the remaining records if there are any
+                info = await result.consume()
                 # use the info for logging etc.
                 return values, info
 
diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py
index 6a1be2d43..bedf2915a 100644
--- a/neo4j/_async/work/session.py
+++ b/neo4j/_async/work/session.py
@@ -47,19 +47,20 @@
 
 
 class AsyncSession(AsyncWorkspace):
-    """A :class:`.Session` is a logical context for transactional units
-    of work. Connections are drawn from the :class:`.Driver` connection
+    """A :class:`.AsyncSession` is a logical context for transactional units
+    of work. Connections are drawn from the :class:`.AsyncDriver` connection
     pool as required.
 
-    Session creation is a lightweight operation and sessions are not thread
-    safe. Therefore a session should generally be short-lived, and not
-    span multiple threads.
+    Session creation is a lightweight operation and sessions are not safe to
+    be used in concurrent contexts (multiple threads/coroutines).
+    Therefore, a session should generally be short-lived, and must not
+    span multiple threads/coroutines.
 
     In general, sessions will be created and destroyed within a `with`
     context. For example::
 
-        with driver.session() as session:
-            result = session.run("MATCH (n:Person) RETURN n.name AS name")
+        async with driver.session() as session:
+            result = await session.run("MATCH (n:Person) RETURN n.name AS name")
             # do something with the result...
 
     :param pool: connection pool instance
@@ -171,23 +172,23 @@ async def run(self, query, parameters=None, **kwargs):
         fetched lazily as consumed by the client application. If a
         query is executed before a previous
-        :class:`neo4j.Result` in the same :class:`.Session` has
+        :class:`neo4j.AsyncResult` in the same :class:`.AsyncSession` has
         been fully consumed, the first result will be fully fetched
         and buffered. Note therefore that the generally recommended
         pattern of usage is to fully consume one result before
         executing a subsequent query. If two results need to be
-        consumed in parallel, multiple :class:`.Session` objects
+        consumed in parallel, multiple :class:`.AsyncSession` objects
         can be used as an alternative to result buffering.
 
-        For more usage details, see :meth:`.Transaction.run`.
+        For more usage details, see :meth:`.AsyncTransaction.run`.
 
         :param query: cypher query
         :type query: str, neo4j.Query
         :param parameters: dictionary of parameters
         :type parameters: dict
         :param kwargs: additional keyword parameters
-        :returns: a new :class:`neo4j.Result` object
-        :rtype: :class:`neo4j.Result`
+        :returns: a new :class:`neo4j.AsyncResult` object
+        :rtype: AsyncResult
         """
         if not query:
             raise ValueError("Cannot run an empty query")
@@ -265,11 +266,11 @@ async def _open_transaction(self, *, access_mode, metadata=None,
         )
 
     async def begin_transaction(self, metadata=None, timeout=None):
-        """ Begin a new unmanaged transaction. Creates a new :class:`.Transaction` within this session.
+        """ Begin a new unmanaged transaction. Creates a new :class:`.AsyncTransaction` within this session.
         At most one transaction may exist in a session at any point in time.
         To maintain multiple concurrent transactions, use multiple concurrent sessions.
 
-        Note: For auto-transaction (Session.run) this will trigger an consume for the current result.
+        Note: For auto-transaction (AsyncSession.run) this will trigger a consume of the current result.
 
         :param metadata: a dictionary with metadata.
@@ -287,7 +288,7 @@ async def begin_transaction(self, metadata=None, timeout=None):
         :type timeout: int
 
         :returns: A new transaction instance.
-        :rtype: :class:`neo4j.Transaction`
+        :rtype: AsyncTransaction
 
         :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open.
         """
@@ -365,37 +366,40 @@ async def read_transaction(self, transaction_function, *args, **kwargs):
         This transaction will automatically be committed unless an exception is thrown during query execution or by the user code.
-        Note, that this function perform retries and that the supplied `transaction_function` might get invoked more than once.
+        Note, that this function performs retries and that the supplied `transaction_function` might get invoked more than once.
-        Managed transactions should not generally be explicitly committed (via tx.commit()).
+        Managed transactions should not generally be explicitly committed
+        (via ``await tx.commit()``).
 
         Example::
 
-            def do_cypher_tx(tx, cypher):
-                result = tx.run(cypher)
-                values = []
-                for record in result:
-                    values.append(record.values())
+            async def do_cypher_tx(tx, cypher):
+                result = await tx.run(cypher)
+                values = [record.values() async for record in result]
                 return values
 
-            with driver.session() as session:
-                values = session.read_transaction(do_cypher_tx, "RETURN 1 AS x")
+            async with driver.session() as session:
+                values = await session.read_transaction(do_cypher_tx, "RETURN 1 AS x")
 
         Example::
 
-            def get_two_tx(tx):
-                result = tx.run("UNWIND [1,2,3,4] AS x RETURN x")
+            async def get_two_tx(tx):
+                result = await tx.run("UNWIND [1,2,3,4] AS x RETURN x")
                 values = []
-                for ix, record in enumerate(result):
-                    if x > 1:
+                async for record in result:
+                    if len(values) >= 2:
                         break
                     values.append(record.values())
-                info = result.consume()  # discard the remaining records if there are any
+                # discard the remaining records if there are any
+                info = await result.consume()
                 # use the info for logging etc.
                 return values
 
-            with driver.session() as session:
-                values = session.read_transaction(get_two_tx)
+            async with driver.session() as session:
+                values = await session.read_transaction(get_two_tx)
 
-        :param transaction_function: a function that takes a transaction as an argument and does work with the transaction. `tx_function(tx, *args, **kwargs)`
+        :param transaction_function: a function that takes a transaction as an
+            argument and does work with the transaction.
+            `transaction_function(tx, *args, **kwargs)` where `tx` is a
+            :class:`.AsyncTransaction`.
         :param args: arguments for the `transaction_function`
         :param kwargs: key word arguments for the `transaction_function`
         :return: a result as returned by the given unit of work
@@ -413,16 +417,19 @@ async def write_transaction(self, transaction_function, *args, **kwargs):
 
         Example::
 
-            def create_node_tx(tx, name):
-                result = tx.run("CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id", name=name)
-                record = result.single()
+            async def create_node_tx(tx, name):
+                query = "CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id"
+                result = await tx.run(query, name=name)
+                record = await result.single()
                 return record["node_id"]
 
-            with driver.session() as session:
-                node_id = session.write_transaction(create_node_tx, "example")
-
+            async with driver.session() as session:
+                node_id = await session.write_transaction(create_node_tx, "example")
 
-        :param transaction_function: a function that takes a transaction as an argument and does work with the transaction. `tx_function(tx, *args, **kwargs)`
+        :param transaction_function: a function that takes a transaction as an
+            argument and does work with the transaction.
+            `transaction_function(tx, *args, **kwargs)` where `tx` is a
+            :class:`.AsyncTransaction`.
-        :param args: key word arguments for the `transaction_function`
+        :param args: arguments for the `transaction_function`
         :param kwargs: key word arguments for the `transaction_function`
         :return: a result as returned by the given unit of work
diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py
index 1753685a0..a2ae32f56 100644
--- a/neo4j/_async/work/transaction.py
+++ b/neo4j/_async/work/transaction.py
@@ -25,13 +25,13 @@
 
 
 class AsyncTransaction:
-    """ Container for multiple Cypher queries to be executed within
-    a single context. Transactions can be used within a :py:const:`with`
+    """ Container for multiple Cypher queries to be executed within a single
+    context. Async transactions can be used within an :py:const:`async with`
-    block where the transaction is committed or rolled back on based on
+    block where the transaction is committed or rolled back based on
-    whether or not an exception is raised::
+    whether an exception is raised::
 
-        with session.begin_transaction() as tx:
-            pass
+        async with session.begin_transaction() as tx:
+            ...
 
     """
 
@@ -91,9 +91,9 @@ async def run(self, query, parameters=None, **kwparameters):
         queries below are all equivalent::
 
             >>> query = "CREATE (a:Person { name: $name, age: $age })"
-            >>> result = tx.run(query, {"name": "Alice", "age": 33})
-            >>> result = tx.run(query, {"name": "Alice"}, age=33)
-            >>> result = tx.run(query, name="Alice", age=33)
+            >>> result = await tx.run(query, {"name": "Alice", "age": 33})
+            >>> result = await tx.run(query, {"name": "Alice"}, age=33)
+            >>> result = await tx.run(query, name="Alice", age=33)
 
         Parameter values can be of any type supported by the Neo4j type
         system. 
In Python, this includes :class:`bool`, :class:`int`, diff --git a/neo4j/api.py b/neo4j/api.py index b95150d59..36c05cbb2 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -106,7 +106,8 @@ def basic_auth(user, password, realm=None): :param realm: specifies the authentication provider :type realm: str or None - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth("basic", user, password, realm) @@ -121,7 +122,8 @@ def kerberos_auth(base64_encoded_ticket): the credentials :type base64_encoded_ticket: str - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth("kerberos", "", base64_encoded_ticket) @@ -136,7 +138,8 @@ def bearer_auth(base64_encoded_token): by a Single-Sign-On provider. :type base64_encoded_token: str - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth("bearer", None, base64_encoded_token) @@ -157,7 +160,8 @@ def custom_auth(principal, credentials, realm, scheme, **parameters): authentication provider :type parameters: Dict[str, Any] - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth(scheme, principal, credentials, realm, **parameters) From 9e8a3c20ce6bf2014e6e9daedf17d8b2cb245ff3 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 20 Dec 2021 17:56:22 +0100 Subject: [PATCH 4/5] Docs and other suggestions by Florent Co-authored-by: Florent Biville --- .editorconfig | 4 ++-- .pre-commit-config.yaml | 10 ---------- CHANGELOG.md | 2 +- CONTRIBUTING.md | 23 ++--------------------- docs/source/api.rst | 9 +++++---- docs/source/async_api.rst | 9 +++++---- docs/source/transactions.rst | 2 +- neo4j/_async/driver.py | 4 ++-- neo4j/_async/io/_bolt.py | 6 ++++++ neo4j/_async_compat/concurrency.py | 26 -------------------------- setup.cfg | 3 --- 11 files changed, 24 insertions(+), 74 deletions(-) diff --git a/.editorconfig b/.editorconfig index 708a87f88..823ad91d1 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,11 +10,11 @@ charset = utf-8 [*.{py,js,rst,txt,sh,bat}] trim_trailing_whitespace = true -[{Makefile,Drockerfile}] +[{Makefile,Dockerfile}] trim_trailing_whitespace = true [*.bat] -end_of_line =crlf +end_of_line = crlf [*.py] max_line_length = 79 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 182008619..c4eafcb56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,16 +19,6 @@ repos: - batch - id: trailing-whitespace args: [ --markdown-linebreak-ext=md ] -# - repo: https://github.com/pycqa/flake8 -# rev: 4.0.1 -# hooks: -# - id: flake8 -# additional_dependencies: -# - flake8-bugbear -# - flake8-builtins -# - flake8-docstrings -# - flake8-quotes -# - pep8-naming - repo: https://github.com/pycqa/isort rev: 5.10.0 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index ca48aa200..4f6e20009 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ - Python 3.10 support added - Python 3.6 support has been dropped. 
-- `Result`, `Session`, and `Transaction`, can no longer be imported from
+- `Result`, `Session`, and `Transaction` can no longer be imported from
   `neo4j.work`. They should've been imported from `neo4j` all along.
 - Experimental pipelines feature has been removed.
 - Experimental async driver has been added.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7cda62d18..d2656fd6c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -58,7 +58,7 @@ install the pre-commit hooks as described below instead. They will take care of
 updating the code if necessary.
 
 Setting up the development environment:
- * Install Python 3.6+
+ * Install Python 3.7+
  * Install the requirements
    ```bash
    $ python3 -m pip install -U pip
@@ -69,25 +69,6 @@ Setting up the development environment:
    ```bash
    $ pre-commit install
    ```
-   Note that this is not an auto-formatter. It will alter some code, but
-   mostly it will just complain about non-compliant code.
-   You can disable a certain check for a single line of code if you think
-   your code-style if preferable. E.g.
-   ```python
-   assume_this_line(violates_rule_e123, and_w321)  # noqa: E123,W321
-   ```
-   Or use just `# noqa` to disable all checks for this line.
-   If you use `# noqa` on its own line, it will disable *all* checks for the
-   whole file. Don't do that.
-   To disable certain rules for a whole file, check out
-   `setup.cfg`.
-   If you want to run the checks manually, you can do so:
-   ```bash
-   $ pre-commit run --all-files
-   # or
-   $ pre-commit run --file path/to/a/file
-   ```
-   For more details see [flake8](https://flake8.pycqa.org/).
 
 ## Got an idea for a new project?
 
@@ -95,7 +76,7 @@ Setting up the development environment:
 If you have an idea for a new tool or library, start by talking to other people in the community.
 Chances are that someone has a similar idea or may have already started working on it.
 The best software comes from getting like minds together to solve a problem.
-And we'll do our best to help you promote and co-ordinate your Neo ecosystem projects.
+And we'll do our best to help you promote and co-ordinate your Neo4j ecosystem projects.
 
 ## Further reading
 
diff --git a/docs/source/api.rst b/docs/source/api.rst
index d9187feaa..f10c71b91 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -17,7 +17,7 @@ The :class:`neo4j.Driver` construction is done via a `classmethod` on the :class
     :members: driver
 
 
-Example, driver creation:
+Driver creation example:
 
 .. code-block:: python
 
@@ -39,7 +39,7 @@ This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``.
 Other authentication methods are described under :ref:`auth-ref`.
 
 
-Example, with block context:
+``with`` block context example:
 
 .. code-block:: python
 
@@ -331,7 +331,8 @@ For example:
 
 Connection details held by the :class:`neo4j.Driver` are immutable.
 Therefore if, for example, a password is changed, a replacement :class:`neo4j.Driver` object must be created.
-More than one :class:`.Driver` may be required if connections to multiple databases, or connections as multiple users, are required.
+More than one :class:`.Driver` may be required if connections to multiple databases, or connections as multiple users, are required,
+except when using impersonation (:ref:`impersonated-user-ref`).
 
 :class:`neo4j.Driver` objects are thread-safe but cannot be shared across processes.
 Therefore, ``multithreading`` should generally be preferred over ``multiprocessing`` for parallel database access.
@@ -371,7 +372,7 @@ All database activity is co-ordinated through two mechanisms: the :class:`neo4j.
 
 A :class:`neo4j.Session` is a logical container for any number of causally-related transactional units of work.
 Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required.
-Sessions provide the top-level of containment for database activity.
+Sessions provide the top level of containment for database activity.
 Session creation is a lightweight operation and *sessions are not thread safe*.
 
 Connections are drawn from the :class:`neo4j.Driver` connection pool as required.
diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst
index c22ccbda2..989ab200f 100644
--- a/docs/source/async_api.rst
+++ b/docs/source/async_api.rst
@@ -23,7 +23,7 @@ The :class:`neo4j.AsyncDriver` construction is done via a `classmethod` on the :
     :members: driver
 
 
-Example, driver creation:
+Driver creation example:
 
 .. code-block:: python
 
@@ -49,7 +49,7 @@ For basic authentication, ``auth`` can be a simple tuple, for example:
 This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``.
 Other authentication methods are described under :ref:`auth-ref`.
 
-Example, with block context:
+``with`` block context example:
 
 .. code-block:: python
 
@@ -197,7 +197,8 @@ For example:
 
 Connection details held by the :class:`neo4j.AsyncDriver` are immutable.
 Therefore if, for example, a password is changed, a replacement :class:`neo4j.AsyncDriver` object must be created.
-More than one :class:`.AsyncDriver` may be required if connections to multiple databases, or connections as multiple users, are required.
+More than one :class:`.AsyncDriver` may be required if connections to multiple databases, or connections as multiple users, are required,
+except when using impersonation (:ref:`impersonated-user-ref`).
 
 :class:`neo4j.AsyncDriver` objects are safe to be used in concurrent coroutines.
 They are not thread-safe.
@@ -236,7 +237,7 @@ All database activity is co-ordinated through two mechanisms: the :class:`neo4j.
 
 A :class:`neo4j.AsyncSession` is a logical container for any number of causally-related transactional units of work.
 Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required.
-Sessions provide the top-level of containment for database activity.
+Sessions provide the top level of containment for database activity.
 Session creation is a lightweight operation and *sessions cannot be shared between coroutines*.
 
 Connections are drawn from the :class:`neo4j.AsyncDriver` connection pool as required.
diff --git a/docs/source/transactions.rst b/docs/source/transactions.rst
index 93dc35106..1bb4db2ab 100644
--- a/docs/source/transactions.rst
+++ b/docs/source/transactions.rst
@@ -11,7 +11,7 @@ Sessions automatically provide guarantees of causal consistency within a cluster
 
 Sessions
 ========
 
-Sessions provide the top-level of containment for database activity.
+Sessions provide the top level of containment for database activity.
 Session creation is a lightweight operation and sessions are `not` thread safe.
 
 Connections are drawn from the :class:`neo4j.Driver` connection pool as required; an idle session will not hold onto a connection.
diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index a5746e7bd..009efc320 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -40,8 +40,8 @@ class AsyncGraphDatabase: @classmethod @AsyncUtil.experimental_async( - "neo4j async is in experimental phase. It might be removed or change " - "its API at any time (including patch releases)." + "neo4j async is in experimental phase. It might be removed or changed " + "at any time (including patch releases)." ) def driver(cls, uri, *, auth=None, **config): """Create a driver. diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index ba6c0eaf1..503ebeaf8 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -350,7 +350,9 @@ async def route(self, database=None, imp_user=None, bookmarks=None): sent to the network, and a response is fetched. :param database: database for which to fetch a routing table + Requires Bolt 4.0+. :param imp_user: the user to impersonate + Requires Bolt 4.4+. :param bookmarks: iterable of bookmark values after which this transaction should begin :return: dictionary of raw routing data @@ -369,7 +371,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, :param metadata: custom metadata dictionary to attach to the transaction :param timeout: timeout for transaction execution (seconds) :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. :param imp_user: the user to impersonate + Requires Bolt 4.4+. :param handlers: handler functions passed into the returned Response object :return: Response object """ @@ -407,7 +411,9 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, :param metadata: custom metadata dictionary to attach to the transaction :param timeout: timeout for transaction execution (seconds) :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. 
:param imp_user: the user to impersonate + Requires Bolt 4.4+ :param handlers: handler functions passed into the returned Response object :return: Response object """ diff --git a/neo4j/_async_compat/concurrency.py b/neo4j/_async_compat/concurrency.py index 5758424d8..868ad1bd3 100644 --- a/neo4j/_async_compat/concurrency.py +++ b/neo4j/_async_compat/concurrency.py @@ -229,29 +229,3 @@ def notify_all(self): Condition = threading.Condition RLock = threading.RLock - - -# async def main(): -# lock = AsyncRLock() -# -# async def say_after(delay, what): -# await asyncio.sleep(delay) -# print(what, repr(lock)) -# async with lock: -# print(what, repr(lock)) -# async with lock: -# print(what, repr(lock)) -# await asyncio.sleep(delay) -# print(what) -# -# task_1 = asyncio.create_task(say_after(0.5, "1")) -# task_2 = asyncio.create_task(say_after(0.5, "2")) -# task_3 = asyncio.create_task(say_after(0.5, "3")) -# -# await task_1 -# await task_2 -# await task_3 -# -# -# if __name__ == "__main__": -# asyncio.run(main()) diff --git a/setup.cfg b/setup.cfg index 5967a7fc4..c63faed9e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,6 +11,3 @@ multi_line_output=3 order_by_type=false remove_redundant_aliases=true use_parentheses=true - -[flake8] -inline-quotes=double From 24370b8ca686a47223727b3e091b63774d9396b3 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 20 Dec 2021 18:16:58 +0100 Subject: [PATCH 5/5] Track generated sync code --- neo4j/_sync/__init__.py | 16 + neo4j/_sync/driver.py | 380 ++++++++++++ neo4j/_sync/io/__init__.py | 43 ++ neo4j/_sync/io/_bolt.py | 571 ++++++++++++++++++ neo4j/_sync/io/_bolt3.py | 396 +++++++++++++ neo4j/_sync/io/_bolt4.py | 537 +++++++++++++++++ neo4j/_sync/io/_common.py | 280 +++++++++ neo4j/_sync/io/_pool.py | 701 +++++++++++++++++++++++ neo4j/_sync/work/__init__.py | 32 ++ neo4j/_sync/work/result.py | 379 ++++++++++++ neo4j/_sync/work/session.py | 447 +++++++++++++++ neo4j/_sync/work/transaction.py | 199 +++++++ neo4j/_sync/work/workspace.py | 102 ++++ testkitbackend/_sync/__init__.py | 16 + testkitbackend/_sync/backend.py | 145 +++++ testkitbackend/_sync/requests.py | 444 ++++++++++++++ tests/unit/sync/__init__.py | 16 + tests/unit/sync/io/__init__.py | 16 + tests/unit/sync/io/conftest.py | 156 +++++ tests/unit/sync/io/test__common.py | 50 ++ tests/unit/sync/io/test_class_bolt.py | 62 ++ tests/unit/sync/io/test_class_bolt3.py | 115 ++++ tests/unit/sync/io/test_class_bolt4x0.py | 209 +++++++ tests/unit/sync/io/test_class_bolt4x1.py | 227 ++++++++ tests/unit/sync/io/test_class_bolt4x2.py | 228 ++++++++ tests/unit/sync/io/test_class_bolt4x3.py | 255 +++++++++ tests/unit/sync/io/test_class_bolt4x4.py | 271 +++++++++ tests/unit/sync/io/test_direct.py | 231 ++++++++ tests/unit/sync/io/test_neo4j_pool.py | 259 +++++++++ tests/unit/sync/test_addressing.py | 125 ++++ tests/unit/sync/test_driver.py | 157 +++++ tests/unit/sync/work/__init__.py | 22 + tests/unit/sync/work/_fake_connection.py | 110 ++++ tests/unit/sync/work/test_result.py | 456 +++++++++++++++ tests/unit/sync/work/test_session.py | 285 +++++++++ tests/unit/sync/work/test_transaction.py | 185 ++++++ 36 files changed, 8123 insertions(+) create mode 100644 neo4j/_sync/__init__.py create mode 100644 neo4j/_sync/driver.py create mode 100644 neo4j/_sync/io/__init__.py create mode 100644 neo4j/_sync/io/_bolt.py create mode 100644 neo4j/_sync/io/_bolt3.py create mode 100644 neo4j/_sync/io/_bolt4.py create mode 100644 neo4j/_sync/io/_common.py create mode 100644 neo4j/_sync/io/_pool.py create mode 100644 
neo4j/_sync/work/__init__.py create mode 100644 neo4j/_sync/work/result.py create mode 100644 neo4j/_sync/work/session.py create mode 100644 neo4j/_sync/work/transaction.py create mode 100644 neo4j/_sync/work/workspace.py create mode 100644 testkitbackend/_sync/__init__.py create mode 100644 testkitbackend/_sync/backend.py create mode 100644 testkitbackend/_sync/requests.py create mode 100644 tests/unit/sync/__init__.py create mode 100644 tests/unit/sync/io/__init__.py create mode 100644 tests/unit/sync/io/conftest.py create mode 100644 tests/unit/sync/io/test__common.py create mode 100644 tests/unit/sync/io/test_class_bolt.py create mode 100644 tests/unit/sync/io/test_class_bolt3.py create mode 100644 tests/unit/sync/io/test_class_bolt4x0.py create mode 100644 tests/unit/sync/io/test_class_bolt4x1.py create mode 100644 tests/unit/sync/io/test_class_bolt4x2.py create mode 100644 tests/unit/sync/io/test_class_bolt4x3.py create mode 100644 tests/unit/sync/io/test_class_bolt4x4.py create mode 100644 tests/unit/sync/io/test_direct.py create mode 100644 tests/unit/sync/io/test_neo4j_pool.py create mode 100644 tests/unit/sync/test_addressing.py create mode 100644 tests/unit/sync/test_driver.py create mode 100644 tests/unit/sync/work/__init__.py create mode 100644 tests/unit/sync/work/_fake_connection.py create mode 100644 tests/unit/sync/work/test_result.py create mode 100644 tests/unit/sync/work/test_session.py create mode 100644 tests/unit/sync/work/test_transaction.py diff --git a/neo4j/_sync/__init__.py b/neo4j/_sync/__init__.py new file mode 100644 index 000000000..b81a309da --- /dev/null +++ b/neo4j/_sync/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py new file mode 100644 index 000000000..711b99e77 --- /dev/null +++ b/neo4j/_sync/driver.py @@ -0,0 +1,380 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +from .._async_compat.util import Util +from ..addressing import Address +from ..api import ( + READ_ACCESS, + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) +from ..conf import ( + Config, + PoolConfig, + SessionConfig, + WorkspaceConfig, +) +from ..meta import experimental + + +class GraphDatabase: + """Accessor for :class:`neo4j.Driver` construction. 
+ """ + + @classmethod + @Util.experimental_async( + "neo4j is in experimental phase. It might be removed or changed " + "at any time (including patch releases)." + ) + def driver(cls, uri, *, auth=None, **config): + """Create a driver. + + :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs. + :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. + :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments. + + :rtype: Neo4jDriver or BoltDriver + """ + + from ..api import ( + DRIVER_BOLT, + DRIVER_NEO4j, + parse_neo4j_uri, + parse_routing_context, + SECURITY_TYPE_NOT_SECURE, + SECURITY_TYPE_SECURE, + SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J, + URI_SCHEME_NEO4J_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + ) + + driver_type, security_type, parsed = parse_neo4j_uri(uri) + + if "trust" in config.keys(): + if config.get("trust") not in [TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES]: + from neo4j.exceptions import ConfigurationError + raise ConfigurationError("The config setting `trust` values are {!r}".format( + [ + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + ] + )) + + if security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] and ("encrypted" in config.keys() or "trust" in config.keys()): + from neo4j.exceptions import ConfigurationError + raise ConfigurationError("The config settings 'encrypted' and 'trust' can only be used with the URI schemes {!r}. Use the other URI schemes {!r} for setting encryption settings.".format( + [ + URI_SCHEME_BOLT, + URI_SCHEME_NEO4J, + ], + [ + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J_SECURE, + ] + )) + + if security_type == SECURITY_TYPE_SECURE: + config["encrypted"] = True + elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: + config["encrypted"] = True + config["trust"] = TRUST_ALL_CERTIFICATES + + if driver_type == DRIVER_BOLT: + return cls.bolt_driver(parsed.netloc, auth=auth, **config) + elif driver_type == DRIVER_NEO4j: + routing_context = parse_routing_context(parsed.query) + return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) + + @classmethod + def bolt_driver(cls, target, *, auth=None, **config): + """ Create a driver for direct Bolt server access that uses + socket I/O and thread-based concurrency. + """ + from .._exceptions import ( + BoltHandshakeError, + BoltSecurityError, + ) + + try: + return BoltDriver.open(target, auth=auth, **config) + except (BoltHandshakeError, BoltSecurityError) as error: + from neo4j.exceptions import ServiceUnavailable + raise ServiceUnavailable(str(error)) from error + + @classmethod + def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): + """ Create a driver for routing-capable Neo4j service access + that uses socket I/O and thread-based concurrency. 
+ """ + from neo4j._exceptions import ( + BoltHandshakeError, + BoltSecurityError, + ) + + try: + return Neo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) + except (BoltHandshakeError, BoltSecurityError) as error: + from neo4j.exceptions import ServiceUnavailable + raise ServiceUnavailable(str(error)) from error + + +class _Direct: + + default_host = "localhost" + default_port = 7687 + + default_target = ":" + + def __init__(self, address): + self._address = address + + @property + def address(self): + return self._address + + @classmethod + def parse_target(cls, target): + """ Parse a target string to produce an address. + """ + if not target: + target = cls.default_target + address = Address.parse(target, default_host=cls.default_host, + default_port=cls.default_port) + return address + + +class _Routing: + + default_host = "localhost" + default_port = 7687 + + default_targets = ": :17601 :17687" + + def __init__(self, initial_addresses): + self._initial_addresses = initial_addresses + + @property + def initial_addresses(self): + return self._initial_addresses + + @classmethod + def parse_targets(cls, *targets): + """ Parse a sequence of target strings to produce an address + list. + """ + targets = " ".join(targets) + if not targets: + targets = cls.default_targets + addresses = Address.parse_list(targets, default_host=cls.default_host, default_port=cls.default_port) + return addresses + + +class Driver: + """ Base class for all types of :class:`neo4j.Driver`, instances of + which are used as the primary access point to Neo4j. + """ + + #: Connection pool + _pool = None + + def __init__(self, pool): + assert pool is not None + self._pool = pool + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + if not asyncio.iscoroutinefunction(self.close): + self.close() + + @property + def encrypted(self): + return bool(self._pool.pool_config.encrypted) + + def session(self, **config): + """Create a session, see :ref:`session-construction-ref` + + :param config: session configuration key-word arguments, + see :ref:`session-configuration-ref` for available key-word + arguments. + + :returns: new :class:`neo4j.Session` object + """ + raise NotImplementedError + + def close(self): + """ Shut down, closing any open connections in the pool. + """ + self._pool.close() + + @experimental("The configuration may change in the future.") + def verify_connectivity(self, **config): + """ This verifies if the driver can connect to a remote server or a cluster + by establishing a network connection with the remote and possibly exchanging + a few data before closing the connection. It throws exception if fails to connect. + + Use the exception to further understand the cause of the connectivity problem. + + Note: Even if this method throws an exception, the driver still need to be closed via close() to free up all resources. + """ + raise NotImplementedError + + @experimental("Feature support query, based on Bolt Protocol Version and Neo4j Server Version will change in the future.") + def supports_multi_db(self): + """ Check if the server or cluster supports multi-databases. + + :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. 
+ :rtype: bool + """ + with self.session() as session: + session._connect(READ_ACCESS) + return session._connection.supports_multiple_databases + + +class BoltDriver(_Direct, Driver): + """:class:`.BoltDriver` is instantiated for ``bolt`` URIs and + addresses a single database machine. This may be a standalone server or + could be a specific member of a cluster. + + Connections established by a :class:`.BoltDriver` are always made to + the exact host and port detailed in the URI. + + This class is not supposed to be instantiated externally. Use + :meth:`GraphDatabase.driver` instead. + """ + + @classmethod + def open(cls, target, *, auth=None, **config): + """ + :param target: + :param auth: + :param config: The values that can be specified are found in :class: `neo4j.PoolConfig` and :class: `neo4j.WorkspaceConfig` + + :return: + :rtype: :class: `neo4j.BoltDriver` + """ + from .io import BoltPool + address = cls.parse_target(target) + pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) + return cls(pool, default_workspace_config) + + def __init__(self, pool, default_workspace_config): + _Direct.__init__(self, pool.address) + Driver.__init__(self, pool) + self._default_workspace_config = default_workspace_config + + def session(self, **config): + """ + :param config: The values that can be specified are found in :class: `neo4j.SessionConfig` + + :return: + :rtype: :class: `neo4j.Session` + """ + from .work import Session + session_config = SessionConfig(self._default_workspace_config, config) + SessionConfig.consume(config) # Consume the config + return Session(self._pool, session_config) + + @experimental("The configuration may change in the future.") + def verify_connectivity(self, **config): + server_agent = None + config["fetch_size"] = -1 + with self.session(**config) as session: + result = session.run("RETURN 1 AS x") + value = result.single().value() + summary = result.consume() + server_agent = summary.server.agent + return server_agent + + +class Neo4jDriver(_Routing, Driver): + """:class:`.Neo4jDriver` is instantiated for ``neo4j`` URIs. The + routing behaviour works in tandem with Neo4j's `Causal Clustering + `_ + feature by directing read and write behaviour to appropriate + cluster members. + + This class is not supposed to be instantiated externally. Use + :meth:`GraphDatabase.driver` instead. 
+ """ + + @classmethod + def open(cls, *targets, auth=None, routing_context=None, **config): + from .io import Neo4jPool + addresses = cls.parse_targets(*targets) + pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) + return cls(pool, default_workspace_config) + + def __init__(self, pool, default_workspace_config): + _Routing.__init__(self, pool.get_default_database_initial_router_addresses()) + Driver.__init__(self, pool) + self._default_workspace_config = default_workspace_config + + def session(self, **config): + from .work import Session + session_config = SessionConfig(self._default_workspace_config, config) + SessionConfig.consume(config) # Consume the config + return Session(self._pool, session_config) + + @experimental("The configuration may change in the future.") + def verify_connectivity(self, **config): + """ + :raise ServiceUnavailable: raised if the server does not support routing or if routing support is broken. + """ + # TODO: Improve and update Stub Test Server to be able to test. + return self._verify_routing_connectivity() + + def _verify_routing_connectivity(self): + from ..exceptions import ( + Neo4jError, + ServiceUnavailable, + SessionExpired, + ) + + table = self._pool.get_routing_table_for_default_database() + routing_info = {} + for ix in list(table.routers): + try: + routing_info[ix] = self._pool.fetch_routing_info( + address=table.routers[0], + database=self._default_workspace_config.database, + imp_user=self._default_workspace_config.impersonated_user, + bookmarks=None, + timeout=self._default_workspace_config + .connection_acquisition_timeout + ) + except (ServiceUnavailable, SessionExpired, Neo4jError): + routing_info[ix] = None + for key, val in routing_info.items(): + if val is not None: + return routing_info + raise ServiceUnavailable("Could not connect to any routing servers.") diff --git a/neo4j/_sync/io/__init__.py b/neo4j/_sync/io/__init__.py new file mode 100644 index 000000000..b598d07d3 --- /dev/null +++ b/neo4j/_sync/io/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This module contains the low-level functionality required for speaking +Bolt. It is not intended to be used directly by driver users. Instead, +the `session` module provides the main user-facing abstractions. 
+""" + + +__all__ = [ + "Bolt", + "BoltPool", + "Neo4jPool", + "check_supported_server_product", + "ConnectionErrorHandler", +] + + +from ._bolt import Bolt +from ._common import ( + check_supported_server_product, + ConnectionErrorHandler, +) +from ._pool import ( + BoltPool, + Neo4jPool, +) diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py new file mode 100644 index 000000000..82ee8b628 --- /dev/null +++ b/neo4j/_sync/io/_bolt.py @@ -0,0 +1,571 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import asyncio +from collections import deque +from logging import getLogger +from time import perf_counter + +from ..._async_compat.network import BoltSocket +from ..._exceptions import BoltHandshakeError +from ...addressing import Address +from ...api import ( + ServerInfo, + Version, +) +from ...conf import PoolConfig +from ...exceptions import ( + AuthError, + IncompleteCommit, + ServiceUnavailable, + SessionExpired, +) +from ...meta import get_user_agent +from ...packstream import ( + Packer, + Unpacker, +) +from ._common import ( + CommitResponse, + Inbox, + Outbox, +) + + +# Set up logger +log = getLogger("neo4j") + + +class Bolt: + """ Server connection for Bolt protocol. + + A :class:`.Bolt` should be constructed following a + successful .open() + + Bolt handshake and takes the socket over which + the handshake was carried out. + """ + + MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" + + PROTOCOL_VERSION = None + + # flag if connection needs RESET to go back to READY state + is_reset = False + + # The socket + in_use = False + + # The socket + _closed = False + + # The socket + _defunct = False + + #: The pool of which this connection is a member + pool = None + + # Store the id of the most recent ran query to be able to reduce sent bits by + # using the default (-1) to refer to the most recent query when pulling + # results for it. + most_recent_qid = None + + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, + auth=None, user_agent=None, routing_context=None): + self.unresolved_address = unresolved_address + self.socket = sock + self.server_info = ServerInfo(Address(sock.getpeername()), + self.PROTOCOL_VERSION) + # so far `connection.recv_timeout_seconds` is the only available + # configuration hint that exists. Therefore, all hints can be stored at + # connection level. This might change in the future. 
+ self.configuration_hints = {} + self.outbox = Outbox() + self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) + self.packer = Packer(self.outbox) + self.unpacker = Unpacker(self.inbox) + self.responses = deque() + self._max_connection_lifetime = max_connection_lifetime + self._creation_timestamp = perf_counter() + self.routing_context = routing_context + + # Determine the user agent + if user_agent: + self.user_agent = user_agent + else: + self.user_agent = get_user_agent() + + # Determine auth details + if not auth: + self.auth_dict = {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from neo4j import Auth + self.auth_dict = vars(Auth("basic", *auth)) + else: + try: + self.auth_dict = vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + + # Check for missing password + try: + credentials = self.auth_dict["credentials"] + except KeyError: + pass + else: + if credentials is None: + raise AuthError("Password cannot be None") + + def __del__(self): + if not asyncio.iscoroutinefunction(self.close): + self.close() + + @property + @abc.abstractmethod + def supports_multiple_results(self): + """ Boolean flag to indicate if the connection version supports multiple + queries to be buffered on the server side (True) or if all results need + to be eagerly pulled before sending the next RUN (False). + """ + pass + + @property + @abc.abstractmethod + def supports_multiple_databases(self): + """ Boolean flag to indicate if the connection version supports multiple + databases. + """ + pass + + @classmethod + def protocol_handlers(cls, protocol_version=None): + """ Return a dictionary of available Bolt protocol handlers, + keyed by version tuple. If an explicit protocol version is + provided, the dictionary will contain either zero or one items, + depending on whether that version is supported. If no protocol + version is provided, all available versions will be returned. + + :param protocol_version: tuple identifying a specific protocol + version (e.g. (3, 5)) or None + :return: dictionary of version tuple to handler class for all + relevant and supported protocol versions + :raise TypeError: if protocol version is not passed in a tuple + """ + + # Carry out Bolt subclass imports locally to avoid circular dependency issues. + from ._bolt3 import Bolt3 + from ._bolt4 import ( + Bolt4x0, + Bolt4x1, + Bolt4x2, + Bolt4x3, + Bolt4x4, + ) + + handlers = { + Bolt3.PROTOCOL_VERSION: Bolt3, + Bolt4x0.PROTOCOL_VERSION: Bolt4x0, + Bolt4x1.PROTOCOL_VERSION: Bolt4x1, + Bolt4x2.PROTOCOL_VERSION: Bolt4x2, + Bolt4x3.PROTOCOL_VERSION: Bolt4x3, + Bolt4x4.PROTOCOL_VERSION: Bolt4x4, + } + + if protocol_version is None: + return handlers + + if not isinstance(protocol_version, tuple): + raise TypeError("Protocol version must be specified as a tuple") + + if protocol_version in handlers: + return {protocol_version: handlers[protocol_version]} + + return {} + + @classmethod + def version_list(cls, versions, limit=4): + """ Return a list of supported protocol versions in order of + preference. The number of protocol versions (or ranges) + returned is limited to four. + """ + # In fact, 4.3 is the fist version to support ranges. However, the range + # support got backported to 4.2. 
But even if the server is too old to
+        # have the backport, negotiating BOLT 4.1 is no problem as it's
+        # equivalent to 4.2
+        first_with_range_support = Version(4, 2)
+        result = []
+        for version in versions:
+            if (result
+                    and version >= first_with_range_support
+                    and result[-1][0] == version[0]
+                    and result[-1][1][1] == version[1] + 1):
+                # can use range to encompass this version
+                result[-1][1][1] = version[1]
+                continue
+            result.append(Version(version[0], [version[1], version[1]]))
+            if len(result) == limit:
+                break
+        return result
+
+    @classmethod
+    def get_handshake(cls):
+        """ Return the supported Bolt versions as bytes.
+        The length is 16 bytes as specified in the Bolt version negotiation.
+        :return: bytes
+        """
+        supported_versions = sorted(cls.protocol_handlers().keys(), reverse=True)
+        offered_versions = cls.version_list(supported_versions)
+        return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")
+
+    @classmethod
+    def ping(cls, address, *, timeout=None, **config):
+        """ Attempt to establish a Bolt connection, returning the
+        agreed Bolt protocol version if successful.
+        """
+        config = PoolConfig.consume(config)
+        try:
+            s, protocol_version, handshake, data = \
+                BoltSocket.connect(
+                    address,
+                    timeout=timeout,
+                    custom_resolver=config.resolver,
+                    ssl_context=config.get_ssl_context(),
+                    keep_alive=config.keep_alive,
+                )
+        except (ServiceUnavailable, SessionExpired, BoltHandshakeError):
+            return None
+        else:
+            BoltSocket.close_socket(s)
+            return protocol_version
+
+    @classmethod
+    def open(
+        cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config
+    ):
+        """ Open a new Bolt connection to a given server address.
+
+        :param address:
+        :param auth:
+        :param timeout: the connection timeout in seconds
+        :param routing_context: dict containing routing context
+        :param pool_config:
+        :return:
+        :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version.
+        :raise ServiceUnavailable: raised if there was a connection issue.
+        """
+        pool_config = PoolConfig.consume(pool_config)
+        s, pool_config.protocol_version, handshake, data = \
+            BoltSocket.connect(
+                address,
+                timeout=timeout,
+                custom_resolver=pool_config.resolver,
+                ssl_context=pool_config.get_ssl_context(),
+                keep_alive=pool_config.keep_alive,
+            )
+
+        # Carry out Bolt subclass imports locally to avoid circular dependency
+        # issues.
+        if pool_config.protocol_version == (3, 0):
+            from ._bolt3 import Bolt3
+            bolt_cls = Bolt3
+        elif pool_config.protocol_version == (4, 0):
+            from ._bolt4 import Bolt4x0
+            bolt_cls = Bolt4x0
+        elif pool_config.protocol_version == (4, 1):
+            from ._bolt4 import Bolt4x1
+            bolt_cls = Bolt4x1
+        elif pool_config.protocol_version == (4, 2):
+            from ._bolt4 import Bolt4x2
+            bolt_cls = Bolt4x2
+        elif pool_config.protocol_version == (4, 3):
+            from ._bolt4 import Bolt4x3
+            bolt_cls = Bolt4x3
+        elif pool_config.protocol_version == (4, 4):
+            from ._bolt4 import Bolt4x4
+            bolt_cls = Bolt4x4
+        else:
+            log.debug("[#%04X]  S: <CLOSE>", s.getsockname()[1])
+            BoltSocket.close_socket(s)
+
+            supported_versions = cls.protocol_handlers().keys()
+            raise BoltHandshakeError("The Neo4j server does not support communication with this driver.
This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data) + + connection = bolt_cls( + address, s, pool_config.max_connection_lifetime, auth=auth, + user_agent=pool_config.user_agent, routing_context=routing_context + ) + + try: + connection.hello() + except Exception: + connection.close() + raise + + return connection + + @property + @abc.abstractmethod + def encrypted(self): + pass + + @property + @abc.abstractmethod + def der_encoded_server_certificate(self): + pass + + @property + @abc.abstractmethod + def local_port(self): + pass + + @abc.abstractmethod + def hello(self): + """ Appends a HELLO message to the outgoing queue, sends it and consumes + all remaining messages. + """ + pass + + @abc.abstractmethod + def route(self, database=None, imp_user=None, bookmarks=None): + """ Fetch a routing table from the server for the given + `database`. For Bolt 4.3 and above, this appends a ROUTE + message; for earlier versions, a procedure call is made via + the regular Cypher execution mechanism. In all cases, this is + sent to the network, and a response is fetched. + + :param database: database for which to fetch a routing table + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+. + :param bookmarks: iterable of bookmark values after which this + transaction should begin + :return: dictionary of raw routing data + """ + pass + + @abc.abstractmethod + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + """ Appends a RUN message to the output queue. + + :param query: Cypher query string + :param parameters: dictionary of Cypher parameters + :param mode: access mode for routing - "READ" or "WRITE" (default) + :param bookmarks: iterable of bookmark values after which this transaction should begin + :param metadata: custom metadata dictionary to attach to the transaction + :param timeout: timeout for transaction execution (seconds) + :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+. + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def discard(self, n=-1, qid=-1, **handlers): + """ Appends a DISCARD message to the output queue. + + :param n: number of records to discard, default = -1 (ALL) + :param qid: query ID to discard for, default = -1 (last query) + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def pull(self, n=-1, qid=-1, **handlers): + """ Appends a PULL message to the output queue. + + :param n: number of records to pull, default = -1 (ALL) + :param qid: query ID to pull for, default = -1 (last query) + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + """ Appends a BEGIN message to the output queue. 
+ + :param mode: access mode for routing - "READ" or "WRITE" (default) + :param bookmarks: iterable of bookmark values after which this transaction should begin + :param metadata: custom metadata dictionary to attach to the transaction + :param timeout: timeout for transaction execution (seconds) + :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+ + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def commit(self, **handlers): + """ Appends a COMMIT message to the output queue.""" + pass + + @abc.abstractmethod + def rollback(self, **handlers): + """ Appends a ROLLBACK message to the output queue.""" + pass + + @abc.abstractmethod + def reset(self): + """ Appends a RESET message to the outgoing queue, sends it and consumes + all remaining messages. + """ + pass + + def _append(self, signature, fields=(), response=None): + """ Appends a message to the outgoing queue. + + :param signature: the signature of the message + :param fields: the fields of the message as a tuple + :param response: a response object to handle callbacks + """ + self.packer.pack_struct(signature, fields) + self.outbox.wrap_message() + self.responses.append(response) + + def _send_all(self): + data = self.outbox.view() + if data: + try: + self.socket.sendall(data) + except OSError as error: + self._set_defunct_write(error) + self.outbox.clear() + + def send_all(self): + """ Send all queued messages to the server. + """ + if self.closed(): + raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self.defunct(): + raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + self._send_all() + + @abc.abstractmethod + def fetch_message(self): + """ Receive at most one message from the server, if available. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + pass + + def fetch_all(self): + """ Fetch all outstanding messages. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + detail_count = summary_count = 0 + while self.responses: + response = self.responses[0] + while not response.complete: + detail_delta, summary_delta = self.fetch_message() + detail_count += detail_delta + summary_count += summary_delta + return detail_count, summary_count + + def _set_defunct_read(self, error=None, silent=False): + message = "Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error, silent=silent) + + def _set_defunct_write(self, error=None, silent=False): + message = "Failed to write data to connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error, silent=silent) + + def _set_defunct(self, message, error=None, silent=False): + from ._pool import BoltPool + direct_driver = isinstance(self.pool, BoltPool) + + if error: + log.debug("[#%04X] %s", self.socket.getsockname()[1], error) + log.error(message) + # We were attempting to receive data but the connection + # has unexpectedly terminated. So, we need to close the + # connection from the client side, and remove the address + # from the connection pool. 
+ self._defunct = True + self.close() + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + # Iterate through the outstanding responses, and if any correspond + # to COMMIT requests then raise an error to signal that we are + # unable to confirm that the COMMIT completed successfully. + if silent: + return + for response in self.responses: + if isinstance(response, CommitResponse): + if error: + raise IncompleteCommit(message) from error + else: + raise IncompleteCommit(message) + + if direct_driver: + if error: + raise ServiceUnavailable(message) from error + else: + raise ServiceUnavailable(message) + else: + if error: + raise SessionExpired(message) from error + else: + raise SessionExpired(message) + + def stale(self): + return (self._stale + or (0 <= self._max_connection_lifetime + <= perf_counter() - self._creation_timestamp)) + + _stale = False + + def set_stale(self): + self._stale = True + + @abc.abstractmethod + def close(self): + """ Close the connection. + """ + pass + + @abc.abstractmethod + def closed(self): + pass + + @abc.abstractmethod + def defunct(self): + pass + + +BoltSocket.Bolt = Bolt diff --git a/neo4j/_sync/io/_bolt3.py b/neo4j/_sync/io/_bolt3.py new file mode 100644 index 000000000..e19bd4c27 --- /dev/null +++ b/neo4j/_sync/io/_bolt3.py @@ -0,0 +1,396 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
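The byte layout produced by `version_list()` and `get_handshake()` above can be made concrete. A standalone sketch, assuming the handler set registered in `protocol_handlers()` (4.4 down to 4.2 collapse into one range slot; the remaining versions get their own slots):

import struct

# Each 4-byte slot is big-endian (reserved, minor-range, minor, major).
# Slot one offers Bolt 4.4 while accepting 2 consecutive lower minors,
# i.e. 4.3 and 4.2, mirroring the range encoding discussed above.
offer = [
    (0, 2, 4, 4),   # Bolt 4.4 (range also covers 4.3 and 4.2)
    (0, 0, 1, 4),   # Bolt 4.1
    (0, 0, 0, 4),   # Bolt 4.0
    (0, 0, 0, 3),   # Bolt 3.0
]
handshake = b"".join(struct.pack(">BBBB", *slot) for slot in offer)
assert len(handshake) == 16  # fixed size; shorter offers are zero-padded
print(handshake.hex())       # 00020404000001040000000400000003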
+ + +from enum import Enum +from logging import getLogger +from ssl import SSLSocket + +from ..._async_compat.util import Util +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...api import ( + READ_ACCESS, + Version, +) +from ...exceptions import ( + ConfigurationError, + DatabaseUnavailable, + DriverError, + ForbiddenOnReadOnlyDatabase, + Neo4jError, + NotALeader, + ServiceUnavailable, +) +from ._bolt import Bolt +from ._common import ( + check_supported_server_product, + CommitResponse, + InitResponse, + Response, +) + + +log = getLogger("neo4j") + + +class ServerStates(Enum): + CONNECTED = "CONNECTED" + READY = "READY" + STREAMING = "STREAMING" + TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING" + FAILED = "FAILED" + + +class ServerStateManager: + _STATE_TRANSITIONS = { + ServerStates.CONNECTED: { + "hello": ServerStates.READY, + }, + ServerStates.READY: { + "run": ServerStates.STREAMING, + "begin": ServerStates.TX_READY_OR_TX_STREAMING, + }, + ServerStates.STREAMING: { + "pull": ServerStates.READY, + "discard": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates.READY, + "rollback": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.FAILED: { + "reset": ServerStates.READY, + } + } + + def __init__(self, init_state, on_change=None): + self.state = init_state + self._on_change = on_change + + def transition(self, message, metadata): + if metadata.get("has_more"): + return + state_before = self.state + self.state = self._STATE_TRANSITIONS \ + .get(self.state, {}) \ + .get(message, self.state) + if state_before != self.state and callable(self._on_change): + self._on_change(state_before, self.state) + + +class Bolt3(Bolt): + """ Protocol handler for Bolt 3. + + This is supported by Neo4j versions 3.5, 4.0, 4.1, 4.2, 4.3, and 4.4. + """ + + PROTOCOL_VERSION = Version(3, 0) + + supports_multiple_results = False + + supports_multiple_databases = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) + + @property + def is_reset(self): + # We can't be sure of the server's state if there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. 
+ if (self.responses and self.responses[-1] + and self.responses[-1].message == "reset"): + return True + return self._server_state_manager.state == ServerStates.READY + + @property + def encrypted(self): + return isinstance(self.socket, SSLSocket) + + @property + def der_encoded_server_certificate(self): + return self.socket.getpeercert(binary_form=True) + + @property + def local_port(self): + try: + return self.socket.getsockname()[1] + except OSError: + return 0 + + def get_base_headers(self): + return { + "user_agent": self.user_agent, + } + + def hello(self): + headers = self.get_base_headers() + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", + on_success=self.server_info.update)) + self.send_all() + self.fetch_all() + check_supported_server_product(self.server_info.agent) + + def route(self, database=None, imp_user=None, bookmarks=None): + if database is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}. " + "Server Agent {!r}".format( + self.PROTOCOL_VERSION, database, self.server_info.agent + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + + metadata = {} + records = [] + + # Ignoring database and bookmarks because there is no multi-db support. + # The bookmarks are only relevant for making sure a previously created + # db exists before querying a routing table for it. + self.run( + "CALL dbms.cluster.routing.getRoutingTable($context)", # This is an internal procedure call. Only available if the Neo4j 3.5 is setup with clustering. + {"context": self.routing_context}, + mode="r", # Bolt Protocol Version(3, 0) supports mode="r" + on_success=metadata.update + ) + self.pull(on_success=metadata.update, on_records=records.extend) + self.send_all() + self.fetch_all() + routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] + return routing_info + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if db is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. 
" + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def discard(self, n=-1, qid=-1, **handlers): + # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + log.debug("[#%04X] C: DISCARD_ALL", self.local_port) + self._append(b"\x2F", (), Response(self, "discard", **handlers)) + + def pull(self, n=-1, qid=-1, **handlers): + # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + log.debug("[#%04X] C: PULL_ALL", self.local_port) + self._append(b"\x3F", (), Response(self, "pull", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + if db is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + + def commit(self, **handlers): + log.debug("[#%04X] C: COMMIT", self.local_port) + self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + + def rollback(self, **handlers): + log.debug("[#%04X] C: ROLLBACK", self.local_port) + self._append(b"\x13", (), Response(self, "rollback", **handlers)) + + def reset(self): + """ Add a RESET message to the outgoing queue, send + it and consume all remaining messages. + """ + + def fail(metadata): + raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) + + log.debug("[#%04X] C: RESET", self.local_port) + self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self.send_all() + self.fetch_all() + + def fetch_message(self): + """ Receive at most one message from the server, if available. 
+ + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + if self._closed: + raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self._defunct: + raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if not self.responses: + return 0, 0 + + # Receive exactly one message + details, summary_signature, summary_metadata = \ + Util.next(self.inbox) + + if details: + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data + self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) + self._server_state_manager.transition(response.message, + summary_metadata) + response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7E": + log.debug("[#%04X] S: IGNORED", self.local_port) + response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7F": + log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) + self._server_state_manager.state = ServerStates.FAILED + try: + response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + self.pool.on_write_failure(address=self.unresolved_address) + raise + except Neo4jError as e: + if self.pool and e.invalidates_all_connections(): + self.pool.mark_all_stale() + raise + else: + raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) + + return len(details), 1 + + def close(self): + """ Close the connection. + """ + if not self._closed: + if not self._defunct: + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", ()) + try: + self._send_all() + except (OSError, BoltError, DriverError): + pass + log.debug("[#%04X] C: ", self.local_port) + try: + self.socket.close() + except OSError: + pass + finally: + self._closed = True + + def closed(self): + return self._closed + + def defunct(self): + return self._defunct diff --git a/neo4j/_sync/io/_bolt4.py b/neo4j/_sync/io/_bolt4.py new file mode 100644 index 000000000..332422162 --- /dev/null +++ b/neo4j/_sync/io/_bolt4.py @@ -0,0 +1,537 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
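The `_STATE_TRANSITIONS` table in `ServerStateManager` above is easiest to read by driving it. A small sketch, assuming the package layout introduced by this patch is importable:

from neo4j._sync.io._bolt3 import ServerStateManager, ServerStates

def show(old, new):
    # Called on every actual state change.
    print(old.name, "->", new.name)

mgr = ServerStateManager(ServerStates.CONNECTED, on_change=show)
mgr.transition("hello", {})                 # CONNECTED -> READY
mgr.transition("begin", {})                 # READY -> TX_READY_OR_TX_STREAMING
mgr.transition("pull", {"has_more": True})  # no output: streaming continues
mgr.transition("commit", {})                # TX_READY_OR_TX_STREAMING -> READY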
+ + +from logging import getLogger +from ssl import SSLSocket + +from ..._async_compat.util import Util +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...api import ( + READ_ACCESS, + SYSTEM_DATABASE, + Version, +) +from ...exceptions import ( + ConfigurationError, + DatabaseUnavailable, + DriverError, + ForbiddenOnReadOnlyDatabase, + Neo4jError, + NotALeader, + ServiceUnavailable, +) +from ._bolt3 import ( + ServerStateManager, + ServerStates, +) +from ._bolt import Bolt +from ._common import ( + check_supported_server_product, + CommitResponse, + InitResponse, + Response, +) + + +log = getLogger("neo4j") + + +class Bolt4x0(Bolt): + """ Protocol handler for Bolt 4.0. + + This is supported by Neo4j versions 4.0, 4.1 and 4.2. + """ + + PROTOCOL_VERSION = Version(4, 0) + + supports_multiple_results = True + + supports_multiple_databases = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) + + @property + def is_reset(self): + # We can't be sure of the server's state if there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. + if (self.responses and self.responses[-1] + and self.responses[-1].message == "reset"): + return True + return self._server_state_manager.state == ServerStates.READY + + @property + def encrypted(self): + return isinstance(self.socket, SSLSocket) + + @property + def der_encoded_server_certificate(self): + return self.socket.getpeercert(binary_form=True) + + @property + def local_port(self): + try: + return self.socket.getsockname()[1] + except OSError: + return 0 + + def get_base_headers(self): + return { + "user_agent": self.user_agent, + } + + def hello(self): + headers = self.get_base_headers() + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", + on_success=self.server_info.update)) + self.send_all() + self.fetch_all() + check_supported_server_product(self.server_info.agent) + + def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. 
" + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + metadata = {} + records = [] + + if database is None: # default database + self.run( + "CALL dbms.routing.getRoutingTable($context)", + {"context": self.routing_context}, + mode="r", + bookmarks=bookmarks, + db=SYSTEM_DATABASE, + on_success=metadata.update + ) + else: + self.run( + "CALL dbms.routing.getRoutingTable($context, $database)", + {"context": self.routing_context, "database": database}, + mode="r", + bookmarks=bookmarks, + db=SYSTEM_DATABASE, + on_success=metadata.update + ) + self.pull(on_success=metadata.update, on_records=records.extend) + self.send_all() + self.fetch_all() + routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] + return routing_info + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if db: + extra["db"] = db + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def discard(self, n=-1, qid=-1, **handlers): + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) + self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + + def pull(self, n=-1, qid=-1, **handlers): + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: PULL %r", self.local_port, extra) + self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. 
" + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if db: + extra["db"] = db + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + + def commit(self, **handlers): + log.debug("[#%04X] C: COMMIT", self.local_port) + self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + + def rollback(self, **handlers): + log.debug("[#%04X] C: ROLLBACK", self.local_port) + self._append(b"\x13", (), Response(self, "rollback", **handlers)) + + def reset(self): + """ Add a RESET message to the outgoing queue, send + it and consume all remaining messages. + """ + + def fail(metadata): + raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) + + log.debug("[#%04X] C: RESET", self.local_port) + self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self.send_all() + self.fetch_all() + + def fetch_message(self): + """ Receive at most one message from the server, if available. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + if self._closed: + raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self._defunct: + raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if not self.responses: + return 0, 0 + + # Receive exactly one message + details, summary_signature, summary_metadata = \ + Util.next(self.inbox) + + if details: + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data + self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) + self._server_state_manager.transition(response.message, + summary_metadata) + response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7E": + log.debug("[#%04X] S: IGNORED", self.local_port) + response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7F": + log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) + self._server_state_manager.state = ServerStates.FAILED + try: + response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + self.pool.on_write_failure(address=self.unresolved_address) + raise + except Neo4jError as e: + if self.pool and e.invalidates_all_connections(): + self.pool.mark_all_stale() + raise + else: + raise BoltProtocolError("Unexpected response message with signature " + "%02X" % ord(summary_signature), 
self.unresolved_address) + + return len(details), 1 + + def close(self): + """ Close the connection. + """ + if not self._closed: + if not self._defunct: + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", ()) + try: + self._send_all() + except (OSError, BoltError, DriverError): + pass + log.debug("[#%04X] C: ", self.local_port) + try: + self.socket.close() + except OSError: + pass + finally: + self._closed = True + + def closed(self): + return self._closed + + def defunct(self): + return self._defunct + + +class Bolt4x1(Bolt4x0): + """ Protocol handler for Bolt 4.1. + + This is supported by Neo4j versions 4.1 - 4.4. + """ + + PROTOCOL_VERSION = Version(4, 1) + + def get_base_headers(self): + """ Bolt 4.1 passes the routing context, originally taken from + the URI, into the connection initialisation message. This + enables server-side routing to propagate the same behaviour + through its driver. + """ + headers = { + "user_agent": self.user_agent, + } + if self.routing_context is not None: + headers["routing"] = self.routing_context + return headers + + +class Bolt4x2(Bolt4x1): + """ Protocol handler for Bolt 4.2. + + This is supported by Neo4j version 4.2 - 4.4. + """ + + PROTOCOL_VERSION = Version(4, 2) + + +class Bolt4x3(Bolt4x2): + """ Protocol handler for Bolt 4.3. + + This is supported by Neo4j version 4.3 - 4.4. + """ + + PROTOCOL_VERSION = Version(4, 3) + + def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + + routing_context = self.routing_context or {} + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, database) + metadata = {} + if bookmarks is None: + bookmarks = [] + else: + bookmarks = list(bookmarks) + self._append(b"\x66", (routing_context, bookmarks, database), + response=Response(self, "route", + on_success=metadata.update)) + self.send_all() + self.fetch_all() + return [metadata.get("rt")] + + def hello(self): + def on_success(metadata): + self.configuration_hints.update(metadata.pop("hints", {})) + self.server_info.update(metadata) + if "connection.recv_timeout_seconds" in self.configuration_hints: + recv_timeout = self.configuration_hints[ + "connection.recv_timeout_seconds" + ] + if isinstance(recv_timeout, int) and recv_timeout > 0: + self.socket.settimeout(recv_timeout) + else: + log.info("[#%04X] Server supplied an invalid value for " + "connection.recv_timeout_seconds (%r). Make sure " + "the server and network is set up correctly.", + self.local_port, recv_timeout) + + headers = self.get_base_headers() + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", + on_success=on_success)) + self.send_all() + self.fetch_all() + check_supported_server_product(self.server_info.agent) + + +class Bolt4x4(Bolt4x3): + """ Protocol handler for Bolt 4.4. + + This is supported by Neo4j version 4.4. 
+ """ + + PROTOCOL_VERSION = Version(4, 4) + + def route(self, database=None, imp_user=None, bookmarks=None): + routing_context = self.routing_context or {} + db_context = {} + if database is not None: + db_context.update(db=database) + if imp_user is not None: + db_context.update(imp_user=imp_user) + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, db_context) + metadata = {} + if bookmarks is None: + bookmarks = [] + else: + bookmarks = list(bookmarks) + self._append(b"\x66", (routing_context, bookmarks, db_context), + response=Response(self, "route", + on_success=metadata.update)) + self.send_all() + self.fetch_all() + return [metadata.get("rt")] + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of " + "seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, + " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of " + "seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) diff --git a/neo4j/_sync/io/_common.py b/neo4j/_sync/io/_common.py new file mode 100644 index 000000000..408de0a1f --- /dev/null +++ b/neo4j/_sync/io/_common.py @@ -0,0 +1,280 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
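The `Outbox` class below implements Bolt's chunked framing: each message is cut into chunks of at most `max_chunk_size` bytes, every chunk carries a two-byte big-endian length header, and a zero-length chunk marks the end of the message (see `wrap_message()`). A minimal standalone sketch of the same framing; the 4-byte limit is only to make the split visible:

import struct

def chunk_message(payload, max_chunk_size=16384):
    # Frame one Bolt message: length-prefixed chunks plus 0x0000 end marker.
    out = bytearray()
    for start in range(0, len(payload), max_chunk_size):
        chunk = payload[start:start + max_chunk_size]
        out += struct.pack(">H", len(chunk)) + chunk
    out += b"\x00\x00"  # end-of-message marker, as in Outbox.wrap_message()
    return bytes(out)

framed = chunk_message(b"abcdef", max_chunk_size=4)
print(framed.hex())  # 0004 61626364 0002 6566 0000 (spaces added for clarity)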
+
+
+import asyncio
+import logging
+import socket
+from struct import pack as struct_pack
+
+from ..._async_compat.util import Util
+from ...exceptions import (
+    Neo4jError,
+    ServiceUnavailable,
+    SessionExpired,
+    UnsupportedServerProduct,
+)
+from ...packstream import (
+    UnpackableBuffer,
+    Unpacker,
+)
+
+
+log = logging.getLogger("neo4j")
+
+
+class MessageInbox:
+
+    def __init__(self, s, on_error):
+        self.on_error = on_error
+        self._messages = self._yield_messages(s)
+
+    def _yield_messages(self, sock):
+        try:
+            buffer = UnpackableBuffer()
+            unpacker = Unpacker(buffer)
+            chunk_size = 0
+            while True:
+
+                while chunk_size == 0:
+                    # Determine the chunk size and skip noop
+                    receive_into_buffer(sock, buffer, 2)
+                    chunk_size = buffer.pop_u16()
+                    if chunk_size == 0:
+                        log.debug("[#%04X]  S: <NOOP>", sock.getsockname()[1])
+
+                receive_into_buffer(sock, buffer, chunk_size + 2)
+                chunk_size = buffer.pop_u16()
+
+                if chunk_size == 0:
+                    # chunk_size was the end marker for the message
+                    size, tag = unpacker.unpack_structure_header()
+                    fields = [unpacker.unpack() for _ in range(size)]
+                    yield tag, fields
+                    # Reset for new message
+                    unpacker.reset()
+
+        except (OSError, socket.timeout) as error:
+            Util.callback(self.on_error, error)
+
+    def pop(self):
+        return Util.next(self._messages)
+
+
+class Inbox(MessageInbox):
+
+    def __next__(self):
+        tag, fields = self.pop()
+        if tag == b"\x71":
+            return fields, None, None
+        elif fields:
+            return [], tag, fields[0]
+        else:
+            return [], tag, None
+
+
+class Outbox:
+
+    def __init__(self, max_chunk_size=16384):
+        self._max_chunk_size = max_chunk_size
+        self._chunked_data = bytearray()
+        self._raw_data = bytearray()
+        self.write = self._raw_data.extend
+
+    def max_chunk_size(self):
+        return self._max_chunk_size
+
+    def clear(self):
+        self._chunked_data = bytearray()
+        self._raw_data.clear()
+
+    def _chunk_data(self):
+        data_len = len(self._raw_data)
+        num_full_chunks, chunk_rest = divmod(
+            data_len, self._max_chunk_size
+        )
+        num_chunks = num_full_chunks + bool(chunk_rest)
+
+        data_view = memoryview(self._raw_data)
+        header_start = len(self._chunked_data)
+        data_start = header_start + 2
+        raw_data_start = 0
+        for i in range(num_chunks):
+            chunk_size = min(data_len - raw_data_start,
+                             self._max_chunk_size)
+            self._chunked_data[header_start:data_start] = struct_pack(
+                ">H", chunk_size
+            )
+            self._chunked_data[data_start:(data_start + chunk_size)] = \
+                data_view[raw_data_start:(raw_data_start + chunk_size)]
+            header_start += chunk_size + 2
+            data_start = header_start + 2
+            raw_data_start += chunk_size
+        del data_view
+        self._raw_data.clear()
+
+    def wrap_message(self):
+        self._chunk_data()
+        self._chunked_data += b"\x00\x00"
+
+    def view(self):
+        self._chunk_data()
+        return memoryview(self._chunked_data)
+
+
+class ConnectionErrorHandler:
+    """
+    Wrapper class for handling connection errors.
+
+    The class will wrap each method to invoke a callback if the method raises
+    Neo4jError, SessionExpired, or ServiceUnavailable.
+    The error will be re-raised after the callback.
+    """
+
+    def __init__(self, connection, on_error):
+        """
+        :param connection: the connection object to wrap
+        :type connection: Bolt
+        :param on_error: the function to be called when a method of
+            connection raises one of the caught errors.
+ :type on_error callable + """ + self.__connection = connection + self.__on_error = on_error + + def __getattr__(self, name): + connection_attr = getattr(self.__connection, name) + if not callable(connection_attr): + return connection_attr + + def outer(func): + def inner(*args, **kwargs): + try: + func(*args, **kwargs) + except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + assert not asyncio.iscoroutinefunction(self.__on_error) + self.__on_error(exc) + raise + return inner + + def outer_async(coroutine_func): + def inner(*args, **kwargs): + try: + coroutine_func(*args, **kwargs) + except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + Util.callback(self.__on_error, exc) + raise + return inner + + if asyncio.iscoroutinefunction(connection_attr): + return outer_async(connection_attr) + return outer(connection_attr) + + def __setattr__(self, name, value): + if name.startswith("_" + self.__class__.__name__ + "__"): + super().__setattr__(name, value) + else: + setattr(self.__connection, name, value) + + +class Response: + """ Subscriber object for a full response (zero or + more detail messages followed by one summary message). + """ + + def __init__(self, connection, message, **handlers): + self.connection = connection + self.handlers = handlers + self.message = message + self.complete = False + + def on_records(self, records): + """ Called when one or more RECORD messages have been received. + """ + handler = self.handlers.get("on_records") + Util.callback(handler, records) + + def on_success(self, metadata): + """ Called when a SUCCESS message has been received. + """ + handler = self.handlers.get("on_success") + Util.callback(handler, metadata) + + if not metadata.get("has_more"): + handler = self.handlers.get("on_summary") + Util.callback(handler) + + def on_failure(self, metadata): + """ Called when a FAILURE message has been received. + """ + try: + self.connection.reset() + except (SessionExpired, ServiceUnavailable): + pass + handler = self.handlers.get("on_failure") + Util.callback(handler, metadata) + handler = self.handlers.get("on_summary") + Util.callback(handler) + raise Neo4jError.hydrate(**metadata) + + def on_ignored(self, metadata=None): + """ Called when an IGNORED message has been received. + """ + handler = self.handlers.get("on_ignored") + Util.callback(handler, metadata) + handler = self.handlers.get("on_summary") + Util.callback(handler) + + +class InitResponse(Response): + + def on_failure(self, metadata): + code = metadata.get("code") + if code == "Neo.ClientError.Security.Unauthorized": + raise Neo4jError.hydrate(**metadata) + else: + raise ServiceUnavailable( + metadata.get("message", "Connection initialisation failed") + ) + + +class CommitResponse(Response): + + pass + + +def check_supported_server_product(agent): + """ Checks that a server product is supported by the driver by + looking at the server agent string. 
+ + :param agent: server agent string to check for validity + :raises UnsupportedServerProduct: if the product is not supported + """ + if not agent.startswith("Neo4j/"): + raise UnsupportedServerProduct(agent) + + +def receive_into_buffer(sock, buffer, n_bytes): + end = buffer.used + n_bytes + if end > len(buffer.data): + buffer.data += bytearray(end - len(buffer.data)) + view = memoryview(buffer.data) + while buffer.used < end: + n = sock.recv_into(view[buffer.used:end], end - buffer.used) + if n == 0: + raise OSError("No data") + buffer.used += n diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py new file mode 100644 index 000000000..5969fd6da --- /dev/null +++ b/neo4j/_sync/io/_pool.py @@ -0,0 +1,701 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +from collections import ( + defaultdict, + deque, +) +import logging +from logging import getLogger +from random import choice +from time import perf_counter + +from ..._async_compat.concurrency import ( + Condition, + RLock, +) +from ..._async_compat.network import NetworkUtil +from ..._exceptions import BoltError +from ...api import ( + READ_ACCESS, + WRITE_ACCESS, +) +from ...conf import ( + PoolConfig, + WorkspaceConfig, +) +from ...exceptions import ( + ClientError, + ConfigurationError, + DriverError, + Neo4jError, + ReadServiceUnavailable, + ServiceUnavailable, + SessionExpired, + WriteServiceUnavailable, +) +from ...routing import RoutingTable +from ._bolt import Bolt + + +# Set up logger +log = getLogger("neo4j") + + +class IOPool(abc.ABC): + """ A collection of connections to one or more server addresses. + """ + + def __init__(self, opener, pool_config, workspace_config): + assert callable(opener) + assert isinstance(pool_config, PoolConfig) + assert isinstance(workspace_config, WorkspaceConfig) + + self.opener = opener + self.pool_config = pool_config + self.workspace_config = workspace_config + self.connections = defaultdict(deque) + self.lock = RLock() + self.cond = Condition(self.lock) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def _acquire(self, address, timeout): + """ Acquire a connection to a given address from the pool. + The address supplied should always be an IP address, not + a host name. + + This method is thread safe. + """ + t0 = perf_counter() + if timeout is None: + timeout = self.workspace_config.connection_acquisition_timeout + + with self.lock: + def time_remaining(): + t = timeout - (perf_counter() - t0) + return t if t > 0 else 0 + + while True: + # try to find a free connection in pool + for connection in list(self.connections.get(address, [])): + if (connection.closed() or connection.defunct() + or (connection.stale() and not connection.in_use)): + # `close` is a noop on already closed connections. + # This is to make sure that the connection is + # gracefully closed, e.g. 
if it's just marked as + # `stale` but still alive. + if log.isEnabledFor(logging.DEBUG): + log.debug( + "[#%04X] C: removing old connection " + "(closed=%s, defunct=%s, stale=%s, in_use=%s)", + connection.local_port, + connection.closed(), connection.defunct(), + connection.stale(), connection.in_use + ) + connection.close() + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + continue + if not connection.in_use: + connection.in_use = True + return connection + # all connections in pool are in-use + connections = self.connections[address] + max_pool_size = self.pool_config.max_connection_pool_size + infinite_pool_size = (max_pool_size < 0 + or max_pool_size == float("inf")) + can_create_new_connection = ( + infinite_pool_size + or len(connections) < max_pool_size + ) + if can_create_new_connection: + timeout = min(self.pool_config.connection_timeout, + time_remaining()) + try: + connection = self.opener(address, timeout) + except ServiceUnavailable: + self.remove(address) + raise + else: + connection.pool = self + connection.in_use = True + connections.append(connection) + return connection + + # failed to obtain a connection from pool because the + # pool is full and no free connection in the pool + if time_remaining(): + self.cond.wait(time_remaining()) + # if timed out, then we throw error. This time + # computation is needed, as with python 2.7, we + # cannot tell if the condition is notified or + # timed out when we come to this line + if not time_remaining(): + raise ClientError("Failed to obtain a connection from pool " + "within {!r}s".format(timeout)) + else: + raise ClientError("Failed to obtain a connection from pool " + "within {!r}s".format(timeout)) + + @abc.abstractmethod + def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + """ Acquire a connection to a server that can satisfy a set of parameters. + + :param access_mode: + :param timeout: + :param database: + :param bookmarks: + """ + + def release(self, *connections): + """ Release a connection back into the pool. + This method is thread safe. + """ + with self.lock: + for connection in connections: + if not (connection.defunct() + or connection.closed() + or connection.is_reset): + try: + connection.reset() + except (Neo4jError, DriverError, BoltError) as e: + log.debug( + "Failed to reset connection on release: %s", e + ) + connection.in_use = False + self.cond.notify_all() + + def in_use_connection_count(self, address): + """ Count the number of connections currently in use to a given + address. 
+ """ + try: + connections = self.connections[address] + except KeyError: + return 0 + else: + return sum(1 if connection.in_use else 0 for connection in connections) + + def mark_all_stale(self): + with self.lock: + for address in self.connections: + for connection in self.connections[address]: + connection.set_stale() + + def deactivate(self, address): + """ Deactivate an address from the connection pool, if present, closing + all idle connection to that address + """ + with self.lock: + try: + connections = self.connections[address] + except KeyError: # already removed from the connection pool + return + for conn in list(connections): + if not conn.in_use: + connections.remove(conn) + try: + conn.close() + except OSError: + pass + if not connections: + self.remove(address) + + def on_write_failure(self, address): + raise WriteServiceUnavailable( + "No write service available for pool {}".format(self) + ) + + def remove(self, address): + """ Remove an address from the connection pool, if present, closing + all connections to that address. + """ + with self.lock: + for connection in self.connections.pop(address, ()): + try: + connection.close() + except OSError: + pass + + def close(self): + """ Close all connections and empty the pool. + This method is thread safe. + """ + try: + with self.lock: + for address in list(self.connections): + self.remove(address) + except TypeError: + pass + + +class BoltPool(IOPool): + + @classmethod + def open(cls, address, *, auth, pool_config, workspace_config): + """Create a new BoltPool + + :param address: + :param auth: + :param pool_config: + :param workspace_config: + :return: BoltPool + """ + + def opener(addr, timeout): + return Bolt.open( + addr, auth=auth, timeout=timeout, routing_context=None, + **pool_config + ) + + pool = cls(opener, pool_config, workspace_config, address) + return pool + + def __init__(self, opener, pool_config, workspace_config, address): + super().__init__(opener, pool_config, workspace_config) + self.address = address + + def __repr__(self): + return "<{} address={!r}>".format(self.__class__.__name__, + self.address) + + def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + # The access_mode and database is not needed for a direct connection, + # it's just there for consistency. + return self._acquire(self.address, timeout) + + +class Neo4jPool(IOPool): + """ Connection pool with routing table. + """ + + @classmethod + def open(cls, *addresses, auth, pool_config, workspace_config, + routing_context=None): + """Create a new Neo4jPool + + :param addresses: one or more address as positional argument + :param auth: + :param pool_config: + :param workspace_config: + :param routing_context: + :return: Neo4jPool + """ + + address = addresses[0] + if routing_context is None: + routing_context = {} + elif "address" in routing_context: + raise ConfigurationError("The key 'address' is reserved for routing context.") + routing_context["address"] = str(address) + + def opener(addr, timeout): + return Bolt.open( + addr, auth=auth, timeout=timeout, + routing_context=routing_context, **pool_config + ) + + pool = cls(opener, pool_config, workspace_config, address) + return pool + + def __init__(self, opener, pool_config, workspace_config, address): + """ + + :param opener: + :param pool_config: + :param workspace_config: + :param address: + """ + super().__init__(opener, pool_config, workspace_config) + # Each database have a routing table, the default database is a special case. 
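        # A hedged illustration of the bookkeeping set up just below: with an
        # unconfigured (None) default database plus one named database, the
        # mapping may look like
        #     {None: RoutingTable(database=None, routers=[address]),
        #      "movies": RoutingTable(database="movies", ...)}
        # where named entries are created lazily by
        # get_or_create_routing_table().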
+        log.debug("[#0000] C: routing address %r", address)
+        self.address = address
+        self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
+        self.refresh_lock = RLock()
+
+    def __repr__(self):
+        """ The representation shows the initial routing addresses.
+
+        :return: The representation
+        :rtype: str
+        """
+        return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())
+
+    @property
+    def first_initial_routing_address(self):
+        return self.get_default_database_initial_router_addresses()[0]
+
+    def get_default_database_initial_router_addresses(self):
+        """ Get the initial router addresses for the default database.
+
+        :return:
+        :rtype: OrderedSet
+        """
+        return self.get_routing_table_for_default_database().initial_routers
+
+    def get_default_database_router_addresses(self):
+        """ Get the router addresses for the default database.
+
+        :return:
+        :rtype: OrderedSet
+        """
+        return self.get_routing_table_for_default_database().routers
+
+    def get_routing_table_for_default_database(self):
+        return self.routing_tables[self.workspace_config.database]
+
+    def get_or_create_routing_table(self, database):
+        with self.refresh_lock:
+            if database not in self.routing_tables:
+                self.routing_tables[database] = RoutingTable(
+                    database=database,
+                    routers=self.get_default_database_initial_router_addresses()
+                )
+            return self.routing_tables[database]
+
+    def fetch_routing_info(
+        self, address, database, imp_user, bookmarks, timeout
+    ):
+        """ Fetch raw routing info from a given router address.
+
+        :param address: router address
+        :param database: the database name to get routing table for
+        :param imp_user: the user to impersonate while fetching the routing
+            table
+        :type imp_user: str or None
+        :param bookmarks: iterable of bookmark values after which the routing
+            info should be fetched
+        :param timeout: connection acquisition timeout in seconds
+
+        :return: list of routing records, or None if no connection
+            could be established or if no readers or writers are present
+        :raise ServiceUnavailable: if the server does not support
+            routing, or if routing support is broken or outdated
+        """
+        cx = self._acquire(address, timeout)
+        try:
+            routing_table = cx.route(
+                database or self.workspace_config.database,
+                imp_user or self.workspace_config.impersonated_user,
+                bookmarks
+            )
+        finally:
+            self.release(cx)
+        return routing_table
+
+    def fetch_routing_table(
+        self, *, address, timeout, database, imp_user, bookmarks
+    ):
+        """ Fetch a routing table from a given router address.
+
+        :param address: router address
+        :param timeout: connection acquisition timeout in seconds
+        :param database: the database name
+        :type database: str
+        :param imp_user: the user to impersonate while fetching the routing
+            table
+        :type imp_user: str or None
+        :param bookmarks: bookmarks used when fetching routing table
+
+        :return: a new RoutingTable instance or None if the given router is
+            currently unable to provide routing information
+        """
+        new_routing_info = None
+        try:
+            new_routing_info = self.fetch_routing_info(
+                address, database, imp_user, bookmarks, timeout
+            )
+        except Neo4jError as e:
+            # Check whether the error was caused by the client. In that case
+            # there is no point in trying to fetch a routing table from
+            # another router, so the driver should fail fast during discovery.
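+            # (For example, an authentication failure. The exact
+            # classification is encapsulated in
+            # Neo4jError.is_fatal_during_discovery; the example given here
+            # is an assumption, not an exhaustive list.)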
+            if e.is_fatal_during_discovery():
+                raise
+        except (ServiceUnavailable, SessionExpired):
+            pass
+        if not new_routing_info:
+            log.debug("Failed to fetch routing info %s", address)
+            return None
+        else:
+            servers = new_routing_info[0]["servers"]
+            ttl = new_routing_info[0]["ttl"]
+            database = new_routing_info[0].get("db", database)
+            new_routing_table = RoutingTable.parse_routing_info(
+                database=database, servers=servers, ttl=ttl
+            )
+
+        # Parse routing info and count the number of each type of server
+        num_routers = len(new_routing_table.routers)
+        num_readers = len(new_routing_table.readers)
+
+        # num_writers = len(new_routing_table.writers)
+        # If no writers are available, this likely indicates a temporary
+        # state, such as a leader switch, so we should not signal an error.
+
+        # No routers
+        if num_routers == 0:
+            log.debug("No routing servers returned from server %s", address)
+            return None
+
+        # No readers
+        if num_readers == 0:
+            log.debug("No read servers returned from server %s", address)
+            return None
+
+        # At least one of each is fine, so return this table
+        return new_routing_table
+
+    def _update_routing_table_from(
+        self, *routers, database=None, imp_user=None, bookmarks=None,
+        database_callback=None
+    ):
+        """ Try to update routing tables with the given routers.
+
+        :return: True if the routing table is successfully updated,
+            otherwise False
+        """
+        log.debug("Attempting to update routing table from {}".format(
+            ", ".join(map(repr, routers)))
+        )
+        for router in routers:
+            for address in NetworkUtil.resolve_address(
+                router, resolver=self.pool_config.resolver
+            ):
+                new_routing_table = self.fetch_routing_table(
+                    address=address,
+                    timeout=self.pool_config.connection_timeout,
+                    database=database, imp_user=imp_user, bookmarks=bookmarks
+                )
+                if new_routing_table is not None:
+                    new_database = new_routing_table.database
+                    old_routing_table = self.get_or_create_routing_table(
+                        new_database
+                    )
+                    old_routing_table.update(new_routing_table)
+                    log.debug(
+                        "[#0000] C: address=%r (%r)",
+                        address, self.routing_tables[new_database]
+                    )
+                    if callable(database_callback):
+                        database_callback(new_database)
+                    return True
+            self.deactivate(router)
+        return False
+
+    def update_routing_table(
+        self, *, database, imp_user, bookmarks, database_callback=None
+    ):
+        """ Update the routing table from the first router able to provide
+        valid routing information.
+
+        :param database: The database name
+        :param imp_user: the user to impersonate while fetching the routing
+            table
+        :type imp_user: str or None
+        :param bookmarks: bookmarks used when fetching routing table
+        :param database_callback: A callback function that will be called with
+            the database name as its only argument when a new routing table
+            has been acquired. This database name might differ from `database`
+            if that was None and the underlying protocol supports reporting
+            back the actual database.
+
+        :raise neo4j.exceptions.ServiceUnavailable:
+        """
+        with self.refresh_lock:
+            routing_table = self.get_or_create_routing_table(database)
+            # copied because it can be modified
+            existing_routers = set(routing_table.routers)
+
+            prefer_initial_routing_address = \
+                self.routing_tables[database].initialized_without_writers
+
+            if prefer_initial_routing_address:
+                # TODO: Test this state
+                if self._update_routing_table_from(
+                    self.first_initial_routing_address, database=database,
+                    imp_user=imp_user, bookmarks=bookmarks,
+                    database_callback=database_callback
+                ):
+                    # Why is only the first initial routing address used?
+                    return
+            if self._update_routing_table_from(
+                *(existing_routers - {self.first_initial_routing_address}),
+                database=database, imp_user=imp_user, bookmarks=bookmarks,
+                database_callback=database_callback
+            ):
+                return
+
+            if not prefer_initial_routing_address:
+                if self._update_routing_table_from(
+                    self.first_initial_routing_address, database=database,
+                    imp_user=imp_user, bookmarks=bookmarks,
+                    database_callback=database_callback
+                ):
+                    # Why is only the first initial routing address used?
+                    return
+
+            # None of the routers have been successful, so just fail
+            log.error("Unable to retrieve routing information")
+            raise ServiceUnavailable("Unable to retrieve routing information")
+
+    def update_connection_pool(self, *, database):
+        routing_table = self.get_or_create_routing_table(database)
+        servers = routing_table.servers()
+        for address in list(self.connections):
+            if address.unresolved not in servers:
+                super(Neo4jPool, self).deactivate(address)
+
+    def ensure_routing_table_is_fresh(
+        self, *, access_mode, database, imp_user, bookmarks,
+        database_callback=None
+    ):
+        """ Update the routing table if stale.
+
+        This method performs two freshness checks, before and after acquiring
+        the refresh lock. If the routing table is already fresh on entry, the
+        method exits immediately; otherwise, the refresh lock is acquired and
+        the second freshness check that follows determines whether an update
+        is still required.
+
+        This method is thread-safe.
+
+        :return: `True` if an update was required, `False` otherwise.
+        """
+        from neo4j.api import READ_ACCESS
+        with self.refresh_lock:
+            routing_table = self.get_or_create_routing_table(database)
+            if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
+                # Readers are fresh.
+                return False
+
+            self.update_routing_table(
+                database=database, imp_user=imp_user, bookmarks=bookmarks,
+                database_callback=database_callback
+            )
+            self.update_connection_pool(database=database)
+
+            for database in list(self.routing_tables.keys()):
+                # Remove unused databases in the routing table
+                # Remove the routing table after a timeout = TTL + 30s
+                log.debug("[#0000] C: database=%s", database)
+                if (self.routing_tables[database].should_be_purged_from_memory()
+                        and database != self.workspace_config.database):
+                    del self.routing_tables[database]
+
+            return True
+
+    def _select_address(self, *, access_mode, database):
+        """ Selects the address with the fewest in-use connections.
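+
+        A sketch of the balancing rule (the counts are made up): given
+        in-use counts ``{addr_a: 2, addr_b: 0, addr_c: 0}``, the candidate
+        set is ``[addr_b, addr_c]`` and one of the two is picked at random.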
+        """
+        from ...api import READ_ACCESS
+        with self.refresh_lock:
+            if access_mode == READ_ACCESS:
+                addresses = self.routing_tables[database].readers
+            else:
+                addresses = self.routing_tables[database].writers
+            addresses_by_usage = {}
+            for address in addresses:
+                addresses_by_usage.setdefault(
+                    self.in_use_connection_count(address), []
+                ).append(address)
+        if not addresses_by_usage:
+            if access_mode == READ_ACCESS:
+                raise ReadServiceUnavailable(
+                    "No read service currently available"
+                )
+            else:
+                raise WriteServiceUnavailable(
+                    "No write service currently available"
+                )
+        return choice(addresses_by_usage[min(addresses_by_usage)])
+
+    def acquire(
+        self, access_mode=None, timeout=None, database=None, bookmarks=None
+    ):
+        if access_mode not in (WRITE_ACCESS, READ_ACCESS):
+            raise ClientError("Invalid 'access_mode'; {}".format(access_mode))
+        if not timeout:
+            raise ClientError("'timeout' must be a float larger than 0; {}"
+                              .format(timeout))
+
+        from neo4j.api import check_access_mode
+        access_mode = check_access_mode(access_mode)
+        with self.refresh_lock:
+            log.debug("[#0000] C: %r",
+                      self.routing_tables)
+            self.ensure_routing_table_is_fresh(
+                access_mode=access_mode, database=database, imp_user=None,
+                bookmarks=bookmarks
+            )
+
+        while True:
+            try:
+                # Get an address for a connection that has the fewest in-use
+                # connections.
+                address = self._select_address(
+                    access_mode=access_mode, database=database
+                )
+            except (ReadServiceUnavailable, WriteServiceUnavailable) as err:
+                raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err
+            try:
+                log.debug("[#0000] C: database=%r address=%r", database, address)
+                # should always be a resolved address
+                connection = self._acquire(address, timeout=timeout)
+            except ServiceUnavailable:
+                self.deactivate(address=address)
+            else:
+                return connection
+
+    def deactivate(self, address):
+        """ Deactivate an address from the connection pool, if present,
+        removing it from the routing table and closing all idle connections
+        to that address.
+        """
+        log.debug("[#0000] C: Deactivating address %r", address)
+        # We use `discard` instead of `remove` here since the former
+        # will not fail if the address has already been removed.
+        for database in self.routing_tables.keys():
+            self.routing_tables[database].routers.discard(address)
+            self.routing_tables[database].readers.discard(address)
+            self.routing_tables[database].writers.discard(address)
+        log.debug("[#0000] C: table=%r", self.routing_tables)
+        super(Neo4jPool, self).deactivate(address)
+
+    def on_write_failure(self, address):
+        """ Remove a writer address from the routing table, if present.
+        """
+        log.debug("[#0000] C: Removing writer %r", address)
+        for database in self.routing_tables.keys():
+            self.routing_tables[database].writers.discard(address)
+        log.debug("[#0000] C: table=%r", self.routing_tables)
diff --git a/neo4j/_sync/work/__init__.py b/neo4j/_sync/work/__init__.py
new file mode 100644
index 000000000..3ceebdb13
--- /dev/null
+++ b/neo4j/_sync/work/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .session import (
+    Result,
+    Session,
+    Transaction,
+    Workspace,
+)
+
+
+__all__ = [
+    "Result",
+    "Session",
+    "Transaction",
+    "Workspace",
+]
diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py
new file mode 100644
index 000000000..8d6342fdc
--- /dev/null
+++ b/neo4j/_sync/work/result.py
@@ -0,0 +1,379 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import deque
+from warnings import warn
+
+from ..._async_compat.util import Util
+from ...data import DataDehydrator
+from ...work import ResultSummary
+from ..io import ConnectionErrorHandler
+
+
+class Result:
+    """A handler for the result of Cypher query execution. Instances
+    of this class are typically constructed and returned by
+    :meth:`.Session.run` and :meth:`.Transaction.run`.
+    """
+
+    def __init__(self, connection, hydrant, fetch_size, on_closed,
+                 on_error):
+        self._connection = ConnectionErrorHandler(connection, on_error)
+        self._hydrant = hydrant
+        self._on_closed = on_closed
+        self._metadata = None
+        self._record_buffer = deque()
+        self._summary = None
+        self._bookmark = None
+        self._raw_qid = -1
+        self._fetch_size = fetch_size
+
+        # states
+        self._discarding = False  # discard the remainder of records
+        self._attached = False  # attached to a connection
+        # there are still more response messages we wait for
+        self._streaming = False
+        # there are more records available to pull from the server
+        self._has_more = False
+        # the result has been fully iterated or consumed
+        self._closed = False
+
+    @property
+    def _qid(self):
+        if self._raw_qid == self._connection.most_recent_qid:
+            return -1
+        else:
+            return self._raw_qid
+
+    def _tx_ready_run(self, query, parameters, **kwargs):
+        # BEGIN+RUN does not carry any extra on the RUN message.
+ # BEGIN {extra} + # RUN "query" {parameters} {extra} + self._run( + query, parameters, None, None, None, None, **kwargs + ) + + def _run( + self, query, parameters, db, imp_user, access_mode, bookmarks, + **kwargs + ): + query_text = str(query) # Query or string object + query_metadata = getattr(query, "metadata", None) + query_timeout = getattr(query, "timeout", None) + + parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs)) + + self._metadata = { + "query": query_text, + "parameters": parameters, + "server": self._connection.server_info, + } + + def on_attached(metadata): + self._metadata.update(metadata) + # For auto-commit there is no qid and Bolt 3 does not support qid + self._raw_qid = metadata.get("qid", -1) + if self._raw_qid != -1: + self._connection.most_recent_qid = self._raw_qid + self._keys = metadata.get("fields") + self._attached = True + + def on_failed_attach(metadata): + self._metadata.update(metadata) + self._attached = False + Util.callback(self._on_closed) + + self._connection.run( + query_text, + parameters=parameters, + mode=access_mode, + bookmarks=bookmarks, + metadata=query_metadata, + timeout=query_timeout, + db=db, + imp_user=imp_user, + on_success=on_attached, + on_failure=on_failed_attach, + ) + self._pull() + self._connection.send_all() + self._attach() + + def _pull(self): + def on_records(records): + if not self._discarding: + self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records)) + + def on_summary(): + self._attached = False + Util.callback(self._on_closed) + + def on_failure(metadata): + self._attached = False + Util.callback(self._on_closed) + + def on_success(summary_metadata): + self._streaming = False + has_more = summary_metadata.get("has_more") + self._has_more = bool(has_more) + if has_more: + return + self._metadata.update(summary_metadata) + self._bookmark = summary_metadata.get("bookmark") + + self._connection.pull( + n=self._fetch_size, + qid=self._qid, + on_records=on_records, + on_success=on_success, + on_failure=on_failure, + on_summary=on_summary, + ) + self._streaming = True + + def _discard(self): + def on_summary(): + self._attached = False + Util.callback(self._on_closed) + + def on_failure(metadata): + self._metadata.update(metadata) + self._attached = False + Util.callback(self._on_closed) + + def on_success(summary_metadata): + self._streaming = False + has_more = summary_metadata.get("has_more") + self._has_more = bool(has_more) + if has_more: + return + self._discarding = False + self._metadata.update(summary_metadata) + self._bookmark = summary_metadata.get("bookmark") + + # This was the last page received, discard the rest + self._connection.discard( + n=-1, + qid=self._qid, + on_success=on_success, + on_failure=on_failure, + on_summary=on_summary, + ) + self._streaming = True + + def __iter__(self): + """Iterator returning Records. + :returns: Record, it is an immutable ordered collection of key-value pairs. + :rtype: :class:`neo4j.Record` + """ + while self._record_buffer or self._attached: + if self._record_buffer: + yield self._record_buffer.popleft() + elif self._streaming: + self._connection.fetch_message() + elif self._discarding: + self._discard() + self._connection.send_all() + elif self._has_more: + self._pull() + self._connection.send_all() + + self._closed = True + + def _attach(self): + """Sets the Result object in an attached state by fetching messages from + the connection to the buffer. 
+        """
+        if self._closed is False:
+            while self._attached is False:
+                self._connection.fetch_message()
+
+    def _buffer(self, n=None):
+        """Try to fill `self._record_buffer` with n records.
+
+        Might end up with more records in the buffer if the fetch size makes it
+        overshoot.
+        Might end up with fewer records in the buffer if there are not enough
+        records available.
+        """
+        record_buffer = deque()
+        for record in self:
+            record_buffer.append(record)
+            if n is not None and len(record_buffer) >= n:
+                break
+        self._closed = False
+        if n is None:
+            self._record_buffer = record_buffer
+        else:
+            self._record_buffer.extend(record_buffer)
+
+    def _buffer_all(self):
+        """Sets the Result object in a detached state by fetching all records
+        from the connection to the buffer.
+        """
+        self._buffer()
+
+    def _obtain_summary(self):
+        """Obtain the summary of this result, buffering any remaining records.
+
+        :returns: The :class:`neo4j.ResultSummary` for this result
+        """
+        if self._summary is None:
+            if self._metadata:
+                self._summary = ResultSummary(
+                    self._connection.unresolved_address, **self._metadata
+                )
+            elif self._connection:
+                self._summary = ResultSummary(
+                    self._connection.unresolved_address,
+                    server=self._connection.server_info
+                )
+
+        return self._summary
+
+    def keys(self):
+        """The keys for the records in this result.
+
+        :returns: tuple of key names
+        :rtype: tuple
+        """
+        return self._keys
+
+    def consume(self):
+        """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`.
+
+        Example::
+
+            def create_node_tx(tx, name):
+                result = tx.run(
+                    "CREATE (n:ExampleNode { name: $name }) RETURN n", name=name
+                )
+                record = result.single()
+                value = record.value()
+                info = result.consume()
+                return value, info
+
+            with driver.session() as session:
+                node_id, info = session.write_transaction(create_node_tx, "example")
+
+        Example::
+
+            def get_two_tx(tx):
+                result = tx.run("UNWIND [1,2,3,4] AS x RETURN x")
+                values = []
+                for record in result:
+                    if len(values) >= 2:
+                        break
+                    values.append(record.values())
+                # discard the remaining records if there are any
+                info = result.consume()
+                # use the info for logging etc.
+                return values, info
+
+            with driver.session() as session:
+                values, info = session.read_transaction(get_two_tx)
+
+        :returns: The :class:`neo4j.ResultSummary` for this result
+        """
+        if self._closed is False:
+            self._discarding = True
+            for _ in self:
+                pass
+
+        return self._obtain_summary()
+
+    def single(self):
+        """Obtain the next and only remaining record from this result if
+        available, else return None.
+        Calling this method always exhausts the result.
+
+        A warning is generated if more than one record is available but
+        the first of these is still returned.
+
+        :returns: the next :class:`neo4j.Record` or :const:`None` if none remain
+        :warns: if more than one record is available
+        """
+        # TODO: in 5.0, replace with this code that raises an error if there
+        # is not exactly one record left in the result stream.
+        # self._buffer(2)
+        # if len(self._record_buffer) != 1:
+        #     raise SomeError("Expected exactly 1 record, found %i"
+        #                     % len(self._record_buffer))
+        # return self._record_buffer.popleft()
+        # TODO: exhaust the result with self.consume if there are more records.
+        records = Util.list(self)
+        size = len(records)
+        if size == 0:
+            return None
+        if size != 1:
+            warn("Expected a result with a single record, but this result contains %d" % size)
+        return records[0]
+
+    def peek(self):
+        """Obtain the next record from this result without consuming it.
+        This leaves the record in the buffer for further processing.
+
+        :returns: the next :class:`.Record` or :const:`None` if none remain
+        """
+        self._buffer(1)
+        if self._record_buffer:
+            return self._record_buffer[0]
+
+    def graph(self):
+        """Return a :class:`neo4j.graph.Graph` instance containing all the graph objects
+        in the result. After calling this method, the result becomes
+        detached, buffering all remaining records.
+
+        :returns: a result graph
+        :rtype: :class:`neo4j.graph.Graph`
+        """
+        self._buffer_all()
+        return self._hydrant.graph
+
+    def value(self, key=0, default=None):
+        """Helper function that returns the remainder of the result as a list of values.
+
+        See :class:`neo4j.Record.value`
+
+        :param key: field to return for each remaining record. Obtain a single value from the record by index or key.
+        :param default: default value, used if the index or key is unavailable
+        :returns: list of individual values
+        :rtype: list
+        """
+        return [record.value(key, default) for record in self]
+
+    def values(self, *keys):
+        """Helper function that returns the remainder of the result as a list of values lists.
+
+        See :class:`neo4j.Record.values`
+
+        :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key.
+        :returns: list of values lists
+        :rtype: list
+        """
+        return [record.values(*keys) for record in self]
+
+    def data(self, *keys):
+        """Helper function that returns the remainder of the result as a list of dictionaries.
+
+        See :class:`neo4j.Record.data`
+
+        :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key.
+        :returns: list of dictionaries
+        :rtype: list
+        """
+        return [record.data(*keys) for record in self]
diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py
new file mode 100644
index 000000000..dbdf94828
--- /dev/null
+++ b/neo4j/_sync/work/session.py
@@ -0,0 +1,447 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from logging import getLogger
+from random import random
+from time import perf_counter
+
+from ..._async_compat import sleep
+from ...api import (
+    READ_ACCESS,
+    WRITE_ACCESS,
+)
+from ...conf import SessionConfig
+from ...data import DataHydrator
+from ...exceptions import (
+    ClientError,
+    IncompleteCommit,
+    Neo4jError,
+    ServiceUnavailable,
+    SessionExpired,
+    TransactionError,
+    TransientError,
+)
+from ...work import Query
+from .result import Result
+from .transaction import Transaction
+from .workspace import Workspace
+
+
+log = getLogger("neo4j")
+
+
+class Session(Workspace):
+    """A :class:`.Session` is a logical context for transactional units
+    of work. Connections are drawn from the :class:`.Driver` connection
+    pool as required.
+
+    Session creation is a lightweight operation, but sessions are not safe
+    to use in concurrent contexts (multiple threads/coroutines).
+    Therefore, a session should generally be short-lived, and must not
+    span multiple threads/coroutines.
+
+    In general, sessions will be created and destroyed within a `with`
+    context. For example::
+
+        with driver.session() as session:
+            result = session.run("MATCH (n:Person) RETURN n.name AS name")
+            # do something with the result...
+
+    :param pool: connection pool instance
+    :param config: session config instance
+    """
+
+    # The current connection.
+    _connection = None
+
+    # The current :class:`.Transaction` instance, if any.
+    _transaction = None
+
+    # The current auto-transaction result, if any.
+    _auto_result = None
+
+    # The state this session is in.
+    _state_failed = False
+
+    # Whether the session has been properly closed.
+    _closed = False
+
+    def __init__(self, pool, session_config):
+        super().__init__(pool, session_config)
+        assert isinstance(session_config, SessionConfig)
+        self._bookmarks = tuple(session_config.bookmarks)
+
+    def __del__(self):
+        if asyncio.iscoroutinefunction(self.close):
+            return
+        try:
+            self.close()
+        except (OSError, ServiceUnavailable, SessionExpired):
+            pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        if exception_type:
+            self._state_failed = True
+        self.close()
+
+    def _connect(self, access_mode):
+        if access_mode is None:
+            access_mode = self._config.default_access_mode
+        super()._connect(access_mode)
+
+    def _collect_bookmark(self, bookmark):
+        if bookmark:
+            self._bookmarks = [bookmark]
+
+    def _result_closed(self):
+        if self._auto_result:
+            self._collect_bookmark(self._auto_result._bookmark)
+            self._auto_result = None
+            self._disconnect()
+
+    def _result_error(self, _):
+        if self._auto_result:
+            self._auto_result = None
+            self._disconnect()
+
+    def close(self):
+        """Close the session.
+
+        This will release any borrowed resources, such as connections, and
+        will roll back any outstanding transactions.
+        """
+        if self._connection:
+            if self._auto_result:
+                if self._state_failed is False:
+                    try:
+                        self._auto_result.consume()
+                        self._collect_bookmark(self._auto_result._bookmark)
+                    except Exception as error:
+                        # TODO: Investigate potential non graceful close states
+                        self._auto_result = None
+                        self._state_failed = True
+
+            if self._transaction:
+                if self._transaction.closed() is False:
+                    # roll back the transaction if it is not closed
+                    self._transaction.rollback()
+                self._transaction = None
+
+            try:
+                if self._connection:
+                    self._connection.send_all()
+                    self._connection.fetch_all()
+                    # TODO: Investigate potential non graceful close states
+            except Neo4jError:
+                pass
+            except TransactionError:
+                pass
+            except ServiceUnavailable:
+                pass
+            except SessionExpired:
+                pass
+            finally:
+                self._disconnect()
+
+            self._state_failed = False
+        self._closed = True
+
+    def run(self, query, parameters=None, **kwargs):
+        """Run a Cypher query within an auto-commit transaction.
+
+        The query is sent and the result header received
+        immediately but the :class:`neo4j.Result` content is
+        fetched lazily as consumed by the client application.
+
+        If a query is executed before a previous
+        :class:`neo4j.Result` in the same :class:`.Session` has
+        been fully consumed, the first result will be fully fetched
+        and buffered. Note therefore that the generally recommended
+        pattern of usage is to fully consume one result before
+        executing a subsequent query. If two results need to be
+        consumed in parallel, multiple :class:`.Session` objects
+        can be used as an alternative to result buffering.
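+
+        A minimal auto-commit sketch (illustrative; assumes an open
+        ``driver``)::
+
+            with driver.session() as session:
+                result = session.run("RETURN $x AS x", x=1)
+                for record in result:
+                    print(record["x"])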
+
+        For more usage details, see :meth:`.Transaction.run`.
+
+        :param query: cypher query
+        :type query: str, neo4j.Query
+        :param parameters: dictionary of parameters
+        :type parameters: dict
+        :param kwargs: additional keyword parameters
+        :returns: a new :class:`neo4j.Result` object
+        :rtype: Result
+        """
+        if not query:
+            raise ValueError("Cannot run an empty query")
+        if not isinstance(query, (str, Query)):
+            raise TypeError("query must be a string or a Query instance")
+
+        if self._transaction:
+            raise ClientError("Explicit Transaction must be handled explicitly")
+
+        if self._auto_result:
+            # This will buffer up all records for the previous auto-transaction
+            self._auto_result._buffer_all()
+
+        if not self._connection:
+            self._connect(self._config.default_access_mode)
+        cx = self._connection
+        protocol_version = cx.PROTOCOL_VERSION
+        server_info = cx.server_info
+
+        hydrant = DataHydrator()
+
+        self._auto_result = Result(
+            cx, hydrant, self._config.fetch_size, self._result_closed,
+            self._result_error
+        )
+        self._auto_result._run(
+            query, parameters, self._config.database,
+            self._config.impersonated_user, self._config.default_access_mode,
+            self._bookmarks, **kwargs
+        )
+
+        return self._auto_result
+
+    def last_bookmark(self):
+        """Return the bookmark received following the last completed transaction.
+
+        Note: For auto-transactions (Session.run), this will trigger a
+        consume of the current result.
+
+        :returns: :class:`neo4j.Bookmark` object
+        """
+        # The set of bookmarks to be passed into the next transaction.
+
+        if self._auto_result:
+            self._auto_result.consume()
+
+        if self._transaction and self._transaction._closed:
+            self._collect_bookmark(self._transaction._bookmark)
+            self._transaction = None
+
+        if self._bookmarks:
+            return self._bookmarks[-1]
+        return None
+
+    def _transaction_closed_handler(self):
+        if self._transaction:
+            self._collect_bookmark(self._transaction._bookmark)
+            self._transaction = None
+            self._disconnect()
+
+    def _transaction_error_handler(self, _):
+        if self._transaction:
+            self._transaction = None
+            self._disconnect()
+
+    def _open_transaction(self, *, access_mode, metadata=None,
+                          timeout=None):
+        self._connect(access_mode=access_mode)
+        self._transaction = Transaction(
+            self._connection, self._config.fetch_size,
+            self._transaction_closed_handler,
+            self._transaction_error_handler
+        )
+        self._transaction._begin(
+            self._config.database, self._config.impersonated_user,
+            self._bookmarks, access_mode, metadata, timeout
+        )
+
+    def begin_transaction(self, metadata=None, timeout=None):
+        """ Begin a new unmanaged transaction. Creates a new :class:`.Transaction` within this session.
+        At most one transaction may exist in a session at any point in time.
+        To maintain multiple concurrent transactions, use multiple concurrent sessions.
+
+        Note: For auto-transactions (Session.run), this will trigger a
+        consume of the current result.
+
+        :param metadata:
+            a dictionary with metadata.
+            Specified metadata will be attached to the executing transaction and visible in the output of the ``dbms.listQueries`` and ``dbms.listTransactions`` procedures.
+            It will also get logged to the ``query.log``.
+            This functionality makes it easier to tag transactions and is equivalent to the ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference.
+        :type metadata: dict
+
+        :param timeout:
+            the transaction timeout in seconds.
+            Transactions that execute longer than the configured timeout will be terminated by the database.
+            This functionality allows limiting query/transaction execution time.
+            The specified timeout overrides the default timeout configured in the database using the ``dbms.transaction.timeout`` setting.
+            The value must not be zero or negative.
+        :type timeout: int
+
+        :returns: A new transaction instance.
+        :rtype: Transaction
+
+        :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open.
+        """
+        # TODO: Implement TransactionConfig consumption
+
+        if self._auto_result:
+            self._auto_result.consume()
+
+        if self._transaction:
+            raise TransactionError("Explicit transaction already open")
+
+        self._open_transaction(
+            access_mode=self._config.default_access_mode, metadata=metadata,
+            timeout=timeout
+        )
+
+        return self._transaction
+
+    def _run_transaction(
+        self, access_mode, transaction_function, *args, **kwargs
+    ):
+        if not callable(transaction_function):
+            raise TypeError("Unit of work is not callable")
+
+        metadata = getattr(transaction_function, "metadata", None)
+        timeout = getattr(transaction_function, "timeout", None)
+
+        retry_delay = retry_delay_generator(self._config.initial_retry_delay, self._config.retry_delay_multiplier, self._config.retry_delay_jitter_factor)
+
+        errors = []
+
+        t0 = -1  # Timer
+
+        while True:
+            try:
+                self._open_transaction(
+                    access_mode=access_mode, metadata=metadata,
+                    timeout=timeout
+                )
+                tx = self._transaction
+                try:
+                    result = transaction_function(tx, *args, **kwargs)
+                except Exception:
+                    tx.close()
+                    raise
+                else:
+                    tx.commit()
+            except IncompleteCommit:
+                raise
+            except (ServiceUnavailable, SessionExpired) as error:
+                errors.append(error)
+                self._disconnect()
+            except TransientError as transient_error:
+                if not transient_error.is_retriable():
+                    raise
+                errors.append(transient_error)
+            else:
+                return result
+            if t0 == -1:
+                # The timer should be started after the first attempt
+                t0 = perf_counter()
+            t1 = perf_counter()
+            if t1 - t0 > self._config.max_transaction_retry_time:
+                break
+            delay = next(retry_delay)
+            log.warning("Transaction failed and will be retried in {}s ({})".format(delay, "; ".join(errors[-1].args)))
+            sleep(delay)
+
+        if errors:
+            raise errors[-1]
+        else:
+            raise ServiceUnavailable("Transaction failed")
+
+    def read_transaction(self, transaction_function, *args, **kwargs):
+        """Execute a unit of work in a managed read transaction.
+        This transaction will automatically be committed unless an exception is raised during query execution or by the user code.
+        Note that this function performs retries and that the supplied `transaction_function` might get invoked more than once.
+
+        Managed transactions should not generally be explicitly committed
+        (via ``tx.commit()``).
+
+        Example::
+
+            def do_cypher_tx(tx, cypher):
+                result = tx.run(cypher)
+                values = [record.values() for record in result]
+                return values
+
+            with driver.session() as session:
+                values = session.read_transaction(do_cypher_tx, "RETURN 1 AS x")
+
+        Example::
+
+            def get_two_tx(tx):
+                result = tx.run("UNWIND [1,2,3,4] AS x RETURN x")
+                values = []
+                for record in result:
+                    if len(values) >= 2:
+                        break
+                    values.append(record.values())
+                # discard the remaining records if there are any
+                info = result.consume()
+                # use the info for logging etc.
+                return values
+
+            with driver.session() as session:
+                values = session.read_transaction(get_two_tx)
+
+        :param transaction_function: a function that takes a transaction as an
+            argument and does work with the transaction.
+            `transaction_function(tx, *args, **kwargs)` where `tx` is a
+            :class:`.Transaction`.
+        :param args: arguments for the `transaction_function`
+        :param kwargs: key word arguments for the `transaction_function`
+        :return: a result as returned by the given unit of work
+        """
+        return self._run_transaction(
+            READ_ACCESS, transaction_function, *args, **kwargs
+        )
+
+    def write_transaction(self, transaction_function, *args, **kwargs):
+        """Execute a unit of work in a managed write transaction.
+        This transaction will automatically be committed unless an exception is raised during query execution or by the user code.
+        Note that this function performs retries and that the supplied `transaction_function` might get invoked more than once.
+
+        Managed transactions should not generally be explicitly committed (via tx.commit()).
+
+        Example::
+
+            def create_node_tx(tx, name):
+                query = "CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id"
+                result = tx.run(query, name=name)
+                record = result.single()
+                return record["node_id"]
+
+            with driver.session() as session:
+                node_id = session.write_transaction(create_node_tx, "example")
+
+        :param transaction_function: a function that takes a transaction as an
+            argument and does work with the transaction.
+            `transaction_function(tx, *args, **kwargs)` where `tx` is a
+            :class:`.Transaction`.
+        :param args: arguments for the `transaction_function`
+        :param kwargs: key word arguments for the `transaction_function`
+        :return: a result as returned by the given unit of work
+        """
+        return self._run_transaction(
+            WRITE_ACCESS, transaction_function, *args, **kwargs
+        )
+
+
+def retry_delay_generator(initial_delay, multiplier, jitter_factor):
+    delay = initial_delay
+    while True:
+        jitter = jitter_factor * delay
+        yield delay - jitter + (2 * jitter * random())
+        delay *= multiplier
diff --git a/neo4j/_sync/work/transaction.py b/neo4j/_sync/work/transaction.py
new file mode 100644
index 000000000..73d082388
--- /dev/null
+++ b/neo4j/_sync/work/transaction.py
@@ -0,0 +1,199 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..._async_compat.util import Util
+from ...data import DataHydrator
+from ...exceptions import TransactionError
+from ...work import Query
+from ..io import ConnectionErrorHandler
+from .result import Result
+
+
+class Transaction:
+    """ Container for multiple Cypher queries to be executed within a single
+    context. Transactions can be used within a :py:const:`with` block,
+    where the transaction is committed or rolled back based on
+    whether an exception is raised::
+
+        with session.begin_transaction() as tx:
+            ...
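+
+    A slightly fuller sketch of the same idea (illustrative; the query
+    and parameter are made up)::
+
+        with session.begin_transaction() as tx:
+            tx.run("CREATE (n:Example {name: $name})", name="a")
+            # committed on clean exit, rolled back if an exception
+            # escapes the block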
+ + """ + + def __init__(self, connection, fetch_size, on_closed, on_error): + self._connection = connection + self._error_handling_connection = ConnectionErrorHandler( + connection, self._error_handler + ) + self._bookmark = None + self._results = [] + self._closed = False + self._last_error = None + self._fetch_size = fetch_size + self._on_closed = on_closed + self._on_error = on_error + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + if self._closed: + return + success = not bool(exception_type) + if success: + self.commit() + self.close() + + def _begin( + self, database, imp_user, bookmarks, access_mode, metadata, timeout + ): + self._connection.begin( + bookmarks=bookmarks, metadata=metadata, timeout=timeout, + mode=access_mode, db=database, imp_user=imp_user + ) + self._error_handling_connection.send_all() + self._error_handling_connection.fetch_all() + + def _result_on_closed_handler(self): + pass + + def _error_handler(self, exc): + self._last_error = exc + Util.callback(self._on_error, exc) + + def _consume_results(self): + for result in self._results: + result.consume() + self._results = [] + + def run(self, query, parameters=None, **kwparameters): + """ Run a Cypher query within the context of this transaction. + + Cypher is typically expressed as a query template plus a + set of named parameters. In Python, parameters may be expressed + through a dictionary of parameters, through individual parameter + arguments, or as a mixture of both. For example, the `run` + queries below are all equivalent:: + + >>> query = "CREATE (a:Person { name: $name, age: $age })" + >>> result = tx.run(query, {"name": "Alice", "age": 33}) + >>> result = tx.run(query, {"name": "Alice"}, age=33) + >>> result = tx.run(query, name="Alice", age=33) + + Parameter values can be of any type supported by the Neo4j type + system. In Python, this includes :class:`bool`, :class:`int`, + :class:`str`, :class:`list` and :class:`dict`. Note however that + :class:`list` properties must be homogenous. + + :param query: cypher query + :type query: str + :param parameters: dictionary of parameters + :type parameters: dict + :param kwparameters: additional keyword parameters + :returns: a new :class:`neo4j.Result` object + :rtype: :class:`neo4j.Result` + :raise TransactionError: if the transaction is already closed + """ + if isinstance(query, Query): + raise ValueError("Query object is only supported for session.run") + + if self._closed: + raise TransactionError(self, "Transaction closed") + if self._last_error: + raise TransactionError(self, + "Transaction failed") from self._last_error + + if (self._results + and self._connection.supports_multiple_results is False): + # Bolt 3 Support + # Buffer up all records for the previous Result because it does not + # have any qid to fetch in batches. + self._results[-1]._buffer_all() + + result = Result( + self._connection, DataHydrator(), self._fetch_size, + self._result_on_closed_handler, + self._error_handler + ) + self._results.append(result) + + result._tx_ready_run(query, parameters, **kwparameters) + + return result + + def commit(self): + """Mark this transaction as successful and close in order to trigger a COMMIT. 
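+
+        In this implementation the bookmark received on COMMIT (if any) is
+        also returned, so a caller may write (illustrative)::
+
+            bookmark = tx.commit()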
+ + :raise TransactionError: if the transaction is already closed + """ + if self._closed: + raise TransactionError(self, "Transaction closed") + if self._last_error: + raise TransactionError(self, + "Transaction failed") from self._last_error + + metadata = {} + try: + # DISCARD pending records then do a commit. + self._consume_results() + self._connection.commit(on_success=metadata.update) + self._connection.send_all() + self._connection.fetch_all() + self._bookmark = metadata.get("bookmark") + finally: + self._closed = True + Util.callback(self._on_closed) + + return self._bookmark + + def rollback(self): + """Mark this transaction as unsuccessful and close in order to trigger a ROLLBACK. + + :raise TransactionError: if the transaction is already closed + """ + if self._closed: + raise TransactionError(self, "Transaction closed") + + metadata = {} + try: + if not (self._connection.defunct() + or self._connection.closed() + or self._connection.is_reset): + # DISCARD pending records then do a rollback. + self._consume_results() + self._connection.rollback(on_success=metadata.update) + self._connection.send_all() + self._connection.fetch_all() + finally: + self._closed = True + Util.callback(self._on_closed) + + def close(self): + """Close this transaction, triggering a ROLLBACK if not closed. + """ + if self._closed: + return + self.rollback() + + def closed(self): + """Indicator to show whether the transaction has been closed. + + :return: :const:`True` if closed, :const:`False` otherwise. + :rtype: bool + """ + return self._closed diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py new file mode 100644 index 000000000..3ed50ad26 --- /dev/null +++ b/neo4j/_sync/work/workspace.py @@ -0,0 +1,102 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +from ...conf import WorkspaceConfig +from ...exceptions import ServiceUnavailable +from ..io import Neo4jPool + + +class Workspace: + + def __init__(self, pool, config): + assert isinstance(config, WorkspaceConfig) + self._pool = pool + self._config = config + self._connection = None + self._connection_access_mode = None + # Sessions are supposed to cache the database on which to operate. 
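+        # (Once the home database is known, _set_cached_database pins it for
+        # the lifetime of this workspace.)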
+ self._cached_database = False + self._bookmarks = None + + def __del__(self): + if asyncio.iscoroutinefunction(self.close): + return + try: + self.close() + except OSError: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def _set_cached_database(self, database): + self._cached_database = True + self._config.database = database + + def _connect(self, access_mode): + if self._connection: + # TODO: Investigate this + # log.warning("FIXME: should always disconnect before connect") + self._connection.send_all() + self._connection.fetch_all() + self._disconnect() + if not self._cached_database: + if (self._config.database is not None + or not isinstance(self._pool, Neo4jPool)): + self._set_cached_database(self._config.database) + else: + # This is the first time we open a connection to a server in a + # cluster environment for this session without explicitly + # configured database. Hence, we request a routing table update + # to try to fetch the home database. If provided by the server, + # we shall use this database explicitly for all subsequent + # actions within this session. + self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=self._bookmarks, + database_callback=self._set_cached_database + ) + self._connection = self._pool.acquire( + access_mode=access_mode, + timeout=self._config.connection_acquisition_timeout, + database=self._config.database, + bookmarks=self._bookmarks + ) + self._connection_access_mode = access_mode + + def _disconnect(self, sync=False): + if self._connection: + if sync: + try: + self._connection.send_all() + self._connection.fetch_all() + except ServiceUnavailable: + pass + if self._connection: + self._pool.release(self._connection) + self._connection = None + self._connection_access_mode = None + + def close(self): + self._disconnect(sync=True) diff --git a/testkitbackend/_sync/__init__.py b/testkitbackend/_sync/__init__.py new file mode 100644 index 000000000..b81a309da --- /dev/null +++ b/testkitbackend/_sync/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py new file mode 100644 index 000000000..9f1bae94c --- /dev/null +++ b/testkitbackend/_sync/backend.py @@ -0,0 +1,145 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from inspect import (
+    getmembers,
+    isfunction,
+)
+from json import (
+    dumps,
+    loads,
+)
+import traceback
+
+from neo4j._exceptions import BoltError
+from neo4j.exceptions import (
+    DriverError,
+    Neo4jError,
+    UnsupportedServerProduct,
+)
+
+from . import requests
+from .._driver_logger import (
+    buffer_handler,
+    log,
+)
+from ..backend import Request
+
+
+class Backend:
+    def __init__(self, rd, wr):
+        self._rd = rd
+        self._wr = wr
+        self.drivers = {}
+        self.custom_resolutions = {}
+        self.dns_resolutions = {}
+        self.sessions = {}
+        self.results = {}
+        self.errors = {}
+        self.transactions = {}
+        self.key = 0
+        # Collect all request handlers
+        self._requestHandlers = dict(
+            [m for m in getmembers(requests, isfunction)])
+
+    def next_key(self):
+        self.key = self.key + 1
+        return self.key
+
+    def process_request(self):
+        """ Reads the next request from the stream and processes it.
+        """
+        in_request = False
+        request = ""
+        for line in self._rd:
+            # Remove trailing newline
+            line = line.decode('UTF-8').rstrip()
+            if line == "#request begin":
+                in_request = True
+            elif line == "#request end":
+                self._process(request)
+                return True
+            else:
+                if in_request:
+                    request = request + line
+        return False
+
+    def _process(self, request):
+        """ Process a received request by retrieving the handler that
+        corresponds to the request name.
+        """
+        try:
+            request = loads(request, object_pairs_hook=Request)
+            if not isinstance(request, Request):
+                raise Exception("Request is not an object")
+            name = request.get('name', 'invalid')
+            handler = self._requestHandlers.get(name)
+            if not handler:
+                raise Exception("No request handler for " + name)
+            data = request["data"]
+            log.info("<<< " + name + dumps(data))
+            handler(self, data)
+            unused_keys = request.unseen_keys
+            if unused_keys:
+                raise NotImplementedError(
+                    "Backend does not support some properties of the " + name
+                    + " request: " + ", ".join(unused_keys)
+                )
+        except (Neo4jError, DriverError, UnsupportedServerProduct,
+                BoltError) as e:
+            log.debug(traceback.format_exc())
+            if isinstance(e, Neo4jError):
+                msg = "" if e.message is None else str(e.message)
+            else:
+                msg = str(e.args[0]) if e.args else ""
+
+            key = self.next_key()
+            self.errors[key] = e
+            payload = {"id": key, "errorType": str(type(e)), "msg": msg}
+            if isinstance(e, Neo4jError):
+                payload["code"] = e.code
+            self.send_response("DriverError", payload)
+        except requests.FrontendError as e:
+            self.send_response("FrontendError", {"msg": str(e)})
+        except Exception:
+            tb = traceback.format_exc()
+            log.error(tb)
+            self.send_response("BackendError", {"msg": tb})
+
+    def send_response(self, name, data):
+        """ Send a response back to the frontend (the TestKit client).
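+
+        Responses are framed on the wire with the same kind of markers used
+        for requests (see process_request)::
+
+            #response begin
+            {"name": ..., "data": ...}
+            #response end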
+ """ + with buffer_handler.lock: + log_output = buffer_handler.stream.getvalue() + buffer_handler.stream.truncate(0) + buffer_handler.stream.seek(0) + if not log_output.endswith("\n"): + log_output += "\n" + self._wr.write(log_output.encode("utf-8")) + response = {"name": name, "data": data} + response = dumps(response) + log.info(">>> " + name + dumps(data)) + self._wr.write(b"#response begin\n") + self._wr.write(bytes(response+"\n", "utf-8")) + self._wr.write(b"#response end\n") + if isinstance(self._wr, asyncio.StreamWriter): + self._wr.drain() + else: + self._wr.flush() diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py new file mode 100644 index 000000000..744d83568 --- /dev/null +++ b/testkitbackend/_sync/requests.py @@ -0,0 +1,444 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from os import path + +import neo4j +from neo4j._async_compat.util import Util + +from .. import ( + fromtestkit, + totestkit, +) + + +class FrontendError(Exception): + pass + + +def load_config(): + config_path = path.join(path.dirname(__file__), "..", "test_config.json") + with open(config_path, "r") as fd: + config = json.load(fd) + skips = config["skips"] + features = [k for k, v in config["features"].items() if v is True] + import ssl + if ssl.HAS_TLSv1_3: + features += ["Feature:TLS:1.3"] + return skips, features + + +SKIPPED_TESTS, FEATURES = load_config() + + +def StartTest(backend, data): + if data["testName"] in SKIPPED_TESTS: + backend.send_response("SkipTest", { + "reason": SKIPPED_TESTS[data["testName"]] + }) + else: + backend.send_response("RunTest", {}) + + +def GetFeatures(backend, data): + backend.send_response("FeatureList", {"features": FEATURES}) + + +def NewDriver(backend, data): + auth_token = data["authorizationToken"]["data"] + data["authorizationToken"].mark_item_as_read_if_equals( + "name", "AuthorizationToken" + ) + scheme = auth_token["scheme"] + if scheme == "basic": + auth = neo4j.basic_auth( + auth_token["principal"], auth_token["credentials"], + realm=auth_token.get("realm", None) + ) + elif scheme == "kerberos": + auth = neo4j.kerberos_auth(auth_token["credentials"]) + elif scheme == "bearer": + auth = neo4j.bearer_auth(auth_token["credentials"]) + else: + auth = neo4j.custom_auth( + auth_token["principal"], auth_token["credentials"], + auth_token["realm"], auth_token["scheme"], + **auth_token.get("parameters", {}) + ) + auth_token.mark_item_as_read("parameters", recursive=True) + resolver = None + if data["resolverRegistered"] or data["domainNameResolverRegistered"]: + resolver = resolution_func(backend, data["resolverRegistered"], + data["domainNameResolverRegistered"]) + connection_timeout = data.get("connectionTimeoutMs") + if connection_timeout is not None: + connection_timeout /= 1000 + max_transaction_retry_time = data.get("maxTxRetryTimeMs") + if max_transaction_retry_time is not None: + max_transaction_retry_time /= 1000 + 
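+    # Note: TestKit expresses durations in milliseconds, while the driver
+    # API takes seconds; hence the divisions by 1000 above.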
+    data.mark_item_as_read("domainNameResolverRegistered")
+    driver = neo4j.GraphDatabase.driver(
+        data["uri"], auth=auth, user_agent=data["userAgent"],
+        resolver=resolver, connection_timeout=connection_timeout,
+        fetch_size=data.get("fetchSize"),
+        max_transaction_retry_time=max_transaction_retry_time,
+    )
+    key = backend.next_key()
+    backend.drivers[key] = driver
+    backend.send_response("Driver", {"id": key})
+
+
+def VerifyConnectivity(backend, data):
+    driver_id = data["driverId"]
+    driver = backend.drivers[driver_id]
+    driver.verify_connectivity()
+    backend.send_response("Driver", {"id": driver_id})
+
+
+def CheckMultiDBSupport(backend, data):
+    driver_id = data["driverId"]
+    driver = backend.drivers[driver_id]
+    backend.send_response("MultiDBSupport", {
+        "id": backend.next_key(), "available": driver.supports_multi_db()
+    })
+
+
+def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False):
+    # This solution (putting custom resolution together with DNS resolution
+    # into one function) only works because the Python driver calls the custom
+    # resolver function for every connection, which is not true for all
+    # drivers. Properly exposing a way to change the DNS lookup behavior is
+    # not possible without changing the driver's code.
+    assert custom_resolver or custom_dns_resolver
+
+    def resolve(address):
+        addresses = [":".join(map(str, address))]
+        if custom_resolver:
+            key = backend.next_key()
+            backend.send_response("ResolverResolutionRequired", {
+                "id": key,
+                "address": addresses[0]
+            })
+            if not backend.process_request():
+                # connection was closed before end of next message
+                return []
+            if key not in backend.custom_resolutions:
+                raise RuntimeError(
+                    "Backend did not receive expected "
+                    "ResolverResolutionCompleted message for id %s" % key
+                )
+            addresses = backend.custom_resolutions.pop(key)
+        if custom_dns_resolver:
+            dns_resolved_addresses = []
+            for address in addresses:
+                key = backend.next_key()
+                address = address.rsplit(":", 1)
+                backend.send_response("DomainNameResolutionRequired", {
+                    "id": key,
+                    "name": address[0]
+                })
+                if not backend.process_request():
+                    # connection was closed before end of next message
+                    return []
+                if key not in backend.dns_resolutions:
+                    raise RuntimeError(
+                        "Backend did not receive expected "
+                        "DomainNameResolutionCompleted message for id %s" % key
+                    )
+                dns_resolved_addresses += list(map(
+                    lambda a: ":".join((a, *address[1:])),
+                    backend.dns_resolutions.pop(key)
+                ))
+
+            addresses = dns_resolved_addresses
+
+        return list(map(neo4j.Address.parse, addresses))
+
+    return resolve
+
+
+def ResolverResolutionCompleted(backend, data):
+    backend.custom_resolutions[data["requestId"]] = data["addresses"]
+
+
+def DomainNameResolutionCompleted(backend, data):
+    backend.dns_resolutions[data["requestId"]] = data["addresses"]
+
+
+def DriverClose(backend, data):
+    key = data["driverId"]
+    driver = backend.drivers[key]
+    driver.close()
+    backend.send_response("Driver", {"id": key})
+
+
+class SessionTracker:
+    """ Keeps some extra state about the tracked session
+    """
+
+    def __init__(self, session):
+        self.session = session
+        self.state = ""
+        self.error_id = ""
+
+
+def NewSession(backend, data):
+    driver = backend.drivers[data["driverId"]]
+    access_mode = data["accessMode"]
+    if access_mode == "r":
+        access_mode = neo4j.READ_ACCESS
+    elif access_mode == "w":
+        access_mode = neo4j.WRITE_ACCESS
+    else:
+        raise ValueError("Unknown access mode: " + access_mode)
+    config = {
+        "default_access_mode": access_mode,
+        "bookmarks":
data["bookmarks"], + "database": data["database"], + "fetch_size": data.get("fetchSize", None), + "impersonated_user": data.get("impersonatedUser", None), + + } + session = driver.session(**config) + key = backend.next_key() + backend.sessions[key] = SessionTracker(session) + backend.send_response("Session", {"id": key}) + + +def SessionRun(backend, data): + session = backend.sessions[data["sessionId"]].session + query, params = fromtestkit.to_query_and_params(data) + result = session.run(query, parameters=params) + key = backend.next_key() + backend.results[key] = result + backend.send_response("Result", {"id": key, "keys": result.keys()}) + + +def SessionClose(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + session.close() + del backend.sessions[key] + backend.send_response("Session", {"id": key}) + + +def SessionBeginTransaction(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + metadata, timeout = fromtestkit.to_meta_and_timeout(data) + tx = session.begin_transaction(metadata=metadata, timeout=timeout) + key = backend.next_key() + backend.transactions[key] = tx + backend.send_response("Transaction", {"id": key}) + + +def SessionReadTransaction(backend, data): + transactionFunc(backend, data, True) + + +def SessionWriteTransaction(backend, data): + transactionFunc(backend, data, False) + + +def transactionFunc(backend, data, is_read): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session = session_tracker.session + metadata, timeout = fromtestkit.to_meta_and_timeout(data) + + @neo4j.unit_of_work(metadata=metadata, timeout=timeout) + def func(tx): + txkey = backend.next_key() + backend.transactions[txkey] = tx + session_tracker.state = '' + backend.send_response("RetryableTry", {"id": txkey}) + + cont = True + while cont: + cont = backend.process_request() + if session_tracker.state == '+': + cont = False + elif session_tracker.state == '-': + if session_tracker.error_id: + raise backend.errors[session_tracker.error_id] + else: + raise FrontendError("Client said no") + + if is_read: + session.read_transaction(func) + else: + session.write_transaction(func) + backend.send_response("RetryableDone", {}) + + +def RetryablePositive(backend, data): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session_tracker.state = '+' + + +def RetryableNegative(backend, data): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session_tracker.state = '-' + session_tracker.error_id = data.get('errorId', '') + + +def SessionLastBookmarks(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + bookmark = session.last_bookmark() + bookmarks = [] + if bookmark: + bookmarks.append(bookmark) + backend.send_response("Bookmarks", {"bookmarks": bookmarks}) + + +def TransactionRun(backend, data): + key = data["txId"] + tx = backend.transactions[key] + cypher, params = fromtestkit.to_cypher_and_params(data) + result = tx.run(cypher, parameters=params) + key = backend.next_key() + backend.results[key] = result + backend.send_response("Result", {"id": key, "keys": result.keys()}) + + +def TransactionCommit(backend, data): + key = data["txId"] + tx = backend.transactions[key] + tx.commit() + backend.send_response("Transaction", {"id": key}) + + +def TransactionRollback(backend, data): + key = data["txId"] + tx = backend.transactions[key] + tx.rollback() + backend.send_response("Transaction", {"id": key}) + + +def TransactionClose(backend, data): + key = 
data["txId"] + tx = backend.transactions[key] + tx.close() + backend.send_response("Transaction", {"id": key}) + + +def ResultNext(backend, data): + result = backend.results[data["resultId"]] + + try: + record = Util.next(Util.iter(result)) + except StopIteration: + backend.send_response("NullRecord", {}) + return + backend.send_response("Record", totestkit.record(record)) + + +def ResultSingle(backend, data): + result = backend.results[data["resultId"]] + backend.send_response("Record", totestkit.record(result.single())) + + +def ResultPeek(backend, data): + result = backend.results[data["resultId"]] + record = result.peek() + if record is not None: + backend.send_response("Record", totestkit.record(record)) + else: + backend.send_response("NullRecord", {}) + + +def ResultList(backend, data): + result = backend.results[data["resultId"]] + records = Util.list(result) + backend.send_response("RecordList", { + "records": [totestkit.record(r) for r in records] + }) + + +def ResultConsume(backend, data): + result = backend.results[data["resultId"]] + summary = result.consume() + from neo4j import ResultSummary + assert isinstance(summary, ResultSummary) + backend.send_response("Summary", { + "serverInfo": { + "address": ":".join(map(str, summary.server.address)), + "agent": summary.server.agent, + "protocolVersion": + ".".join(map(str, summary.server.protocol_version)), + }, + "counters": None if not summary.counters else { + "constraintsAdded": summary.counters.constraints_added, + "constraintsRemoved": summary.counters.constraints_removed, + "containsSystemUpdates": summary.counters.contains_system_updates, + "containsUpdates": summary.counters.contains_updates, + "indexesAdded": summary.counters.indexes_added, + "indexesRemoved": summary.counters.indexes_removed, + "labelsAdded": summary.counters.labels_added, + "labelsRemoved": summary.counters.labels_removed, + "nodesCreated": summary.counters.nodes_created, + "nodesDeleted": summary.counters.nodes_deleted, + "propertiesSet": summary.counters.properties_set, + "relationshipsCreated": summary.counters.relationships_created, + "relationshipsDeleted": summary.counters.relationships_deleted, + "systemUpdates": summary.counters.system_updates, + }, + "database": summary.database, + "notifications": summary.notifications, + "plan": summary.plan, + "profile": summary.profile, + "query": { + "text": summary.query, + "parameters": {k: totestkit.field(v) + for k, v in summary.parameters.items()}, + }, + "queryType": summary.query_type, + "resultAvailableAfter": summary.result_available_after, + "resultConsumedAfter": summary.result_consumed_after, + }) + + +def ForcedRoutingTableUpdate(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + database = data["database"] + bookmarks = data["bookmarks"] + with driver._pool.refresh_lock: + driver._pool.update_routing_table( + database=database, imp_user=None, bookmarks=bookmarks + ) + backend.send_response("Driver", {"id": driver_id}) + + +def GetRoutingTable(backend, data): + driver_id = data["driverId"] + database = data["database"] + driver = backend.drivers[driver_id] + routing_table = driver._pool.routing_tables[database] + response_data = { + "database": routing_table.database, + "ttl": routing_table.ttl, + } + for role in ("routers", "readers", "writers"): + addresses = routing_table.__getattribute__(role) + response_data[role] = list(map(str, addresses)) + backend.send_response("RoutingTable", response_data) diff --git a/tests/unit/sync/__init__.py 
b/tests/unit/sync/__init__.py new file mode 100644 index 000000000..b81a309da --- /dev/null +++ b/tests/unit/sync/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/sync/io/__init__.py b/tests/unit/sync/io/__init__.py new file mode 100644 index 000000000..b81a309da --- /dev/null +++ b/tests/unit/sync/io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/sync/io/conftest.py b/tests/unit/sync/io/conftest.py new file mode 100644 index 000000000..33309fc9b --- /dev/null +++ b/tests/unit/sync/io/conftest.py @@ -0,0 +1,156 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from io import BytesIO +from struct import ( + pack as struct_pack, + unpack as struct_unpack, +) + +import pytest + +from neo4j._sync.io._common import MessageInbox +from neo4j.packstream import ( + Packer, + UnpackableBuffer, + Unpacker, +) + + +class FakeSocket: + + def __init__(self, address): + self.address = address + self.captured = b"" + self.messages = MessageInbox(self, on_error=print) + + def getsockname(self): + return "127.0.0.1", 0xFFFF + + def getpeername(self): + return self.address + + def recv_into(self, buffer, nbytes): + data = self.captured[:nbytes] + actual = len(data) + buffer[:actual] = data + self.captured = self.captured[actual:] + return actual + + def sendall(self, data): + self.captured += data + + def close(self): + return + + def pop_message(self): + return self.messages.pop() + + +class FakeSocket2: + + def __init__(self, address=None, on_send=None): + self.address = address + self.recv_buffer = bytearray() + self._messages = MessageInbox(self, on_error=print) + self.on_send = on_send + + def getsockname(self): + return "127.0.0.1", 0xFFFF + + def getpeername(self): + return self.address + + def recv_into(self, buffer, nbytes): + data = self.recv_buffer[:nbytes] + actual = len(data) + buffer[:actual] = data + self.recv_buffer = self.recv_buffer[actual:] + return actual + + def sendall(self, data): + if callable(self.on_send): + self.on_send(data) + + def close(self): + return + + def inject(self, data): + self.recv_buffer += data + + def _pop_chunk(self): + chunk_size, = struct_unpack(">H", self.recv_buffer[:2]) + print("CHUNK SIZE %r" % chunk_size) + end = 2 + chunk_size + chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:] + return chunk_data + + def pop_message(self): + data = bytearray() + while True: + chunk = self._pop_chunk() + print("CHUNK %r" % chunk) + if chunk: + data.extend(chunk) + elif data: + break # end of message + else: + continue # NOOP + header = data[0] + n_fields = header % 0x10 + tag = data[1] + buffer = UnpackableBuffer(data[2:]) + unpacker = Unpacker(buffer) + fields = [unpacker.unpack() for _ in range(n_fields)] + return tag, fields + + def send_message(self, tag, *fields): + data = self.encode_message(tag, *fields) + self.sendall(struct_pack(">H", len(data)) + data + b"\x00\x00") + + @classmethod + def encode_message(cls, tag, *fields): + b = BytesIO() + packer = Packer(b) + for field in fields: + packer.pack(field) + return bytearray([0xB0 + len(fields), tag]) + b.getvalue() + + +class FakeSocketPair: + + def __init__(self, address): + self.client = FakeSocket2(address) + self.server = FakeSocket2() + self.client.on_send = self.server.inject + self.server.on_send = self.client.inject + + +@pytest.fixture +def fake_socket(): + return FakeSocket + + +@pytest.fixture +def fake_socket_2(): + return FakeSocket2 + + +@pytest.fixture +def fake_socket_pair(): + return FakeSocketPair diff --git a/tests/unit/sync/io/test__common.py b/tests/unit/sync/io/test__common.py new file mode 100644 index 000000000..5106a2da4 --- /dev/null +++ b/tests/unit/sync/io/test__common.py @@ -0,0 +1,50 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
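The fake sockets above implement Bolt's chunked framing by hand: pop_message reassembles chunks that each carry a two-byte big-endian length prefix, send_message appends a zero-length chunk (b"\x00\x00") as the end-of-message marker, and empty chunks ahead of any data are treated as NOOPs. A self-contained round-trip sketch of that framing, using only the struct module (names are illustrative, not code from this patch):

    import struct

    def frame_message(payload, max_chunk_size=0xFFFF):
        """Split a serialized message into length-prefixed chunks + end marker."""
        out = bytearray()
        for i in range(0, len(payload), max_chunk_size):
            piece = payload[i:i + max_chunk_size]
            out += struct.pack(">H", len(piece)) + piece
        return bytes(out) + b"\x00\x00"  # zero-size chunk ends the message

    def unframe_message(buffer):
        """Inverse of frame_message; returns (payload, remaining buffer)."""
        payload = bytearray()
        while True:
            size, = struct.unpack(">H", buffer[:2])
            chunk, buffer = buffer[2:2 + size], buffer[2 + size:]
            if size == 0 and payload:  # end marker (leading NOOPs are skipped)
                return bytes(payload), buffer
            payload += chunk

    framed = frame_message(b"\x01\x02\x03", max_chunk_size=2)
    assert framed == b"\x00\x02\x01\x02\x00\x01\x03\x00\x00"
    assert unframe_message(framed) == (b"\x01\x02\x03", b"")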
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._sync.io._common import Outbox
+
+
+@pytest.mark.parametrize(("chunk_size", "data", "result"), (
+    (
+        2,
+        (bytes(range(10, 15)),),
+        bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
+    ),
+    (
+        2,
+        (bytes(range(10, 14)),),
+        bytes((0, 2, 10, 11, 0, 2, 12, 13))
+    ),
+    (
+        2,
+        (bytes((5, 6, 7)), bytes((8, 9))),
+        bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
+    ),
+))
+def test_sync_outbox_chunking(chunk_size, data, result):
+    outbox = Outbox(max_chunk_size=chunk_size)
+    assert bytes(outbox.view()) == b""
+    for d in data:
+        outbox.write(d)
+    assert bytes(outbox.view()) == result
+    # make sure this works multiple times
+    assert bytes(outbox.view()) == result
+    outbox.clear()
+    assert bytes(outbox.view()) == b""
diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py
new file mode 100644
index 000000000..b7d1e6c55
--- /dev/null
+++ b/tests/unit/sync/io/test_class_bolt.py
@@ -0,0 +1,62 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
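The expected byte strings in the Outbox parametrization above follow from one rule: the outbox buffers all writes into a single stream and emits chunks of at most max_chunk_size bytes, each preceded by its two-byte big-endian length (no end-of-message marker at this layer). A tiny standalone chunker reproducing two of the expected values (a sketch, not the driver's Outbox):

    def chunked(pieces, max_chunk_size):
        data = b"".join(pieces)  # writes are buffered back to back
        out = bytearray()
        for i in range(0, len(data), max_chunk_size):
            piece = data[i:i + max_chunk_size]
            out += len(piece).to_bytes(2, "big") + piece
        return bytes(out)

    assert chunked((bytes(range(10, 15)),), 2) == bytes(
        (0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
    assert chunked((bytes((5, 6, 7)), bytes((8, 9))), 2) == bytes(
        (0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))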
+
+
+import pytest
+
+from neo4j._sync.io import Bolt
+
+
+# python -m pytest tests/unit/sync/io/test_class_bolt.py -s -v
+
+
+def test_class_method_protocol_handlers():
+    # python -m pytest tests/unit/sync/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers
+    protocol_handlers = Bolt.protocol_handlers()
+    assert len(protocol_handlers) == 6
+
+
+@pytest.mark.parametrize(
+    "test_input, expected",
+    [
+        ((0, 0), 0),
+        ((4, 0), 1),
+    ]
+)
+def test_class_method_protocol_handlers_with_protocol_version(test_input, expected):
+    # python -m pytest tests/unit/sync/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_protocol_version
+    protocol_handlers = Bolt.protocol_handlers(protocol_version=test_input)
+    assert len(protocol_handlers) == expected
+
+
+def test_class_method_protocol_handlers_with_invalid_protocol_version():
+    # python -m pytest tests/unit/sync/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_invalid_protocol_version
+    with pytest.raises(TypeError):
+        Bolt.protocol_handlers(protocol_version=2)
+
+
+def test_class_method_get_handshake():
+    # python -m pytest tests/unit/sync/io/test_class_bolt.py -s -v -k test_class_method_get_handshake
+    handshake = Bolt.get_handshake()
+    assert handshake == b"\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x04\x00\x00\x00\x03"
+
+
+def test_magic_preamble():
+    # python -m pytest tests/unit/sync/io/test_class_bolt.py -s -v -k test_magic_preamble
+    preamble = 0x6060B017
+    preamble_bytes = preamble.to_bytes(4, byteorder="big")
+    assert Bolt.MAGIC_PREAMBLE == preamble_bytes
diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py
new file mode 100644
index 000000000..b42512d0f
--- /dev/null
+++ b/tests/unit/sync/io/test_class_bolt3.py
@@ -0,0 +1,115 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
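The 16 handshake bytes asserted above encode four version proposals of four bytes each; reading each proposal back to front gives the major version, the minor version, and (in the second byte) how many consecutive lower minors are also offered. Decoding yields exactly the six protocol versions that test_class_method_protocol_handlers counts (a decoding sketch based on the Bolt handshake layout, not code from this patch):

    handshake = (b"\x00\x02\x04\x04\x00\x00\x01\x04"
                 b"\x00\x00\x00\x04\x00\x00\x00\x03")

    versions = []
    for i in range(0, len(handshake), 4):
        _, minor_range, minor, major = handshake[i:i + 4]
        for m in range(minor, minor - minor_range - 1, -1):
            versions.append((major, m))

    # 4.4 (covering 4.3 and 4.2 via the range byte), then 4.1, 4.0 and 3.0
    assert versions == [(4, 4), (4, 3), (4, 2), (4, 1), (4, 0), (3, 0)]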
+ + +import pytest + +from neo4j._sync.io._bolt3 import Bolt3 +from neo4j.conf import PoolConfig +from neo4j.exceptions import ConfigurationError + +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +def test_db_extra_not_supported_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError): + connection.begin(db="something") + + +def test_db_extra_not_supported_in_run(fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError): + connection.run("", db="something") + + +@mark_sync_test +def test_simple_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard() + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 0 + + +@mark_sync_test +def test_simple_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull() + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 0 + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/3.5.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py new file mode 100644 index 000000000..5f94d5c0f --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -0,0 +1,209 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
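The three staleness tests above pin down the lifetime semantics: max_connection_lifetime = 0 makes every connection immediately stale, a negative value disables age-based expiry, and set_stale() marks a connection stale regardless of age. A single predicate is consistent with all six assertions (a sketch with hypothetical attribute names, not the driver's implementation):

    from time import perf_counter

    class StalenessSketch:
        def __init__(self, max_connection_lifetime):
            self._max_lifetime = max_connection_lifetime
            self._created = perf_counter()
            self._stale = False

        def set_stale(self):
            self._stale = True

        def stale(self):
            # negative lifetime: never age out; zero: always aged out
            age = perf_counter() - self._created
            return self._stale or 0 <= self._max_lifetime <= age

    assert StalenessSketch(0).stale()       # lifetime 0 -> stale at once
    assert not StalenessSketch(-1).stale()  # negative -> age is ignored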
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock + +import pytest + +from neo4j._sync.io._bolt4 import Bolt4x0 +from neo4j.conf import PoolConfig + +from ..._async_compat import mark_sync_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_sync_test +def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_sync_test +def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, 
expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/4.0.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py new file mode 100644 index 000000000..2d69b9de2 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -0,0 +1,227 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
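The DISCARD and PULL parametrizations in these Bolt 4.x tests all encode the same rule: n defaults to -1 ("all remaining records") and is always sent, while qid only enters the extras when it differs from its default of -1 ("the most recent query"). A helper inferred from the expected dicts (a sketch, not the driver's code):

    def fetch_extra(n=-1, qid=-1):
        """Extras dict for PULL/DISCARD as the assertions above expect."""
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        return extra

    assert fetch_extra(qid=666) == {"n": -1, "qid": 666}
    assert fetch_extra(n=666, qid=-1) == {"n": 666}
    assert fetch_extra(n=666, qid=777) == {"n": 666, "qid": 777}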
+ + +import pytest + +from neo4j._sync.io._bolt4 import Bolt4x1 +from neo4j.conf import PoolConfig + +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_sync_test +def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_sync_test +def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + 
+ +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"}) + connection = Bolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/4.1.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt4x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py new file mode 100644 index 000000000..036057960 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -0,0 +1,228 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
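test_hello_passes_routing_metadata above exercises both ends of the fake socket pair: the server side is pre-loaded with a SUCCESS reply so that hello() can complete, after which the HELLO message (tag 0x01) is popped back off the server to inspect its single extras field. The pattern generalizes to a small assertion helper (illustrative only; the tests inline these steps):

    def assert_routing_context(server_socket, expected):
        """Pop one message off a FakeSocket2 server and check HELLO routing."""
        tag, fields = server_socket.pop_message()
        assert tag == 0x01, "expected a HELLO message"
        assert len(fields) == 1
        assert fields[0]["routing"] == expected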
+ + +import pytest + +from neo4j._sync.io._bolt4 import Bolt4x2 +from neo4j.conf import PoolConfig + +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_sync_test +def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_sync_test +def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + 
+ +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) + connection = Bolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/4.2.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py new file mode 100644 index 000000000..43469fc91 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -0,0 +1,255 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
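Read side by side, the *_gets_ignored tests for Bolt 3.0 through 4.2 and the validating tests for Bolt 4.3+ (below) encode a protocol gate: servers may always send the connection.recv_timeout_seconds hint, but only clients speaking Bolt 4.3 or later act on it by calling settimeout. A sketch of that gate (hypothetical structure; the real handling lives in the Bolt subclasses, and the 4.3+ value validation is covered by the tests that follow):

    def apply_recv_timeout_hint(protocol_version, hints, socket):
        # Only Bolt 4.3+ connections honour the server's timeout hint.
        timeout = hints.get("connection.recv_timeout_seconds")
        if protocol_version < (4, 3) or timeout is None:
            return  # older protocols ignore the hint entirely
        socket.settimeout(timeout)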
+ + +import logging + +import pytest + +from neo4j._sync.io._bolt4 import Bolt4x3 +from neo4j.conf import PoolConfig + +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_sync_test +def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_sync_test +def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert 
fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"}) + connection = Bolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message( + 0x70, {"server": "Neo4j/4.3.0", "hints": hints} + ) + connection = Bolt4x3(address, sockets.client, + PoolConfig.max_connection_lifetime) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) diff --git 
a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py
new file mode 100644
index 000000000..b2523b1ca
--- /dev/null
+++ b/tests/unit/sync/io/test_class_bolt4x4.py
@@ -0,0 +1,270 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+import pytest
+
+from neo4j._sync.io._bolt4 import Bolt4x4
+from neo4j.conf import PoolConfig
+
+from ..._async_compat import (
+    MagicMock,
+    mark_sync_test,
+)
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_stale(fake_socket, set_stale):
+    address = ("127.0.0.1", 7687)
+    max_connection_lifetime = 0
+    connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime)
+    if set_stale:
+        connection.set_stale()
+    assert connection.stale() is True
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale):
+    address = ("127.0.0.1", 7687)
+    max_connection_lifetime = -1
+    connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime)
+    if set_stale:
+        connection.set_stale()
+    assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale(fake_socket, set_stale):
+    address = ("127.0.0.1", 7687)
+    max_connection_lifetime = 999999999
+    connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime)
+    if set_stale:
+        connection.set_stale()
+    assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), (
+    (("", {}), {"db": "something"}, ({"db": "something"},)),
+    (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)),
+    (
+        ("", {}),
+        {"db": "something", "imp_user": "imposter"},
+        ({"db": "something", "imp_user": "imposter"},)
+    ),
+))
+@mark_sync_test
+def test_extra_in_begin(fake_socket, args, kwargs, expected_fields):
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+    connection.begin(*args, **kwargs)
+    connection.send_all()
+    tag, is_fields = socket.pop_message()
+    assert tag == b"\x11"
+    assert tuple(is_fields) == expected_fields
+
+
+@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), (
+    (("", {}), {"db": "something"}, ("", {}, {"db": "something"})),
+    (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})),
+    (
+        ("", {}),
+        {"db": "something", "imp_user": "imposter"},
+        ("", {}, {"db": "something", "imp_user": "imposter"})
+    ),
+))
+@mark_sync_test
+def test_extra_in_run(fake_socket, args, kwargs, expected_fields):
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+    connection.run(*args, **kwargs)
+    connection.send_all()
+    tag, is_fields = socket.pop_message()
+    assert tag == b"\x10"
+    assert tuple(is_fields) == expected_fields
+
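Bolt 4.4 is where impersonation arrives: the begin/run parametrizations above show the extras dict gaining an imp_user key alongside db, each included only when supplied. A builder reproducing the expected_fields shapes (a sketch inferred from the test table, not the driver's internals):

    def tx_extra(db=None, imp_user=None):
        extra = {}
        if db is not None:
            extra["db"] = db
        if imp_user is not None:
            extra["imp_user"] = imp_user
        return extra

    assert tx_extra(db="something") == {"db": "something"}
    assert tx_extra(imp_user="imposter") == {"imp_user": "imposter"}
    assert tx_extra(db="something", imp_user="imposter") == \
        {"db": "something", "imp_user": "imposter"}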
+ +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"}) + connection = Bolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + 
+@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message( + 0x70, {"server": "Neo4j/4.3.4", "hints": hints} + ) + connection = Bolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py new file mode 100644 index 000000000..d5ff16cb3 --- /dev/null +++ b/tests/unit/sync/io/test_direct.py @@ -0,0 +1,231 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import pytest
+
+from neo4j import (
+    Config,
+    PoolConfig,
+    WorkspaceConfig,
+)
+from neo4j._sync.io import Bolt
+from neo4j._sync.io._pool import IOPool
+from neo4j.exceptions import (
+    ClientError,
+    ServiceUnavailable,
+)
+
+from ..._async_compat import (
+    mark_sync_test,
+    Mock,
+    mock,
+)
+
+
+class FakeSocket:
+    def __init__(self, address):
+        self.address = address
+
+    def getpeername(self):
+        return self.address
+
+    def sendall(self, data):
+        return
+
+    def close(self):
+        return
+
+
+class QuickConnection:
+    def __init__(self, socket):
+        self.socket = socket
+        self.address = socket.getpeername()
+
+    @property
+    def is_reset(self):
+        return True
+
+    def stale(self):
+        return False
+
+    def reset(self):
+        pass
+
+    def close(self):
+        self.socket.close()
+
+    def closed(self):
+        return False
+
+    def defunct(self):
+        return False
+
+    def timedout(self):
+        return False
+
+
+class FakeBoltPool(IOPool):
+
+    def __init__(self, address, *, auth=None, **config):
+        self.pool_config, self.workspace_config = Config.consume_chain(
+            config, PoolConfig, WorkspaceConfig
+        )
+        if config:
+            raise ValueError(
+                "Unexpected config keys: %s" % ", ".join(config.keys())
+            )
+
+        def opener(addr, timeout):
+            return QuickConnection(FakeSocket(addr))
+
+        super().__init__(opener, self.pool_config, self.workspace_config)
+        self.address = address
+
+    def acquire(
+        self, access_mode=None, timeout=None, database=None, bookmarks=None
+    ):
+        return self._acquire(self.address, timeout)
+
+
+@mark_sync_test
+def test_bolt_connection_open():
+    with pytest.raises(ServiceUnavailable):
+        Bolt.open(("localhost", 9999), auth=("test", "test"))
+
+
+@mark_sync_test
+def test_bolt_connection_open_timeout():
+    with pytest.raises(ServiceUnavailable):
+        Bolt.open(("localhost", 9999), auth=("test", "test"),
+                  timeout=1)
+
+
+@mark_sync_test
+def test_bolt_connection_ping():
+    protocol_version = Bolt.ping(("localhost", 9999))
+    assert protocol_version is None
+
+
+@mark_sync_test
+def test_bolt_connection_ping_timeout():
+    protocol_version = Bolt.ping(("localhost", 9999), timeout=1)
+    assert protocol_version is None
+
+
+@pytest.fixture
+def pool():
+    with FakeBoltPool(("127.0.0.1", 7687)) as pool:
+        yield pool
+
+
+def assert_pool_size(address, expected_active, expected_inactive, pool):
+    try:
+        connections = pool.connections[address]
+    except KeyError:
+        assert 0 == expected_active
+        assert 0 == expected_inactive
+    else:
+        assert expected_active == len([cx for cx in connections if cx.in_use])
+        assert (expected_inactive
+                == len([cx for cx in connections if not cx.in_use]))
+
+
+@mark_sync_test
+def test_pool_can_acquire(pool):
+    address = ("127.0.0.1", 7687)
+    connection = pool._acquire(address, timeout=3)
+    assert connection.address == address
+    assert_pool_size(address, 1, 0, pool)
+
+
+@mark_sync_test
+def test_pool_can_acquire_twice(pool):
+    address = ("127.0.0.1", 7687)
+    connection_1 = pool._acquire(address, timeout=3)
+    connection_2 = pool._acquire(address, timeout=3)
+    assert connection_1.address == address
+    assert connection_2.address == address
+    assert connection_1 is not connection_2
+    assert_pool_size(address, 2, 0, pool)
+
+
+@mark_sync_test
+def test_pool_can_acquire_two_addresses(pool):
+    address_1 = ("127.0.0.1", 7687)
+    address_2 = ("127.0.0.1", 7474)
+    connection_1 = pool._acquire(address_1, timeout=3)
+    connection_2 = pool._acquire(address_2, timeout=3)
+    assert connection_1.address == address_1
+    assert connection_2.address == address_2
+    assert_pool_size(address_1, 1, 0, pool)
+    assert_pool_size(address_2, 1, 0, pool)
+
+
+@mark_sync_test
+def test_pool_can_acquire_and_release(pool):
+    address = ("127.0.0.1", 7687)
+    connection = pool._acquire(address, timeout=3)
+    assert_pool_size(address, 1, 0, pool)
+    pool.release(connection)
+    assert_pool_size(address, 0, 1, pool)
+
+
+@mark_sync_test
+def test_pool_releasing_twice(pool):
+    address = ("127.0.0.1", 7687)
+    connection = pool._acquire(address, timeout=3)
+    pool.release(connection)
+    assert_pool_size(address, 0, 1, pool)
+    pool.release(connection)
+    assert_pool_size(address, 0, 1, pool)
+
+
+@mark_sync_test
+def test_pool_in_use_count(pool):
+    address = ("127.0.0.1", 7687)
+    assert pool.in_use_connection_count(address) == 0
+    connection = pool._acquire(address, timeout=3)
+    assert pool.in_use_connection_count(address) == 1
+    pool.release(connection)
+    assert pool.in_use_connection_count(address) == 0
+
+
+@mark_sync_test
+def test_pool_max_conn_pool_size():
+    with FakeBoltPool((), max_connection_pool_size=1) as pool:
+        address = ("127.0.0.1", 7687)
+        pool._acquire(address, timeout=0)
+        assert pool.in_use_connection_count(address) == 1
+        with pytest.raises(ClientError):
+            pool._acquire(address, timeout=0)
+        assert pool.in_use_connection_count(address) == 1
+
+
+@pytest.mark.parametrize("is_reset", (True, False))
+@mark_sync_test
+def test_pool_reset_when_released(is_reset, pool):
+    address = ("127.0.0.1", 7687)
+    quick_connection_name = QuickConnection.__name__
+    with mock.patch(f"{__name__}.{quick_connection_name}.is_reset",
+                    new_callable=mock.PropertyMock) as is_reset_mock:
+        with mock.patch(f"{__name__}.{quick_connection_name}.reset",
+                        new_callable=Mock) as reset_mock:
+            is_reset_mock.return_value = is_reset
+            connection = pool._acquire(address, timeout=3)
+            assert isinstance(connection, QuickConnection)
+            assert is_reset_mock.call_count == 0
+            assert reset_mock.call_count == 0
+            pool.release(connection)
+            assert is_reset_mock.call_count == 1
+            assert reset_mock.call_count == int(not is_reset)
diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py
new file mode 100644
index 000000000..6fb57b985
--- /dev/null
+++ b/tests/unit/sync/io/test_neo4j_pool.py
@@ -0,0 +1,259 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
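+
+# These tests drive Neo4jPool against a stubbed opener: every connection's
+# ROUTE call answers with a fixed routing table of the shape used by the
+# ``opener`` fixture below (sketch):
+#
+#     {"ttl": 1000,
+#      "servers": [
+#          {"addresses": ["<router>:9001"], "role": "ROUTE"},
+#          {"addresses": ["<reader>:9002"], "role": "READ"},
+#          {"addresses": ["<writer>:9003"], "role": "WRITE"},
+#      ]}
+#
+# which lets routing-table refresh and purge logic run without a server.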
+
+
+import pytest
+
+from neo4j import (
+    READ_ACCESS,
+    WRITE_ACCESS,
+)
+from neo4j._sync.io import Neo4jPool
+from neo4j.addressing import ResolvedAddress
+from neo4j.conf import (
+    PoolConfig,
+    RoutingConfig,
+    WorkspaceConfig,
+)
+
+from ..._async_compat import (
+    mark_sync_test,
+    Mock,
+)
+from ..work import FakeConnection
+
+
+ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
+READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
+WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")
+
+
+@pytest.fixture()
+def opener():
+    def open_(addr, timeout):
+        connection = FakeConnection()
+        connection.addr = addr
+        connection.timeout = timeout
+        route_mock = Mock()
+        route_mock.return_value = [{
+            "ttl": 1000,
+            "servers": [
+                {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"},
+                {"addresses": [str(READER_ADDRESS)], "role": "READ"},
+                {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
+            ],
+        }]
+        connection.attach_mock(route_mock, "route")
+        opener_.connections.append(connection)
+        return connection
+
+    opener_ = Mock()
+    opener_.connections = []
+    opener_.side_effect = open_
+    return opener_
+
+
+@mark_sync_test
+def test_acquires_new_routing_table_if_deleted(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx)
+    assert pool.routing_tables.get("test_db")
+
+    del pool.routing_tables["test_db"]
+
+    cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx)
+    assert pool.routing_tables.get("test_db")
+
+
+@mark_sync_test
+def test_acquires_new_routing_table_if_stale(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx)
+    assert pool.routing_tables.get("test_db")
+
+    old_value = pool.routing_tables["test_db"].last_updated_time
+    pool.routing_tables["test_db"].ttl = 0
+
+    cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx)
+    assert pool.routing_tables["test_db"].last_updated_time > old_value
+
+
+@mark_sync_test
+def test_removes_old_routing_table(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx = pool.acquire(READ_ACCESS, 30, "test_db1", None)
+    pool.release(cx)
+    assert pool.routing_tables.get("test_db1")
+    cx = pool.acquire(READ_ACCESS, 30, "test_db2", None)
+    pool.release(cx)
+    assert pool.routing_tables.get("test_db2")
+
+    old_value = pool.routing_tables["test_db1"].last_updated_time
+    pool.routing_tables["test_db1"].ttl = 0
+    pool.routing_tables["test_db2"].ttl = \
+        -RoutingConfig.routing_table_purge_delay
+
+    cx = pool.acquire(READ_ACCESS, 30, "test_db1", None)
+    pool.release(cx)
+    assert pool.routing_tables["test_db1"].last_updated_time > old_value
+    assert "test_db2" not in pool.routing_tables
+
+
+@pytest.mark.parametrize("type_", ("r", "w"))
+@mark_sync_test
+def test_chooses_right_connection_type(opener, type_):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
+                       30, "test_db", None)
+    pool.release(cx1)
+    if type_ == "r":
+        assert cx1.addr == READER_ADDRESS
+    else:
+        assert cx1.addr == WRITER_ADDRESS
+
+
+@mark_sync_test
+def test_reuses_connection(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx1)
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    assert cx1 is cx2
+
+
+@pytest.mark.parametrize("break_on_close", (True, False))
+@mark_sync_test
+def test_closes_stale_connections(opener, break_on_close):
+    def break_connection():
+        pool.deactivate(cx1.addr)
+
+        if cx_close_mock_side_effect:
+            cx_close_mock_side_effect()
+
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx1)
+    assert cx1 in pool.connections[cx1.addr]
+    # simulate the connection going stale (e.g. by exceeding its maximum
+    # lifetime) and then breaking when the pool tries to close it
+    cx1.stale.return_value = True
+    cx_close_mock = cx1.close
+    if break_on_close:
+        cx_close_mock_side_effect = cx_close_mock.side_effect
+        cx_close_mock.side_effect = break_connection
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx2)
+    if break_on_close:
+        cx1.close.assert_called()
+    else:
+        cx1.close.assert_called_once()
+    assert cx2 is not cx1
+    assert cx2.addr == cx1.addr
+    assert cx1 not in pool.connections[cx1.addr]
+    assert cx2 in pool.connections[cx2.addr]
+
+
+@mark_sync_test
+def test_does_not_close_stale_connections_in_use(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    assert cx1 in pool.connections[cx1.addr]
+    # simulate the connection going stale (e.g. by exceeding its maximum
+    # lifetime) while it is in use
+    cx1.stale.return_value = True
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx2)
+    cx1.close.assert_not_called()
+    assert cx2 is not cx1
+    assert cx2.addr == cx1.addr
+    assert cx1 in pool.connections[cx1.addr]
+    assert cx2 in pool.connections[cx2.addr]
+
+    pool.release(cx1)
+    # now that cx1 is back in the pool and still stale,
+    # it should be closed when trying to acquire the next connection
+    cx1.close.assert_not_called()
+
+    cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    pool.release(cx3)
+    cx1.close.assert_called_once()
+    assert cx2 is cx3
+    assert cx3.addr == cx1.addr
+    assert cx1 not in pool.connections[cx1.addr]
+    assert cx3 in pool.connections[cx2.addr]
+
+
+@mark_sync_test
+def test_release_resets_connections(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    cx1.is_reset_mock.return_value = False
+    cx1.is_reset_mock.reset_mock()
+    pool.release(cx1)
+    cx1.is_reset_mock.assert_called_once()
+    cx1.reset.assert_called_once()
+
+
+@mark_sync_test
+def test_release_does_not_reset_closed_connections(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    cx1.closed.return_value = True
+    cx1.closed.reset_mock()
+    cx1.is_reset_mock.reset_mock()
+    pool.release(cx1)
+    cx1.closed.assert_called_once()
+    cx1.is_reset_mock.assert_not_called()
+    cx1.reset.assert_not_called()
+
+
+@mark_sync_test
+def test_release_does_not_reset_defunct_connections(opener):
+    pool = Neo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
+    cx1.defunct.return_value = True
+    cx1.defunct.reset_mock()
+    cx1.is_reset_mock.reset_mock()
+    pool.release(cx1)
+    cx1.defunct.assert_called_once()
+    cx1.is_reset_mock.assert_not_called()
+    cx1.reset.assert_not_called()
diff --git a/tests/unit/sync/test_addressing.py b/tests/unit/sync/test_addressing.py
new file mode 100644
index 000000000..4fd814b25 --- /dev/null +++ b/tests/unit/sync/test_addressing.py @@ -0,0 +1,125 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from socket import ( + AF_INET, + AF_INET6, +) +import unittest.mock as mock + +import pytest + +from neo4j import ( + Address, + IPv4Address, +) +from neo4j._async_compat.network import NetworkUtil +from neo4j._async_compat.util import Util + +from .._async_compat import mark_sync_test + + +mock_socket_ipv4 = mock.Mock() +mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port) + +mock_socket_ipv6 = mock.Mock() +mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id) + + +@mark_sync_test +def test_address_resolve(): + address = Address(("127.0.0.1", 7687)) + resolved = NetworkUtil.resolve_address(address) + resolved = Util.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 1 + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + + +@mark_sync_test +def test_address_resolve_with_custom_resolver_none(): + address = Address(("127.0.0.1", 7687)) + resolved = NetworkUtil.resolve_address(address, resolver=None) + resolved = Util.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 1 + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (Address(("127.0.0.1", "abcd")), ValueError), + (Address((None, None)), ValueError), + ] + +) +@mark_sync_test +def test_address_resolve_with_unresolvable_address(test_input, expected): + with pytest.raises(expected): + Util.list( + NetworkUtil.resolve_address(test_input, resolver=None) + ) + + +@mark_sync_test +@pytest.mark.parametrize("resolver_type", ("sync", "async")) +def test_address_resolve_with_custom_resolver(resolver_type): + def custom_resolver_sync(_): + return [("127.0.0.1", 7687), ("localhost", 1234)] + + def custom_resolver_async(_): + return [("127.0.0.1", 7687), ("localhost", 1234)] + + if resolver_type == "sync": + custom_resolver = custom_resolver_sync + else: + custom_resolver = custom_resolver_async + + address = Address(("127.0.0.1", 7687)) + resolved = NetworkUtil.resolve_address( + address, family=AF_INET, resolver=custom_resolver + ) + resolved = Util.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 2 # IPv4 only + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + assert resolved[1] == IPv4Address(('127.0.0.1', 1234)) + + +@mark_sync_test +def test_address_unresolve(): + custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)] + custom_resolver = lambda _: custom_resolved + + address = Address(("foobar", 1234)) + unresolved = address.unresolved + assert address.__class__ == unresolved.__class__ + assert address == 
unresolved
+    resolved = NetworkUtil.resolve_address(
+        address, family=AF_INET, resolver=custom_resolver
+    )
+    resolved = Util.list(resolved)
+    custom_resolved = sorted(Address(a) for a in custom_resolved)
+    unresolved = sorted(a.unresolved for a in resolved)
+    assert custom_resolved == unresolved
+    assert (list(map(lambda a: a.__class__, custom_resolved))
+            == list(map(lambda a: a.__class__, unresolved)))
diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py
new file mode 100644
index 000000000..93579b2e6
--- /dev/null
+++ b/tests/unit/sync/test_driver.py
@@ -0,0 +1,157 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j import (
+    BoltDriver,
+    GraphDatabase,
+    Neo4jDriver,
+    Transaction,
+    TRUST_ALL_CERTIFICATES,
+    TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+)
+from neo4j.api import WRITE_ACCESS
+from neo4j.exceptions import ConfigurationError
+
+from .._async_compat import (
+    mark_sync_test,
+    mock,
+)
+
+
+@pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://"))
+@pytest.mark.parametrize("host", ("localhost", "127.0.0.1",
+                                  "[::1]", "[0:0:0:0:0:0:0:1]"))
+@pytest.mark.parametrize("port", (":1234", "", ":7687"))
+@pytest.mark.parametrize("auth_token", (("test", "test"), None))
+def test_direct_driver_constructor(protocol, host, port, auth_token):
+    uri = protocol + host + port
+    driver = GraphDatabase.driver(uri, auth=auth_token)
+    assert isinstance(driver, BoltDriver)
+
+
+@pytest.mark.parametrize("protocol",
+                         ("neo4j://", "neo4j+s://", "neo4j+ssc://"))
+@pytest.mark.parametrize("host", ("localhost", "127.0.0.1",
+                                  "[::1]", "[0:0:0:0:0:0:0:1]"))
+@pytest.mark.parametrize("port", (":1234", "", ":7687"))
+@pytest.mark.parametrize("auth_token", (("test", "test"), None))
+def test_routing_driver_constructor(protocol, host, port, auth_token):
+    uri = protocol + host + port
+    driver = GraphDatabase.driver(uri, auth=auth_token)
+    assert isinstance(driver, Neo4jDriver)
+
+
+@pytest.mark.parametrize("test_uri", (
+    "bolt+ssc://127.0.0.1:9001",
+    "bolt+s://127.0.0.1:9001",
+    "bolt://127.0.0.1:9001",
+    "neo4j+ssc://127.0.0.1:9001",
+    "neo4j+s://127.0.0.1:9001",
+    "neo4j://127.0.0.1:9001",
+))
+@pytest.mark.parametrize(
+    ("test_config", "expected_failure", "expected_failure_message"),
+    (
+        ({"encrypted": False}, ConfigurationError, "The config settings"),
+        ({"encrypted": True}, ConfigurationError, "The config settings"),
+        (
+            {"encrypted": True, "trust": TRUST_ALL_CERTIFICATES},
+            ConfigurationError, "The config settings"
+        ),
+        (
+            {"trust": TRUST_ALL_CERTIFICATES},
+            ConfigurationError, "The config settings"
+        ),
+        (
+            {"trust": TRUST_SYSTEM_CA_SIGNED_CERTIFICATES},
+            ConfigurationError, "The config settings"
+        ),
+    )
+)
+def test_driver_config_error(
+    test_uri, test_config, expected_failure, expected_failure_message
+):
+    if "+" in test_uri:
+        # `+s` and `+ssc` are shorthand for configuring the encryption
+        # behavior of the driver in the URI scheme. Additionally passing
+        # encryption settings as config arguments is ambiguous, hence
+        # invalid.
+        with pytest.raises(expected_failure, match=expected_failure_message):
+            GraphDatabase.driver(test_uri, **test_config)
+    else:
+        GraphDatabase.driver(test_uri, **test_config)
+
+
+@pytest.mark.parametrize("test_uri", (
+    "http://localhost:9001",
+    "ftp://localhost:9001",
+    "x://localhost:9001",
+))
+def test_invalid_protocol(test_uri):
+    with pytest.raises(ConfigurationError, match="scheme"):
+        GraphDatabase.driver(test_uri)
+
+
+@pytest.mark.parametrize(
+    ("test_config", "expected_failure", "expected_failure_message"),
+    (
+        ({"trust": 1}, ConfigurationError, "The config setting `trust`"),
+        ({"trust": True}, ConfigurationError, "The config setting `trust`"),
+        ({"trust": None}, ConfigurationError, "The config setting `trust`"),
+    )
+)
+def test_driver_trust_config_error(
+    test_config, expected_failure, expected_failure_message
+):
+    with pytest.raises(expected_failure, match=expected_failure_message):
+        GraphDatabase.driver("bolt://127.0.0.1:9001", **test_config)
+
+
+@pytest.mark.parametrize("uri", (
+    "bolt://127.0.0.1:9000",
+    "neo4j://127.0.0.1:9000",
+))
+@mark_sync_test
+def test_driver_opens_write_session_by_default(uri, mocker):
+    driver = GraphDatabase.driver(uri)
+
+    # we set a specific database, because otherwise the driver would try to
+    # fetch a routing table to get hold of the actual home database (which
+    # won't work in this unit test)
+    with driver.session(database="foobar") as session:
+        with mock.patch.object(
+            session._pool, "acquire", autospec=True
+        ) as acquire_mock:
+            with mock.patch.object(
+                Transaction, "_begin", autospec=True
+            ) as tx_begin_mock:
+                tx = session.begin_transaction()
+                acquire_mock.assert_called_once_with(
+                    access_mode=WRITE_ACCESS,
+                    timeout=mocker.ANY,
+                    database=mocker.ANY,
+                    bookmarks=mocker.ANY
+                )
+                tx_begin_mock.assert_called_once_with(
+                    tx,
+                    mocker.ANY,
+                    mocker.ANY,
+                    mocker.ANY,
+                    WRITE_ACCESS,
+                    mocker.ANY,
+                    mocker.ANY
+                )
diff --git a/tests/unit/sync/work/__init__.py b/tests/unit/sync/work/__init__.py
new file mode 100644
index 000000000..2613b53d2
--- /dev/null
+++ b/tests/unit/sync/work/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ._fake_connection import (
+    fake_connection,
+    FakeConnection,
+)
diff --git a/tests/unit/sync/work/_fake_connection.py b/tests/unit/sync/work/_fake_connection.py
new file mode 100644
index 000000000..1748ea61a
--- /dev/null
+++ b/tests/unit/sync/work/_fake_connection.py
@@ -0,0 +1,110 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect + +import pytest + +from neo4j import ServerInfo +from neo4j._sync.io import Bolt + +from ..._async_compat import ( + Mock, + mock, +) + + +class FakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + + def __init__(self, *args, **kwargs): + kwargs["spec"] = Bolt + super().__init__(*args, **kwargs) + self.attach_mock(Mock(return_value=True), "is_reset_mock") + self.attach_mock(Mock(return_value=False), "defunct") + self.attach_mock(Mock(return_value=False), "stale") + self.attach_mock(Mock(return_value=False), "closed") + self.attach_mock(Mock(), "unresolved_address") + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(Mock(side_effect=close_side_effect), + "close") + + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError( + "is_reset should not be called on a closed or defunct " + "connection." + ) + return self.is_reset_mock() + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + res = cb({}) + else: + res = cb() + try: + res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + +@pytest.fixture +def fake_connection(): + return FakeConnection() diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py new file mode 100644 index 000000000..863df8388 --- /dev/null +++ b/tests/unit/sync/work/test_result.py @@ -0,0 +1,456 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
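+
+# Result is exercised against a ConnectionStub (defined below) that replays
+# the callbacks a Bolt server would trigger. The replayed flow is roughly:
+#
+#     result._run(...)   # queues RUN; answered via on_success({fields, ...})
+#     iteration          # queues PULL; answered via on_records([...]) plus
+#                        # on_success({"has_more": True} or summary metadata)
+#     result.consume()   # queues DISCARD; answered via on_success/on_summary
+#
+# This describes the stub below, not the real Bolt connection classes.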
+ + +from unittest import mock + +import pytest + +from neo4j import ( + Address, + Record, + Result, + ResultSummary, + ServerInfo, + SummaryCounters, + Version, +) +from neo4j._async_compat.util import Util +from neo4j.data import DataHydrator + +from ..._async_compat import mark_sync_test + + +class Records: + def __init__(self, fields, records): + assert all(len(fields) == len(r) for r in records) + self.fields = fields + # self.records = [{"record_values": r} for r in records] + self.records = records + + def __len__(self): + return self.records.__len__() + + def __iter__(self): + return self.records.__iter__() + + def __getitem__(self, item): + return self.records.__getitem__(item) + + +class ConnectionStub: + class Message: + def __init__(self, message, *args, **kwargs): + self.message = message + self.args = args + self.kwargs = kwargs + + def _cb(self, cb_name, *args, **kwargs): + # print(self.message, cb_name.upper(), args, kwargs) + cb = self.kwargs.get(cb_name) + Util.callback(cb, *args, **kwargs) + + def on_success(self, metadata): + self._cb("on_success", metadata) + + def on_summary(self): + self._cb("on_summary") + + def on_records(self, records): + self._cb("on_records", records) + + def __eq__(self, other): + return self.message == other + + def __repr__(self): + return "Message(%s)" % self.message + + def __init__(self, records=None, run_meta=None, summary_meta=None, + force_qid=False): + self._multi_result = isinstance(records, (list, tuple)) + if self._multi_result: + self._records = records + self._use_qid = True + else: + self._records = records, + self._use_qid = force_qid + self.fetch_idx = 0 + self._qid = -1 + self.most_recent_qid = None + self.record_idxs = [0] * len(self._records) + self.to_pull = [None] * len(self._records) + self._exhausted = [False] * len(self._records) + self.queued = [] + self.sent = [] + self.run_meta = run_meta + self.summary_meta = summary_meta + ConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) + self.unresolved_address = None + + def send_all(self): + self.sent += self.queued + self.queued = [] + + def fetch_message(self): + if self.fetch_idx >= len(self.sent): + pytest.fail("Waits for reply to never sent message") + msg = self.sent[self.fetch_idx] + if msg == "RUN": + self.fetch_idx += 1 + self._qid += 1 + meta = {"fields": self._records[self._qid].fields, + **(self.run_meta or {})} + if self._use_qid: + meta.update(qid=self._qid) + msg.on_success(meta) + elif msg == "DISCARD": + self.fetch_idx += 1 + qid = msg.kwargs.get("qid", -1) + if qid < 0: + qid = self._qid + self.record_idxs[qid] = len(self._records[qid]) + msg.on_success(self.summary_meta or {}) + msg.on_summary() + elif msg == "PULL": + qid = msg.kwargs.get("qid", -1) + if qid < 0: + qid = self._qid + if self._exhausted[qid]: + pytest.fail("PULLing exhausted result") + if self.to_pull[qid] is None: + n = msg.kwargs.get("n", -1) + if n < 0: + n = len(self._records[qid]) + self.to_pull[qid] = \ + min(n, len(self._records[qid]) - self.record_idxs[qid]) + # if to == len(self._records): + # self.fetch_idx += 1 + if self.to_pull[qid] > 0: + record = self._records[qid][self.record_idxs[qid]] + self.record_idxs[qid] += 1 + self.to_pull[qid] -= 1 + msg.on_records([record]) + elif self.to_pull[qid] == 0: + self.to_pull[qid] = None + self.fetch_idx += 1 + if self.record_idxs[qid] < len(self._records[qid]): + msg.on_success({"has_more": True}) + else: + msg.on_success( + {"bookmark": "foo", **(self.summary_meta or {})} + ) + self._exhausted[qid] = True + msg.on_summary() + + 
def fetch_all(self): + while self.fetch_idx < len(self.sent): + self.fetch_message() + + def run(self, *args, **kwargs): + self.queued.append(ConnectionStub.Message("RUN", *args, **kwargs)) + + def discard(self, *args, **kwargs): + self.queued.append(ConnectionStub.Message("DISCARD", *args, **kwargs)) + + def pull(self, *args, **kwargs): + self.queued.append(ConnectionStub.Message("PULL", *args, **kwargs)) + + server_info = ServerInfo(Address(("bolt://localhost", 7687)), Version(4, 3)) + + def defunct(self): + return False + + +class HydratorStub(DataHydrator): + def hydrate(self, values): + return values + + +def noop(*_, **__): + pass + + +def fetch_and_compare_all_records( + result, key, expected_records, method, limit=None +): + received_records = [] + if method == "for loop": + for record in result: + assert isinstance(record, Record) + received_records.append([record.data().get(key, None)]) + if limit is not None and len(received_records) == limit: + break + if limit is None: + assert result._closed + elif method == "next": + iter_ = Util.iter(result) + n = len(expected_records) if limit is None else limit + for _ in range(n): + record = Util.next(iter_) + received_records.append([record.get(key, None)]) + if limit is None: + with pytest.raises(StopIteration): + Util.next(iter_) + assert result._closed + elif method == "new iter": + n = len(expected_records) if limit is None else limit + for _ in range(n): + iter_ = Util.iter(result) + record = Util.next(iter_) + received_records.append([record.get(key, None)]) + if limit is None: + iter_ = Util.iter(result) + with pytest.raises(StopIteration): + Util.next(iter_) + assert result._closed + else: + raise ValueError() + assert received_records == expected_records + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("records", ( + [], + [[42]], + [[1], [2], [3], [4], [5]], +)) +@mark_sync_test +def test_result_iteration(method, records): + connection = ConnectionStub(records=Records(["x"], records)) + result = Result(connection, HydratorStub(), 2, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + fetch_and_compare_all_records(result, "x", records, method) + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("invert_fetch", (True, False)) +@mark_sync_test +def test_parallel_result_iteration(method, invert_fetch): + records1 = [[i] for i in range(1, 6)] + records2 = [[i] for i in range(6, 11)] + connection = ConnectionStub( + records=(Records(["x"], records1), Records(["x"], records2)) + ) + result1 = Result(connection, HydratorStub(), 2, noop, noop) + result1._run("CYPHER1", {}, None, None, "r", None) + result2 = Result(connection, HydratorStub(), 2, noop, noop) + result2._run("CYPHER2", {}, None, None, "r", None) + if invert_fetch: + fetch_and_compare_all_records( + result2, "x", records2, method + ) + fetch_and_compare_all_records( + result1, "x", records1, method + ) + else: + fetch_and_compare_all_records( + result1, "x", records1, method + ) + fetch_and_compare_all_records( + result2, "x", records2, method + ) + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("invert_fetch", (True, False)) +@mark_sync_test +def test_interwoven_result_iteration(method, invert_fetch): + records1 = [[i] for i in range(1, 10)] + records2 = [[i] for i in range(11, 20)] + connection = ConnectionStub( + records=(Records(["x"], records1), Records(["y"], records2)) + ) + result1 = Result(connection, 
HydratorStub(), 2, noop, noop) + result1._run("CYPHER1", {}, None, None, "r", None) + result2 = Result(connection, HydratorStub(), 2, noop, noop) + result2._run("CYPHER2", {}, None, None, "r", None) + start = 0 + for n in (1, 2, 3, 1, None): + end = n if n is None else start + n + if invert_fetch: + fetch_and_compare_all_records( + result2, "y", records2[start:end], method, n + ) + fetch_and_compare_all_records( + result1, "x", records1[start:end], method, n + ) + else: + fetch_and_compare_all_records( + result1, "x", records1[start:end], method, n + ) + fetch_and_compare_all_records( + result2, "y", records2[start:end], method, n + ) + start = end + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_sync_test +def test_result_peek(records, fetch_size): + connection = ConnectionStub(records=Records(["x"], records)) + result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + for i in range(len(records) + 1): + record = result.peek() + if i == len(records): + assert record is None + else: + assert isinstance(record, Record) + assert record.get("x") == records[i][0] + iter_ = Util.iter(result) + Util.next(iter_) # consume the record + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_sync_test +def test_result_single(records, fetch_size): + connection = ConnectionStub(records=Records(["x"], records)) + result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + with pytest.warns(None) as warning_record: + record = result.single() + if not records: + assert not warning_record + assert record is None + else: + if len(records) > 1: + assert len(warning_record) == 1 + else: + assert not warning_record + assert isinstance(record, Record) + assert record.get("x") == records[0][0] + + +@mark_sync_test +def test_keys_are_available_before_and_after_stream(): + connection = ConnectionStub(records=Records(["x"], [[1], [2]])) + result = Result(connection, HydratorStub(), 1, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + assert list(result.keys()) == ["x"] + Util.list(result) + assert list(result.keys()) == ["x"] + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("consume_one", (True, False)) +@pytest.mark.parametrize("summary_meta", (None, {"database": "foobar"})) +@mark_sync_test +def test_consume(records, consume_one, summary_meta): + connection = ConnectionStub( + records=Records(["x"], records), summary_meta=summary_meta + ) + result = Result(connection, HydratorStub(), 1, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + if consume_one: + try: + Util.next(Util.iter(result)) + except StopIteration: + pass + summary = result.consume() + assert isinstance(summary, ResultSummary) + if summary_meta and "db" in summary_meta: + assert summary.database == summary_meta["db"] + else: + assert summary.database is None + server_info = summary.server + assert isinstance(server_info, ServerInfo) + assert server_info.version_info() == Version(4, 3) + assert server_info.protocol_version == Version(4, 3) + assert isinstance(summary.counters, SummaryCounters) + + +@pytest.mark.parametrize("t_first", (None, 0, 1, 123456789)) +@pytest.mark.parametrize("t_last", (None, 0, 1, 123456789)) +@mark_sync_test +def test_time_in_summary(t_first, t_last): + run_meta = None + if t_first is not None: + 
run_meta = {"t_first": t_first}
+    summary_meta = None
+    if t_last is not None:
+        summary_meta = {"t_last": t_last}
+    connection = ConnectionStub(
+        records=Records(["n"], [[i] for i in range(100)]), run_meta=run_meta,
+        summary_meta=summary_meta
+    )
+
+    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result._run("CYPHER", {}, None, None, "r", None)
+    summary = result.consume()
+
+    if t_first is not None:
+        assert isinstance(summary.result_available_after, int)
+        assert summary.result_available_after == t_first
+    else:
+        assert summary.result_available_after is None
+    if t_last is not None:
+        assert isinstance(summary.result_consumed_after, int)
+        assert summary.result_consumed_after == t_last
+    else:
+        assert summary.result_consumed_after is None
+    assert not hasattr(summary, "t_first")
+    assert not hasattr(summary, "t_last")
+
+
+@mark_sync_test
+def test_counts_in_summary():
+    connection = ConnectionStub(records=Records(["n"], [[1], [2]]))
+
+    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result._run("CYPHER", {}, None, None, "r", None)
+    summary = result.consume()
+
+    assert isinstance(summary.counters, SummaryCounters)
+
+
+@pytest.mark.parametrize("query_type", ("r", "w", "rw", "s"))
+@mark_sync_test
+def test_query_type(query_type):
+    connection = ConnectionStub(
+        records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type}
+    )
+
+    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result._run("CYPHER", {}, None, None, "r", None)
+    summary = result.consume()
+
+    assert isinstance(summary.query_type, str)
+    assert summary.query_type == query_type
+
+
+@pytest.mark.parametrize("num_records", range(0, 5))
+@mark_sync_test
+def test_data(num_records):
+    connection = ConnectionStub(
+        records=Records(["n"], [[i + 1] for i in range(num_records)])
+    )
+
+    result = Result(connection, HydratorStub(), 1, noop, noop)
+    result._run("CYPHER", {}, None, None, "r", None)
+    result._buffer_all()
+    records = result._record_buffer.copy()
+    assert len(records) == num_records
+    expected_data = []
+    for i, record in enumerate(records):
+        record.data = mock.Mock()
+        expected_data.append("magic_return_%s" % i)
+        record.data.return_value = expected_data[-1]
+    assert result.data("hello", "world") == expected_data
+    for record in records:
+        record.data.assert_called_once_with("hello", "world")
diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py
new file mode 100644
index 000000000..4a12f695d
--- /dev/null
+++ b/tests/unit/sync/work/test_session.py
@@ -0,0 +1,285 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
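+
+# The ``pool`` fixture below uses the two-argument form of ``iter``:
+# ``iter(FakeConnection, 0)`` calls ``FakeConnection()`` on every step until
+# the call returns the sentinel ``0`` (which a mock never does), so each
+# ``acquire`` hands out a fresh FakeConnection. An equivalent sketch:
+#
+#     def fresh_connections():
+#         while True:
+#             yield FakeConnection()
+#
+#     pool.acquire.side_effect = fresh_connections()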
+ + +from contextlib import contextmanager + +import pytest + +from neo4j import ( + Session, + SessionConfig, + Transaction, + unit_of_work, +) +from neo4j._sync.io._pool import IOPool + +from ..._async_compat import ( + mark_sync_test, + Mock, + mock, +) +from ._fake_connection import FakeConnection + + +@pytest.fixture() +def pool(): + pool = Mock(spec=IOPool) + pool.acquire.side_effect = iter(FakeConnection, 0) + return pool + + +@mark_sync_test +def test_session_context_calls_close(): + s = Session(None, SessionConfig()) + with mock.patch.object(s, 'close', autospec=True) as mock_close: + with s: + pass + mock_close.assert_called_once_with() + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize(("repetitions", "consume"), ( + (1, False), (2, False), (2, True) +)) +@mark_sync_test +def test_opens_connection_on_run( + pool, test_run_args, repetitions, consume +): + with Session(pool, SessionConfig()) as session: + assert session._connection is None + result = session.run(*test_run_args) + assert session._connection is not None + if consume: + result.consume() + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize("repetitions", range(1, 3)) +@mark_sync_test +def test_closes_connection_after_consume( + pool, test_run_args, repetitions +): + with Session(pool, SessionConfig()) as session: + result = session.run(*test_run_args) + result.consume() + assert session._connection is None + assert session._connection is None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@mark_sync_test +def test_keeps_connection_until_last_result_consumed( + pool, test_run_args +): + with Session(pool, SessionConfig()) as session: + result1 = session.run(*test_run_args) + result2 = session.run(*test_run_args) + assert session._connection is not None + result1.consume() + assert session._connection is not None + result2.consume() + assert session._connection is None + + +@mark_sync_test +def test_opens_connection_on_tx_begin(pool): + with Session(pool, SessionConfig()) as session: + assert session._connection is None + with session.begin_transaction() as _: + assert session._connection is not None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize("repetitions", range(1, 3)) +@mark_sync_test +def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): + with Session(pool, SessionConfig()) as session: + with session.begin_transaction() as tx: + for _ in range(repetitions): + tx.run(*test_run_args) + assert session._connection is not None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize("repetitions", range(1, 3)) +@mark_sync_test +def test_keeps_connection_on_tx_consume( + pool, test_run_args, repetitions +): + with Session(pool, SessionConfig()) as session: + with session.begin_transaction() as tx: + for _ in range(repetitions): + result = tx.run(*test_run_args) + result.consume() + assert session._connection is not None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@mark_sync_test +def test_closes_connection_after_tx_close(pool, test_run_args): + with Session(pool, SessionConfig()) as session: + with session.begin_transaction() as tx: + for _ in range(2): + result = tx.run(*test_run_args) + result.consume() + tx.close() + assert 
session._connection is None + assert session._connection is None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@mark_sync_test +def test_closes_connection_after_tx_commit(pool, test_run_args): + with Session(pool, SessionConfig()) as session: + with session.begin_transaction() as tx: + for _ in range(2): + result = tx.run(*test_run_args) + result.consume() + tx.commit() + assert session._connection is None + assert session._connection is None + + +@pytest.mark.parametrize("bookmarks", (None, [], ["abc"], ["foo", "bar"])) +@mark_sync_test +def test_session_returns_bookmark_directly(pool, bookmarks): + with Session( + pool, SessionConfig(bookmarks=bookmarks) + ) as session: + if bookmarks: + assert session.last_bookmark() == bookmarks[-1] + else: + assert session.last_bookmark() is None + + +@pytest.mark.parametrize(("query", "error_type"), ( + (None, ValueError), + (1234, TypeError), + ({"how about": "no?"}, TypeError), + (["I don't", "think so"], TypeError), +)) +@mark_sync_test +def test_session_run_wrong_types(pool, query, error_type): + with Session(pool, SessionConfig()) as session: + with pytest.raises(error_type): + session.run(query) + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) +@mark_sync_test +def test_tx_function_argument_type(pool, tx_type): + def work(tx): + assert isinstance(tx, Transaction) + + with Session(pool, SessionConfig()) as session: + getattr(session, tx_type)(work) + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) +@pytest.mark.parametrize("decorator_kwargs", ( + {}, + {"timeout": 5}, + {"metadata": {"foo": "bar"}}, + {"timeout": 5, "metadata": {"foo": "bar"}}, + +)) +@mark_sync_test +def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): + @unit_of_work(**decorator_kwargs) + def work(tx): + assert isinstance(tx, Transaction) + + with Session(pool, SessionConfig()) as session: + getattr(session, tx_type)(work) + + +@mark_sync_test +def test_session_tx_type(pool): + with Session(pool, SessionConfig()) as session: + tx = session.begin_transaction() + assert isinstance(tx, Transaction) + + +@pytest.mark.parametrize(("parameters", "error_type"), ( + ({"x": None}, None), + ({"x": True}, None), + ({"x": False}, None), + ({"x": 123456789}, None), + ({"x": 3.1415926}, None), + ({"x": float("nan")}, None), + ({"x": float("inf")}, None), + ({"x": float("-inf")}, None), + ({"x": "foo"}, None), + ({"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, None), + ({"x": b"\x00\x33\x66\x99\xcc\xff"}, None), + ({"x": [1, 2, 3]}, None), + ({"x": ["a", "b", "c"]}, None), + ({"x": ["a", 2, 1.234]}, None), + ({"x": ["a", 2, ["c"]]}, None), + ({"x": {"one": "eins", "two": "zwei", "three": "drei"}}, None), + ({"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, None), + + # maps must have string keys + ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), + ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), +)) +@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) +@mark_sync_test +def test_session_run_with_parameters( + pool, parameters, error_type, run_type +): + @contextmanager + def raises(): + if error_type is not None: + with pytest.raises(error_type) as exc: + yield exc + else: + yield None + + with Session(pool, SessionConfig()) as session: + if run_type == "auto": + with raises(): + session.run("RETURN $x", **parameters) + elif run_type == "unmanaged": + tx = session.begin_transaction() + with raises(): + 
tx.run("RETURN $x", **parameters) + elif run_type == "managed": + def work(tx): + with raises() as exc: + tx.run("RETURN $x", **parameters) + if exc is not None: + raise exc + with raises(): + session.write_transaction(work) + else: + raise ValueError(run_type) diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py new file mode 100644 index 000000000..b5b40283c --- /dev/null +++ b/tests/unit/sync/work/test_transaction.py @@ -0,0 +1,185 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from neo4j import ( + Query, + Transaction, +) + +from ._fake_connection import fake_connection + + +@pytest.mark.parametrize(("explicit_commit", "close"), ( + (False, False), + (True, False), + (True, True), +)) +def test_transaction_context_when_committing(mocker, fake_connection, + explicit_commit, close): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with tx as tx_: + assert mock_commit.call_count == 0 + assert mock_rollback.call_count == 0 + assert tx is tx_ + if explicit_commit: + tx_.commit() + mock_commit.assert_called_once_with() + assert tx.closed() + if close: + tx_.close() + assert tx_.closed() + mock_commit.assert_called_once_with() + assert mock_rollback.call_count == 0 + assert tx_.closed() + + +@pytest.mark.parametrize(("rollback", "close"), ( + (True, False), + (False, True), + (True, True), +)) +def test_transaction_context_with_explicit_rollback(mocker, fake_connection, + rollback, close): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with tx as tx_: + assert mock_commit.call_count == 0 + assert mock_rollback.call_count == 0 + assert tx is tx_ + if rollback: + tx_.rollback() + mock_rollback.assert_called_once_with() + assert tx_.closed() + if close: + tx_.close() + mock_rollback.assert_called_once_with() + assert tx_.closed() + assert mock_commit.call_count == 0 + mock_rollback.assert_called_once_with() + assert tx_.closed() + + +def test_transaction_context_calls_rollback_on_error(mocker, fake_connection): + class OopsError(RuntimeError): + pass + + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with pytest.raises(OopsError): + with tx as tx_: + assert mock_commit.call_count == 0 + assert mock_rollback.call_count == 0 + assert tx is tx_ + raise OopsError + assert 
mock_commit.call_count == 0 + mock_rollback.assert_called_once_with() + assert tx_.closed() + + +@pytest.mark.parametrize(("parameters", "error_type"), ( + # maps must have string keys + ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), + ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), + ({"x": uuid4()}, TypeError), +)) +def test_transaction_run_with_invalid_parameters(fake_connection, parameters, + error_type): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error) + with pytest.raises(error_type): + tx.run("RETURN $x", **parameters) + + +def test_transaction_run_takes_no_query_object(fake_connection): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error) + with pytest.raises(ValueError): + tx.run(Query("RETURN 1")) + + +def test_transaction_rollbacks_on_open_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.is_reset_mock.return_value = False + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.is_reset_mock.assert_called_once() + fake_connection.reset.assert_not_called() + fake_connection.rollback.assert_called_once() + + +def test_transaction_no_rollback_on_reset_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.is_reset_mock.return_value = True + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.is_reset_mock.assert_called_once() + fake_connection.reset.asset_not_called() + fake_connection.rollback.asset_not_called() + + +def test_transaction_no_rollback_on_closed_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.closed.return_value = True + fake_connection.closed.reset_mock() + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.closed.assert_called_once() + fake_connection.is_reset_mock.asset_not_called() + fake_connection.reset.asset_not_called() + fake_connection.rollback.asset_not_called() + + +def test_transaction_no_rollback_on_defunct_connections(fake_connection): + tx = Transaction(fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + fake_connection.defunct.return_value = True + fake_connection.defunct.reset_mock() + fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + fake_connection.defunct.assert_called_once() + fake_connection.is_reset_mock.asset_not_called() + fake_connection.reset.asset_not_called() + fake_connection.rollback.asset_not_called()
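+
+
+# The four rollback tests above pin down the same guard from different
+# angles: ROLLBACK is only sent over a connection that is still usable and
+# not already clean. As a sketch of the behaviour being asserted (not the
+# literal Transaction implementation):
+#
+#     if not (connection.closed() or connection.defunct()):
+#         if not connection.is_reset:
+#             connection.rollback()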