Skip to content

Commit 19ad25b

Browse files
authored
Roll back data written to output buffer on packing failure (#640)
While packing data to packstream, several errors can occur (integers that are out of bounds, unknown data types, etc.). On packing failure, the driver should never send the half-finished packed data over the wire. This will most likely cause the server to close the connection as the data will be corrupt.
1 parent 0121447 commit 19ad25b

File tree

4 files changed

+44
-2
lines changed

4 files changed

+44
-2
lines changed

neo4j/_async/io/_bolt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ def _append(self, signature, fields=(), response=None):
443443
:param fields: the fields of the message as a tuple
444444
:param response: a response object to handle callbacks
445445
"""
446-
self.packer.pack_struct(signature, fields)
446+
with self.outbox.tmp_buffer():
447+
self.packer.pack_struct(signature, fields)
447448
self.outbox.wrap_message()
448449
self.responses.append(response)
449450

neo4j/_async/io/_common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
import asyncio
20+
from contextlib import contextmanager
2021
import logging
2122
import socket
2223
from struct import pack as struct_pack
@@ -94,11 +95,14 @@ def __init__(self, max_chunk_size=16384):
9495
self._chunked_data = bytearray()
9596
self._raw_data = bytearray()
9697
self.write = self._raw_data.extend
98+
self._tmp_buffering = 0
9799

98100
def max_chunk_size(self):
99101
return self._max_chunk_size
100102

101103
def clear(self):
104+
if self._tmp_buffering:
105+
raise RuntimeError("Cannot clear while buffering")
102106
self._chunked_data = bytearray()
103107
self._raw_data.clear()
104108

@@ -128,13 +132,29 @@ def _chunk_data(self):
128132
self._raw_data.clear()
129133

130134
def wrap_message(self):
135+
if self._tmp_buffering:
136+
raise RuntimeError("Cannot wrap message while buffering")
131137
self._chunk_data()
132138
self._chunked_data += b"\x00\x00"
133139

134140
def view(self):
141+
if self._tmp_buffering:
142+
raise RuntimeError("Cannot view while buffering")
135143
self._chunk_data()
136144
return memoryview(self._chunked_data)
137145

146+
@contextmanager
147+
def tmp_buffer(self):
148+
self._tmp_buffering += 1
149+
old_len = len(self._raw_data)
150+
try:
151+
yield
152+
except Exception:
153+
del self._raw_data[old_len:]
154+
raise
155+
finally:
156+
self._tmp_buffering -= 1
157+
138158

139159
class ConnectionErrorHandler:
140160
"""

neo4j/_sync/io/_bolt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ def _append(self, signature, fields=(), response=None):
443443
:param fields: the fields of the message as a tuple
444444
:param response: a response object to handle callbacks
445445
"""
446-
self.packer.pack_struct(signature, fields)
446+
with self.outbox.tmp_buffer():
447+
self.packer.pack_struct(signature, fields)
447448
self.outbox.wrap_message()
448449
self.responses.append(response)
449450

neo4j/_sync/io/_common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
import asyncio
20+
from contextlib import contextmanager
2021
import logging
2122
import socket
2223
from struct import pack as struct_pack
@@ -94,11 +95,14 @@ def __init__(self, max_chunk_size=16384):
9495
self._chunked_data = bytearray()
9596
self._raw_data = bytearray()
9697
self.write = self._raw_data.extend
98+
self._tmp_buffering = 0
9799

98100
def max_chunk_size(self):
99101
return self._max_chunk_size
100102

101103
def clear(self):
104+
if self._tmp_buffering:
105+
raise RuntimeError("Cannot clear while buffering")
102106
self._chunked_data = bytearray()
103107
self._raw_data.clear()
104108

@@ -128,13 +132,29 @@ def _chunk_data(self):
128132
self._raw_data.clear()
129133

130134
def wrap_message(self):
135+
if self._tmp_buffering:
136+
raise RuntimeError("Cannot wrap message while buffering")
131137
self._chunk_data()
132138
self._chunked_data += b"\x00\x00"
133139

134140
def view(self):
141+
if self._tmp_buffering:
142+
raise RuntimeError("Cannot view while buffering")
135143
self._chunk_data()
136144
return memoryview(self._chunked_data)
137145

146+
@contextmanager
147+
def tmp_buffer(self):
148+
self._tmp_buffering += 1
149+
old_len = len(self._raw_data)
150+
try:
151+
yield
152+
except Exception:
153+
del self._raw_data[old_len:]
154+
raise
155+
finally:
156+
self._tmp_buffering -= 1
157+
138158

139159
class ConnectionErrorHandler:
140160
"""

0 commit comments

Comments
 (0)