@@ -390,24 +390,24 @@ def wrap_socket(self, sock, server_side=False,
390
390
server_hostname = None , session = None ):
391
391
# SSLSocket class handles server_hostname encoding before it calls
392
392
# ctx._wrap_socket()
393
- return self .sslsocket_class (
393
+ return self .sslsocket_class . _create (
394
394
sock = sock ,
395
395
server_side = server_side ,
396
396
do_handshake_on_connect = do_handshake_on_connect ,
397
397
suppress_ragged_eofs = suppress_ragged_eofs ,
398
398
server_hostname = server_hostname ,
399
- _context = self ,
400
- _session = session
399
+ context = self ,
400
+ session = session
401
401
)
402
402
403
403
def wrap_bio (self , incoming , outgoing , server_side = False ,
404
404
server_hostname = None , session = None ):
405
405
# Need to encode server_hostname here because _wrap_bio() can only
406
406
# handle ASCII str.
407
- return self .sslobject_class (
407
+ return self .sslobject_class . _create (
408
408
incoming , outgoing , server_side = server_side ,
409
409
server_hostname = self ._encode_hostname (server_hostname ),
410
- session = session , _context = self ,
410
+ session = session , context = self ,
411
411
)
412
412
413
413
def set_npn_protocols (self , npn_protocols ):
@@ -612,14 +612,23 @@ class SSLObject:
612
612
* Any form of network IO incluging methods such as ``recv`` and ``send``.
613
613
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
614
614
"""
615
+ def __init__ (self , * args , ** kwargs ):
616
+ raise TypeError (
617
+ f"{ self .__class__ .__name__ } does not have a public "
618
+ f"constructor. Instances are returned by SSLContext.wrap_bio()."
619
+ )
615
620
616
- def __init__ (self , incoming , outgoing , server_side = False ,
617
- server_hostname = None , session = None , _context = None ):
618
- self ._sslobj = _context ._wrap_bio (
621
+ @classmethod
622
+ def _create (cls , incoming , outgoing , server_side = False ,
623
+ server_hostname = None , session = None , context = None ):
624
+ self = cls .__new__ (cls )
625
+ sslobj = context ._wrap_bio (
619
626
incoming , outgoing , server_side = server_side ,
620
627
server_hostname = server_hostname ,
621
628
owner = self , session = session
622
629
)
630
+ self ._sslobj = sslobj
631
+ return self
623
632
624
633
@property
625
634
def context (self ):
@@ -741,72 +750,48 @@ def version(self):
741
750
class SSLSocket (socket ):
742
751
"""This class implements a subtype of socket.socket that wraps
743
752
the underlying OS socket in an SSL context when necessary, and
744
- provides read and write methods over that channel."""
745
-
746
- def __init__ (self , sock = None , keyfile = None , certfile = None ,
747
- server_side = False , cert_reqs = CERT_NONE ,
748
- ssl_version = PROTOCOL_TLS , ca_certs = None ,
749
- do_handshake_on_connect = True ,
750
- family = AF_INET , type = SOCK_STREAM , proto = 0 , fileno = None ,
751
- suppress_ragged_eofs = True , npn_protocols = None , ciphers = None ,
752
- server_hostname = None ,
753
- _context = None , _session = None ):
754
-
755
- if _context :
756
- self ._context = _context
757
- else :
758
- if server_side and not certfile :
759
- raise ValueError ("certfile must be specified for server-side "
760
- "operations" )
761
- if keyfile and not certfile :
762
- raise ValueError ("certfile must be specified" )
763
- if certfile and not keyfile :
764
- keyfile = certfile
765
- self ._context = SSLContext (ssl_version )
766
- self ._context .verify_mode = cert_reqs
767
- if ca_certs :
768
- self ._context .load_verify_locations (ca_certs )
769
- if certfile :
770
- self ._context .load_cert_chain (certfile , keyfile )
771
- if npn_protocols :
772
- self ._context .set_npn_protocols (npn_protocols )
773
- if ciphers :
774
- self ._context .set_ciphers (ciphers )
775
- self .keyfile = keyfile
776
- self .certfile = certfile
777
- self .cert_reqs = cert_reqs
778
- self .ssl_version = ssl_version
779
- self .ca_certs = ca_certs
780
- self .ciphers = ciphers
781
- # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
782
- # mixed in.
753
+ provides read and write methods over that channel. """
754
+
755
+ def __init__ (self , * args , ** kwargs ):
756
+ raise TypeError (
757
+ f"{ self .__class__ .__name__ } does not have a public "
758
+ f"constructor. Instances are returned by "
759
+ f"SSLContext.wrap_socket()."
760
+ )
761
+
762
+ @classmethod
763
+ def _create (cls , sock , server_side = False , do_handshake_on_connect = True ,
764
+ suppress_ragged_eofs = True , server_hostname = None ,
765
+ context = None , session = None ):
783
766
if sock .getsockopt (SOL_SOCKET , SO_TYPE ) != SOCK_STREAM :
784
767
raise NotImplementedError ("only stream sockets are supported" )
785
768
if server_side :
786
769
if server_hostname :
787
770
raise ValueError ("server_hostname can only be specified "
788
771
"in client mode" )
789
- if _session is not None :
772
+ if session is not None :
790
773
raise ValueError ("session can only be specified in "
791
774
"client mode" )
792
- if self . _context .check_hostname and not server_hostname :
775
+ if context .check_hostname and not server_hostname :
793
776
raise ValueError ("check_hostname requires server_hostname" )
794
- self ._session = _session
777
+
778
+ kwargs = dict (
779
+ family = sock .family , type = sock .type , proto = sock .proto ,
780
+ fileno = sock .fileno ()
781
+ )
782
+ self = cls .__new__ (cls , ** kwargs )
783
+ super (SSLSocket , self ).__init__ (** kwargs )
784
+ self .settimeout (sock .gettimeout ())
785
+ sock .detach ()
786
+
787
+ self ._context = context
788
+ self ._session = session
789
+ self ._closed = False
790
+ self ._sslobj = None
795
791
self .server_side = server_side
796
- self .server_hostname = self . _context ._encode_hostname (server_hostname )
792
+ self .server_hostname = context ._encode_hostname (server_hostname )
797
793
self .do_handshake_on_connect = do_handshake_on_connect
798
794
self .suppress_ragged_eofs = suppress_ragged_eofs
799
- if sock is not None :
800
- super ().__init__ (family = sock .family ,
801
- type = sock .type ,
802
- proto = sock .proto ,
803
- fileno = sock .fileno ())
804
- self .settimeout (sock .gettimeout ())
805
- sock .detach ()
806
- elif fileno is not None :
807
- super ().__init__ (fileno = fileno )
808
- else :
809
- super ().__init__ (family = family , type = type , proto = proto )
810
795
811
796
# See if we are connected
812
797
try :
@@ -818,8 +803,6 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
818
803
else :
819
804
connected = True
820
805
821
- self ._closed = False
822
- self ._sslobj = None
823
806
self ._connected = connected
824
807
if connected :
825
808
# create the SSL object
@@ -834,10 +817,10 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
834
817
# non-blocking
835
818
raise ValueError ("do_handshake_on_connect should not be specified for non-blocking sockets" )
836
819
self .do_handshake ()
837
-
838
820
except (OSError , ValueError ):
839
821
self .close ()
840
822
raise
823
+ return self
841
824
842
825
@property
843
826
def context (self ):
@@ -1184,12 +1167,25 @@ def wrap_socket(sock, keyfile=None, certfile=None,
1184
1167
do_handshake_on_connect = True ,
1185
1168
suppress_ragged_eofs = True ,
1186
1169
ciphers = None ):
1187
- return SSLSocket (sock = sock , keyfile = keyfile , certfile = certfile ,
1188
- server_side = server_side , cert_reqs = cert_reqs ,
1189
- ssl_version = ssl_version , ca_certs = ca_certs ,
1190
- do_handshake_on_connect = do_handshake_on_connect ,
1191
- suppress_ragged_eofs = suppress_ragged_eofs ,
1192
- ciphers = ciphers )
1170
+
1171
+ if server_side and not certfile :
1172
+ raise ValueError ("certfile must be specified for server-side "
1173
+ "operations" )
1174
+ if keyfile and not certfile :
1175
+ raise ValueError ("certfile must be specified" )
1176
+ context = SSLContext (ssl_version )
1177
+ context .verify_mode = cert_reqs
1178
+ if ca_certs :
1179
+ context .load_verify_locations (ca_certs )
1180
+ if certfile :
1181
+ context .load_cert_chain (certfile , keyfile )
1182
+ if ciphers :
1183
+ context .set_ciphers (ciphers )
1184
+ return context .wrap_socket (
1185
+ sock = sock , server_side = server_side ,
1186
+ do_handshake_on_connect = do_handshake_on_connect ,
1187
+ suppress_ragged_eofs = suppress_ragged_eofs
1188
+ )
1193
1189
1194
1190
# some utility functions
1195
1191
0 commit comments