diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ddb9daca026936..ce03948afac803 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -180,14 +180,29 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None, buffer_size=65536): - self._pending_data_length = -1 self._paused = True + self._buffer_size = buffer_size + self._pending_data = None + self._pending_data_length = -1 + self._recv_fut_canceled = False + self._read_fut = None + self._protocol = None + super().__init__(loop, sock, protocol, waiter, extra, server) - self._data = bytearray(buffer_size) self._loop.call_soon(self._loop_reading) self._paused = False + def set_protocol(self, protocol): + super().set_protocol(protocol) + + if isinstance(protocol, protocols.BufferedProtocol): + self._data = protocol.get_buffer(self._buffer_size) + + if self._read_fut: + self._read_fut.cancel() + self._recv_fut_canceled = True + def is_reading(self): return not self._paused and not self._closing @@ -218,12 +233,20 @@ def resume_reading(self): if self._read_fut is None: self._loop.call_soon(self._loop_reading, None) - length = self._pending_data_length - self._pending_data_length = -1 - if length > -1: - # Call the protocol method after calling _loop_reading(), - # since the protocol can decide to pause reading again. - self._loop.call_soon(self._data_received, self._data[:length], length) + if isinstance(self._protocol, protocols.BufferedProtocol): + length = self._pending_data_length + self._pending_data_length = -1 + if length > -1: + # Call the protocol method after calling _loop_reading(), + # since the protocol can decide to pause reading again. + self._loop.call_soon(self._buffer_updated, length) + else: + data = self._pending_data + self._pending_data = None + if data is not None: + # Call the protocol method after calling _loop_reading(), + # since the protocol can decide to pause reading again. + self._loop.call_soon(self._data_received, data) if self._loop.get_debug(): logger.debug("%r resumes reading", self) @@ -244,7 +267,7 @@ def _eof_received(self): if not keep_open: self.close() - def _data_received(self, data, length): + def _buffer_updated(self, length): if self._paused: # Don't call any protocol method while reading is paused. # The protocol will be called on resume_reading(). @@ -256,35 +279,62 @@ def _data_received(self, data, length): self._eof_received() return + try: + self._protocol.buffer_updated(length) + except BaseException as exc: + self._fatal_error(exc, + 'Fatal error: protocol.buffer_updated() ' + 'call failed.') + + def _data_received(self, data): + if self._paused: + # Don't call any protocol method while reading is paused. + # The protocol will be called on resume_reading(). + assert self._pending_data is None + self._pending_data = data + return + + if not data: + self._eof_received() + return + + self._protocol.data_received(data) + + def _handle_recv_result(self, result): + """ + Handles the future result of recv / recv_into. + Returns if should continue reading or not, determined by EOF. + """ if isinstance(self._protocol, protocols.BufferedProtocol): - try: - protocols._feed_data_to_buffered_proto(self._protocol, data) - except (SystemExit, KeyboardInterrupt): - raise - except BaseException as exc: - self._fatal_error(exc, - 'Fatal error: protocol.buffer_updated() ' - 'call failed.') - return + length = result + if length > -1: + self._buffer_updated(length) + if length == 0: + return False else: - self._protocol.data_received(data) + data = result + self._data_received(data) + if not data: + return False + return True def _loop_reading(self, fut=None): - length = -1 - data = None try: if fut is not None: assert self._read_fut is fut or (self._read_fut is None and self._closing) self._read_fut = None if fut.done(): - # deliver data later in "finally" clause - length = fut.result() - if length == 0: - # we got end-of-file so no need to reschedule a new read - return - - data = self._data[:length] + try: + if not self._handle_recv_result(fut.result()): + # we got end-of-file so no need to reschedule a new read + return + except exceptions.CancelledError: + if self._recv_fut_canceled: + # a cancellation is expected on a change of protocol + self._recv_fut_canceled = True + else: + raise else: # the future will be replaced by next proactor.recv call fut.cancel() @@ -298,7 +348,12 @@ def _loop_reading(self, fut=None): if not self._paused: # reschedule a new read - self._read_fut = self._loop._proactor.recv_into(self._sock, self._data) + if isinstance(self._protocol, protocols.BufferedProtocol): + self._read_fut = self._loop._proactor.recv_into( + self._sock, self._data) + else: + self._read_fut = self._loop._proactor.recv( + self._sock, self._buffer_size) except ConnectionAbortedError as exc: if not self._closing: self._fatal_error(exc, 'Fatal read error on pipe transport') @@ -315,9 +370,6 @@ def _loop_reading(self, fut=None): else: if not self._paused: self._read_fut.add_done_callback(self._loop_reading) - finally: - if length > -1: - self._data_received(data, length) class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9a169035de8865..48ccb0df76460a 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -112,7 +112,7 @@ def factory(): return await loop.create_unix_server(factory, path, **kwds) -class FlowControlMixin(protocols.Protocol): +class FlowControlMixin(protocols.BaseProtocol): """Reusable flow control logic for StreamWriter.drain(). This implements the protocol methods pause_writing(), @@ -180,7 +180,7 @@ def _get_close_waiter(self, stream): raise NotImplementedError -class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): +class BaseStreamReaderProtocol(FlowControlMixin): """Helper class to adapt between Protocol and StreamReader. (This is a helper class instead of making StreamReader itself a @@ -267,11 +267,6 @@ def connection_lost(self, exc): self._stream_writer = None self._transport = None - def data_received(self, data): - reader = self._stream_reader - if reader is not None: - reader.feed_data(data) - def eof_received(self): reader = self._stream_reader if reader is not None: @@ -298,6 +293,30 @@ def __del__(self): closed.exception() +class StreamReaderProtocol(BaseStreamReaderProtocol, protocols.Protocol): + def data_received(self, data): + reader = self._stream_reader + if reader is not None: + reader.feed_data(data) + + +class StreamReaderBufferedProtocol(BaseStreamReaderProtocol, protocols.BufferedProtocol): + def __init__(self, stream_reader, client_connected_cb=None, loop=None, + buffer_size=65536): + super().__init__(stream_reader, + client_connected_cb=client_connected_cb, + loop=loop) + self._buffer = memoryview(bytearray(buffer_size)) + + def get_buffer(self, sizehint): + return self._buffer + + def buffer_updated(self, nbytes): + reader = self._stream_reader + if reader is not None: + reader.feed_data(self._buffer[:nbytes]) + + class StreamWriter: """Wraps a Transport. diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index cf7683fee64621..a99c2c94a02737 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -20,6 +20,7 @@ from . import events from . import exceptions from . import futures +from . import protocols from . import selector_events from . import tasks from . import transports @@ -480,9 +481,13 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): os.set_blocking(self._fileno, False) self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called - self._loop.call_soon(self._loop._add_reader, - self._fileno, self._read_ready) + if isinstance(protocol, protocols.BufferedProtocol): + self._read_ready = self._readinto_buffer_ready + else: + self._read_ready = self._read_buffer_ready + self._loop.call_soon(self._loop._add_reader, self._fileno, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(futures._set_result_unless_cancelled, @@ -509,7 +514,37 @@ def __repr__(self): info.append('closed') return '<{}>'.format(' '.join(info)) - def _read_ready(self): + def _readinto_buffer_ready(self): + try: + buf = self._protocol.get_buffer(-1) + if not len(buf): + raise RuntimeError('get_buffer() returned an empty buffer') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.get_buffer() call failed.') + return + + nbytes = 0 + try: + nbytes = self._pipe.readinto(buf) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc, 'Fatal read error on pipe transport') + else: + if nbytes: + self._protocol.buffer_updated(nbytes) + else: + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) + self._closing = True + self._loop._remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def _read_buffer_ready(self): try: data = os.read(self._fileno, self.max_size) except (BlockingIOError, InterruptedError): diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py index 7fca0541ee75a2..7d91edfe2152ec 100644 --- a/Lib/test/test_asyncio/test_proactor_events.py +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -30,17 +30,20 @@ def close_transport(transport): transport._sock = None -class ProactorSocketTransportTests(test_utils.TestCase): +class ProactorSocketTransportTestsBase: - def setUp(self): - super().setUp() + def make_test_protocol(self): + raise NotImplementedError + + def setup_globals(self): self.loop = self.new_test_loop() self.addCleanup(self.loop.close) self.proactor = mock.Mock() self.loop._proactor = self.proactor - self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.protocol = None self.sock = mock.Mock(socket.socket) self.buffer_size = 65536 + self.make_test_protocol() def socket_transport(self, waiter=None): transport = _ProactorSocketTransport(self.loop, self.sock, @@ -48,35 +51,6 @@ def socket_transport(self, waiter=None): self.addCleanup(close_transport, transport) return transport - def test_ctor(self): - fut = self.loop.create_future() - tr = self.socket_transport(waiter=fut) - test_utils.run_briefly(self.loop) - self.assertIsNone(fut.result()) - self.protocol.connection_made(tr) - self.proactor.recv_into.assert_called_with(self.sock, bytearray(self.buffer_size)) - - def test_loop_reading(self): - tr = self.socket_transport() - tr._loop_reading() - self.loop._proactor.recv_into.assert_called_with(self.sock, bytearray(self.buffer_size)) - self.assertFalse(self.protocol.data_received.called) - self.assertFalse(self.protocol.eof_received.called) - - def test_loop_reading_data(self): - buf = b'data' - res = self.loop.create_future() - res.set_result(len(buf)) - - tr = self.socket_transport() - tr._read_fut = res - tr._data[:len(buf)] = buf - tr._loop_reading(res) - called_buf = bytearray(self.buffer_size) - called_buf[:len(buf)] = buf - self.loop._proactor.recv_into.assert_called_with(self.sock, called_buf) - self.protocol.data_received.assert_called_with(bytearray(buf)) - @unittest.skipIf(sys.flags.optimize, "Assertions are disabled in optimized mode") def test_loop_reading_no_data(self): res = self.loop.create_future() @@ -88,12 +62,12 @@ def test_loop_reading_no_data(self): tr.close = mock.Mock() tr._read_fut = res tr._loop_reading(res) - self.assertFalse(self.loop._proactor.recv_into.called) + self.assertFalse(self.loop._proactor.recv.called) self.assertTrue(self.protocol.eof_received.called) self.assertTrue(tr.close.called) def test_loop_reading_aborted(self): - err = self.loop._proactor.recv_into.side_effect = ConnectionAbortedError() + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() tr = self.socket_transport() tr._fatal_error = mock.Mock() @@ -103,7 +77,7 @@ def test_loop_reading_aborted(self): 'Fatal read error on pipe transport') def test_loop_reading_aborted_closing(self): - self.loop._proactor.recv_into.side_effect = ConnectionAbortedError() + self.loop._proactor.recv.side_effect = ConnectionAbortedError() tr = self.socket_transport() tr._closing = True @@ -112,7 +86,7 @@ def test_loop_reading_aborted_closing(self): self.assertFalse(tr._fatal_error.called) def test_loop_reading_aborted_is_fatal(self): - self.loop._proactor.recv_into.side_effect = ConnectionAbortedError() + self.loop._proactor.recv.side_effect = ConnectionAbortedError() tr = self.socket_transport() tr._closing = False tr._fatal_error = mock.Mock() @@ -120,7 +94,7 @@ def test_loop_reading_aborted_is_fatal(self): self.assertTrue(tr._fatal_error.called) def test_loop_reading_conn_reset_lost(self): - err = self.loop._proactor.recv_into.side_effect = ConnectionResetError() + err = self.loop._proactor.recv.side_effect = ConnectionResetError() tr = self.socket_transport() tr._closing = False @@ -131,7 +105,7 @@ def test_loop_reading_conn_reset_lost(self): tr._force_close.assert_called_with(err) def test_loop_reading_exception(self): - err = self.loop._proactor.recv_into.side_effect = (OSError()) + err = self.loop._proactor.recv.side_effect = (OSError()) tr = self.socket_transport() tr._fatal_error = mock.Mock() @@ -363,62 +337,6 @@ def test_write_eof_duplex_pipe(self): tr.write_eof() close_transport(tr) - def test_pause_resume_reading(self): - tr = self.socket_transport() - index = 0 - msgs = [b'data1', b'data2', b'data3', b'data4', b'data5', b''] - reversed_msgs = list(reversed(msgs)) - - def recv_into(sock, data): - f = self.loop.create_future() - msg = reversed_msgs.pop() - - result = f.result - def monkey(): - data[:len(msg)] = msg - return result() - f.result = monkey - - f.set_result(len(msg)) - return f - - self.loop._proactor.recv_into.side_effect = recv_into - self.loop._run_once() - self.assertFalse(tr._paused) - self.assertTrue(tr.is_reading()) - - for msg in msgs[:2]: - self.loop._run_once() - self.protocol.data_received.assert_called_with(bytearray(msg)) - - tr.pause_reading() - tr.pause_reading() - self.assertTrue(tr._paused) - self.assertFalse(tr.is_reading()) - for i in range(10): - self.loop._run_once() - self.protocol.data_received.assert_called_with(bytearray(msgs[1])) - - tr.resume_reading() - tr.resume_reading() - self.assertFalse(tr._paused) - self.assertTrue(tr.is_reading()) - - for msg in msgs[2:4]: - self.loop._run_once() - self.protocol.data_received.assert_called_with(bytearray(msg)) - - tr.pause_reading() - tr.resume_reading() - self.loop.call_exception_handler = mock.Mock() - self.loop._run_once() - self.loop.call_exception_handler.assert_not_called() - self.protocol.data_received.assert_called_with(bytearray(msgs[4])) - tr.close() - - self.assertFalse(tr.is_reading()) - - def pause_writing_transport(self, high): tr = self.socket_transport() tr.set_write_buffer_limits(high=high) @@ -498,6 +416,187 @@ def test_dont_pause_writing(self): self.assertFalse(self.protocol.pause_writing.called) +class ProactorSocketTransportWithProtocolTests(test_utils.TestCase, ProactorSocketTransportTestsBase): + + def make_test_protocol(self): + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + + def setUp(self): + super().setUp() + self.setup_globals() + + def test_ctor(self): + fut = self.loop.create_future() + tr = self.socket_transport(waiter=fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, self.buffer_size) + + def test_loop_reading(self): + tr = self.socket_transport() + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, self.buffer_size) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + buf = b'data' + res = self.loop.create_future() + res.set_result(buf) + + tr = self.socket_transport() + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, self.buffer_size) + self.protocol.data_received.assert_called_with(buf) + + def test_pause_resume_reading(self): + tr = self.socket_transport() + futures = [] + msgs = [b'data1', b'data2', b'data3', b'data4', b'data5', b''] + for msg in msgs: + f = self.loop.create_future() + f.set_result(msg) + futures.append(f) + + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + + for msg in msgs[:2]: + self.loop._run_once() + self.protocol.data_received.assert_called_with(msg) + + tr.pause_reading() + self.assertTrue(tr._paused) + self.assertFalse(tr.is_reading()) + for _ in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(msgs[1]) + + tr.resume_reading() + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + + for msg in msgs[2:4]: + self.loop._run_once() + self.protocol.data_received.assert_called_with(msg) + + tr.pause_reading() + tr.resume_reading() + self.loop.call_exception_handler = mock.Mock() + self.loop._run_once() + self.loop.call_exception_handler.assert_not_called() + self.protocol.data_received.assert_called_with(msgs[4]) + tr.close() + + self.assertFalse(tr.is_reading()) + + +class ProactorSocketTransportWithBufferedProtocolTests(test_utils.TestCase, ProactorSocketTransportTestsBase): + + def make_test_protocol(self): + self.protocol = test_utils.make_test_buffered_protocol( + asyncio.BufferedProtocol, self.buffer_size) + + # patch for some test in the base class to pass + # BufferedProtocol calls recv_into instead of recv and all the tests + # that assert if the function is called check if `recv` is called. + self.loop._proactor.recv = self.loop._proactor.recv_into + + def setUp(self): + super().setUp() + self.setup_globals() + + def assert_buffer_received(self, data): + self.assertTrue(self.protocol.buffer_updated.called) + self.assertEqual(self.protocol._last_called_buffer[:len(data)], + bytearray(data)) + + def test_ctor(self): + fut = self.loop.create_future() + tr = self.socket_transport(waiter=fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv_into.assert_called_with(self.sock, bytearray(self.buffer_size)) + + def test_loop_reading(self): + tr = self.socket_transport() + tr._loop_reading() + self.loop._proactor.recv_into.assert_called_with(self.sock, bytearray(self.buffer_size)) + self.assertFalse(self.protocol.buffer_updated.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + buf = b'data' + res = self.loop.create_future() + res.set_result(len(buf)) + + tr = self.socket_transport() + tr._read_fut = res + tr._data[:len(buf)] = buf + tr._loop_reading(res) + called_buf = bytearray(self.buffer_size) + called_buf[:len(buf)] = buf + self.loop._proactor.recv_into.assert_called_with(self.sock, called_buf) + self.assert_buffer_received(buf) + + def test_pause_resume_reading(self): + tr = self.socket_transport() + msgs = [b'data1', b'data2', b'data3', b'data4', b'data5', b''] + reversed_msgs = list(reversed(msgs)) + + def recv_into(sock, data): + msg = reversed_msgs.pop() + + f = self.loop.create_future() + result = f.result + def monkey(): + data[:len(msg)] = msg + return result() + f.result = monkey + + f.set_result(len(msg)) + return f + + self.loop._proactor.recv_into.side_effect = recv_into + self.loop._run_once() + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + + for msg in msgs[:2]: + self.loop._run_once() + self.assert_buffer_received(msg) + + tr.pause_reading() + self.assertTrue(tr._paused) + self.assertFalse(tr.is_reading()) + for _ in range(10): + self.loop._run_once() + self.assert_buffer_received(msgs[1]) + + tr.resume_reading() + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + + for msg in msgs[2:4]: + self.loop._run_once() + self.assert_buffer_received(msg) + + tr.pause_reading() + tr.resume_reading() + self.loop.call_exception_handler = mock.Mock() + self.loop._run_once() + self.loop.call_exception_handler.assert_not_called() + self.assert_buffer_received(msgs[4]) + tr.close() + + self.assertFalse(tr.is_reading()) + + class ProactorDatagramTransportTests(test_utils.TestCase): def setUp(self): diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py index 9918165909f7f8..9c3f9b7234a210 100644 --- a/Lib/test/test_asyncio/test_unix_events.py +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -652,12 +652,14 @@ def test_sock_sendfile_exception(self): self.assertEqual(1000, self.file.tell()) -class UnixReadPipeTransportTests(test_utils.TestCase): +class UnixReadPipeTransportTestsBase: - def setUp(self): - super().setUp() + def make_test_protocol(self): + raise NotImplementedError + + def setup_globals(self): self.loop = self.new_test_loop() - self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.protocol = None self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 @@ -672,6 +674,8 @@ def setUp(self): m_fstat.return_value = st self.addCleanup(fstat_patcher.stop) + self.make_test_protocol() + def read_pipe_transport(self, waiter=None): transport = unix_events._UnixReadPipeTransport(self.loop, self.pipe, self.protocol, @@ -688,54 +692,6 @@ def test_ctor(self): self.loop.assert_reader(5, tr._read_ready) self.assertIsNone(waiter.result()) - @mock.patch('os.read') - def test__read_ready(self, m_read): - tr = self.read_pipe_transport() - m_read.return_value = b'data' - tr._read_ready() - - m_read.assert_called_with(5, tr.max_size) - self.protocol.data_received.assert_called_with(b'data') - - @mock.patch('os.read') - def test__read_ready_eof(self, m_read): - tr = self.read_pipe_transport() - m_read.return_value = b'' - tr._read_ready() - - m_read.assert_called_with(5, tr.max_size) - self.assertFalse(self.loop.readers) - test_utils.run_briefly(self.loop) - self.protocol.eof_received.assert_called_with() - self.protocol.connection_lost.assert_called_with(None) - - @mock.patch('os.read') - def test__read_ready_blocked(self, m_read): - tr = self.read_pipe_transport() - m_read.side_effect = BlockingIOError - tr._read_ready() - - m_read.assert_called_with(5, tr.max_size) - test_utils.run_briefly(self.loop) - self.assertFalse(self.protocol.data_received.called) - - @mock.patch('asyncio.log.logger.error') - @mock.patch('os.read') - def test__read_ready_error(self, m_read, m_logexc): - tr = self.read_pipe_transport() - err = OSError() - m_read.side_effect = err - tr._close = mock.Mock() - tr._read_ready() - - m_read.assert_called_with(5, tr.max_size) - tr._close.assert_called_with(err) - m_logexc.assert_called_with( - test_utils.MockPattern( - 'Fatal read error on pipe transport' - '\nprotocol:.*\ntransport:.*'), - exc_info=(OSError, MOCK_ANY, MOCK_ANY)) - @mock.patch('os.read') def test_pause_reading(self, m_read): tr = self.read_pipe_transport() @@ -829,6 +785,126 @@ def test_resume_reading_on_paused_pipe(self): tr.resume_reading() +class UnixReadPipeTransportWithProtocolTests(test_utils.TestCase, UnixReadPipeTransportTestsBase): + + def make_test_protocol(self): + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + + def setUp(self): + super().setUp() + self.setup_globals() + + @mock.patch('os.read') + def test__read_ready(self, m_read): + tr = self.read_pipe_transport() + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = self.read_pipe_transport() + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = self.read_pipe_transport() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = self.read_pipe_transport() + err = OSError() + m_read.side_effect = err + tr._close = mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal read error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + + +class UnixReadPipeTransportWithBufferedProtocolTests(test_utils.TestCase, UnixReadPipeTransportTestsBase): + + def make_test_protocol(self): + self.protocol = test_utils.make_test_buffered_protocol( + asyncio.BufferedProtocol, 65536) + + def setUp(self): + super().setUp() + self.setup_globals() + + def set_next_buffered_read(self, data): + data_length = len(data) + buf = self.protocol.get_buffer(data_length) + buf[:data_length] = data + self.pipe.readinto.return_value = data_length + + def test__read_ready(self): + tr = self.read_pipe_transport() + data = b'data' + self.set_next_buffered_read(data) + tr._read_ready() + + self.pipe.readinto.assert_called_once() + self.assertTrue(self.protocol.buffer_updated.called) + self.assertEqual(self.protocol._last_called_buffer[:len(data)], bytearray(data)) + + def test__read_ready_eof(self): + tr = self.read_pipe_transport() + self.set_next_buffered_read(b'') + tr._read_ready() + + self.pipe.readinto.assert_called_once() + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + + def test__read_ready_blocked(self): + tr = self.read_pipe_transport() + self.pipe.readinto.side_effect = BlockingIOError + tr._read_ready() + + self.pipe.readinto.assert_called_once() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.buffer_updated.called) + + @mock.patch('asyncio.log.logger.error') + def test__read_ready_error(self, m_logexc): + tr = self.read_pipe_transport() + err = OSError() + self.pipe.readinto.side_effect = err + tr._close = mock.Mock() + tr._read_ready() + + self.pipe.readinto.assert_called_once() + tr._close.assert_called_with(err) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal read error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + + class UnixWritePipeTransportTests(test_utils.TestCase): def setUp(self): diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index 96be5a1c3bcf77..159d2413999081 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -315,6 +315,25 @@ def make_test_protocol(base): return type('TestProtocol', (base,) + base.__bases__, dct)() +def make_test_buffered_protocol(base, buffer_size): + protocol = make_test_protocol(base) + protocol._buffer = bytearray(buffer_size) + protocol._last_called_buffer = None + + def get_buffer(*_, **__): + return protocol._buffer + + def buffer_updated(nbytes): + protocol.buffer_updated.called = True + protocol._last_called_buffer = protocol._buffer[:nbytes] + + protocol.get_buffer = get_buffer + + protocol.buffer_updated = buffer_updated + protocol.buffer_updated.called = False + return protocol + + class TestSelector(selectors.BaseSelector): def __init__(self): diff --git a/Misc/NEWS.d/next/Library/2020-07-11-20-15-29.bpo-41279.M4OEou.rst b/Misc/NEWS.d/next/Library/2020-07-11-20-15-29.bpo-41279.M4OEou.rst new file mode 100644 index 00000000000000..66a6d7e516a18b --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-07-11-20-15-29.bpo-41279.M4OEou.rst @@ -0,0 +1 @@ +Add ``StreamReaderBufferedProtocol``. diff --git a/Misc/NEWS.d/next/Library/2020-07-11-20-16-56.bpo-41279.PYW8U8.rst b/Misc/NEWS.d/next/Library/2020-07-11-20-16-56.bpo-41279.PYW8U8.rst new file mode 100644 index 00000000000000..9c207bfe7f102e --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-07-11-20-16-56.bpo-41279.PYW8U8.rst @@ -0,0 +1 @@ +Add ``BufferedProtocol`` support to ``_UnixReadPipeTransport``. diff --git a/Misc/NEWS.d/next/Library/2020-07-14-23-35-08.bpo-41279.Beuyjq.rst b/Misc/NEWS.d/next/Library/2020-07-14-23-35-08.bpo-41279.Beuyjq.rst new file mode 100644 index 00000000000000..cd76a1edbed5c8 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-07-14-23-35-08.bpo-41279.Beuyjq.rst @@ -0,0 +1,3 @@ +Call ``protocol.get_buffer`` on the protocol given to +``_ProactorReadPipeTransport`` if the protocol is of instance +``BufferedProtocol`` instead of creating a new buffer.