Skip to content

Commit 4f8b091

Browse files
bigmontzrobsdedudethelonelyvulpesfbivilleinjectives
authored
Add tests for Bolt fix for broken DateTime encoding (#470) (#476)
The structures with signature `0x46` and `0x66` are being replaced by `0x49` and `0x69`. This new structures changes the meaning of seconds and nano seconds from `adjusted Unix epoch` to `UTC`. This changes have with goal of avoiding un-existing or ambiguous ZonedDateTime to be received or sent over Bolt. Bolt v4.3 and v4.4 were patched to support this feature if the server supports the patch. Co-authored-by: Rouven Bauer <[email protected]> Co-authored-by: grant lodge <[email protected]> Co-authored-by: Florent Biville <[email protected]> Co-authored-by: Dmitriy Tverdiakov <[email protected]>
1 parent 4df66b2 commit 4f8b091

File tree

133 files changed

+4270
-1011
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

133 files changed

+4270
-1011
lines changed

boltstub/bolt_protocol.py

Lines changed: 138 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
ServerExit,
66
)
77
from .packstream import Structure
8-
from .simple_jolt import dumps_simple as jolt_dumps
8+
from .simple_jolt import v1 as jolt_v1
9+
from .simple_jolt import v2 as jolt_v2
910
from .util import (
1011
hex_repr,
1112
recursive_subclasses,
1213
)
1314

15+
jolt_package = {
16+
1: jolt_v1,
17+
2: jolt_v2,
18+
}
19+
1420

1521
def get_bolt_protocol(version):
1622
if version is None:
@@ -27,33 +33,21 @@ def get_bolt_protocol(version):
2733
def verify_script_messages(script):
2834
protocol = get_bolt_protocol(script.context.bolt_version)
2935
for line in script.client_lines:
30-
if line.parsed[0] not in protocol.messages["C"].values():
31-
raise BoltUnknownMessageError(
32-
"Unsupported client message {} for BOLT version {}. "
33-
"Must be one of {}".format(
34-
line.parsed[0], script.context.bolt_version,
35-
list(protocol.messages["C"].values())
36-
),
37-
line
38-
)
36+
# will raise an exception if the message is unknown or the body
37+
# cannot be decoded
38+
protocol.translate_client_line(line)
3939
for line in script.server_lines:
40-
if line.parsed[0] is None:
40+
if line.is_command:
4141
continue # this server line contains a command, not a message
42-
if line.parsed[0] not in protocol.messages["S"].values():
43-
raise BoltUnknownMessageError(
44-
"Unsupported server message {} for BOLT version {}. "
45-
"Must be one of {}".format(
46-
line.parsed[0], script.context.bolt_version,
47-
list(protocol.messages["S"].values())
48-
),
49-
line
50-
)
42+
# will raise an exception if the message is unknown
43+
protocol.translate_server_line(line)
5144

5245

5346
class TranslatedStructure(Structure):
54-
def __init__(self, name, tag, *fields):
47+
def __init__(self, name, tag, *fields, packstream_version):
5548
# Verified is false as this class in only used for message structs
56-
super().__init__(tag, *fields, verified=False)
49+
super().__init__(tag, *fields, packstream_version=packstream_version,
50+
verified=False)
5751
self.name = name
5852

5953
def __repr__(self):
@@ -62,7 +56,8 @@ def __repr__(self):
6256

6357
def __str__(self):
6458
return self.name + " {}".format(" ".join(
65-
map(jolt_dumps, self.fields_to_jolt_types())
59+
map(jolt_package[self.packstream_version].dumps_simple,
60+
self.fields_to_jolt_types())
6661
))
6762

6863
def __eq__(self, other):
@@ -80,6 +75,8 @@ class BoltProtocol:
8075
# allow the server to negotiate other bolt versions
8176
equivalent_versions = set()
8277

78+
packstream_version = None
79+
8380
messages = {
8481
"C": {},
8582
"S": {},
@@ -94,25 +91,55 @@ def decode_versions(cls, b):
9491
for minor in range(spec_minor, spec_minor - range_ - 1, -1):
9592
yield major, minor
9693

94+
@classmethod
95+
def translate_client_line(cls, client_line):
96+
if not client_line.jolt_parsed:
97+
client_line.parse_jolt(cls.get_jolt_package())
98+
name, fields = client_line.jolt_parsed
99+
try:
100+
tag = next(tag_ for tag_, name_ in cls.messages["C"].items()
101+
if name == name_)
102+
except StopIteration:
103+
raise BoltUnknownMessageError(
104+
"Unsupported client message {} for BOLT version {}. "
105+
"Must be one of {}".format(
106+
name, cls.protocol_version,
107+
list(cls.messages["C"].values())
108+
),
109+
client_line
110+
)
111+
return TranslatedStructure(
112+
name, tag, *fields, packstream_version=cls.packstream_version
113+
)
114+
97115
@classmethod
98116
def translate_server_line(cls, server_line):
99-
name, fields = server_line.parsed
117+
if not server_line.jolt_parsed:
118+
server_line.parse_jolt(cls.get_jolt_package())
119+
name, fields = server_line.jolt_parsed
100120
try:
101121
tag = next(tag_ for tag_, name_ in cls.messages["S"].items()
102122
if name == name_)
103-
except StopIteration as e:
104-
raise ValueError(
105-
"Unknown response message type {} in Bolt version {}".format(
106-
name, ".".join(map(str, cls.protocol_version))
107-
)
108-
) from e
109-
return TranslatedStructure(name, tag, *fields)
123+
except StopIteration:
124+
raise BoltUnknownMessageError(
125+
"Unsupported server message {} for BOLT version {}. "
126+
"Must be one of {}".format(
127+
name, cls.protocol_version,
128+
list(cls.messages["S"].values())
129+
),
130+
server_line
131+
)
132+
return TranslatedStructure(
133+
name, tag, *fields, packstream_version=cls.packstream_version
134+
)
110135

111136
@classmethod
112137
def translate_structure(cls, structure: Structure):
113138
try:
114-
return TranslatedStructure(cls.messages["C"][structure.tag],
115-
structure.tag, *structure.fields)
139+
return TranslatedStructure(
140+
cls.messages["C"][structure.tag], structure.tag,
141+
*structure.fields, packstream_version=cls.packstream_version
142+
)
116143
except KeyError:
117144
raise ServerExit(
118145
"Unknown response message type {} in Bolt version {}".format(
@@ -121,6 +148,10 @@ def translate_structure(cls, structure: Structure):
121148
)
122149
)
123150

151+
@classmethod
152+
def get_jolt_package(cls):
153+
return jolt_package[cls.packstream_version]
154+
124155

125156
class Bolt1Protocol(BoltProtocol):
126157

@@ -129,6 +160,8 @@ class Bolt1Protocol(BoltProtocol):
129160
# allow the server to negotiate other bolt versions
130161
equivalent_versions = set()
131162

163+
packstream_version = 1
164+
132165
messages = {
133166
"C": {
134167
b"\x01": "INIT",
@@ -159,11 +192,15 @@ def decode_versions(cls, b):
159192
@classmethod
160193
def get_auto_response(cls, request: TranslatedStructure):
161194
if request.tag == b"\x01":
162-
return TranslatedStructure("SUCCESS", b"\x70", {
163-
"server": cls.server_agent,
164-
})
195+
return TranslatedStructure(
196+
"SUCCESS", b"\x70", {"server": cls.server_agent},
197+
packstream_version=cls.packstream_version
198+
)
165199
else:
166-
return TranslatedStructure("SUCCESS", b"\x70", {})
200+
return TranslatedStructure(
201+
"SUCCESS", b"\x70", {},
202+
packstream_version=cls.packstream_version
203+
)
167204

168205

169206
class Bolt2Protocol(Bolt1Protocol):
@@ -173,16 +210,22 @@ class Bolt2Protocol(Bolt1Protocol):
173210
# allow the server to negotiate other bolt versions
174211
equivalent_versions = set()
175212

213+
packstream_version = 1
214+
176215
server_agent = "Neo4j/3.4.0"
177216

178217
@classmethod
179218
def get_auto_response(cls, request: TranslatedStructure):
180219
if request.tag == b"\x01":
181-
return TranslatedStructure("SUCCESS", b"\x70", {
182-
"server": cls.server_agent,
183-
})
220+
return TranslatedStructure(
221+
"SUCCESS", b"\x70", {"server": cls.server_agent},
222+
packstream_version=cls.packstream_version
223+
)
184224
else:
185-
return TranslatedStructure("SUCCESS", b"\x70", {})
225+
return TranslatedStructure(
226+
"SUCCESS", b"\x70", {},
227+
packstream_version=cls.packstream_version
228+
)
186229

187230

188231
class Bolt3Protocol(Bolt2Protocol):
@@ -192,6 +235,8 @@ class Bolt3Protocol(Bolt2Protocol):
192235
# allow the server to negotiate other bolt versions
193236
equivalent_versions = set()
194237

238+
packstream_version = 1
239+
195240
messages = {
196241
"C": {
197242
b"\x01": "HELLO",
@@ -217,12 +262,16 @@ class Bolt3Protocol(Bolt2Protocol):
217262
@classmethod
218263
def get_auto_response(cls, request: TranslatedStructure):
219264
if request.tag == b"\x01":
220-
return TranslatedStructure("SUCCESS", b"\x70", {
221-
"connection_id": "bolt-0",
222-
"server": cls.server_agent,
223-
})
265+
return TranslatedStructure(
266+
"SUCCESS", b"\x70",
267+
{"connection_id": "bolt-0", "server": cls.server_agent},
268+
packstream_version=cls.packstream_version
269+
)
224270
else:
225-
return TranslatedStructure("SUCCESS", b"\x70", {})
271+
return TranslatedStructure(
272+
"SUCCESS", b"\x70", {},
273+
packstream_version=cls.packstream_version
274+
)
226275

227276

228277
class Bolt4x0Protocol(Bolt3Protocol):
@@ -232,6 +281,8 @@ class Bolt4x0Protocol(Bolt3Protocol):
232281
# allow the server to negotiate other bolt versions
233282
equivalent_versions = set()
234283

284+
packstream_version = 1
285+
235286
messages = {
236287
"C": {
237288
b"\x01": "HELLO",
@@ -265,12 +316,16 @@ def decode_versions(cls, b):
265316
@classmethod
266317
def get_auto_response(cls, request: TranslatedStructure):
267318
if request.tag == b"\x01":
268-
return TranslatedStructure("SUCCESS", b"\x70", {
269-
"connection_id": "bolt-0",
270-
"server": cls.server_agent,
271-
})
319+
return TranslatedStructure(
320+
"SUCCESS", b"\x70",
321+
{"connection_id": "bolt-0", "server": cls.server_agent},
322+
packstream_version=cls.packstream_version
323+
)
272324
else:
273-
return TranslatedStructure("SUCCESS", b"\x70", {})
325+
return TranslatedStructure(
326+
"SUCCESS", b"\x70", {},
327+
packstream_version=cls.packstream_version
328+
)
274329

275330

276331
class Bolt4x1Protocol(Bolt4x0Protocol):
@@ -280,6 +335,8 @@ class Bolt4x1Protocol(Bolt4x0Protocol):
280335
# allow the server to negotiate other bolt versions
281336
equivalent_versions = set()
282337

338+
packstream_version = 1
339+
283340
messages = {
284341
"C": {
285342
b"\x01": "HELLO",
@@ -305,13 +362,19 @@ class Bolt4x1Protocol(Bolt4x0Protocol):
305362
@classmethod
306363
def get_auto_response(cls, request: TranslatedStructure):
307364
if request.tag == b"\x01":
308-
return TranslatedStructure("SUCCESS", b"\x70", {
309-
"connection_id": "bolt-0",
310-
"server": cls.server_agent,
311-
"routing": None,
312-
})
365+
return TranslatedStructure(
366+
"SUCCESS", b"\x70",
367+
{
368+
"connection_id": "bolt-0", "server": cls.server_agent,
369+
"routing": None,
370+
},
371+
packstream_version=cls.packstream_version
372+
)
313373
else:
314-
return TranslatedStructure("SUCCESS", b"\x70", {})
374+
return TranslatedStructure(
375+
"SUCCESS", b"\x70", {},
376+
packstream_version=cls.packstream_version
377+
)
315378

316379

317380
class Bolt4x2Protocol(Bolt4x1Protocol):
@@ -321,6 +384,8 @@ class Bolt4x2Protocol(Bolt4x1Protocol):
321384
# allow the server to negotiate other bolt versions
322385
equivalent_versions = {(4, 1)}
323386

387+
packstream_version = 1
388+
324389
server_agent = "Neo4j/4.2.0"
325390

326391

@@ -331,6 +396,8 @@ class Bolt4x3Protocol(Bolt4x2Protocol):
331396
# allow the server to negotiate other bolt versions
332397
equivalent_versions = set()
333398

399+
packstream_version = 1
400+
334401
messages = {
335402
"C": {
336403
b"\x01": "HELLO",
@@ -370,4 +437,18 @@ class Bolt4x4Protocol(Bolt4x3Protocol):
370437
# allow the server to negotiate other bolt versions
371438
equivalent_versions = set()
372439

440+
packstream_version = 1
441+
373442
server_agent = "Neo4j/4.4.0"
443+
444+
445+
class Bolt5x0Protocol(Bolt4x4Protocol):
446+
447+
protocol_version = (5, 0)
448+
version_aliases = set()
449+
# allow the server to negotiate other bolt versions
450+
equivalent_versions = set()
451+
452+
packstream_version = 2
453+
454+
server_agent = "Neo4j/5.0.0"

boltstub/channel.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88

99

1010
class Channel:
11+
# This class is the glue between a stub script, the socket, and the bolt
12+
# protocol.
13+
1114
def __init__(self, wire, bolt_version, log_cb=None, handshake_data=None):
1215
self.wire = wire
13-
self.stream = PackStream(wire)
1416
self.bolt_protocol = get_bolt_protocol(bolt_version)
17+
self.stream = PackStream(wire, self.bolt_protocol.packstream_version)
1518
self.log = log_cb
1619
self.handshake_data = handshake_data
1720
self._buffered_msg = None
@@ -69,6 +72,9 @@ def version_handshake(self):
6972
self.wire.send()
7073
self._log("S: <HANDSHAKE> %s", hex_repr(response))
7174

75+
def match_client_line(self, client_line, msg):
76+
return client_line.match_message(msg.name, msg.fields)
77+
7278
def send_raw(self, b):
7379
self.log("%s", hex_repr(b))
7480
self.wire.write(b)

0 commit comments

Comments
 (0)