@@ -504,6 +504,95 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504
504
return addrs , params , config
505
505
506
506
507
+ class TLSUpgradeProto (asyncio .Protocol ):
508
+ def __init__ (self , loop , host , port , ssl_context , ssl_is_advisory ):
509
+ self .on_data = _create_future (loop )
510
+ self .host = host
511
+ self .port = port
512
+ self .ssl_context = ssl_context
513
+ self .ssl_is_advisory = ssl_is_advisory
514
+
515
+ def data_received (self , data ):
516
+ if data == b'S' :
517
+ self .on_data .set_result (True )
518
+ elif (self .ssl_is_advisory and
519
+ self .ssl_context .verify_mode == ssl_module .CERT_NONE and
520
+ data == b'N' ):
521
+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522
+ # since the only way to get ssl_is_advisory is from
523
+ # sslmode=prefer (or sslmode=allow). But be extra sure to
524
+ # disallow insecure connections when the ssl context asks for
525
+ # real security.
526
+ self .on_data .set_result (False )
527
+ else :
528
+ self .on_data .set_exception (
529
+ ConnectionError (
530
+ 'PostgreSQL server at "{host}:{port}" '
531
+ 'rejected SSL upgrade' .format (
532
+ host = self .host , port = self .port )))
533
+
534
+ def connection_lost (self , exc ):
535
+ if not self .on_data .done ():
536
+ if exc is None :
537
+ exc = ConnectionError ('unexpected connection_lost() call' )
538
+ self .on_data .set_exception (exc )
539
+
540
+
541
+ async def _create_ssl_connection (protocol_factory , host , port , * ,
542
+ loop , ssl_context , ssl_is_advisory = False ):
543
+
544
+ if ssl_context is True :
545
+ ssl_context = ssl_module .create_default_context ()
546
+
547
+ tr , pr = await loop .create_connection (
548
+ lambda : TLSUpgradeProto (loop , host , port ,
549
+ ssl_context , ssl_is_advisory ),
550
+ host , port )
551
+
552
+ tr .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
553
+
554
+ try :
555
+ do_ssl_upgrade = await pr .on_data
556
+ except (Exception , asyncio .CancelledError ):
557
+ tr .close ()
558
+ raise
559
+
560
+ if hasattr (loop , 'start_tls' ):
561
+ if do_ssl_upgrade :
562
+ try :
563
+ new_tr = await loop .start_tls (
564
+ tr , pr , ssl_context , server_hostname = host )
565
+ except (Exception , asyncio .CancelledError ):
566
+ tr .close ()
567
+ raise
568
+ else :
569
+ new_tr = tr
570
+
571
+ pg_proto = protocol_factory ()
572
+ pg_proto .connection_made (new_tr )
573
+ new_tr .set_protocol (pg_proto )
574
+
575
+ return new_tr , pg_proto
576
+ else :
577
+ conn_factory = functools .partial (
578
+ loop .create_connection , protocol_factory )
579
+
580
+ if do_ssl_upgrade :
581
+ conn_factory = functools .partial (
582
+ conn_factory , ssl = ssl_context , server_hostname = host )
583
+
584
+ sock = _get_socket (tr )
585
+ sock = sock .dup ()
586
+ _set_nodelay (sock )
587
+ tr .close ()
588
+
589
+ try :
590
+ return await conn_factory (sock = sock )
591
+ except (Exception , asyncio .CancelledError ):
592
+ sock .close ()
593
+ raise
594
+
595
+
507
596
async def _connect_addr (* , addr , loop , timeout , params , config ,
508
597
connection_class ):
509
598
assert loop is not None
@@ -526,8 +615,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
526
615
else :
527
616
connector = loop .create_connection (proto_factory , * addr )
528
617
529
- connector = asyncio .ensure_future (connector )
530
-
531
618
before = time .monotonic ()
532
619
try :
533
620
tr , pr = await asyncio .wait_for (
@@ -575,79 +662,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
575
662
raise last_error
576
663
577
664
578
- async def _negotiate_ssl_connection (host , port , conn_factory , * , loop , ssl ,
579
- server_hostname , ssl_is_advisory = False ):
580
- # Note: ssl_is_advisory only affects behavior when the server does not
581
- # accept SSLRequests. If the SSLRequest is accepted but either the SSL
582
- # negotiation fails or the PostgreSQL user isn't permitted to use SSL,
583
- # there's nothing that would attempt to reconnect with a non-SSL socket.
584
- reader , writer = await asyncio .open_connection (host , port )
585
-
586
- tr = writer .transport
587
- try :
588
- sock = _get_socket (tr )
589
- _set_nodelay (sock )
590
-
591
- writer .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
592
- await writer .drain ()
593
- resp = await reader .readexactly (1 )
594
-
595
- if resp == b'S' :
596
- conn_factory = functools .partial (
597
- conn_factory , ssl = ssl , server_hostname = server_hostname )
598
- elif (ssl_is_advisory and
599
- ssl .verify_mode == ssl_module .CERT_NONE and
600
- resp == b'N' ):
601
- # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
602
- # since the only way to get ssl_is_advisory is from sslmode=prefer
603
- # (or sslmode=allow). But be extra sure to disallow insecure
604
- # connections when the ssl context asks for real security.
605
- pass
606
- else :
607
- raise ConnectionError (
608
- 'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
609
- host , port ))
610
-
611
- sock = sock .dup () # Must come before tr.close()
612
- finally :
613
- writer .close ()
614
- await compat .wait_closed (writer )
615
-
616
- try :
617
- return await conn_factory (sock = sock ) # Must come after tr.close()
618
- except (Exception , asyncio .CancelledError ):
619
- sock .close ()
620
- raise
665
+ async def _cancel (* , loop , addr , params : _ConnectionParameters ,
666
+ backend_pid , backend_secret ):
621
667
668
+ class CancelProto (asyncio .Protocol ):
622
669
623
- async def _create_ssl_connection (protocol_factory , host , port , * ,
624
- loop , ssl_context , ssl_is_advisory = False ):
625
- return await _negotiate_ssl_connection (
626
- host , port ,
627
- functools .partial (loop .create_connection , protocol_factory ),
628
- loop = loop ,
629
- ssl = ssl_context ,
630
- server_hostname = host ,
631
- ssl_is_advisory = ssl_is_advisory )
670
+ def __init__ (self ):
671
+ self .on_disconnect = _create_future (loop )
632
672
673
+ def connection_lost (self , exc ):
674
+ if not self .on_disconnect .done ():
675
+ self .on_disconnect .set_result (True )
633
676
634
- async def _open_connection (* , loop , addr , params : _ConnectionParameters ):
635
677
if isinstance (addr , str ):
636
- r , w = await asyncio . open_unix_connection ( addr )
678
+ tr , pr = await loop . create_unix_connection ( CancelProto , addr )
637
679
else :
638
680
if params .ssl :
639
- r , w = await _negotiate_ssl_connection (
681
+ tr , pr = await _create_ssl_connection (
682
+ CancelProto ,
640
683
* addr ,
641
- asyncio .open_connection ,
642
684
loop = loop ,
643
- ssl = params .ssl ,
644
- server_hostname = addr [0 ],
685
+ ssl_context = params .ssl ,
645
686
ssl_is_advisory = params .ssl_is_advisory )
646
687
else :
647
- r , w = await asyncio .open_connection (* addr )
648
- _set_nodelay (_get_socket (w .transport ))
688
+ tr , pr = await loop .create_connection (
689
+ CancelProto , * addr )
690
+ _set_nodelay (_get_socket (tr ))
691
+
692
+ # Pack a CancelRequest message
693
+ msg = struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
649
694
650
- return r , w
695
+ try :
696
+ tr .write (msg )
697
+ await pr .on_disconnect
698
+ finally :
699
+ tr .close ()
651
700
652
701
653
702
def _get_socket (transport ):
0 commit comments