Skip to content

gh-109051: fix start_tls() on paused-for-writing transport #109603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,10 +1261,22 @@ async def start_tls(self, transport, protocol, sslcontext, *,
raise TypeError(
f'transport {transport!r} is not supported by start_tls()')

# gh-109051: SSLProtocol needs to preserve "writing paused" state
if isinstance(transport, transports._FlowControlMixin):
writing_paused = transport._protocol_paused
else:
# Don't break compatibility with transports that don't implement
# the private _FlowControlMixin (e.g. wrapper transports) as much
# as possible.
_, high_water = transport.get_write_buffer_limits()
buffer_size = transport.get_write_buffer_size()
writing_paused = buffer_size > high_water
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it safe to only use this code path? (remove fast-path for _FlowControlMixin)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, and I'm not sure. I saw the pattern used by _SendfileFallbackProtocol which checks for the mixin. It's the only other outstanding use case for set_protocol() in the stdlib.

Pedantically speaking, the documentation doesn't make an unambiguous claim that pause_writing() is called instantly when the buffer is filled above the high water mark, or that a background thread (or a proactor callback? I don't truly understand that part) doesn't drain the buffer in a way that is accidentally observable. But I don't think any transport in the stdlib is susceptible to this.

A notable case happens if the high water mark is increased while paused, and before start_tls. In the current implementation, the limit returned by get_write_buffer_limits() would be the increased one, but resume_writing() would not be called yet until the next opportunity to drain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In short, the answer is no, I don't think it would be safe. But I'd defer to an asyncio expert.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case of doubt, keep the specific code path for _FlowControlMixin.


waiter = self.create_future()
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
writing_paused=writing_paused,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)
Expand Down
5 changes: 3 additions & 2 deletions Lib/asyncio/sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class SSLProtocol(protocols.BufferedProtocol):

def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None,
writing_paused=False,
call_connection_made=True,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
Expand Down Expand Up @@ -331,7 +332,7 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,

# Flow Control

self._ssl_writing_paused = False
self._ssl_writing_paused = writing_paused

self._app_reading_paused = False

Expand All @@ -341,7 +342,7 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
self._set_read_buffer_limits()
self._eof_received = False

self._app_writing_paused = False
self._app_writing_paused = writing_paused
self._outgoing_high_water = 0
self._outgoing_low_water = 0
self._set_write_buffer_limits()
Expand Down
125 changes: 125 additions & 0 deletions Lib/test/test_asyncio/test_sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import socket
import threading
import unittest
import weakref
from test import support
Expand Down Expand Up @@ -594,6 +595,130 @@ async def run_main():

self.loop.run_until_complete(run_main())

def test_start_tls_writing_paused(self):
# gh-109051: start_tls() should not break if called while transport has
# paused writing.

HELLO_MSG = b'1' * self.PAYLOAD_SIZE

server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
has_paused_writing_event = threading.Event()
hello_count = 0

def client(sock, addr):
sock.settimeout(self.TIMEOUT)
sock.connect(addr)

# Wait until we know that the server transport has paused writing
has_paused_writing_event.wait(timeout=support.SHORT_TIMEOUT)
has_paused_writing_event.clear()

for _ in range(hello_count):
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.start_tls(client_context)
self.assertFalse(has_paused_writing_event.is_set())
sock.sendall(HELLO_MSG)

# Wait once again
has_paused_writing_event.wait(timeout=support.SHORT_TIMEOUT)
for _ in range(hello_count):
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))

sock.close()

class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_con_lost, on_got_hello):
self.on_con = on_con
self.on_con_lost = on_con_lost
self.on_got_hello = on_got_hello
self.data = b''
self.transport = None
self.is_writing_paused = False

def connection_made(self, tr):
self.transport = tr
self.on_con.set_result(tr)

def replace_transport(self, tr):
self.transport = tr

def data_received(self, data):
self.data += data
if len(self.data) >= len(HELLO_MSG):
self.on_got_hello.set_result(None)

def pause_writing(self) -> None:
self.is_writing_paused = True

def resume_writing(self) -> None:
self.is_writing_paused = False

def connection_lost(self, exc):
self.transport = None
if exc is None:
self.on_con_lost.set_result(None)
else:
self.on_con_lost.set_exception(exc)

async def main(proto, on_con, on_con_lost, on_got_hello):
nonlocal hello_count
tr = await on_con

while not proto.is_writing_paused:
tr.write(HELLO_MSG)
hello_count += 1
has_paused_writing_event.set()

self.assertEqual(proto.data, b'')

new_tr = await self.loop.start_tls(
tr, proto, server_context,
server_side=True,
ssl_handshake_timeout=self.TIMEOUT)
proto.replace_transport(new_tr)

await on_got_hello

# Check pause/resume_writing are still being called after the
# protocol is switched.
self.assertFalse(proto.is_writing_paused)
self.assertFalse(has_paused_writing_event.is_set())
hello_count = 0
while not proto.is_writing_paused:
new_tr.write(HELLO_MSG)
hello_count += 1
has_paused_writing_event.set()

await on_con_lost
self.assertFalse(proto.is_writing_paused)
self.assertEqual(proto.data, HELLO_MSG)
new_tr.close()

async def run_main():
on_con = self.loop.create_future()
on_con_lost = self.loop.create_future()
on_got_hello = self.loop.create_future()
proto = ServerProto(on_con, on_con_lost, on_got_hello)

server = await self.loop.create_server(
lambda: proto, '127.0.0.1', 0)
addr = server.sockets[0].getsockname()

with self.tcp_client(lambda sock: client(sock, addr),
timeout=self.TIMEOUT):
await asyncio.wait_for(
main(proto, on_con, on_con_lost, on_got_hello),
timeout=self.TIMEOUT)

server.close()
await server.wait_closed()

self.loop.run_until_complete(run_main())

def test_start_tls_wrong_args(self):
async def main():
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix an assertion failure after :meth:`asyncio.loop.start_tls` is called with a
transport that had called :meth:`~asyncio.BaseProtocol.pause_writing` but not
resumed yet.