Skip to content

Commit dde0488

Browse files
authored
[3.10] Raise TypeError if SSLSocket is passed to asyncio transport-based methods (GH-31442). (GH-31443)
(cherry picked from commit 1f9d4c9) Co-authored-by: Andrew Svetlov <[email protected]>
1 parent ea3e042 commit dde0488

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

Lib/asyncio/base_events.py

+15
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ def _set_nodelay(sock):
202202
pass
203203

204204

205+
def _check_ssl_socket(sock):
206+
if ssl is not None and isinstance(sock, ssl.SSLSocket):
207+
raise TypeError("Socket cannot be of type SSLSocket")
208+
209+
205210
class _SendfileFallbackProtocol(protocols.Protocol):
206211
def __init__(self, transp):
207212
if not isinstance(transp, transports._FlowControlMixin):
@@ -863,6 +868,7 @@ async def sock_sendfile(self, sock, file, offset=0, count=None,
863868
*, fallback=True):
864869
if self._debug and sock.gettimeout() != 0:
865870
raise ValueError("the socket must be non-blocking")
871+
_check_ssl_socket(sock)
866872
self._check_sendfile_params(sock, file, offset, count)
867873
try:
868874
return await self._sock_sendfile_native(sock, file,
@@ -1004,6 +1010,9 @@ async def create_connection(
10041010
raise ValueError(
10051011
'ssl_handshake_timeout is only meaningful with ssl')
10061012

1013+
if sock is not None:
1014+
_check_ssl_socket(sock)
1015+
10071016
if happy_eyeballs_delay is not None and interleave is None:
10081017
# If using happy eyeballs, default to interleave addresses by family
10091018
interleave = 1
@@ -1437,6 +1446,9 @@ async def create_server(
14371446
raise ValueError(
14381447
'ssl_handshake_timeout is only meaningful with ssl')
14391448

1449+
if sock is not None:
1450+
_check_ssl_socket(sock)
1451+
14401452
if host is not None or port is not None:
14411453
if sock is not None:
14421454
raise ValueError(
@@ -1531,6 +1543,9 @@ async def connect_accepted_socket(
15311543
raise ValueError(
15321544
'ssl_handshake_timeout is only meaningful with ssl')
15331545

1546+
if sock is not None:
1547+
_check_ssl_socket(sock)
1548+
15341549
transport, protocol = await self._create_connection_transport(
15351550
sock, protocol_factory, ssl, '', server_side=True,
15361551
ssl_handshake_timeout=ssl_handshake_timeout)

Lib/asyncio/selector_events.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ def _test_selector_event(selector, fd, event):
4040
return bool(key.events & event)
4141

4242

43-
def _check_ssl_socket(sock):
44-
if ssl is not None and isinstance(sock, ssl.SSLSocket):
45-
raise TypeError("Socket cannot be of type SSLSocket")
46-
47-
4843
class BaseSelectorEventLoop(base_events.BaseEventLoop):
4944
"""Selector event loop.
5045
@@ -357,7 +352,7 @@ async def sock_recv(self, sock, n):
357352
The maximum amount of data to be received at once is specified by
358353
nbytes.
359354
"""
360-
_check_ssl_socket(sock)
355+
base_events._check_ssl_socket(sock)
361356
if self._debug and sock.gettimeout() != 0:
362357
raise ValueError("the socket must be non-blocking")
363358
try:
@@ -398,7 +393,7 @@ async def sock_recv_into(self, sock, buf):
398393
The received data is written into *buf* (a writable buffer).
399394
The return value is the number of bytes written.
400395
"""
401-
_check_ssl_socket(sock)
396+
base_events._check_ssl_socket(sock)
402397
if self._debug and sock.gettimeout() != 0:
403398
raise ValueError("the socket must be non-blocking")
404399
try:
@@ -439,7 +434,7 @@ async def sock_sendall(self, sock, data):
439434
raised, and there is no way to determine how much data, if any, was
440435
successfully processed by the receiving end of the connection.
441436
"""
442-
_check_ssl_socket(sock)
437+
base_events._check_ssl_socket(sock)
443438
if self._debug and sock.gettimeout() != 0:
444439
raise ValueError("the socket must be non-blocking")
445440
try:
@@ -488,7 +483,7 @@ async def sock_connect(self, sock, address):
488483
489484
This method is a coroutine.
490485
"""
491-
_check_ssl_socket(sock)
486+
base_events._check_ssl_socket(sock)
492487
if self._debug and sock.gettimeout() != 0:
493488
raise ValueError("the socket must be non-blocking")
494489

@@ -553,7 +548,7 @@ async def sock_accept(self, sock):
553548
object usable to send and receive data on the connection, and address
554549
is the address bound to the socket on the other end of the connection.
555550
"""
556-
_check_ssl_socket(sock)
551+
base_events._check_ssl_socket(sock)
557552
if self._debug and sock.gettimeout() != 0:
558553
raise ValueError("the socket must be non-blocking")
559554
fut = self.create_future()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Raise :exc:`TypeError` if :class:`ssl.SSLSocket` is passed to
2+
transport-based APIs.

0 commit comments

Comments
 (0)