Skip to content

Commit 9b85b6c

Browse files
authored
Fix aiohttp wait for closed ssl connections (#153)
1 parent 7f402c8 commit 9b85b6c

File tree

3 files changed

+136
-17
lines changed

3 files changed

+136
-17
lines changed

gql/transport/aiohttp.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import functools
13
import io
24
import json
35
import logging
@@ -44,6 +46,7 @@ def __init__(
4446
auth: Optional[BasicAuth] = None,
4547
ssl: Union[SSLContext, bool, Fingerprint] = False,
4648
timeout: Optional[int] = None,
49+
ssl_close_timeout: Optional[Union[int, float]] = 10,
4750
client_session_args: Optional[Dict[str, Any]] = None,
4851
) -> None:
4952
"""Initialize the transport with the given aiohttp parameters.
@@ -53,6 +56,8 @@ def __init__(
5356
:param cookies: Dict of HTTP cookies.
5457
:param auth: BasicAuth object to enable Basic HTTP auth if needed
5558
:param ssl: ssl_context of the connection. Use ssl=False to disable encryption
59+
:param ssl_close_timeout: Timeout in seconds to wait for the ssl connection
60+
to close properly
5661
:param client_session_args: Dict of extra args passed to
5762
`aiohttp.ClientSession`_
5863
@@ -65,6 +70,7 @@ def __init__(
6570
self.auth: Optional[BasicAuth] = auth
6671
self.ssl: Union[SSLContext, bool, Fingerprint] = ssl
6772
self.timeout: Optional[int] = timeout
73+
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
6874
self.client_session_args = client_session_args
6975
self.session: Optional[aiohttp.ClientSession] = None
7076

@@ -100,6 +106,59 @@ async def connect(self) -> None:
100106
else:
101107
raise TransportAlreadyConnected("Transport is already connected")
102108

109+
@staticmethod
110+
def create_aiohttp_closed_event(session) -> asyncio.Event:
111+
"""Work around aiohttp issue that doesn't properly close transports on exit.
112+
113+
See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209
114+
115+
Returns:
116+
An event that will be set once all transports have been properly closed.
117+
"""
118+
119+
ssl_transports = 0
120+
all_is_lost = asyncio.Event()
121+
122+
def connection_lost(exc, orig_lost):
123+
nonlocal ssl_transports
124+
125+
try:
126+
orig_lost(exc)
127+
finally:
128+
ssl_transports -= 1
129+
if ssl_transports == 0:
130+
all_is_lost.set()
131+
132+
def eof_received(orig_eof_received):
133+
try:
134+
orig_eof_received()
135+
except AttributeError: # pragma: no cover
136+
# It may happen that eof_received() is called after
137+
# _app_protocol and _transport are set to None.
138+
pass
139+
140+
for conn in session.connector._conns.values():
141+
for handler, _ in conn:
142+
proto = getattr(handler.transport, "_ssl_protocol", None)
143+
if proto is None:
144+
continue
145+
146+
ssl_transports += 1
147+
orig_lost = proto.connection_lost
148+
orig_eof_received = proto.eof_received
149+
150+
proto.connection_lost = functools.partial(
151+
connection_lost, orig_lost=orig_lost
152+
)
153+
proto.eof_received = functools.partial(
154+
eof_received, orig_eof_received=orig_eof_received
155+
)
156+
157+
if ssl_transports == 0:
158+
all_is_lost.set()
159+
160+
return all_is_lost
161+
103162
async def close(self) -> None:
104163
"""Coroutine which will close the aiohttp session.
105164
@@ -108,7 +167,12 @@ async def close(self) -> None:
108167
when you exit the async context manager.
109168
"""
110169
if self.session is not None:
170+
closed_event = self.create_aiohttp_closed_event(self.session)
111171
await self.session.close()
172+
try:
173+
await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout)
174+
except asyncio.TimeoutError:
175+
pass
112176
self.session = None
113177

114178
async def execute(

tests/conftest.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ def pytest_collection_modifyitems(config, items):
7777
item.add_marker(skip_transport)
7878

7979

80-
@pytest.fixture
81-
async def aiohttp_server():
80+
async def aiohttp_server_base(with_ssl=False):
8281
"""Factory to create a TestServer instance, given an app.
8382
8483
aiohttp_server(app, **kwargs)
@@ -89,7 +88,13 @@ async def aiohttp_server():
8988

9089
async def go(app, *, port=None, **kwargs): # type: ignore
9190
server = AIOHTTPTestServer(app, port=port)
92-
await server.start_server(**kwargs)
91+
92+
start_server_args = {**kwargs}
93+
if with_ssl:
94+
testcert, ssl_context = get_localhost_ssl_context()
95+
start_server_args["ssl"] = ssl_context
96+
97+
await server.start_server(**start_server_args)
9398
servers.append(server)
9499
return server
95100

@@ -99,6 +104,18 @@ async def go(app, *, port=None, **kwargs): # type: ignore
99104
await servers.pop().close()
100105

101106

107+
@pytest.fixture
108+
async def aiohttp_server():
109+
async for server in aiohttp_server_base():
110+
yield server
111+
112+
113+
@pytest.fixture
114+
async def ssl_aiohttp_server():
115+
async for server in aiohttp_server_base(with_ssl=True):
116+
yield server
117+
118+
102119
# Adding debug logs to websocket tests
103120
for name in [
104121
"websockets.legacy.server",
@@ -121,6 +138,22 @@ async def go(app, *, port=None, **kwargs): # type: ignore
121138
MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1))
122139

123140

141+
def get_localhost_ssl_context():
142+
# This is a copy of certificate from websockets tests folder
143+
#
144+
# Generate TLS certificate with:
145+
# $ openssl req -x509 -config test_localhost.cnf \
146+
# -days 15340 -newkey rsa:2048 \
147+
# -out test_localhost.crt -keyout test_localhost.key
148+
# $ cat test_localhost.key test_localhost.crt > test_localhost.pem
149+
# $ rm test_localhost.key test_localhost.crt
150+
testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem"))
151+
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
152+
ssl_context.load_cert_chain(testcert)
153+
154+
return (testcert, ssl_context)
155+
156+
124157
class WebSocketServer:
125158
"""Websocket server on localhost on a free port.
126159
@@ -141,20 +174,7 @@ async def start(self, handler, extra_serve_args=None):
141174
extra_serve_args = {}
142175

143176
if self.with_ssl:
144-
# This is a copy of certificate from websockets tests folder
145-
#
146-
# Generate TLS certificate with:
147-
# $ openssl req -x509 -config test_localhost.cnf \
148-
# -days 15340 -newkey rsa:2048 \
149-
# -out test_localhost.crt -keyout test_localhost.key
150-
# $ cat test_localhost.key test_localhost.crt > test_localhost.pem
151-
# $ rm test_localhost.key test_localhost.crt
152-
self.testcert = bytes(
153-
pathlib.Path(__file__).with_name("test_localhost.pem")
154-
)
155-
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
156-
ssl_context.load_cert_chain(self.testcert)
157-
177+
self.testcert, ssl_context = get_localhost_ssl_context()
158178
extra_serve_args["ssl"] = ssl_context
159179

160180
# Start a server with a random open port

tests/test_aiohttp.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,3 +1073,38 @@ async def handler(request):
10731073
execution_result = await session.execute(query, get_execution_result=True)
10741074

10751075
assert execution_result.extensions["key1"] == "val1"
1076+
1077+
1078+
@pytest.mark.asyncio
1079+
@pytest.mark.parametrize("ssl_close_timeout", [0, 10])
1080+
async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout):
1081+
from aiohttp import web
1082+
from gql.transport.aiohttp import AIOHTTPTransport
1083+
1084+
async def handler(request):
1085+
return web.Response(text=query1_server_answer, content_type="application/json")
1086+
1087+
app = web.Application()
1088+
app.router.add_route("POST", "/", handler)
1089+
server = await ssl_aiohttp_server(app)
1090+
1091+
url = server.make_url("/")
1092+
1093+
assert str(url).startswith("https://")
1094+
1095+
sample_transport = AIOHTTPTransport(
1096+
url=url, timeout=10, ssl_close_timeout=ssl_close_timeout
1097+
)
1098+
1099+
async with Client(transport=sample_transport,) as session:
1100+
1101+
query = gql(query1_str)
1102+
1103+
# Execute query asynchronously
1104+
result = await session.execute(query)
1105+
1106+
continents = result["continents"]
1107+
1108+
africa = continents[0]
1109+
1110+
assert africa["code"] == "AF"

0 commit comments

Comments
 (0)