Skip to content

Commit 9a2a20a

Browse files
authored
Enable Rust Extension for Faster PackStream (#979)
* Making the driver pick up optional rust extension * Enable rust extension for packing as well * Fix not using rust packer * Minor clean-ups in TestKit glue * TestKit backend: make error classification more robust * Optimization: check only once for rust availability
1 parent d7dec04 commit 9a2a20a

File tree

16 files changed

+223
-104
lines changed

16 files changed

+223
-104
lines changed

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,5 +117,9 @@ asyncio_mode = "auto"
117117
[tool.mypy]
118118

119119
[[tool.mypy.overrides]]
120-
module = "pandas.*"
120+
module = [
121+
"pandas.*",
122+
"neo4j._codec.packstream._rust",
123+
"neo4j._codec.packstream._rust.*",
124+
]
121125
ignore_missing_imports = true

src/neo4j/_codec/packstream/_common.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,7 @@
1616
# limitations under the License.
1717

1818

19-
class Structure:
20-
21-
def __init__(self, tag, *fields):
22-
self.tag = tag
23-
self.fields = list(fields)
24-
25-
def __repr__(self):
26-
return "Structure[0x%02X](%s)" % (ord(self.tag), ", ".join(map(repr, self.fields)))
27-
28-
def __eq__(self, other):
29-
try:
30-
return self.tag == other.tag and self.fields == other.fields
31-
except AttributeError:
32-
return False
33-
34-
def __ne__(self, other):
35-
return not self.__eq__(other)
36-
37-
def __len__(self):
38-
return len(self.fields)
39-
40-
def __getitem__(self, key):
41-
return self.fields[key]
42-
43-
def __setitem__(self, key, value):
44-
self.fields[key] = value
19+
try:
20+
from ._rust import Structure
21+
except ImportError:
22+
from ._python import Structure
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# https://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
from ._common import Structure
20+
21+
22+
__all__ = [
23+
"Structure",
24+
]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# https://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
class Structure:
20+
21+
def __init__(self, tag, *fields):
22+
self.tag = tag
23+
self.fields = list(fields)
24+
25+
def __repr__(self):
26+
return "Structure[0x%02X](%s)" % (
27+
ord(self.tag), ", ".join(map(repr, self.fields))
28+
)
29+
30+
def __eq__(self, other):
31+
try:
32+
return self.tag == other.tag and self.fields == other.fields
33+
except AttributeError:
34+
return False
35+
36+
def __ne__(self, other):
37+
return not self.__eq__(other)
38+
39+
def __len__(self):
40+
return len(self.fields)
41+
42+
def __getitem__(self, key):
43+
return self.fields[key]
44+
45+
def __setitem__(self, key, value):
46+
self.fields[key] = value

src/neo4j/_codec/packstream/v1/__init__.py

Lines changed: 56 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,26 @@
1616
# limitations under the License.
1717

1818

19-
import typing as t
2019
from codecs import decode
2120
from contextlib import contextmanager
2221
from struct import (
2322
pack as struct_pack,
2423
unpack as struct_unpack,
2524
)
2625

27-
from ...._optional_deps import (
28-
np,
29-
pd,
30-
)
3126
from ...hydration import DehydrationHooks
3227
from .._common import Structure
28+
from .types import *
3329

3430

35-
NONE_VALUES: t.Tuple = (None,)
36-
TRUE_VALUES: t.Tuple = (True,)
37-
FALSE_VALUES: t.Tuple = (False,)
38-
INT_TYPES: t.Tuple[t.Type, ...] = (int,)
39-
FLOAT_TYPES: t.Tuple[t.Type, ...] = (float,)
40-
# we can't put tuple here because spatial types subclass tuple,
41-
# and we don't want to treat them as sequences
42-
SEQUENCE_TYPES: t.Tuple[t.Type, ...] = (list,)
43-
MAPPING_TYPES: t.Tuple[t.Type, ...] = (dict,)
44-
BYTES_TYPES: t.Tuple[t.Type, ...] = (bytes, bytearray)
45-
46-
47-
if np is not None:
48-
TRUE_VALUES = (*TRUE_VALUES, np.bool_(True))
49-
FALSE_VALUES = (*FALSE_VALUES, np.bool_(False))
50-
INT_TYPES = (*INT_TYPES, np.integer)
51-
FLOAT_TYPES = (*FLOAT_TYPES, np.floating)
52-
SEQUENCE_TYPES = (*SEQUENCE_TYPES, np.ndarray)
53-
54-
if pd is not None:
55-
NONE_VALUES = (*NONE_VALUES, pd.NA)
56-
SEQUENCE_TYPES = (*SEQUENCE_TYPES, pd.Series, pd.Categorical,
57-
pd.core.arrays.ExtensionArray)
58-
MAPPING_TYPES = (*MAPPING_TYPES, pd.DataFrame)
31+
try:
32+
from .._rust.v1 import (
33+
pack as _rust_pack,
34+
unpack as _rust_unpack,
35+
)
36+
except ImportError:
37+
_rust_pack = None
38+
_rust_unpack = None
5939

6040

6141
PACKED_UINT_8 = [struct_pack(">B", value) for value in range(0x100)]
@@ -74,12 +54,17 @@ def __init__(self, stream):
7454
self.stream = stream
7555
self._write = self.stream.write
7656

77-
def _pack_raw(self, data):
78-
self._write(data)
79-
8057
def pack(self, data, dehydration_hooks=None):
81-
self._pack(data,
82-
dehydration_hooks=self._inject_hooks(dehydration_hooks))
58+
dehydration_hooks = self._inject_hooks(dehydration_hooks)
59+
self._pack(data, dehydration_hooks=dehydration_hooks)
60+
61+
if _rust_pack:
62+
def _pack(self, data, dehydration_hooks=None):
63+
data = _rust_pack(data, dehydration_hooks)
64+
self._write(data)
65+
else:
66+
def _pack(self, data, dehydration_hooks=None):
67+
self._py_pack(data, dehydration_hooks)
8368

8469
@classmethod
8570
def _inject_hooks(cls, dehydration_hooks=None):
@@ -93,8 +78,7 @@ def _inject_hooks(cls, dehydration_hooks=None):
9378
subtypes={}
9479
)
9580

96-
97-
def _pack(self, value, dehydration_hooks=None):
81+
def _py_pack(self, value, dehydration_hooks=None):
9882
write = self._write
9983

10084
# None
@@ -136,18 +120,18 @@ def _pack(self, value, dehydration_hooks=None):
136120
elif isinstance(value, str):
137121
encoded = value.encode("utf-8")
138122
self._pack_string_header(len(encoded))
139-
self._pack_raw(encoded)
123+
self._write(encoded)
140124

141125
# Bytes
142126
elif isinstance(value, BYTES_TYPES):
143127
self._pack_bytes_header(len(value))
144-
self._pack_raw(value)
128+
self._write(value)
145129

146130
# List
147131
elif isinstance(value, SEQUENCE_TYPES):
148132
self._pack_list_header(len(value))
149133
for item in value:
150-
self._pack(item, dehydration_hooks)
134+
self._py_pack(item, dehydration_hooks)
151135

152136
# Map
153137
elif isinstance(value, MAPPING_TYPES):
@@ -157,8 +141,8 @@ def _pack(self, value, dehydration_hooks=None):
157141
raise TypeError(
158142
"Map keys must be strings, not {}".format(type(key))
159143
)
160-
self._pack(key, dehydration_hooks)
161-
self._pack(item, dehydration_hooks)
144+
self._py_pack(key, dehydration_hooks)
145+
self._py_pack(item, dehydration_hooks)
162146

163147
# Structure
164148
elif isinstance(value, Structure):
@@ -169,7 +153,7 @@ def _pack(self, value, dehydration_hooks=None):
169153
if dehydration_hooks:
170154
transformer = dehydration_hooks.get_transformer(value)
171155
if transformer is not None:
172-
self._pack(transformer(value), dehydration_hooks)
156+
self._py_pack(transformer(value), dehydration_hooks)
173157
return
174158

175159
raise ValueError("Values of type %s are not supported" % type(value))
@@ -298,11 +282,16 @@ def read(self, n=1):
298282
def read_u8(self):
299283
return self.unpackable.read_u8()
300284

301-
def unpack(self, hydration_hooks=None):
302-
value = self._unpack(hydration_hooks=hydration_hooks)
303-
if hydration_hooks and type(value) in hydration_hooks:
304-
return hydration_hooks[type(value)](value)
305-
return value
285+
if _rust_unpack:
286+
def unpack(self, hydration_hooks=None):
287+
value, i = _rust_unpack(
288+
self.unpackable.data, self.unpackable.p, hydration_hooks
289+
)
290+
self.unpackable.p = i
291+
return value
292+
else:
293+
def unpack(self, hydration_hooks=None):
294+
return self._unpack(hydration_hooks=hydration_hooks)
306295

307296
def _unpack(self, hydration_hooks=None):
308297
marker = self.read_u8()
@@ -384,8 +373,13 @@ def _unpack(self, hydration_hooks=None):
384373
size, tag = self._unpack_structure_header(marker)
385374
value = Structure(tag, *([None] * size))
386375
for i in range(len(value)):
387-
value[i] = self.unpack(hydration_hooks=hydration_hooks)
388-
return value
376+
value[i] = self._unpack(hydration_hooks=hydration_hooks)
377+
if not hydration_hooks:
378+
return value
379+
hydration_hook = hydration_hooks.get(type(value))
380+
if not hydration_hook:
381+
return value
382+
return hydration_hook(value)
389383

390384
else:
391385
raise ValueError("Unknown PackStream marker %02X" % marker)
@@ -397,22 +391,22 @@ def _unpack_list_items(self, marker, hydration_hooks=None):
397391
if size == 0:
398392
return
399393
elif size == 1:
400-
yield self.unpack(hydration_hooks=hydration_hooks)
394+
yield self._unpack(hydration_hooks=hydration_hooks)
401395
else:
402396
for _ in range(size):
403-
yield self.unpack(hydration_hooks=hydration_hooks)
397+
yield self._unpack(hydration_hooks=hydration_hooks)
404398
elif marker == 0xD4: # LIST_8:
405399
size, = struct_unpack(">B", self.read(1))
406400
for _ in range(size):
407-
yield self.unpack(hydration_hooks=hydration_hooks)
401+
yield self._unpack(hydration_hooks=hydration_hooks)
408402
elif marker == 0xD5: # LIST_16:
409403
size, = struct_unpack(">H", self.read(2))
410404
for _ in range(size):
411-
yield self.unpack(hydration_hooks=hydration_hooks)
405+
yield self._unpack(hydration_hooks=hydration_hooks)
412406
elif marker == 0xD6: # LIST_32:
413407
size, = struct_unpack(">I", self.read(4))
414408
for _ in range(size):
415-
yield self.unpack(hydration_hooks=hydration_hooks)
409+
yield self._unpack(hydration_hooks=hydration_hooks)
416410
else:
417411
return
418412

@@ -426,29 +420,29 @@ def _unpack_map(self, marker, hydration_hooks=None):
426420
size = marker & 0x0F
427421
value = {}
428422
for _ in range(size):
429-
key = self.unpack(hydration_hooks=hydration_hooks)
430-
value[key] = self.unpack(hydration_hooks=hydration_hooks)
423+
key = self._unpack(hydration_hooks=hydration_hooks)
424+
value[key] = self._unpack(hydration_hooks=hydration_hooks)
431425
return value
432426
elif marker == 0xD8: # MAP_8:
433427
size, = struct_unpack(">B", self.read(1))
434428
value = {}
435429
for _ in range(size):
436-
key = self.unpack(hydration_hooks=hydration_hooks)
437-
value[key] = self.unpack(hydration_hooks=hydration_hooks)
430+
key = self._unpack(hydration_hooks=hydration_hooks)
431+
value[key] = self._unpack(hydration_hooks=hydration_hooks)
438432
return value
439433
elif marker == 0xD9: # MAP_16:
440434
size, = struct_unpack(">H", self.read(2))
441435
value = {}
442436
for _ in range(size):
443-
key = self.unpack(hydration_hooks=hydration_hooks)
444-
value[key] = self.unpack(hydration_hooks=hydration_hooks)
437+
key = self._unpack(hydration_hooks=hydration_hooks)
438+
value[key] = self._unpack(hydration_hooks=hydration_hooks)
445439
return value
446440
elif marker == 0xDA: # MAP_32:
447441
size, = struct_unpack(">I", self.read(4))
448442
value = {}
449443
for _ in range(size):
450-
key = self.unpack(hydration_hooks=hydration_hooks)
451-
value[key] = self.unpack(hydration_hooks=hydration_hooks)
444+
key = self._unpack(hydration_hooks=hydration_hooks)
445+
value[key] = self._unpack(hydration_hooks=hydration_hooks)
452446
return value
453447
else:
454448
return None

0 commit comments

Comments
 (0)