Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,14 @@ def explode(*args, **kwargs): # type: ignore[no-untyped-def]
"Getting group name is not supported by the linked OpenSSL version",
)

_requires_client_hello_cb = _make_requires(
getattr(_lib, "Cryptography_HAS_CLIENT_HELLO_CB", 0),
(
"SSL client hello callback is not supported by the "
"linked cryptographic library"
),
)


class Session:
"""
Expand Down Expand Up @@ -905,6 +913,7 @@ def __init__(self, method: int) -> None:
self._info_callback = None
self._keylog_callback = None
self._tlsext_servername_callback = None
self._client_hello_callback = None
self._app_data = None
self._alpn_select_helper: _ALPNSelectHelper | None = None
self._alpn_select_callback: _ALPNSelectCallback | None = None
Expand Down Expand Up @@ -1762,6 +1771,33 @@ def wrapper(ssl, alert, arg): # type: ignore[no-untyped-def]
self._context, self._tlsext_servername_callback
)

@_requires_client_hello_cb
@_require_not_used
def set_ssl_ctx_client_hello_callback(
self, callback: Callable[[Connection], None]
) -> None:
"""
Specify a callback function to be called when the ClientHello
is received.

:param callback: The callback function. It will be invoked with one
argument, the Connection instance.

.. versionadded:: 0.13
"""

@wraps(callback)
def wrapper(ssl, alert, arg): # type: ignore[no-untyped-def]
callback(Connection._reverse_mapping[ssl])
return 1

self._client_hello_callback = _ffi.callback(
"int (*)(SSL *, int *, void *)", wrapper
)
_lib.SSL_CTX_set_client_hello_cb(
self._context, self._client_hello_callback, _ffi.NULL
)

@_require_not_used
def set_tlsext_use_srtp(self, profiles: bytes) -> None:
"""
Expand Down Expand Up @@ -3262,3 +3298,50 @@ def wrapper(ssl, where, return_code): # type: ignore[no-untyped-def]
"void (*)(const SSL *, int, int)", wrapper
)
_lib.SSL_set_info_callback(self._ssl, self._info_callback)

@_requires_client_hello_cb
def get_client_hello_extension(self, type: int) -> bytes:
"""
Returns the client extension with the specified type. If the extensions
cannot be found an empty byte string is returned.

:param type: The type of extension to retrieve as integer.
:return: A byte array containing the extension or an empty byte array
if the extension is absent.
"""
out = _ffi.new("const unsigned char **")
outlen = _ffi.new("size_t *")
_lib.SSL_client_hello_get0_ext(self._ssl, type, out, outlen)

if not outlen:
return b""

return _ffi.buffer(out[0], outlen[0])[:]

@_requires_client_hello_cb
def get_client_hello_extensions_present(self) -> list[int]:
"""
Returns a list of the types of the client hello extensions
that are present in the ClientHello message.
"""
# SSL_client_hello_get1_extensions_present returns a new array
# allocated by OpenSSL_malloc
data = _ffi.new("int **")
data_len = _ffi.new("size_t *")
rc = _lib.SSL_client_hello_get1_extensions_present(
self._ssl, data, data_len
)

_openssl_assert(rc == 1)

if not data_len:
return []

# OpenSSL returns the number of items and FFI wants the numbers of
# types, so multiply it by the size of each item (int)
data_gc = _ffi.gc(data[0], _lib.OPENSSL_free)

buf = _ffi.buffer(data_gc, data_len[0] * _ffi.sizeof("int"))
retarray = _ffi.from_buffer("int[]", buf)

return list(retarray)
54 changes: 54 additions & 0 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,60 @@ def select(conn: Connection, options: list[bytes]) -> bytes:
interact_in_memory(server, client)
assert select_args == [(server, [b"http/1.1", b"spdy/2"])]

@pytest.mark.skipif(
not getattr(_lib, "Cryptography_HAS_CLIENT_HELLO_CB", None),
reason="Client hello callback unavailable in crypto library",
)
def test_client_hello_callback(self) -> None:
"""
We can handle exceptions in the ALPN select callback.
"""
client_hello_extensions = {}

def client_hello_callback(conn: Connection) -> None:
for ext in conn.get_client_hello_extensions_present():
client_hello_extensions[ext] = conn.get_client_hello_extension(
ext
)

client_context = Context(SSLv23_METHOD)
client_context.set_alpn_protos([b"http/1.1", b"spdy/2"])

server_context = Context(SSLv23_METHOD)
server_context.set_ssl_ctx_client_hello_callback(client_hello_callback)

# Necessary to actually accept the connection
server_context.use_privatekey(
load_privatekey(FILETYPE_PEM, server_key_pem)
)
server_context.use_certificate(
load_certificate(FILETYPE_PEM, server_cert_pem)
)

# Do a little connection to trigger the logic
server = Connection(server_context, None)
server.set_accept_state()

client = Connection(client_context, None)
client.set_tlsext_host_name(b"unitest.example.com")
client.set_connect_state()

interact_in_memory(server, client)

# Servername indication has extensions number 0
# ALPN has extension number 16
assert 0 in client_hello_extensions
assert 16 in client_hello_extensions

# OpenSSL does not expose good APIs to parse hello extensions. Instead
# of implementing parsing them just for the unit test we hardcode the
# string we expect to see
assert (
client_hello_extensions[0]
== b"\x00\x16\x00\x00\x13unitest.example.com"
)
assert client_hello_extensions[16] == b"\x00\x10\x08http/1.1\x06spdy/2"


class TestSession:
"""
Expand Down
Loading