diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 503ebeaf8..83c03c9eb 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -443,7 +443,8 @@ def _append(self, signature, fields=(), response=None): :param fields: the fields of the message as a tuple :param response: a response object to handle callbacks """ - self.packer.pack_struct(signature, fields) + with self.outbox.tmp_buffer(): + self.packer.pack_struct(signature, fields) self.outbox.wrap_message() self.responses.append(response) diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py index 486e33df2..aaf458f76 100644 --- a/neo4j/_async/io/_common.py +++ b/neo4j/_async/io/_common.py @@ -17,6 +17,7 @@ import asyncio +from contextlib import contextmanager import logging import socket from struct import pack as struct_pack @@ -94,11 +95,14 @@ def __init__(self, max_chunk_size=16384): self._chunked_data = bytearray() self._raw_data = bytearray() self.write = self._raw_data.extend + self._tmp_buffering = 0 def max_chunk_size(self): return self._max_chunk_size def clear(self): + if self._tmp_buffering: + raise RuntimeError("Cannot clear while buffering") self._chunked_data = bytearray() self._raw_data.clear() @@ -128,13 +132,29 @@ def _chunk_data(self): self._raw_data.clear() def wrap_message(self): + if self._tmp_buffering: + raise RuntimeError("Cannot wrap message while buffering") self._chunk_data() self._chunked_data += b"\x00\x00" def view(self): + if self._tmp_buffering: + raise RuntimeError("Cannot view while buffering") self._chunk_data() return memoryview(self._chunked_data) + @contextmanager + def tmp_buffer(self): + self._tmp_buffering += 1 + old_len = len(self._raw_data) + try: + yield + except Exception: + del self._raw_data[old_len:] + raise + finally: + self._tmp_buffering -= 1 + class ConnectionErrorHandler: """ diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index 82ee8b628..007bb6b48 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -443,7 +443,8 @@ def _append(self, signature, fields=(), response=None): :param fields: the fields of the message as a tuple :param response: a response object to handle callbacks """ - self.packer.pack_struct(signature, fields) + with self.outbox.tmp_buffer(): + self.packer.pack_struct(signature, fields) self.outbox.wrap_message() self.responses.append(response) diff --git a/neo4j/_sync/io/_common.py b/neo4j/_sync/io/_common.py index 408de0a1f..647da7eca 100644 --- a/neo4j/_sync/io/_common.py +++ b/neo4j/_sync/io/_common.py @@ -17,6 +17,7 @@ import asyncio +from contextlib import contextmanager import logging import socket from struct import pack as struct_pack @@ -94,11 +95,14 @@ def __init__(self, max_chunk_size=16384): self._chunked_data = bytearray() self._raw_data = bytearray() self.write = self._raw_data.extend + self._tmp_buffering = 0 def max_chunk_size(self): return self._max_chunk_size def clear(self): + if self._tmp_buffering: + raise RuntimeError("Cannot clear while buffering") self._chunked_data = bytearray() self._raw_data.clear() @@ -128,13 +132,29 @@ def _chunk_data(self): self._raw_data.clear() def wrap_message(self): + if self._tmp_buffering: + raise RuntimeError("Cannot wrap message while buffering") self._chunk_data() self._chunked_data += b"\x00\x00" def view(self): + if self._tmp_buffering: + raise RuntimeError("Cannot view while buffering") self._chunk_data() return memoryview(self._chunked_data) + @contextmanager + def tmp_buffer(self): + self._tmp_buffering += 1 + old_len = len(self._raw_data) + try: + yield + except Exception: + del self._raw_data[old_len:] + raise + finally: + self._tmp_buffering -= 1 + class ConnectionErrorHandler: """