diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py index 3819b492..e41b8235 100644 --- a/src/OpenSSL/SSL.py +++ b/src/OpenSSL/SSL.py @@ -820,6 +820,14 @@ def explode(*args, **kwargs): # type: ignore[no-untyped-def] "Getting group name is not supported by the linked OpenSSL version", ) +_requires_client_hello_cb = _make_requires( + getattr(_lib, "Cryptography_HAS_CLIENT_HELLO_CB", 0), + ( + "SSL client hello callback is not supported by the " + "linked cryptographic library" + ), +) + class Session: """ @@ -905,6 +913,7 @@ def __init__(self, method: int) -> None: self._info_callback = None self._keylog_callback = None self._tlsext_servername_callback = None + self._client_hello_callback = None self._app_data = None self._alpn_select_helper: _ALPNSelectHelper | None = None self._alpn_select_callback: _ALPNSelectCallback | None = None @@ -1762,6 +1771,33 @@ def wrapper(ssl, alert, arg): # type: ignore[no-untyped-def] self._context, self._tlsext_servername_callback ) + @_requires_client_hello_cb + @_require_not_used + def set_ssl_ctx_client_hello_callback( + self, callback: Callable[[Connection], None] + ) -> None: + """ + Specify a callback function to be called when the ClientHello + is received. + + :param callback: The callback function. It will be invoked with one + argument, the Connection instance. + + .. versionadded:: 0.13 + """ + + @wraps(callback) + def wrapper(ssl, alert, arg): # type: ignore[no-untyped-def] + callback(Connection._reverse_mapping[ssl]) + return 1 + + self._client_hello_callback = _ffi.callback( + "int (*)(SSL *, int *, void *)", wrapper + ) + _lib.SSL_CTX_set_client_hello_cb( + self._context, self._client_hello_callback, _ffi.NULL + ) + @_require_not_used def set_tlsext_use_srtp(self, profiles: bytes) -> None: """ @@ -3262,3 +3298,50 @@ def wrapper(ssl, where, return_code): # type: ignore[no-untyped-def] "void (*)(const SSL *, int, int)", wrapper ) _lib.SSL_set_info_callback(self._ssl, self._info_callback) + + @_requires_client_hello_cb + def get_client_hello_extension(self, type: int) -> bytes: + """ + Returns the client extension with the specified type. If the extensions + cannot be found an empty byte string is returned. + + :param type: The type of extension to retrieve as integer. + :return: A byte array containing the extension or an empty byte array + if the extension is absent. + """ + out = _ffi.new("const unsigned char **") + outlen = _ffi.new("size_t *") + _lib.SSL_client_hello_get0_ext(self._ssl, type, out, outlen) + + if not outlen: + return b"" + + return _ffi.buffer(out[0], outlen[0])[:] + + @_requires_client_hello_cb + def get_client_hello_extensions_present(self) -> list[int]: + """ + Returns a list of the types of the client hello extensions + that are present in the ClientHello message. + """ + # SSL_client_hello_get1_extensions_present returns a new array + # allocated by OpenSSL_malloc + data = _ffi.new("int **") + data_len = _ffi.new("size_t *") + rc = _lib.SSL_client_hello_get1_extensions_present( + self._ssl, data, data_len + ) + + _openssl_assert(rc == 1) + + if not data_len: + return [] + + # OpenSSL returns the number of items and FFI wants the numbers of + # types, so multiply it by the size of each item (int) + data_gc = _ffi.gc(data[0], _lib.OPENSSL_free) + + buf = _ffi.buffer(data_gc, data_len[0] * _ffi.sizeof("int")) + retarray = _ffi.from_buffer("int[]", buf) + + return list(retarray) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index b4506afc..8d9a0e59 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -2313,6 +2313,60 @@ def select(conn: Connection, options: list[bytes]) -> bytes: interact_in_memory(server, client) assert select_args == [(server, [b"http/1.1", b"spdy/2"])] + @pytest.mark.skipif( + not getattr(_lib, "Cryptography_HAS_CLIENT_HELLO_CB", None), + reason="Client hello callback unavailable in crypto library", + ) + def test_client_hello_callback(self) -> None: + """ + We can handle exceptions in the ALPN select callback. + """ + client_hello_extensions = {} + + def client_hello_callback(conn: Connection) -> None: + for ext in conn.get_client_hello_extensions_present(): + client_hello_extensions[ext] = conn.get_client_hello_extension( + ext + ) + + client_context = Context(SSLv23_METHOD) + client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) + + server_context = Context(SSLv23_METHOD) + server_context.set_ssl_ctx_client_hello_callback(client_hello_callback) + + # Necessary to actually accept the connection + server_context.use_privatekey( + load_privatekey(FILETYPE_PEM, server_key_pem) + ) + server_context.use_certificate( + load_certificate(FILETYPE_PEM, server_cert_pem) + ) + + # Do a little connection to trigger the logic + server = Connection(server_context, None) + server.set_accept_state() + + client = Connection(client_context, None) + client.set_tlsext_host_name(b"unitest.example.com") + client.set_connect_state() + + interact_in_memory(server, client) + + # Servername indication has extensions number 0 + # ALPN has extension number 16 + assert 0 in client_hello_extensions + assert 16 in client_hello_extensions + + # OpenSSL does not expose good APIs to parse hello extensions. Instead + # of implementing parsing them just for the unit test we hardcode the + # string we expect to see + assert ( + client_hello_extensions[0] + == b"\x00\x16\x00\x00\x13unitest.example.com" + ) + assert client_hello_extensions[16] == b"\x00\x10\x08http/1.1\x06spdy/2" + class TestSession: """