diff --git a/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt b/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt index 7daf5c9863fe..5708e0985dfd 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt +++ b/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt @@ -7,3 +7,6 @@ Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportFactory.Bin ~Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportFactory.SocketTransportFactory(Microsoft.Extensions.Options.IOptions! options, Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory) -> void static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder! static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder, System.Action! configureOptions) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder! +static Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateDefaultBoundListenSocket(System.Net.EndPoint! endpoint) -> System.Net.Sockets.Socket! +Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.get -> System.Func! +Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.set -> void \ No newline at end of file diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs index 6b618e5d5f76..43e5f19172bc 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.ComponentModel; using System.Diagnostics; using System.IO.Pipelines; using System.Net; @@ -23,7 +24,6 @@ internal sealed class SocketConnectionListener : IConnectionListener private Socket? _listenSocket; private int _settingsIndex; private readonly SocketTransportOptions _options; - private SafeSocketHandle? _socketHandle; public EndPoint EndPoint { get; private set; } @@ -92,43 +92,13 @@ internal void Bind() } Socket listenSocket; - - switch (EndPoint) + try { - case FileHandleEndPoint fileHandle: - _socketHandle = new SafeSocketHandle((IntPtr)fileHandle.FileHandle, ownsHandle: true); - listenSocket = new Socket(_socketHandle); - break; - case UnixDomainSocketEndPoint unix: - listenSocket = new Socket(unix.AddressFamily, SocketType.Stream, ProtocolType.Unspecified); - BindSocket(); - break; - case IPEndPoint ip: - listenSocket = new Socket(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - - // Kestrel expects IPv6Any to bind to both IPv6 and IPv4 - if (ip.Address == IPAddress.IPv6Any) - { - listenSocket.DualMode = true; - } - BindSocket(); - break; - default: - listenSocket = new Socket(EndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - BindSocket(); - break; + listenSocket = _options.CreateBoundListenSocket(EndPoint); } - - void BindSocket() + catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse) { - try - { - listenSocket.Bind(EndPoint); - } - catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse) - { - throw new AddressInUseException(e.Message, e); - } + throw new AddressInUseException(e.Message, e); } Debug.Assert(listenSocket.LocalEndPoint != null); @@ -193,8 +163,6 @@ void BindSocket() public ValueTask UnbindAsync(CancellationToken cancellationToken = default) { _listenSocket?.Dispose(); - - _socketHandle?.Dispose(); return default; } @@ -202,8 +170,6 @@ public ValueTask DisposeAsync() { _listenSocket?.Dispose(); - _socketHandle?.Dispose(); - // Dispose the memory pool _memoryPool.Dispose(); diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs index 1c6ab6114886..6e2cb7ca4735 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs @@ -3,6 +3,9 @@ using System; using System.Buffers; +using System.Net; +using System.Net.Sockets; +using Microsoft.AspNetCore.Connections; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets { @@ -65,6 +68,78 @@ public class SocketTransportOptions /// public bool UnsafePreferInlineScheduling { get; set; } + /// + /// A function used to create a new to listen with. If + /// not set, is used. + /// + /// + /// Implementors are expected to call on the + /// . Please note that + /// calls as part of its implementation, so implementors + /// using this method do not need to call it again. + /// + public Func CreateBoundListenSocket { get; set; } = CreateDefaultBoundListenSocket; + + /// + /// Creates a default instance of for the given + /// that can be used by a connection listener to listen for inbound requests. + /// is called by this method. + /// + /// + /// An . + /// + /// + /// A instance. + /// + public static Socket CreateDefaultBoundListenSocket(EndPoint endpoint) + { + Socket listenSocket; + switch (endpoint) + { + case FileHandleEndPoint fileHandle: + // We're passing "ownsHandle: true" here even though we don't necessarily + // own the handle because Socket.Dispose will clean-up everything safely. + // If the handle was already closed or disposed then the socket will + // be torn down gracefully, and if the caller never cleans up their handle + // then we'll do it for them. + // + // If we don't do this then we run the risk of Kestrel hanging because the + // the underlying socket is never closed and the transport manager can hang + // when it attempts to stop. + listenSocket = new Socket( + new SafeSocketHandle((IntPtr)fileHandle.FileHandle, ownsHandle: true) + ); + break; + case UnixDomainSocketEndPoint unix: + listenSocket = new Socket(unix.AddressFamily, SocketType.Stream, ProtocolType.Unspecified); + break; + case IPEndPoint ip: + listenSocket = new Socket(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + // Kestrel expects IPv6Any to bind to both IPv6 and IPv4 + if (ip.Address == IPAddress.IPv6Any) + { + listenSocket.DualMode = true; + } + + break; + default: + listenSocket = new Socket(endpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + break; + } + + // we only call Bind on sockets that were _not_ created + // using a file handle; the handle is already bound + // to an underlying socket so doing it again causes the + // underlying PAL call to throw + if (!(endpoint is FileHandleEndPoint)) + { + listenSocket.Bind(endpoint); + } + + return listenSocket; + } + internal Func> MemoryPoolFactory { get; set; } = System.Buffers.PinnedBlockMemoryPoolFactory.Create; } } diff --git a/src/Servers/Kestrel/test/Sockets.BindTests/SocketTransportOptionsTests.cs b/src/Servers/Kestrel/test/Sockets.BindTests/SocketTransportOptionsTests.cs new file mode 100644 index 000000000000..b094510bfde1 --- /dev/null +++ b/src/Servers/Kestrel/test/Sockets.BindTests/SocketTransportOptionsTests.cs @@ -0,0 +1,104 @@ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.FunctionalTests; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Sockets.BindTests +{ + public class SocketTransportOptionsTests : LoggedTestBase + { + [Theory] + [MemberData(nameof(GetEndpoints))] + public async Task SocketTransportCallsCreateBoundListenSocket(EndPoint endpointToTest) + { + var wasCalled = false; + + Socket CreateListenSocket(EndPoint endpoint) + { + wasCalled = true; + return SocketTransportOptions.CreateDefaultBoundListenSocket(endpoint); + } + + using var host = CreateWebHost( + endpointToTest, + options => + { + options.CreateBoundListenSocket = CreateListenSocket; + } + ); + + await host.StartAsync(); + Assert.True(wasCalled, $"Expected {nameof(SocketTransportOptions.CreateBoundListenSocket)} to be called."); + await host.StopAsync(); + } + + [Theory] + [MemberData(nameof(GetEndpoints))] + public void CreateDefaultBoundListenSocket_BindsForAllEndPoints(EndPoint endpoint) + { + using var listenSocket = SocketTransportOptions.CreateDefaultBoundListenSocket(endpoint); + Assert.NotNull(listenSocket.LocalEndPoint); + } + + // static to ensure that the underlying handle doesn't get disposed + // when a local reference is GCed by the iterator in GetEndPoints + private static Socket _fileHandleSocket; + + public static IEnumerable GetEndpoints() + { + // IPv4 + yield return new object[] {new IPEndPoint(IPAddress.Loopback, 0)}; + // IPv6 + yield return new object[] {new IPEndPoint(IPAddress.IPv6Loopback, 0)}; + // Unix sockets + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + yield return new object[] + { + new UnixDomainSocketEndPoint($"/tmp/{DateTime.UtcNow:yyyyMMddTHHmmss.fff}.sock") + }; + } + + // file handle + // slightly messy but allows us to create a FileHandleEndPoint + // from the underlying OS handle used by the socket + _fileHandleSocket = new( + AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp + ); + _fileHandleSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + yield return new object[] + { + new FileHandleEndPoint((ulong) _fileHandleSocket.Handle, FileHandleType.Auto) + }; + + // TODO: other endpoint types? + } + + private IHost CreateWebHost(EndPoint endpoint, Action configureSocketOptions) => + TransportSelector.GetHostBuilder() + .ConfigureWebHost( + webHostBuilder => + { + webHostBuilder + .UseSockets(configureSocketOptions) + .UseKestrel(options => options.Listen(endpoint)) + .Configure( + app => app.Run(ctx => ctx.Response.WriteAsync("Hello World")) + ); + } + ) + .ConfigureServices(AddTestLogging) + .Build(); + } +}