Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,47 @@ class _TestSSL(tb.SSLTestCase):
PAYLOAD_SIZE = 1024 * 100
TIMEOUT = 60

def test_start_tls_buffer_transfer(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

HELLO_MSG = b'1' * self.PAYLOAD_SIZE
BUFFERED_MSG = b'buffered data before TLS'

server_context = self._create_server_ssl_context(
self.ONLYCERT, self.ONLYKEY)
client_context = self._create_client_ssl_context()

async def handle_client(reader, writer):
# Send data before TLS upgrade
writer.write(BUFFERED_MSG)
await writer.drain()
await asyncio.sleep(0.2)

# Read pre-TLS data
data = await reader.readexactly(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

# Upgrade to TLS (server side)
try:
# We need the wait_for because the broken version hangs here
await asyncio.wait_for(
writer.start_tls(server_context),
timeout=2)
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
except asyncio.TimeoutError:
self.assertIsNotNone(writer.get_extra_info('sslcontext'))

# Send/receive over TLS
writer.write(b'OK')
await writer.drain()

data = await reader.readexactly(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

writer.close()
await self.wait_closed(writer)

def test_create_server_ssl_1(self):
CNT = 0 # number of clients that were successful
TOTAL_CNT = 25 # total number of clients that test will create
Expand Down
10 changes: 10 additions & 0 deletions winloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,16 @@ cdef class Loop:
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)

stream_buff = None
if hasattr(protocol, '_stream_reader'):
stream_reader = protocol._stream_reader
if stream_reader is not None:
stream_buff = getattr(stream_reader, '_buffer', None)

if stream_buff is not None:
ssl_protocol._incoming.write(stream_buff)
stream_buff.clear()

# Pause early so that "ssl_protocol.data_received()" doesn't
# have a chance to get called before "ssl_protocol.connection_made()".
transport.pause_reading()
Expand Down
Loading