Skip to content

Commit 263b8a5

Browse files
committed
Fix handling of sub-ms transaction timeouts
Transaction timeouts are specified in seconds as float. However, the server expects it in milliseconds as int. This would lead to 1) rounding issues: previously, the driver would multiply by 1000 and then truncate to int. E.g., 256.4 seconds would be turned into 256399 ms because of float imprecision. Therefore, the built-in `round` is now used instead. 2) values below 1 ms (e.g., 0.0001) would be rounded down to 0. However, 0 is a special value that instructs the server to not apply any timeout. This is likely to surprise the user which specified a non-zero timeout. In this special case, the driver now rounds up to 1 ms. Back-port of: neo4j#940
1 parent 9346ce8 commit 263b8a5

9 files changed

+386
-34
lines changed

neo4j/io/_bolt3.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
CommitResponse,
4848
InitResponse,
4949
Response,
50+
tx_timeout_as_ms,
5051
)
5152

5253

@@ -225,11 +226,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
225226
extra["tx_metadata"] = dict(metadata)
226227
except TypeError:
227228
raise TypeError("Metadata must be coercible to a dict")
228-
if timeout:
229-
try:
230-
extra["tx_timeout"] = int(1000 * timeout)
231-
except TypeError:
232-
raise TypeError("Timeout must be specified as a number of seconds")
229+
if timeout or timeout == 0:
230+
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
233231
fields = (query, parameters, extra)
234232
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
235233
if query.upper() == u"COMMIT":
@@ -277,12 +275,8 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
277275
extra["tx_metadata"] = dict(metadata)
278276
except TypeError:
279277
raise TypeError("Metadata must be coercible to a dict")
280-
if timeout:
281-
try:
282-
extra["tx_timeout"] = int(1000 * timeout)
283-
except TypeError:
284-
raise TypeError("Timeout must be specified as a number of seconds")
285-
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
278+
if timeout or timeout == 0:
279+
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
286280
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
287281

288282
def commit(self, **handlers):

neo4j/io/_bolt4.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from neo4j.exceptions import (
3535
ConfigurationError,
3636
DatabaseUnavailable,
37-
DriverError,
3837
ForbiddenOnReadOnlyDatabase,
3938
Neo4jError,
4039
NotALeader,
@@ -48,6 +47,7 @@
4847
CommitResponse,
4948
InitResponse,
5049
Response,
50+
tx_timeout_as_ms,
5151
)
5252
from neo4j.io._bolt3 import (
5353
ServerStateManager,
@@ -178,11 +178,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
178178
extra["tx_metadata"] = dict(metadata)
179179
except TypeError:
180180
raise TypeError("Metadata must be coercible to a dict")
181-
if timeout:
182-
try:
183-
extra["tx_timeout"] = int(1000 * timeout)
184-
except TypeError:
185-
raise TypeError("Timeout must be specified as a number of seconds")
181+
if timeout or timeout == 0:
182+
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
186183
fields = (query, parameters, extra)
187184
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
188185
if query.upper() == u"COMMIT":
@@ -229,11 +226,8 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
229226
extra["tx_metadata"] = dict(metadata)
230227
except TypeError:
231228
raise TypeError("Metadata must be coercible to a dict")
232-
if timeout:
233-
try:
234-
extra["tx_timeout"] = int(1000 * timeout)
235-
except TypeError:
236-
raise TypeError("Timeout must be specified as a number of seconds")
229+
if timeout or timeout == 0:
230+
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
237231
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
238232
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
239233

@@ -490,12 +484,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
490484
extra["tx_metadata"] = dict(metadata)
491485
except TypeError:
492486
raise TypeError("Metadata must be coercible to a dict")
493-
if timeout:
494-
try:
495-
extra["tx_timeout"] = int(1000 * timeout)
496-
except TypeError:
497-
raise TypeError("Timeout must be specified as a number of "
498-
"seconds")
487+
if timeout or timeout == 0:
488+
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
499489
fields = (query, parameters, extra)
500490
log.debug("[#%04X] C: RUN %s", self.local_port,
501491
" ".join(map(repr, fields)))
@@ -525,11 +515,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
525515
extra["tx_metadata"] = dict(metadata)
526516
except TypeError:
527517
raise TypeError("Metadata must be coercible to a dict")
528-
if timeout:
529-
try:
530-
extra["tx_timeout"] = int(1000 * timeout)
531-
except TypeError:
532-
raise TypeError("Timeout must be specified as a number of "
533-
"seconds")
518+
if timeout or timeout == 0:
519+
extra["tx_timeout"] = tx_timeout_as_ms(timeout)
534520
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
535521
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

neo4j/io/_common.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,33 @@ def on_failure(self, metadata):
270270
class CommitResponse(Response):
271271

272272
pass
273+
274+
275+
def tx_timeout_as_ms(timeout: float) -> int:
276+
"""
277+
Round transaction timeout to milliseconds.
278+
279+
Values in (0, 1], else values are rounded using the built-in round()
280+
function (round n.5 values to nearest even).
281+
282+
:param timeout: timeout in seconds (must be >= 0)
283+
284+
:returns: timeout in milliseconds (rounded)
285+
286+
:raise ValueError: if timeout is negative
287+
"""
288+
try:
289+
timeout = float(timeout)
290+
except (TypeError, ValueError) as e:
291+
err_type = type(e)
292+
msg = "Timeout must be specified as a number of seconds"
293+
raise err_type(msg) from None
294+
if timeout < 0:
295+
raise ValueError("Timeout must be a positive number or 0.")
296+
ms = int(round(1000 * timeout))
297+
if ms == 0 and timeout > 0:
298+
# Special case for 0 < timeout < 0.5 ms.
299+
# This would be rounded to 0 ms, but the server interprets this as
300+
# infinite timeout. So we round to the smallest possible timeout: 1 ms.
301+
ms = 1
302+
return ms

tests/unit/io/test_class_bolt3.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
110110
PoolConfig.max_connection_lifetime)
111111
connection.hello()
112112
sockets.client.settimeout.assert_not_called()
113+
114+
115+
@pytest.mark.parametrize(
116+
("func", "args", "extra_idx"),
117+
(
118+
("run", ("RETURN 1",), 2),
119+
("begin", (), 0),
120+
)
121+
)
122+
@pytest.mark.parametrize(
123+
("timeout", "res"),
124+
(
125+
(None, None),
126+
(0, 0),
127+
(0.1, 100),
128+
(0.001, 1),
129+
(1e-15, 1),
130+
(0.0005, 1),
131+
(0.0001, 1),
132+
(1.0015, 1002),
133+
(1.000499, 1000),
134+
(1.0025, 1002),
135+
(3.0005, 3000),
136+
(3.456, 3456),
137+
(1, 1000),
138+
(
139+
-1e-15,
140+
ValueError("Timeout must be a positive number or 0")
141+
),
142+
(
143+
"foo",
144+
ValueError("Timeout must be specified as a number of seconds")
145+
),
146+
(
147+
[1, 2],
148+
TypeError("Timeout must be specified as a number of seconds")
149+
)
150+
)
151+
)
152+
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
153+
address = ("127.0.0.1", 7687)
154+
sockets = fake_socket_pair(address)
155+
sockets.server.send_message(0x70, {})
156+
connection = Bolt3(address, sockets.client, 0)
157+
func = getattr(connection, func)
158+
if isinstance(res, Exception):
159+
with pytest.raises(type(res), match=str(res)):
160+
func(*args, timeout=timeout)
161+
else:
162+
func(*args, timeout=timeout)
163+
connection.send_all()
164+
tag, fields = sockets.server.pop_message()
165+
extra = fields[extra_idx]
166+
if timeout is None:
167+
assert "tx_timeout" not in extra
168+
else:
169+
assert extra["tx_timeout"] == res

tests/unit/io/test_class_bolt4x0.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
197197
PoolConfig.max_connection_lifetime)
198198
connection.hello()
199199
sockets.client.settimeout.assert_not_called()
200+
201+
202+
@pytest.mark.parametrize(
203+
("func", "args", "extra_idx"),
204+
(
205+
("run", ("RETURN 1",), 2),
206+
("begin", (), 0),
207+
)
208+
)
209+
@pytest.mark.parametrize(
210+
("timeout", "res"),
211+
(
212+
(None, None),
213+
(0, 0),
214+
(0.1, 100),
215+
(0.001, 1),
216+
(1e-15, 1),
217+
(0.0005, 1),
218+
(0.0001, 1),
219+
(1.0015, 1002),
220+
(1.000499, 1000),
221+
(1.0025, 1002),
222+
(3.0005, 3000),
223+
(3.456, 3456),
224+
(1, 1000),
225+
(
226+
-1e-15,
227+
ValueError("Timeout must be a positive number or 0")
228+
),
229+
(
230+
"foo",
231+
ValueError("Timeout must be specified as a number of seconds")
232+
),
233+
(
234+
[1, 2],
235+
TypeError("Timeout must be specified as a number of seconds")
236+
)
237+
)
238+
)
239+
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
240+
address = ("127.0.0.1", 7687)
241+
sockets = fake_socket_pair(address)
242+
sockets.server.send_message(0x70, {})
243+
connection = Bolt4x0(address, sockets.client, 0)
244+
func = getattr(connection, func)
245+
if isinstance(res, Exception):
246+
with pytest.raises(type(res), match=str(res)):
247+
func(*args, timeout=timeout)
248+
else:
249+
func(*args, timeout=timeout)
250+
connection.send_all()
251+
tag, fields = sockets.server.pop_message()
252+
extra = fields[extra_idx]
253+
if timeout is None:
254+
assert "tx_timeout" not in extra
255+
else:
256+
assert extra["tx_timeout"] == res

tests/unit/io/test_class_bolt4x1.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
210210
PoolConfig.max_connection_lifetime)
211211
connection.hello()
212212
sockets.client.settimeout.assert_not_called()
213+
214+
215+
@pytest.mark.parametrize(
216+
("func", "args", "extra_idx"),
217+
(
218+
("run", ("RETURN 1",), 2),
219+
("begin", (), 0),
220+
)
221+
)
222+
@pytest.mark.parametrize(
223+
("timeout", "res"),
224+
(
225+
(None, None),
226+
(0, 0),
227+
(0.1, 100),
228+
(0.001, 1),
229+
(1e-15, 1),
230+
(0.0005, 1),
231+
(0.0001, 1),
232+
(1.0015, 1002),
233+
(1.000499, 1000),
234+
(1.0025, 1002),
235+
(3.0005, 3000),
236+
(3.456, 3456),
237+
(1, 1000),
238+
(
239+
-1e-15,
240+
ValueError("Timeout must be a positive number or 0")
241+
),
242+
(
243+
"foo",
244+
ValueError("Timeout must be specified as a number of seconds")
245+
),
246+
(
247+
[1, 2],
248+
TypeError("Timeout must be specified as a number of seconds")
249+
)
250+
)
251+
)
252+
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
253+
address = ("127.0.0.1", 7687)
254+
sockets = fake_socket_pair(address)
255+
sockets.server.send_message(0x70, {})
256+
connection = Bolt4x1(address, sockets.client, 0)
257+
func = getattr(connection, func)
258+
if isinstance(res, Exception):
259+
with pytest.raises(type(res), match=str(res)):
260+
func(*args, timeout=timeout)
261+
else:
262+
func(*args, timeout=timeout)
263+
connection.send_all()
264+
tag, fields = sockets.server.pop_message()
265+
extra = fields[extra_idx]
266+
if timeout is None:
267+
assert "tx_timeout" not in extra
268+
else:
269+
assert extra["tx_timeout"] == res

tests/unit/io/test_class_bolt4x2.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,60 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
211211
PoolConfig.max_connection_lifetime)
212212
connection.hello()
213213
sockets.client.settimeout.assert_not_called()
214+
215+
216+
@pytest.mark.parametrize(
217+
("func", "args", "extra_idx"),
218+
(
219+
("run", ("RETURN 1",), 2),
220+
("begin", (), 0),
221+
)
222+
)
223+
@pytest.mark.parametrize(
224+
("timeout", "res"),
225+
(
226+
(None, None),
227+
(0, 0),
228+
(0.1, 100),
229+
(0.001, 1),
230+
(1e-15, 1),
231+
(0.0005, 1),
232+
(0.0001, 1),
233+
(1.0015, 1002),
234+
(1.000499, 1000),
235+
(1.0025, 1002),
236+
(3.0005, 3000),
237+
(3.456, 3456),
238+
(1, 1000),
239+
(
240+
-1e-15,
241+
ValueError("Timeout must be a positive number or 0")
242+
),
243+
(
244+
"foo",
245+
ValueError("Timeout must be specified as a number of seconds")
246+
),
247+
(
248+
[1, 2],
249+
TypeError("Timeout must be specified as a number of seconds")
250+
)
251+
)
252+
)
253+
def test_tx_timeout(fake_socket_pair, func, args, extra_idx, timeout, res):
254+
address = ("127.0.0.1", 7687)
255+
sockets = fake_socket_pair(address)
256+
sockets.server.send_message(0x70, {})
257+
connection = Bolt4x2(address, sockets.client, 0)
258+
func = getattr(connection, func)
259+
if isinstance(res, Exception):
260+
with pytest.raises(type(res), match=str(res)):
261+
func(*args, timeout=timeout)
262+
else:
263+
func(*args, timeout=timeout)
264+
connection.send_all()
265+
tag, fields = sockets.server.pop_message()
266+
extra = fields[extra_idx]
267+
if timeout is None:
268+
assert "tx_timeout" not in extra
269+
else:
270+
assert extra["tx_timeout"] == res

0 commit comments

Comments
 (0)