Skip to content

Commit bdba7ce

Browse files
committed
Use loop.start_tls() to upgrade connections to SSL
The old way of TLS upgrade (openining a connection, asking postgres to do TLS and then duping the underlying socket) seems not to work anymore on Windows with Python 3.8.
1 parent d655a39 commit bdba7ce

File tree

2 files changed

+117
-80
lines changed

2 files changed

+117
-80
lines changed

asyncpg/connect_utils.py

Lines changed: 112 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,95 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504504
return addrs, params, config
505505

506506

507+
class TLSUpgradeProto(asyncio.Protocol):
508+
def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
509+
self.on_data = _create_future(loop)
510+
self.host = host
511+
self.port = port
512+
self.ssl_context = ssl_context
513+
self.ssl_is_advisory = ssl_is_advisory
514+
515+
def data_received(self, data):
516+
if data == b'S':
517+
self.on_data.set_result(True)
518+
elif (self.ssl_is_advisory and
519+
self.ssl_context.verify_mode == ssl_module.CERT_NONE and
520+
data == b'N'):
521+
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522+
# since the only way to get ssl_is_advisory is from
523+
# sslmode=prefer (or sslmode=allow). But be extra sure to
524+
# disallow insecure connections when the ssl context asks for
525+
# real security.
526+
self.on_data.set_result(False)
527+
else:
528+
self.on_data.set_exception(
529+
ConnectionError(
530+
'PostgreSQL server at "{host}:{port}" '
531+
'rejected SSL upgrade'.format(
532+
host=self.host, port=self.port)))
533+
534+
def connection_lost(self, exc):
535+
if not self.on_data.done():
536+
if exc is None:
537+
exc = ConnectionError('unexpected connection_lost() call')
538+
self.on_data.set_exception(exc)
539+
540+
541+
async def _create_ssl_connection(protocol_factory, host, port, *,
542+
loop, ssl_context, ssl_is_advisory=False):
543+
544+
if ssl_context is True:
545+
ssl_context = ssl_module.create_default_context()
546+
547+
tr, pr = await loop.create_connection(
548+
lambda: TLSUpgradeProto(loop, host, port,
549+
ssl_context, ssl_is_advisory),
550+
host, port)
551+
552+
tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
553+
554+
try:
555+
do_ssl_upgrade = await pr.on_data
556+
except (Exception, asyncio.CancelledError):
557+
tr.close()
558+
raise
559+
560+
if hasattr(loop, 'start_tls'):
561+
if do_ssl_upgrade:
562+
try:
563+
new_tr = await loop.start_tls(
564+
tr, pr, ssl_context, server_hostname=host)
565+
except (Exception, asyncio.CancelledError):
566+
tr.close()
567+
raise
568+
else:
569+
new_tr = tr
570+
571+
pg_proto = protocol_factory()
572+
pg_proto.connection_made(new_tr)
573+
new_tr.set_protocol(pg_proto)
574+
575+
return new_tr, pg_proto
576+
else:
577+
conn_factory = functools.partial(
578+
loop.create_connection, protocol_factory)
579+
580+
if do_ssl_upgrade:
581+
conn_factory = functools.partial(
582+
conn_factory, ssl=ssl_context, server_hostname=host)
583+
584+
sock = _get_socket(tr)
585+
sock = sock.dup()
586+
_set_nodelay(sock)
587+
tr.close()
588+
589+
try:
590+
return await conn_factory(sock=sock)
591+
except (Exception, asyncio.CancelledError):
592+
sock.close()
593+
raise
594+
595+
507596
async def _connect_addr(*, addr, loop, timeout, params, config,
508597
connection_class):
509598
assert loop is not None
@@ -526,8 +615,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
526615
else:
527616
connector = loop.create_connection(proto_factory, *addr)
528617

529-
connector = asyncio.ensure_future(connector)
530-
531618
before = time.monotonic()
532619
try:
533620
tr, pr = await asyncio.wait_for(
@@ -575,79 +662,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
575662
raise last_error
576663

577664

578-
async def _negotiate_ssl_connection(host, port, conn_factory, *, loop, ssl,
579-
server_hostname, ssl_is_advisory=False):
580-
# Note: ssl_is_advisory only affects behavior when the server does not
581-
# accept SSLRequests. If the SSLRequest is accepted but either the SSL
582-
# negotiation fails or the PostgreSQL user isn't permitted to use SSL,
583-
# there's nothing that would attempt to reconnect with a non-SSL socket.
584-
reader, writer = await asyncio.open_connection(host, port)
585-
586-
tr = writer.transport
587-
try:
588-
sock = _get_socket(tr)
589-
_set_nodelay(sock)
590-
591-
writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
592-
await writer.drain()
593-
resp = await reader.readexactly(1)
594-
595-
if resp == b'S':
596-
conn_factory = functools.partial(
597-
conn_factory, ssl=ssl, server_hostname=server_hostname)
598-
elif (ssl_is_advisory and
599-
ssl.verify_mode == ssl_module.CERT_NONE and
600-
resp == b'N'):
601-
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
602-
# since the only way to get ssl_is_advisory is from sslmode=prefer
603-
# (or sslmode=allow). But be extra sure to disallow insecure
604-
# connections when the ssl context asks for real security.
605-
pass
606-
else:
607-
raise ConnectionError(
608-
'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
609-
host, port))
610-
611-
sock = sock.dup() # Must come before tr.close()
612-
finally:
613-
writer.close()
614-
await compat.wait_closed(writer)
615-
616-
try:
617-
return await conn_factory(sock=sock) # Must come after tr.close()
618-
except (Exception, asyncio.CancelledError):
619-
sock.close()
620-
raise
665+
async def _cancel(*, loop, addr, params: _ConnectionParameters,
666+
backend_pid, backend_secret):
621667

668+
class CancelProto(asyncio.Protocol):
622669

623-
async def _create_ssl_connection(protocol_factory, host, port, *,
624-
loop, ssl_context, ssl_is_advisory=False):
625-
return await _negotiate_ssl_connection(
626-
host, port,
627-
functools.partial(loop.create_connection, protocol_factory),
628-
loop=loop,
629-
ssl=ssl_context,
630-
server_hostname=host,
631-
ssl_is_advisory=ssl_is_advisory)
670+
def __init__(self):
671+
self.on_disconnect = _create_future(loop)
632672

673+
def connection_lost(self, exc):
674+
if not self.on_disconnect.done():
675+
self.on_disconnect.set_result(True)
633676

634-
async def _open_connection(*, loop, addr, params: _ConnectionParameters):
635677
if isinstance(addr, str):
636-
r, w = await asyncio.open_unix_connection(addr)
678+
tr, pr = await loop.create_unix_connection(CancelProto, addr)
637679
else:
638680
if params.ssl:
639-
r, w = await _negotiate_ssl_connection(
681+
tr, pr = await _create_ssl_connection(
682+
CancelProto,
640683
*addr,
641-
asyncio.open_connection,
642684
loop=loop,
643-
ssl=params.ssl,
644-
server_hostname=addr[0],
685+
ssl_context=params.ssl,
645686
ssl_is_advisory=params.ssl_is_advisory)
646687
else:
647-
r, w = await asyncio.open_connection(*addr)
648-
_set_nodelay(_get_socket(w.transport))
688+
tr, pr = await loop.create_connection(
689+
CancelProto, *addr)
690+
_set_nodelay(_get_socket(tr))
691+
692+
# Pack a CancelRequest message
693+
msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)
649694

650-
return r, w
695+
try:
696+
tr.write(msg)
697+
await pr.on_disconnect
698+
finally:
699+
tr.close()
651700

652701

653702
def _get_socket(transport):

asyncpg/connection.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import collections
1111
import collections.abc
1212
import itertools
13-
import struct
1413
import sys
1514
import time
1615
import traceback
@@ -1186,24 +1185,16 @@ async def _cleanup_stmts(self):
11861185
await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
11871186

11881187
async def _cancel(self, waiter):
1189-
r = w = None
1190-
11911188
try:
11921189
# Open new connection to the server
1193-
r, w = await connect_utils._open_connection(
1194-
loop=self._loop, addr=self._addr, params=self._params)
1195-
1196-
# Pack CancelRequest message
1197-
msg = struct.pack('!llll', 16, 80877102,
1198-
self._protocol.backend_pid,
1199-
self._protocol.backend_secret)
1200-
1201-
w.write(msg)
1202-
await r.read() # Wait until EOF
1190+
await connect_utils._cancel(
1191+
loop=self._loop, addr=self._addr, params=self._params,
1192+
backend_pid=self._protocol.backend_pid,
1193+
backend_secret=self._protocol.backend_secret)
12031194
except ConnectionResetError as ex:
12041195
# On some systems Postgres will reset the connection
12051196
# after processing the cancellation command.
1206-
if r is None and not waiter.done():
1197+
if not waiter.done():
12071198
waiter.set_exception(ex)
12081199
except asyncio.CancelledError:
12091200
# There are two scenarios in which the cancellation
@@ -1221,9 +1212,6 @@ async def _cancel(self, waiter):
12211212
compat.current_asyncio_task(self._loop))
12221213
if not waiter.done():
12231214
waiter.set_result(None)
1224-
if w is not None:
1225-
w.close()
1226-
await compat.wait_closed(w)
12271215

12281216
def _cancel_current_command(self, waiter):
12291217
self._cancellations.add(self._loop.create_task(self._cancel(waiter)))

0 commit comments

Comments
 (0)