Skip to content

Commit a5d1eb8

Browse files
authored
bpo-34638: Store a weak reference to stream reader to break strong references loop (GH-9201)
Store a weak reference to stream readerfor breaking strong references It breaks the strong reference loop between reader and protocol and allows to detect and close the socket if the stream is deleted (garbage collected)
1 parent aca819f commit a5d1eb8

File tree

4 files changed

+160
-10
lines changed

4 files changed

+160
-10
lines changed

Lib/asyncio/streams.py

+81-10
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
'open_connection', 'start_server')
44

55
import socket
6+
import sys
7+
import weakref
68

79
if hasattr(socket, 'AF_UNIX'):
810
__all__ += ('open_unix_connection', 'start_unix_server')
911

1012
from . import coroutines
1113
from . import events
1214
from . import exceptions
15+
from . import format_helpers
1316
from . import protocols
1417
from .log import logger
1518
from .tasks import sleep
@@ -186,46 +189,106 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
186189
call inappropriate methods of the protocol.)
187190
"""
188191

192+
_source_traceback = None
193+
189194
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
190195
super().__init__(loop=loop)
191-
self._stream_reader = stream_reader
196+
if stream_reader is not None:
197+
self._stream_reader_wr = weakref.ref(stream_reader,
198+
self._on_reader_gc)
199+
self._source_traceback = stream_reader._source_traceback
200+
else:
201+
self._stream_reader_wr = None
202+
if client_connected_cb is not None:
203+
# This is a stream created by the `create_server()` function.
204+
# Keep a strong reference to the reader until a connection
205+
# is established.
206+
self._strong_reader = stream_reader
207+
self._reject_connection = False
192208
self._stream_writer = None
209+
self._transport = None
193210
self._client_connected_cb = client_connected_cb
194211
self._over_ssl = False
195212
self._closed = self._loop.create_future()
196213

214+
def _on_reader_gc(self, wr):
215+
transport = self._transport
216+
if transport is not None:
217+
# connection_made was called
218+
context = {
219+
'message': ('An open stream object is being garbage '
220+
'collected; call "stream.close()" explicitly.')
221+
}
222+
if self._source_traceback:
223+
context['source_traceback'] = self._source_traceback
224+
self._loop.call_exception_handler(context)
225+
transport.abort()
226+
else:
227+
self._reject_connection = True
228+
self._stream_reader_wr = None
229+
230+
def _untrack_reader(self):
231+
self._stream_reader_wr = None
232+
233+
@property
234+
def _stream_reader(self):
235+
if self._stream_reader_wr is None:
236+
return None
237+
return self._stream_reader_wr()
238+
197239
def connection_made(self, transport):
198-
self._stream_reader.set_transport(transport)
240+
if self._reject_connection:
241+
context = {
242+
'message': ('An open stream was garbage collected prior to '
243+
'establishing network connection; '
244+
'call "stream.close()" explicitly.')
245+
}
246+
if self._source_traceback:
247+
context['source_traceback'] = self._source_traceback
248+
self._loop.call_exception_handler(context)
249+
transport.abort()
250+
return
251+
self._transport = transport
252+
reader = self._stream_reader
253+
if reader is not None:
254+
reader.set_transport(transport)
199255
self._over_ssl = transport.get_extra_info('sslcontext') is not None
200256
if self._client_connected_cb is not None:
201257
self._stream_writer = StreamWriter(transport, self,
202-
self._stream_reader,
258+
reader,
203259
self._loop)
204-
res = self._client_connected_cb(self._stream_reader,
260+
res = self._client_connected_cb(reader,
205261
self._stream_writer)
206262
if coroutines.iscoroutine(res):
207263
self._loop.create_task(res)
264+
self._strong_reader = None
208265

209266
def connection_lost(self, exc):
210-
if self._stream_reader is not None:
267+
reader = self._stream_reader
268+
if reader is not None:
211269
if exc is None:
212-
self._stream_reader.feed_eof()
270+
reader.feed_eof()
213271
else:
214-
self._stream_reader.set_exception(exc)
272+
reader.set_exception(exc)
215273
if not self._closed.done():
216274
if exc is None:
217275
self._closed.set_result(None)
218276
else:
219277
self._closed.set_exception(exc)
220278
super().connection_lost(exc)
221-
self._stream_reader = None
279+
self._stream_reader_wr = None
222280
self._stream_writer = None
281+
self._transport = None
223282

224283
def data_received(self, data):
225-
self._stream_reader.feed_data(data)
284+
reader = self._stream_reader
285+
if reader is not None:
286+
reader.feed_data(data)
226287

227288
def eof_received(self):
228-
self._stream_reader.feed_eof()
289+
reader = self._stream_reader
290+
if reader is not None:
291+
reader.feed_eof()
229292
if self._over_ssl:
230293
# Prevent a warning in SSLProtocol.eof_received:
231294
# "returning true from eof_received()
@@ -282,6 +345,9 @@ def can_write_eof(self):
282345
return self._transport.can_write_eof()
283346

284347
def close(self):
348+
# a reader can be garbage collected
349+
# after connection closing
350+
self._protocol._untrack_reader()
285351
return self._transport.close()
286352

287353
def is_closing(self):
@@ -318,6 +384,8 @@ async def drain(self):
318384

319385
class StreamReader:
320386

387+
_source_traceback = None
388+
321389
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
322390
# The line length limit is a security feature;
323391
# it also doubles as half the buffer limit.
@@ -336,6 +404,9 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
336404
self._exception = None
337405
self._transport = None
338406
self._paused = False
407+
if self._loop.get_debug():
408+
self._source_traceback = format_helpers.extract_stack(
409+
sys._getframe(1))
339410

340411
def __repr__(self):
341412
info = ['StreamReader']

Lib/asyncio/subprocess.py

+5
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ def __repr__(self):
3636
info.append(f'stderr={self.stderr!r}')
3737
return '<{}>'.format(' '.join(info))
3838

39+
def _untrack_reader(self):
40+
# StreamWriter.close() expects the protocol
41+
# to have this method defined.
42+
pass
43+
3944
def connection_made(self, transport):
4045
self._transport = transport
4146

Lib/test/test_asyncio/test_streams.py

+71
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def test_ctor_global_loop(self, m_events):
4646
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
4747

4848
def _basetest_open_connection(self, open_connection_fut):
49+
messages = []
50+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
4951
reader, writer = self.loop.run_until_complete(open_connection_fut)
5052
writer.write(b'GET / HTTP/1.0\r\n\r\n')
5153
f = reader.readline()
@@ -55,6 +57,7 @@ def _basetest_open_connection(self, open_connection_fut):
5557
data = self.loop.run_until_complete(f)
5658
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
5759
writer.close()
60+
self.assertEqual(messages, [])
5861

5962
def test_open_connection(self):
6063
with test_utils.run_test_server() as httpd:
@@ -70,6 +73,8 @@ def test_open_unix_connection(self):
7073
self._basetest_open_connection(conn_fut)
7174

7275
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
76+
messages = []
77+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
7378
try:
7479
reader, writer = self.loop.run_until_complete(open_connection_fut)
7580
finally:
@@ -80,6 +85,7 @@ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
8085
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
8186

8287
writer.close()
88+
self.assertEqual(messages, [])
8389

8490
@unittest.skipIf(ssl is None, 'No ssl module')
8591
def test_open_connection_no_loop_ssl(self):
@@ -104,13 +110,16 @@ def test_open_unix_connection_no_loop_ssl(self):
104110
self._basetest_open_connection_no_loop_ssl(conn_fut)
105111

106112
def _basetest_open_connection_error(self, open_connection_fut):
113+
messages = []
114+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
107115
reader, writer = self.loop.run_until_complete(open_connection_fut)
108116
writer._protocol.connection_lost(ZeroDivisionError())
109117
f = reader.read()
110118
with self.assertRaises(ZeroDivisionError):
111119
self.loop.run_until_complete(f)
112120
writer.close()
113121
test_utils.run_briefly(self.loop)
122+
self.assertEqual(messages, [])
114123

115124
def test_open_connection_error(self):
116125
with test_utils.run_test_server() as httpd:
@@ -621,6 +630,9 @@ async def client(addr):
621630
writer.close()
622631
return msgback
623632

633+
messages = []
634+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
635+
624636
# test the server variant with a coroutine as client handler
625637
server = MyServer(self.loop)
626638
addr = server.start()
@@ -637,6 +649,8 @@ async def client(addr):
637649
server.stop()
638650
self.assertEqual(msg, b"hello world!\n")
639651

652+
self.assertEqual(messages, [])
653+
640654
@support.skip_unless_bind_unix_socket
641655
def test_start_unix_server(self):
642656

@@ -685,6 +699,9 @@ async def client(path):
685699
writer.close()
686700
return msgback
687701

702+
messages = []
703+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
704+
688705
# test the server variant with a coroutine as client handler
689706
with test_utils.unix_socket_path() as path:
690707
server = MyServer(self.loop, path)
@@ -703,6 +720,8 @@ async def client(path):
703720
server.stop()
704721
self.assertEqual(msg, b"hello world!\n")
705722

723+
self.assertEqual(messages, [])
724+
706725
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
707726
def test_read_all_from_pipe_reader(self):
708727
# See asyncio issue 168. This test is derived from the example
@@ -893,6 +912,58 @@ def test_wait_closed_on_close_with_unread_data(self):
893912
wr.close()
894913
self.loop.run_until_complete(wr.wait_closed())
895914

915+
def test_del_stream_before_sock_closing(self):
916+
messages = []
917+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
918+
919+
with test_utils.run_test_server() as httpd:
920+
rd, wr = self.loop.run_until_complete(
921+
asyncio.open_connection(*httpd.address, loop=self.loop))
922+
sock = wr.get_extra_info('socket')
923+
self.assertNotEqual(sock.fileno(), -1)
924+
925+
wr.write(b'GET / HTTP/1.0\r\n\r\n')
926+
f = rd.readline()
927+
data = self.loop.run_until_complete(f)
928+
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
929+
930+
# drop refs to reader/writer
931+
del rd
932+
del wr
933+
gc.collect()
934+
# make a chance to close the socket
935+
test_utils.run_briefly(self.loop)
936+
937+
self.assertEqual(1, len(messages))
938+
self.assertEqual(sock.fileno(), -1)
939+
940+
self.assertEqual(1, len(messages))
941+
self.assertEqual('An open stream object is being garbage '
942+
'collected; call "stream.close()" explicitly.',
943+
messages[0]['message'])
944+
945+
def test_del_stream_before_connection_made(self):
946+
messages = []
947+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
948+
949+
with test_utils.run_test_server() as httpd:
950+
rd = asyncio.StreamReader(loop=self.loop)
951+
pr = asyncio.StreamReaderProtocol(rd, loop=self.loop)
952+
del rd
953+
gc.collect()
954+
tr, _ = self.loop.run_until_complete(
955+
self.loop.create_connection(
956+
lambda: pr, *httpd.address))
957+
958+
sock = tr.get_extra_info('socket')
959+
self.assertEqual(sock.fileno(), -1)
960+
961+
self.assertEqual(1, len(messages))
962+
self.assertEqual('An open stream was garbage collected prior to '
963+
'establishing network connection; '
964+
'call "stream.close()" explicitly.',
965+
messages[0]['message'])
966+
896967

897968
if __name__ == '__main__':
898969
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Store a weak reference to stream reader to break strong references loop
2+
between reader and protocol. It allows to detect and close the socket if
3+
the stream is deleted (garbage collected) without ``close()`` call.

0 commit comments

Comments
 (0)