From 364c001b3eba7f3d437fa98c4e4995c5af2389cc Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 16 Jun 2022 14:34:09 +0200 Subject: [PATCH 1/3] UTC encoding patch for Bolt 4.4 and 4.3 This is to be considered quite literally a patch. A proper implementation of this functionality would require a lot of refactoring that potentially involves moving modules and packages around that are undocumented, hence they shouldn't be used, but they are not marked as internal by a leading underscore. Therefore, I don't want to introduce such a refactoring in a patch release of the driver. --- neo4j/data.py | 34 ++++++++--- neo4j/io/__init__.py | 2 + neo4j/io/_bolt4.py | 17 ++++++ neo4j/time/hydration.py | 81 +++++++++++++++++++++++++- neo4j/work/result.py | 5 +- neo4j/work/simple.py | 2 +- neo4j/work/transaction.py | 6 +- testkitbackend/test_config.json | 1 + tests/unit/time/test_dehydration.py | 88 +++++++++++++++++++++++++++-- tests/unit/time/test_hydration.py | 75 +++++++++++++++++++++++- tests/unit/work/_fake_connection.py | 1 + tests/unit/work/test_result.py | 1 + 12 files changed, 295 insertions(+), 18 deletions(-) diff --git a/neo4j/data.py b/neo4j/data.py index e1a3a959c..58d8d8ad4 100644 --- a/neo4j/data.py +++ b/neo4j/data.py @@ -34,6 +34,7 @@ hydrate_date, dehydrate_date, hydrate_time, dehydrate_time, hydrate_datetime, dehydrate_datetime, + hydrate_datetime_v2, dehydrate_datetime_v2, hydrate_duration, dehydrate_duration, dehydrate_timedelta, ) @@ -268,7 +269,7 @@ def transform(self, x): class DataHydrator: # TODO: extend DataTransformer - def __init__(self): + def __init__(self, patch_utc=False): super(DataHydrator, self).__init__() self.graph = Graph() self.graph_hydrator = Graph.Hydrator(self.graph) @@ -282,11 +283,19 @@ def __init__(self): b"D": hydrate_date, b"T": hydrate_time, # time zone offset b"t": hydrate_time, # no time zone - b"F": hydrate_datetime, # time zone offset - b"f": hydrate_datetime, # time zone name b"d": hydrate_datetime, # no time zone b"E": hydrate_duration, } + if not patch_utc: + self.hydration_functions.update({ + b"F": hydrate_datetime, # time zone offset + b"f": hydrate_datetime, # time zone name + }) + else: + self.hydration_functions.update({ + b"I": hydrate_datetime_v2, # time zone offset + b"i": hydrate_datetime_v2, # time zone name + }) def hydrate(self, values): """ Convert PackStream values into native values. @@ -320,10 +329,10 @@ class DataDehydrator: # TODO: extend DataTransformer @classmethod - def fix_parameters(cls, parameters): + def fix_parameters(cls, parameters, patch_utc=False): if not parameters: return {} - dehydrator = cls() + dehydrator = cls(patch_utc=patch_utc) try: dehydrated, = dehydrator.dehydrate([parameters]) except TypeError as error: @@ -332,7 +341,7 @@ def fix_parameters(cls, parameters): else: return dehydrated - def __init__(self): + def __init__(self, patch_utc=False): self.dehydration_functions = {} self.dehydration_functions.update({ Point: dehydrate_point, @@ -340,11 +349,20 @@ def __init__(self): date: dehydrate_date, Time: dehydrate_time, time: dehydrate_time, - DateTime: dehydrate_datetime, - datetime: dehydrate_datetime, Duration: dehydrate_duration, timedelta: dehydrate_timedelta, }) + if not patch_utc: + self.dehydration_functions.update({ + DateTime: dehydrate_datetime, + datetime: dehydrate_datetime, + }) + else: + self.dehydration_functions.update({ + DateTime: dehydrate_datetime_v2, + datetime: dehydrate_datetime_v2, + }) + # Allow dehydration from any direct Point subclass self.dehydration_functions.update({cls: dehydrate_point for cls in Point.__subclasses__()}) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 40d12b575..0867c0996 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -147,6 +147,8 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. self.configuration_hints = {} + # back ported protocol patches negotiated with the server + self.bolt_patches = set() self.outbox = Outbox() self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) self.packer = Packer(self.outbox) diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index c29c6aebb..2feeea14b 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -368,6 +368,20 @@ class Bolt4x3(Bolt4x2): PROTOCOL_VERSION = Version(4, 3) + def get_base_headers(self): + """ Bolt 4.1 passes the routing context, originally taken from + the URI, into the connection initialisation message. This + enables server-side routing to propagate the same behaviour + through its driver. + """ + headers = { + "user_agent": self.user_agent, + "patch_bolt": ["utc"] + } + if self.routing_context is not None: + headers["routing"] = self.routing_context + return headers + def route(self, database=None, imp_user=None, bookmarks=None): if imp_user is not None: raise ConfigurationError( @@ -394,6 +408,7 @@ def route(self, database=None, imp_user=None, bookmarks=None): def hello(self): def on_success(metadata): + # configuration hints self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) if "connection.recv_timeout_seconds" in self.configuration_hints: @@ -407,6 +422,8 @@ def on_success(metadata): "connection.recv_timeout_seconds (%r). Make sure " "the server and network is set up correctly.", self.local_port, recv_timeout) + # bolt patch handshake + self.bolt_patches.update(set(metadata.pop("patch_bolt", ()))) headers = self.get_base_headers() headers.update(self.auth_dict) diff --git a/neo4j/time/hydration.py b/neo4j/time/hydration.py index e061004a8..b45ec96f5 100644 --- a/neo4j/time/hydration.py +++ b/neo4j/time/hydration.py @@ -132,6 +132,34 @@ def hydrate_datetime(seconds, nanoseconds, tz=None): return zone.localize(t) +def hydrate_datetime_v2(seconds, nanoseconds, tz=None): + """ Hydrator for `DateTime` and `LocalDateTime` values. + + :param seconds: + :param nanoseconds: + :param tz: + :return: datetime + """ + import pytz + + minutes, seconds = map(int, divmod(seconds, 60)) + hours, minutes = map(int, divmod(minutes, 60)) + days, hours = map(int, divmod(hours, 24)) + t = DateTime.combine( + Date.from_ordinal(get_date_unix_epoch_ordinal() + days), + Time(hours, minutes, seconds, nanoseconds) + ) + if tz is None: + return t + if isinstance(tz, int): + tz_offset_minutes, tz_offset_seconds = divmod(tz, 60) + zone = pytz.FixedOffset(tz_offset_minutes) + else: + zone = pytz.timezone(tz) + t = t.replace(tzinfo=pytz.UTC) + return t.as_timezone(zone) + + def dehydrate_datetime(value): """ Dehydrator for `datetime` values. @@ -167,8 +195,57 @@ def seconds_and_nanoseconds(dt): else: # with time offset seconds, nanoseconds = seconds_and_nanoseconds(value) - return Structure(b"F", seconds, nanoseconds, - int(tz.utcoffset(value).total_seconds())) + offset = tz.utcoffset(value) + if offset.microseconds: + raise ValueError("Bolt protocol does not support sub-second " + "UTC offsets.") + offset_seconds = offset.days * 86400 + offset.seconds + return Structure(b"F", seconds, nanoseconds, offset_seconds) + + +def dehydrate_datetime_v2(value): + """ Dehydrator for `datetime` values. + + :param value: + :type value: datetime + :return: + """ + + import pytz + + def seconds_and_nanoseconds(dt): + if isinstance(dt, datetime): + dt = DateTime.from_native(dt) + dt = dt.astimezone(pytz.UTC) + utc_epoch = DateTime(1970, 1, 1, tzinfo=pytz.UTC) + dt_clock_time = dt.to_clock_time() + utc_epoch_clock_time = utc_epoch.to_clock_time() + t = dt_clock_time - utc_epoch_clock_time + return t.seconds, t.nanoseconds + + tz = value.tzinfo + if tz is None: + # without time zone + value = pytz.UTC.localize(value) + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"d", seconds, nanoseconds) + elif hasattr(tz, "zone") and tz.zone and isinstance(tz.zone, str): + # with named pytz time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"i", seconds, nanoseconds, tz.zone) + elif hasattr(tz, "key") and tz.key and isinstance(tz.key, str): + # with named zoneinfo (Python 3.9+) time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"i", seconds, nanoseconds, tz.key) + else: + # with time offset + seconds, nanoseconds = seconds_and_nanoseconds(value) + offset = tz.utcoffset(value) + if offset.microseconds: + raise ValueError("Bolt protocol does not support sub-second " + "UTC offsets.") + offset_seconds = offset.days * 86400 + offset.seconds + return Structure(b"I", seconds, nanoseconds, offset_seconds) def hydrate_duration(months, days, seconds, nanoseconds): diff --git a/neo4j/work/result.py b/neo4j/work/result.py index b01cd30c6..89086d46e 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -75,7 +75,10 @@ def _run(self, query, parameters, db, imp_user, access_mode, bookmarks, query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) - parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwparameters)) + parameters = DataDehydrator.fix_parameters( + dict(parameters or {}, **kwparameters), + patch_utc="utc" in self._connection.bolt_patches + ) self._metadata = { "query": query_text, diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index cc6b94c30..835eddb0c 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -206,7 +206,7 @@ def run(self, query, parameters=None, **kwparameters): protocol_version = cx.PROTOCOL_VERSION server_info = cx.server_info - hydrant = DataHydrator() + hydrant = DataHydrator(patch_utc="utc" in cx.bolt_patches) self._autoResult = Result( cx, hydrant, self._config.fetch_size, self._result_closed, diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index 8b7739877..c0c756d5e 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -127,8 +127,12 @@ def run(self, query, parameters=None, **kwparameters): # have any qid to fetch in batches. self._results[-1]._buffer_all() + hydrant = DataHydrator( + patch_utc="utc" in self._connection.bolt_patches + ) + result = Result( - self._connection, DataHydrator(), self._fetch_size, + self._connection, hydrant, self._fetch_size, self._result_on_closed_handler, self._error_handler ) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 075349cc7..db15c9193 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -42,6 +42,7 @@ "Feature:Bolt:4.2": true, "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, + "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", "Feature:TLS:1.2": true, diff --git a/tests/unit/time/test_dehydration.py b/tests/unit/time/test_dehydration.py index 7fbd4b23f..f05486d7c 100644 --- a/tests/unit/time/test_dehydration.py +++ b/tests/unit/time/test_dehydration.py @@ -91,15 +91,15 @@ def test_native_date_time_negative_offset(self): assert struct == Structure(b"F", 1539344261, 474716000, -3600) def test_date_time_zone_id(self): - dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, - pytz.timezone("Europe/Stockholm")) + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + dt = pytz.timezone("Europe/Stockholm").localize(dt) struct, = self.dehydrator.dehydrate((dt,)) assert struct == Structure(b"f", 1539344261, 474716862, "Europe/Stockholm") def test_native_date_time_zone_id(self): - dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, - pytz.timezone("Europe/Stockholm")) + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716) + dt = pytz.timezone("Europe/Stockholm").localize(dt) struct, = self.dehydrator.dehydrate((dt,)) assert struct == Structure(b"f", 1539344261, 474716000, "Europe/Stockholm") @@ -133,3 +133,83 @@ def test_native_duration_mixed_sign(self): duration = datetime.timedelta(days=-1, seconds=2, microseconds=3) struct, = self.dehydrator.dehydrate((duration,)) assert struct == Structure(b"E", 0, -1, 2, 3000) + + +class TestPatchedTemporalDehydration(TestTemporalDehydration): + + def setUp(self): + self.dehydrator = DataDehydrator(patch_utc=True) + + def test_date(self): + super().test_date() + + def test_native_date(self): + super().test_native_date() + + def test_time(self): + super().test_time() + + def test_native_time(self): + super().test_native_time() + + def test_local_time(self): + super().test_local_time() + + def test_local_native_time(self): + super().test_local_native_time() + + def test_date_time(self): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(60)) + struct, = self.dehydrator.dehydrate((dt,)) + assert struct == Structure(b"I", 1539340661, 474716862, 3600) + + def test_native_date_time(self): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(60)) + struct, = self.dehydrator.dehydrate((dt,)) + assert struct == Structure(b"I", 1539340661, 474716000, 3600) + + def test_date_time_negative_offset(self): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(-60)) + struct, = self.dehydrator.dehydrate((dt,)) + assert struct == Structure(b"I", 1539347861, 474716862, -3600) + + def test_native_date_time_negative_offset(self): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(-60)) + struct, = self.dehydrator.dehydrate((dt,)) + assert struct == Structure(b"I", 1539347861, 474716000, -3600) + + def test_date_time_zone_id(self): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + dt = pytz.timezone("Europe/Stockholm").localize(dt) + struct, = self.dehydrator.dehydrate((dt,)) + assert struct == Structure(b"i", 1539337061, 474716862, + "Europe/Stockholm") + + def test_native_date_time_zone_id(self): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716) + dt = pytz.timezone("Europe/Stockholm").localize(dt) + struct, = self.dehydrator.dehydrate((dt,)) + assert struct == Structure(b"i", 1539337061, 474716000, + "Europe/Stockholm") + + def test_local_date_time(self): + super().test_local_date_time() + + def test_native_local_date_time(self): + super().test_native_local_date_time() + + def test_duration(self): + super().test_duration() + + def test_native_duration(self): + super().test_native_duration() + + def test_duration_mixed_sign(self): + super().test_duration_mixed_sign() + + def test_native_duration_mixed_sign(self): + super().test_native_duration_mixed_sign() diff --git a/tests/unit/time/test_hydration.py b/tests/unit/time/test_hydration.py index d66f69494..30131799f 100644 --- a/tests/unit/time/test_hydration.py +++ b/tests/unit/time/test_hydration.py @@ -22,8 +22,11 @@ from decimal import Decimal from unittest import TestCase +import pytz + from neo4j.data import DataHydrator from neo4j.packstream import Structure +from neo4j.time import DateTime class TestTemporalHydration(TestCase): @@ -31,7 +34,7 @@ class TestTemporalHydration(TestCase): def setUp(self): self.hydrant = DataHydrator() - def test_can_hydrate_date_time_structure(self): + def test_local_date_time(self): struct = Structure(b'd', 1539344261, 474716862) dt, = self.hydrant.hydrate([struct]) self.assertEqual(dt.year, 2018) @@ -40,3 +43,73 @@ def test_can_hydrate_date_time_structure(self): self.assertEqual(dt.hour, 11) self.assertEqual(dt.minute, 37) self.assertEqual(dt.second, Decimal("41.474716862")) + + def test_date_time(self): + struct = Structure(b"F", 1539344261, 474716862, 3600) + dt, = self.hydrant.hydrate([struct]) + expected_dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + expected_dt = pytz.FixedOffset(60).localize(expected_dt) + assert dt == expected_dt + + def test_date_time_negative_offset(self): + struct = Structure(b"F", 1539344261, 474716862, -3600) + dt, = self.hydrant.hydrate([struct]) + expected_dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + expected_dt = pytz.FixedOffset(-60).localize(expected_dt) + assert dt == expected_dt + + def test_date_time_zone_id(self): + struct = Structure(b"f", 1539344261, 474716862, "Europe/Stockholm") + dt, = self.hydrant.hydrate([struct]) + expected_dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + expected_dt = pytz.timezone("Europe/Stockholm").localize(expected_dt) + assert dt == expected_dt + + def test_does_not_handle_patched_date_time(self): + struct = Structure(b"I", 123, 456, 3600) + dt, = self.hydrant.hydrate([struct]) + assert struct == dt # no hydration defined + + def test_does_not_handle_patched_date_time_zone_id(self): + struct = Structure(b"i", 123, 456, "Europe/Stockholm") + dt, = self.hydrant.hydrate([struct]) + assert struct == dt # no hydration defined + + +class TestPatchedTemporalHydration(TestCase): + + def setUp(self): + self.hydrant = DataHydrator(patch_utc=True) + + test_local_date_time = TestTemporalHydration.test_local_date_time + + def test_date_time(self): + struct = Structure(b"I", 1539340661, 474716862, 3600) + dt, = self.hydrant.hydrate([struct]) + expected_dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + expected_dt = pytz.FixedOffset(60).localize(expected_dt) + assert dt == expected_dt + + def test_date_time_negative_offset(self): + struct = Structure(b"I", 1539347861, 474716862, -3600) + dt, = self.hydrant.hydrate([struct]) + expected_dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + expected_dt = pytz.FixedOffset(-60).localize(expected_dt) + assert dt == expected_dt + + def test_date_time_zone_id(self): + struct = Structure(b"i", 1539337061, 474716862, "Europe/Stockholm") + dt, = self.hydrant.hydrate([struct]) + expected_dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + expected_dt = pytz.timezone("Europe/Stockholm").localize(expected_dt) + assert dt == expected_dt + + def test_does_not_handle_unpatched_date_time(self): + struct = Structure(b"F", 123, 456, 3600) + dt, = self.hydrant.hydrate([struct]) + assert struct == dt # no hydration defined + + def test_does_not_handle_unpatched_date_time_zone_id(self): + struct = Structure(b"f", 123, 456, "Europe/Stockholm") + dt, = self.hydrant.hydrate([struct]) + assert struct == dt # no hydration defined diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py index c94026ea2..9bc6815ad 100644 --- a/tests/unit/work/_fake_connection.py +++ b/tests/unit/work/_fake_connection.py @@ -32,6 +32,7 @@ class FakeConnection(mock.NonCallableMagicMock): callbacks = [] server_info = ServerInfo("127.0.0.1", (4, 3)) local_port = 1234 + bolt_patches = set() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index c33331f6a..2cf5ddeb9 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -105,6 +105,7 @@ def __init__(self, records=None, run_meta=None, summary_meta=None, self.run_meta = run_meta self.summary_meta = summary_meta ConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) + self.bolt_patches = set() self.unresolved_address = None def send_all(self): From 33198656a78dbd0f139a27db6b5ca7f91de02ad8 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 30 Jun 2022 16:06:40 +0200 Subject: [PATCH 2/3] Enable TestKit tests for temporal types --- testkit/build.py | 12 +++- testkitbackend/_driver_logger.py | 41 +++++++++++ testkitbackend/backend.py | 113 +++++++++++++++++++------------ testkitbackend/exceptions.py | 25 +++++++ testkitbackend/fromtestkit.py | 81 +++++++++++++++++++--- testkitbackend/requirements.txt | 1 + testkitbackend/test_config.json | 1 + testkitbackend/totestkit.py | 58 ++++++++++++++++ 8 files changed, 275 insertions(+), 57 deletions(-) create mode 100644 testkitbackend/_driver_logger.py create mode 100644 testkitbackend/exceptions.py create mode 100644 testkitbackend/requirements.txt diff --git a/testkit/build.py b/testkit/build.py index 132452e56..2a2b1eeff 100644 --- a/testkit/build.py +++ b/testkit/build.py @@ -1,14 +1,20 @@ """ -Executed in Go driver container. +Executed in driver container. Responsible for building driver and test backend. """ + + import subprocess +import sys def run(args, env=None): - subprocess.run(args, universal_newlines=True, stderr=subprocess.STDOUT, - check=True, env=env) + subprocess.run(args, universal_newlines=True, stdout=sys.stdout, + stderr=sys.stderr, check=True, env=env) if __name__ == "__main__": run(["python", "setup.py", "build"]) + run(["python", "-m", "pip", "install", "-U", "pip"]) + run(["python", "-m", "pip", "install", "-Ur", + "testkitbackend/requirements.txt"]) diff --git a/testkitbackend/_driver_logger.py b/testkitbackend/_driver_logger.py new file mode 100644 index 000000000..ef09ec355 --- /dev/null +++ b/testkitbackend/_driver_logger.py @@ -0,0 +1,41 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import io +import logging +import sys + + +buffer_handler = logging.StreamHandler(io.StringIO()) +buffer_handler.setLevel(logging.DEBUG) + +handler = logging.StreamHandler(sys.stdout) +handler.setLevel(logging.DEBUG) +logging.getLogger("neo4j").addHandler(handler) +logging.getLogger("neo4j").addHandler(buffer_handler) +logging.getLogger("neo4j").setLevel(logging.DEBUG) + +log = logging.getLogger("testkitbackend") +log.addHandler(handler) +log.setLevel(logging.DEBUG) + + +__all__ = [ + "buffer_handler", + "log", +] diff --git a/testkitbackend/backend.py b/testkitbackend/backend.py index c488f8adb..497735517 100644 --- a/testkitbackend/backend.py +++ b/testkitbackend/backend.py @@ -14,39 +14,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +import asyncio +import traceback from inspect import ( getmembers, isfunction, ) -import io -from json import loads, dumps -import logging -import sys -import traceback - -from neo4j._exceptions import ( - BoltError +from json import ( + dumps, + loads, ) +from pathlib import Path + +from neo4j._exceptions import BoltError from neo4j.exceptions import ( DriverError, Neo4jError, UnsupportedServerProduct, ) -import testkitbackend.requests as requests - -buffer_handler = logging.StreamHandler(io.StringIO()) -buffer_handler.setLevel(logging.DEBUG) +from ._driver_logger import ( + buffer_handler, + log, +) +from .exceptions import MarkdAsDriverException +from . import requests -handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.DEBUG) -logging.getLogger("neo4j").addHandler(handler) -logging.getLogger("neo4j").addHandler(buffer_handler) -logging.getLogger("neo4j").setLevel(logging.DEBUG) -log = logging.getLogger("testkitbackend") -log.addHandler(handler) -log.setLevel(logging.DEBUG) +TESTKIT_BACKEND_PATH = Path(__file__).absolute().resolve().parent +DRIVER_PATH = TESTKIT_BACKEND_PATH.parent / "neo4j" class Request(dict): @@ -134,6 +132,41 @@ def process_request(self): request = request + line return False + @staticmethod + def _exc_stems_from_driver(exc): + stack = traceback.extract_tb(exc.__traceback__) + for frame in stack[-1:1:-1]: + p = Path(frame.filename) + if TESTKIT_BACKEND_PATH in p.parents: + return False + if DRIVER_PATH in p.parents: + return True + + def write_driver_exc(self, exc): + log.debug(traceback.format_exc()) + + key = self.next_key() + self.errors[key] = exc + + payload = {"id": key, "msg": ""} + + if isinstance(exc, MarkdAsDriverException): + wrapped_exc = exc.wrapped_exc + payload["errorType"] = str(type(wrapped_exc)) + if wrapped_exc.args: + payload["msg"] = str(wrapped_exc.args[0]) + else: + payload["errorType"] = str(type(exc)) + if isinstance(exc, Neo4jError) and exc.message is not None: + payload["msg"] = str(exc.message) + elif exc.args: + payload["msg"] = str(exc.args[0]) + + if isinstance(exc, Neo4jError): + payload["code"] = exc.code + + self.send_response("DriverError", payload) + def _process(self, request): """ Process a received request by retrieving handler that corresponds to the request name. @@ -156,34 +189,25 @@ def _process(self, request): " request: " + ", ".join(unsused_keys) ) except (Neo4jError, DriverError, UnsupportedServerProduct, - BoltError) as e: - log.debug(traceback.format_exc()) - if isinstance(e, Neo4jError): - msg = "" if e.message is None else str(e.message) - else: - msg = str(e.args[0]) if e.args else "" - - key = self.next_key() - self.errors[key] = e - payload = {"id": key, "errorType": str(type(e)), "msg": msg} - if isinstance(e, Neo4jError): - payload["code"] = e.code - self.send_response("DriverError", payload) + BoltError, MarkdAsDriverException) as e: + self.write_driver_exc(e) except requests.FrontendError as e: self.send_response("FrontendError", {"msg": str(e)}) - except Exception: - tb = traceback.format_exc() - log.error(tb) - self.send_response("BackendError", {"msg": tb}) + except Exception as e: + if self._exc_stems_from_driver(e): + self.write_driver_exc(e) + else: + tb = traceback.format_exc() + log.error(tb) + self.send_response("BackendError", {"msg": tb}) def send_response(self, name, data): """ Sends a response to backend. """ - buffer_handler.acquire() - log_output = buffer_handler.stream.getvalue() - buffer_handler.stream.truncate(0) - buffer_handler.stream.seek(0) - buffer_handler.release() + with buffer_handler.lock: + log_output = buffer_handler.stream.getvalue() + buffer_handler.stream.truncate(0) + buffer_handler.stream.seek(0) if not log_output.endswith("\n"): log_output += "\n" self._wr.write(log_output.encode("utf-8")) @@ -193,4 +217,7 @@ def send_response(self, name, data): self._wr.write(b"#response begin\n") self._wr.write(bytes(response+"\n", "utf-8")) self._wr.write(b"#response end\n") - self._wr.flush() + if isinstance(self._wr, asyncio.StreamWriter): + self._wr.drain() + else: + self._wr.flush() diff --git a/testkitbackend/exceptions.py b/testkitbackend/exceptions.py new file mode 100644 index 000000000..b16625960 --- /dev/null +++ b/testkitbackend/exceptions.py @@ -0,0 +1,25 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class MarkdAsDriverException(Exception): + """ + Wrap any error as DriverException + """ + def __init__(self, wrapped_exc): + super().__init__() + self.wrapped_exc = wrapped_exc diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index 5c3c92d87..b85cacba2 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -15,7 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j.work.simple import Query + +from datetime import timedelta + +import pytz + +from neo4j import Query +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) def to_cypher_and_params(data): @@ -50,24 +61,72 @@ def to_query_and_params(data): def to_param(m): """ Converts testkit parameter format to driver (python) parameter """ - value = m["data"]["value"] + data = m["data"] name = m["name"] if name == "CypherNull": + if data["value"] is not None: + raise ValueError("CypherNull should be None") return None if name == "CypherString": - return str(value) + return str(data["value"]) if name == "CypherBool": - return bool(value) + return bool(data["value"]) if name == "CypherInt": - return int(value) + return int(data["value"]) if name == "CypherFloat": - return float(value) + return float(data["value"]) if name == "CypherString": - return str(value) + return str(data["value"]) if name == "CypherBytes": - return bytearray([int(byte, 16) for byte in value.split()]) + return bytearray([int(byte, 16) for byte in data["value"].split()]) if name == "CypherList": - return [to_param(v) for v in value] + return [to_param(v) for v in data["value"]] if name == "CypherMap": - return {k: to_param(value[k]) for k in value} - raise Exception("Unknown param type " + name) + return {k: to_param(data["value"][k]) for k in data["value"]} + if name == "CypherDate": + return Date(data["year"], data["month"], data["day"]) + if name == "CypherTime": + tz = None + utc_offset_s = data.get("utc_offset_s") + if utc_offset_s is not None: + utc_offset_m = utc_offset_s // 60 + if utc_offset_m * 60 != utc_offset_s: + raise ValueError("the used timezone library only supports " + "UTC offsets by minutes") + tz = pytz.FixedOffset(utc_offset_m) + return Time(data["hour"], data["minute"], data["second"], + data["nanosecond"], tzinfo=tz) + if name == "CypherDateTime": + datetime = DateTime( + data["year"], data["month"], data["day"], + data["hour"], data["minute"], data["second"], data["nanosecond"] + ) + utc_offset_s = data["utc_offset_s"] + timezone_id = data["timezone_id"] + if timezone_id is not None: + utc_offset = timedelta(seconds=utc_offset_s) + tz = pytz.timezone(timezone_id) + localized_datetime = tz.localize(datetime, is_dst=False) + if localized_datetime.utcoffset() == utc_offset: + return localized_datetime + localized_datetime = tz.localize(datetime, is_dst=True) + if localized_datetime.utcoffset() == utc_offset: + return localized_datetime + raise ValueError( + "cannot localize datetime %s to timezone %s with UTC " + "offset %s" % (datetime, timezone_id, utc_offset) + ) + elif utc_offset_s is not None: + utc_offset_m = utc_offset_s // 60 + if utc_offset_m * 60 != utc_offset_s: + raise ValueError("the used timezone library only supports " + "UTC offsets by minutes") + tz = pytz.FixedOffset(utc_offset_m) + return tz.localize(datetime) + return datetime + if name == "CypherDuration": + return Duration( + months=data["months"], days=data["days"], + seconds=data["seconds"], nanoseconds=data["nanoseconds"] + ) + raise ValueError("Unknown param type " + name) diff --git a/testkitbackend/requirements.txt b/testkitbackend/requirements.txt new file mode 100644 index 000000000..3c8d7e782 --- /dev/null +++ b/testkitbackend/requirements.txt @@ -0,0 +1 @@ +-r ../requirements.txt diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index db15c9193..139898018 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -32,6 +32,7 @@ "Feature:API:Result.Peek": true, "Feature:API:Result.Single": "Does not raise error when not exactly one record is available. To be fixed in 5.0.", "Feature:API:SessionConnectionTimeout": true, + "Feature:API:Type.Temporal": true, "Feature:API:UpdateRoutingTableTimeout": true, "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 3068017aa..acc642e5f 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -21,6 +21,12 @@ Path, Relationship, ) +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) def record(rec): @@ -82,5 +88,57 @@ def to(name, val): "relationships": field(list(v.relationships)), } return {"name": "Path", "data": path} + if isinstance(v, Date): + return { + "name": "CypherDate", + "data": { + "year": v.year, + "month": v.month, + "day": v.day + } + } + if isinstance(v, Time): + data = { + "hour": v.hour, + "minute": v.minute, + "second": int(v.second), + "nanosecond": v.nanosecond + } + if v.tzinfo is not None: + data["utc_offset_s"] = v.tzinfo.utcoffset(v).total_seconds() + return { + "name": "CypherTime", + "data": data + } + if isinstance(v, DateTime): + data = { + "year": v.year, + "month": v.month, + "day": v.day, + "hour": v.hour, + "minute": v.minute, + "second": int(v.second), + "nanosecond": v.nanosecond + } + if v.tzinfo is not None: + data["utc_offset_s"] = v.tzinfo.utcoffset(v).total_seconds() + for attr in ("zone", "key"): + timezone_id = getattr(v.tzinfo, attr, None) + if isinstance(timezone_id, str): + data["timezone_id"] = timezone_id + return { + "name": "CypherDateTime", + "data": data, + } + if isinstance(v, Duration): + return { + "name": "CypherDuration", + "data": { + "months": v.months, + "days": v.days, + "seconds": v.seconds, + "nanoseconds": v.nanoseconds + }, + } raise Exception("Unhandled type:" + str(type(v))) From 8d4778b3e55b19f7aaa392281c3c9c68a0b41f2d Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 6 Jul 2022 12:31:21 +0200 Subject: [PATCH 3/3] Add TestKit feature flag `Detail:ResultStreamWorksAfterBrokenRecord` --- testkitbackend/test_config.json | 1 + 1 file changed, 1 insertion(+) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 139898018..e34fb7ca9 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -55,6 +55,7 @@ "Optimization:MinimalResets": true, "Optimization:PullPipelining": true, "Optimization:ResultListFetchAll": "The idiomatic way to cast to list is indistinguishable from iterating over the result.", + "Detail:ResultStreamWorksAfterBrokenRecord": true, "ConfHint:connection.recv_timeout_seconds": true } }