From a8dd60f038752da1887fd7cd0c66ce8ee2415631 Mon Sep 17 00:00:00 2001 From: Fediz01 Date: Fri, 19 Dec 2025 16:36:12 +0100 Subject: [PATCH] Fix start_tls() loses StreamReader buffer data and hangs when upgrading the connection --- tests/test_tcp.py | 41 +++++++++++++++++++++++++++++++++++++++++ winloop/loop.pyx | 10 ++++++++++ 2 files changed, 51 insertions(+) diff --git a/tests/test_tcp.py b/tests/test_tcp.py index f88c07b..7b0fe77 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -1262,6 +1262,47 @@ class _TestSSL(tb.SSLTestCase): PAYLOAD_SIZE = 1024 * 100 TIMEOUT = 60 + def test_start_tls_buffer_transfer(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + BUFFERED_MSG = b'buffered data before TLS' + + server_context = self._create_server_ssl_context( + self.ONLYCERT, self.ONLYKEY) + client_context = self._create_client_ssl_context() + + async def handle_client(reader, writer): + # Send data before TLS upgrade + writer.write(BUFFERED_MSG) + await writer.drain() + await asyncio.sleep(0.2) + + # Read pre-TLS data + data = await reader.readexactly(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + # Upgrade to TLS (server side) + try: + # We need the wait_for because the broken version hangs here + await asyncio.wait_for( + writer.start_tls(server_context), + timeout=2) + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + except asyncio.TimeoutError: + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + + # Send/receive over TLS + writer.write(b'OK') + await writer.drain() + + data = await reader.readexactly(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + writer.close() + await self.wait_closed(writer) + def test_create_server_ssl_1(self): CNT = 0 # number of clients that were successful TOTAL_CNT = 25 # total number of clients that test will create diff --git a/winloop/loop.pyx b/winloop/loop.pyx index 301e780..0e6d9cf 100644 --- a/winloop/loop.pyx +++ b/winloop/loop.pyx @@ -1713,6 +1713,16 @@ cdef class Loop: ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) + stream_buff = None + if hasattr(protocol, '_stream_reader'): + stream_reader = protocol._stream_reader + if stream_reader is not None: + stream_buff = getattr(stream_reader, '_buffer', None) + + if stream_buff is not None: + ssl_protocol._incoming.write(stream_buff) + stream_buff.clear() + # Pause early so that "ssl_protocol.data_received()" doesn't # have a chance to get called before "ssl_protocol.connection_made()". transport.pause_reading()