diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 21f71b180..7664a10e3 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -47,6 +47,7 @@ CommitResponse, InitResponse, Response, + tx_timeout_as_ms, ) @@ -225,11 +226,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["tx_metadata"] = dict(metadata) except TypeError: raise TypeError("Metadata must be coercible to a dict") - if timeout: - try: - extra["tx_timeout"] = int(1000 * timeout) - except TypeError: - raise TypeError("Timeout must be specified as a number of seconds") + if timeout or (isinstance(timeout, (float, int)) and timeout == 0): + extra["tx_timeout"] = tx_timeout_as_ms(timeout) fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) if query.upper() == u"COMMIT": @@ -277,12 +275,8 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["tx_metadata"] = dict(metadata) except TypeError: raise TypeError("Metadata must be coercible to a dict") - if timeout: - try: - extra["tx_timeout"] = int(1000 * timeout) - except TypeError: - raise TypeError("Timeout must be specified as a number of seconds") - log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + if timeout or (isinstance(timeout, (float, int)) and timeout == 0): + extra["tx_timeout"] = tx_timeout_as_ms(timeout) self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) def commit(self, **handlers): diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 2feeea14b..370a9ccff 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -34,7 +34,6 @@ from neo4j.exceptions import ( ConfigurationError, DatabaseUnavailable, - DriverError, ForbiddenOnReadOnlyDatabase, Neo4jError, NotALeader, @@ -48,6 +47,7 @@ CommitResponse, InitResponse, Response, + tx_timeout_as_ms, ) from neo4j.io._bolt3 import ( ServerStateManager, @@ -178,11 +178,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["tx_metadata"] = dict(metadata) except TypeError: raise TypeError("Metadata must be coercible to a dict") - if timeout: - try: - extra["tx_timeout"] = int(1000 * timeout) - except TypeError: - raise TypeError("Timeout must be specified as a number of seconds") + if timeout or (isinstance(timeout, (float, int)) and timeout == 0): + extra["tx_timeout"] = tx_timeout_as_ms(timeout) fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) if query.upper() == u"COMMIT": @@ -229,11 +226,8 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["tx_metadata"] = dict(metadata) except TypeError: raise TypeError("Metadata must be coercible to a dict") - if timeout: - try: - extra["tx_timeout"] = int(1000 * timeout) - except TypeError: - raise TypeError("Timeout must be specified as a number of seconds") + if timeout or (isinstance(timeout, (float, int)) and timeout == 0): + extra["tx_timeout"] = tx_timeout_as_ms(timeout) log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) @@ -490,12 +484,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["tx_metadata"] = dict(metadata) except TypeError: raise TypeError("Metadata must be coercible to a dict") - if timeout: - try: - extra["tx_timeout"] = int(1000 * timeout) - except TypeError: - raise TypeError("Timeout must be specified as a number of " - "seconds") + if timeout or (isinstance(timeout, (float, int)) and timeout == 0): + extra["tx_timeout"] = tx_timeout_as_ms(timeout) fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) @@ -525,11 +515,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["tx_metadata"] = dict(metadata) except TypeError: raise TypeError("Metadata must be coercible to a dict") - if timeout: - try: - extra["tx_timeout"] = int(1000 * timeout) - except TypeError: - raise TypeError("Timeout must be specified as a number of " - "seconds") + if timeout or (isinstance(timeout, (float, int)) and timeout == 0): + extra["tx_timeout"] = tx_timeout_as_ms(timeout) log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index f11ddd8e0..fbf8fe126 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -270,3 +270,31 @@ def on_failure(self, metadata): class CommitResponse(Response): pass + + +def tx_timeout_as_ms(timeout: float) -> int: + """ + Round transaction timeout to milliseconds. + + Values in (0, 1], else values are rounded using the built-in round() + function (round n.5 values to nearest even). + + :param timeout: timeout in seconds + + :returns: timeout in milliseconds (rounded) + + :raise ValueError: if timeout is negative + """ + try: + timeout = float(timeout) + except (TypeError, ValueError) as e: + err_type = type(e) + msg = "Timeout must be specified as a number of seconds" + raise err_type(msg) from e + ms = int(round(1000 * timeout)) + if ms == 0 and timeout > 0: + # Special case for 0 < timeout < 0.5 ms. + # This would be rounded to 0 ms, but the server interprets this as + # infinite timeout. So we round to the smallest possible timeout: 1 ms. + ms = 1 + return ms diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index d07aaba8d..6d9b89523 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -281,10 +281,19 @@ def begin_transaction(self, metadata=None, timeout=None): :param timeout: the transaction timeout in seconds. - Transactions that execute longer than the configured timeout will be terminated by the database. - This functionality allows to limit query/transaction execution time. - Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting. - Value should not represent a duration of zero or negative duration. + Transactions that execute longer than the configured timeout will + be terminated by the database. + This functionality allows user code to limit query/transaction + execution time. + The specified timeout overrides the default timeout configured in + the database using the ``db.transaction.timeout`` setting + (``dbms.transaction.timeout`` before Neo4j 5.0). + Values higher than ``db.transaction.timeout`` will be ignored and + will fall back to the default for server versions between 4.2 and + 5.2 (inclusive). + The value should not represent a negative duration. + A ``0`` duration will make the transaction execute indefinitely. + :data:`None` will use the default timeout configured on the server. :type timeout: int :returns: A new transaction instance. @@ -441,7 +450,21 @@ class Query: :type text: str :param metadata: metadata attached to the query. :type metadata: dict - :param timeout: seconds. + :param timeout: + the transaction timeout in seconds. + Transactions that execute longer than the configured timeout will + be terminated by the database. + This functionality allows user code to limit query/transaction + execution time. + The specified timeout overrides the default timeout configured in + the database using the ``db.transaction.timeout`` setting + (``dbms.transaction.timeout`` before Neo4j 5.0). + Values higher than ``db.transaction.timeout`` will be ignored and + will fall back to the default for server versions between 4.2 and + 5.2 (inclusive). + The value should not represent a negative duration. + A ``0`` duration will make the transaction execute indefinitely. + :data:`None` will use the default timeout configured on the server. :type timeout: float or :const:`None` """ def __init__(self, text, metadata=None, timeout=None): @@ -476,12 +499,19 @@ def count_people_tx(tx): :param timeout: the transaction timeout in seconds. - Transactions that execute longer than the configured timeout will be terminated by the database. - This functionality allows to limit query/transaction execution time. - Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting. - Values higher than ``dbms.transaction.timeout`` will be ignored and - will fall back to default (unless using Neo4j < 4.2). - Value should not represent a duration of zero or negative duration. + Transactions that execute longer than the configured timeout will + be terminated by the database. + This functionality allows user code to limit query/transaction + execution time. + The specified timeout overrides the default timeout configured in + the database using the ``db.transaction.timeout`` setting + (``dbms.transaction.timeout`` before Neo4j 5.0). + Values higher than ``db.transaction.timeout`` will be ignored and + will fall back to the default for server versions between 4.2 and + 5.2 (inclusive). + The value should not represent a negative duration. + A ``0`` duration will make the transaction execute indefinitely. + :data:`None` will use the default timeout configured on the server. :type timeout: float or :const:`None` """ diff --git a/tests/unit/io/test_class_bolt3.py b/tests/unit/io/test_class_bolt3.py index f7d63e855..0a5d49cdb 100644 --- a/tests/unit/io/test_class_bolt3.py +++ b/tests/unit/io/test_class_bolt3.py @@ -110,3 +110,56 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {}) + connection = Bolt3(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res diff --git a/tests/unit/io/test_class_bolt4x0.py b/tests/unit/io/test_class_bolt4x0.py index 3879acb0c..f797266d0 100644 --- a/tests/unit/io/test_class_bolt4x0.py +++ b/tests/unit/io/test_class_bolt4x0.py @@ -197,3 +197,56 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {}) + connection = Bolt4x0(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res diff --git a/tests/unit/io/test_class_bolt4x1.py b/tests/unit/io/test_class_bolt4x1.py index 663d3cbe1..34748750b 100644 --- a/tests/unit/io/test_class_bolt4x1.py +++ b/tests/unit/io/test_class_bolt4x1.py @@ -210,3 +210,56 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {}) + connection = Bolt4x1(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res diff --git a/tests/unit/io/test_class_bolt4x2.py b/tests/unit/io/test_class_bolt4x2.py index 470adf5c7..80d53a4c5 100644 --- a/tests/unit/io/test_class_bolt4x2.py +++ b/tests/unit/io/test_class_bolt4x2.py @@ -211,3 +211,56 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {}) + connection = Bolt4x2(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/io/test_class_bolt4x3.py index fc08f5b92..18273f569 100644 --- a/tests/unit/io/test_class_bolt4x3.py +++ b/tests/unit/io/test_class_bolt4x3.py @@ -237,3 +237,56 @@ def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid, and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {}) + connection = Bolt4x3(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res diff --git a/tests/unit/io/test_class_bolt4x4.py b/tests/unit/io/test_class_bolt4x4.py index 19378a1cd..dfa19ad2a 100644 --- a/tests/unit/io/test_class_bolt4x4.py +++ b/tests/unit/io/test_class_bolt4x4.py @@ -249,3 +249,56 @@ def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid, and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {}) + connection = Bolt4x4(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res