Skip to content

Commit 3aa03f4

Browse files
authored
Use handshake timeout for Tls listener callback (#62177)
1 parent f445836 commit 3aa03f4

File tree

5 files changed

+158
-71
lines changed

5 files changed

+158
-71
lines changed

src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System.Security.Cryptography.X509Certificates;
66
using Microsoft.AspNetCore.Server.Kestrel.Core;
77
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
8-
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
98
using Microsoft.AspNetCore.Server.Kestrel.Https;
109
using Microsoft.AspNetCore.Server.Kestrel.Https.Internal;
1110
using Microsoft.Extensions.DependencyInjection;
@@ -198,15 +197,6 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn
198197
listenOptions.IsTls = true;
199198
listenOptions.HttpsOptions = httpsOptions;
200199

201-
if (httpsOptions.TlsClientHelloBytesCallback is not null)
202-
{
203-
listenOptions.Use(next =>
204-
{
205-
var middleware = new TlsListenerMiddleware(next, httpsOptions.TlsClientHelloBytesCallback);
206-
return middleware.OnTlsClientHelloAsync;
207-
});
208-
}
209-
210200
listenOptions.Use(next =>
211201
{
212202
var middleware = new HttpsConnectionMiddleware(next, httpsOptions, listenOptions.Protocols, loggerFactory, metrics);

src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
1818
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
1919
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
20+
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
2021
using Microsoft.Extensions.Logging;
2122
using Microsoft.Extensions.Logging.Abstractions;
2223

@@ -44,6 +45,9 @@ internal sealed class HttpsConnectionMiddleware
4445
private readonly Func<TlsHandshakeCallbackContext, ValueTask<SslServerAuthenticationOptions>>? _tlsCallbackOptions;
4546
private readonly object? _tlsCallbackOptionsState;
4647

48+
// Captures raw TLS client hello and invokes a user callback if any
49+
private readonly TlsListener? _tlsListener;
50+
4751
// Internal for testing
4852
internal readonly HttpProtocols _httpProtocols;
4953

@@ -112,6 +116,11 @@ public HttpsConnectionMiddleware(ConnectionDelegate next, HttpsConnectionAdapter
112116
(RemoteCertificateValidationCallback?)null : RemoteCertificateValidationCallback;
113117

114118
_sslStreamFactory = s => new SslStream(s, leaveInnerStreamOpen: false, userCertificateValidationCallback: remoteCertificateValidationCallback);
119+
120+
if (options.TlsClientHelloBytesCallback is not null)
121+
{
122+
_tlsListener = new TlsListener(options.TlsClientHelloBytesCallback);
123+
}
115124
}
116125

117126
internal HttpsConnectionMiddleware(
@@ -162,6 +171,10 @@ public async Task OnConnectionAsync(ConnectionContext context)
162171
using var cancellationTokenSource = _ctsPool.Rent();
163172
cancellationTokenSource.CancelAfter(_handshakeTimeout);
164173

174+
if (_tlsListener is not null)
175+
{
176+
await _tlsListener.OnTlsClientHelloAsync(context, cancellationTokenSource.Token);
177+
}
165178
if (_tlsCallbackOptions is null)
166179
{
167180
await DoOptionsBasedHandshakeAsync(context, sslStream, feature, cancellationTokenSource.Token);

src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs renamed to src/Servers/Kestrel/Core/src/Middleware/TlsListener.cs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,27 @@
77

88
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
99

10-
internal sealed class TlsListenerMiddleware
10+
internal sealed class TlsListener
1111
{
12-
private readonly ConnectionDelegate _next;
1312
private readonly Action<ConnectionContext, ReadOnlySequence<byte>> _tlsClientHelloBytesCallback;
1413

15-
public TlsListenerMiddleware(ConnectionDelegate next, Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
14+
public TlsListener(Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
1615
{
17-
_next = next;
1816
_tlsClientHelloBytesCallback = tlsClientHelloBytesCallback;
1917
}
2018

2119
/// <summary>
2220
/// Sniffs the TLS Client Hello message, and invokes a callback if found.
2321
/// </summary>
24-
internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
22+
internal async Task OnTlsClientHelloAsync(ConnectionContext connection, CancellationToken cancellationToken)
2523
{
2624
var input = connection.Transport.Input;
2725
ClientHelloParseState parseState = ClientHelloParseState.NotEnoughData;
26+
short recordLength = -1; // remembers the length of TLS record to not re-parse header on every iteration
2827

2928
while (true)
3029
{
31-
var result = await input.ReadAsync();
30+
var result = await input.ReadAsync(cancellationToken);
3231
var buffer = result.Buffer;
3332

3433
try
@@ -40,7 +39,7 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
4039
break;
4140
}
4241

43-
parseState = TryParseClientHello(buffer, out var clientHelloBytes);
42+
parseState = TryParseClientHello(buffer, ref recordLength, out var clientHelloBytes);
4443
if (parseState == ClientHelloParseState.NotEnoughData)
4544
{
4645
// if no data will be added, and we still lack enough bytes
@@ -74,8 +73,6 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
7473
}
7574
}
7675
}
77-
78-
await _next(connection);
7976
}
8077

8178
/// <summary>
@@ -85,10 +82,25 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
8582
/// TLS 1.2: https://datatracker.ietf.org/doc/html/rfc5246#section-6.2
8683
/// TLS 1.3: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1
8784
/// </summary>
88-
private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte> buffer, out ReadOnlySequence<byte> clientHelloBytes)
85+
private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte> buffer, ref short recordLength, out ReadOnlySequence<byte> clientHelloBytes)
8986
{
9087
clientHelloBytes = default;
9188

89+
// in case bad actor will be sending a TLS client hello one byte at a time
90+
// and we know the expected length of TLS client hello,
91+
// we can check and fail quickly here instead of re-parsing the TLS client hello "header" on each iteration
92+
if (recordLength != -1 && buffer.Length < 5 + recordLength)
93+
{
94+
return ClientHelloParseState.NotEnoughData;
95+
}
96+
97+
// this means we finally got a full tls record, so we can return without parsing again
98+
if (recordLength != -1)
99+
{
100+
clientHelloBytes = buffer.Slice(0, 5 + recordLength);
101+
return ClientHelloParseState.ValidTlsClientHello;
102+
}
103+
92104
if (buffer.Length < 6)
93105
{
94106
return ClientHelloParseState.NotEnoughData;
@@ -109,7 +121,7 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte>
109121
}
110122

111123
// Record length
112-
if (!reader.TryReadBigEndian(out short recordLength))
124+
if (!reader.TryReadBigEndian(out recordLength))
113125
{
114126
return ClientHelloParseState.NotTlsClientHello;
115127
}

0 commit comments

Comments
 (0)