Skip to content

Refactor websockets transports #536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
57 changes: 2 additions & 55 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import functools
import io
import json
import logging
Expand Down Expand Up @@ -28,6 +27,7 @@
from ..utils import extract_files
from .appsync_auth import AppSyncAuthentication
from .async_transport import AsyncTransport
from .common.aiohttp_closed_event import create_aiohttp_closed_event
from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
Expand Down Expand Up @@ -147,59 +147,6 @@ async def connect(self) -> None:
else:
raise TransportAlreadyConnected("Transport is already connected")

@staticmethod
def create_aiohttp_closed_event(session) -> asyncio.Event:
"""Work around aiohttp issue that doesn't properly close transports on exit.

See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209

Returns:
An event that will be set once all transports have been properly closed.
"""

ssl_transports = 0
all_is_lost = asyncio.Event()

def connection_lost(exc, orig_lost):
nonlocal ssl_transports

try:
orig_lost(exc)
finally:
ssl_transports -= 1
if ssl_transports == 0:
all_is_lost.set()

def eof_received(orig_eof_received):
try: # pragma: no cover
orig_eof_received()
except AttributeError: # pragma: no cover
# It may happen that eof_received() is called after
# _app_protocol and _transport are set to None.
pass

for conn in session.connector._conns.values():
for handler, _ in conn:
proto = getattr(handler.transport, "_ssl_protocol", None)
if proto is None:
continue

ssl_transports += 1
orig_lost = proto.connection_lost
orig_eof_received = proto.eof_received

proto.connection_lost = functools.partial(
connection_lost, orig_lost=orig_lost
)
proto.eof_received = functools.partial(
eof_received, orig_eof_received=orig_eof_received
)

if ssl_transports == 0:
all_is_lost.set()

return all_is_lost

async def close(self) -> None:
"""Coroutine which will close the aiohttp session.

Expand All @@ -219,7 +166,7 @@ async def close(self) -> None:
log.debug("connector_owner is False -> not closing connector")

else:
closed_event = self.create_aiohttp_closed_event(self.session)
closed_event = create_aiohttp_closed_event(self.session)
await self.session.close()
try:
await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout)
Expand Down
Loading
Loading