diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Client/SocketConnectionFactory.cs b/src/Servers/Kestrel/Transport.Sockets/src/Client/SocketConnectionFactory.cs index 7465ece15f98..94f28c3b616e 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/Client/SocketConnectionFactory.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/Client/SocketConnectionFactory.cs @@ -81,7 +81,6 @@ public async ValueTask ConnectAsync(EndPoint endpoint, Cancel _outputOptions, _options.WaitForDataBeforeAllocatingBuffer); - socketConnection.Start(); return socketConnection; } diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.DuplexPipe.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.DuplexPipe.cs new file mode 100644 index 000000000000..1b6e2eccc0b6 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.DuplexPipe.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + internal sealed partial class SocketConnection + { + // We could implement this on SocketConnection to remove an extra allocation but this is a + // bit cleaner + private class SocketDuplexPipe : IDuplexPipe + { + public SocketDuplexPipe(SocketConnection connection) + { + Input = new SocketPipeReader(connection); + Output = new SocketPipeWriter(connection); + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeReader.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeReader.cs new file mode 100644 index 000000000000..9c394d9cc05b --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeReader.cs @@ -0,0 +1,77 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + internal sealed partial class SocketConnection + { + private class SocketPipeReader : PipeReader + { + private readonly SocketConnection _socketConnection; + private readonly PipeReader _reader; + + public SocketPipeReader(SocketConnection socketConnection) + { + _socketConnection = socketConnection; + _reader = socketConnection.InnerTransport.Input; + } + + public override void AdvanceTo(SequencePosition consumed) + { + _reader.AdvanceTo(consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + _reader.AdvanceTo(consumed, examined); + } + + public override void CancelPendingRead() + { + _reader.CancelPendingRead(); + } + + public override void Complete(Exception? exception = null) + { + _reader.Complete(exception); + } + + public override ValueTask CompleteAsync(Exception? exception = null) + { + return _reader.CompleteAsync(exception); + } + + public override Task CopyToAsync(PipeWriter destination, CancellationToken cancellationToken = default) + { + _socketConnection.EnsureStarted(); + return _reader.CopyToAsync(destination, cancellationToken); + } + + public override Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default) + { + _socketConnection.EnsureStarted(); + return _reader.CopyToAsync(destination, cancellationToken); + } + + protected override ValueTask ReadAtLeastAsyncCore(int minimumSize, CancellationToken cancellationToken) + { + _socketConnection.EnsureStarted(); + return _reader.ReadAtLeastAsync(minimumSize, cancellationToken); + } + + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + _socketConnection.EnsureStarted(); + return _reader.ReadAsync(cancellationToken); + } + + public override bool TryRead(out ReadResult result) + { + _socketConnection.EnsureStarted(); + return _reader.TryRead(out result); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeWriter.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeWriter.cs new file mode 100644 index 000000000000..8193d7af6200 --- /dev/null +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.PipeWriter.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + internal sealed partial class SocketConnection + { + private class SocketPipeWriter : PipeWriter + { + private readonly SocketConnection _socketConnection; + private readonly PipeWriter _writer; + + public SocketPipeWriter(SocketConnection socketConnection) + { + _socketConnection = socketConnection; + _writer = socketConnection.InnerTransport.Output; + } + + public override bool CanGetUnflushedBytes => _writer.CanGetUnflushedBytes; + + public override long UnflushedBytes => _writer.UnflushedBytes; + + public override void Advance(int bytes) + { + _writer.Advance(bytes); + } + + public override void CancelPendingFlush() + { + _writer.CancelPendingFlush(); + } + + public override void Complete(Exception? exception = null) + { + _writer.Complete(exception); + } + + public override ValueTask CompleteAsync(Exception? exception = null) + { + return _writer.CompleteAsync(exception); + } + + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + _socketConnection.EnsureStarted(); + return _writer.WriteAsync(source, cancellationToken); + } + + public override ValueTask FlushAsync(CancellationToken cancellationToken = default) + { + _socketConnection.EnsureStarted(); + return _writer.FlushAsync(cancellationToken); + } + + public override Memory GetMemory(int sizeHint = 0) + { + return _writer.GetMemory(sizeHint); + } + + public override Span GetSpan(int sizeHint = 0) + { + return _writer.GetSpan(sizeHint); + } + } + } +} diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs index 17e662a69f0a..7eb4a649f103 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs @@ -33,6 +33,7 @@ internal sealed partial class SocketConnection : TransportConnection private readonly TaskCompletionSource _waitForConnectionClosedTcs = new TaskCompletionSource(); private bool _connectionClosed; private readonly bool _waitForData; + private int _connectionStarted; internal SocketConnection(Socket socket, MemoryPool memoryPool, @@ -67,31 +68,32 @@ internal SocketConnection(Socket socket, var pair = DuplexPipe.CreateConnectionPair(inputOptions, outputOptions); - // Set the transport and connection id - Transport = _originalTransport = pair.Transport; + _originalTransport = pair.Transport; Application = pair.Application; + Transport = new SocketDuplexPipe(this); + InitiaizeFeatures(); } + public IDuplexPipe InnerTransport => _originalTransport; + public PipeWriter Input => Application.Output; public PipeReader Output => Application.Input; public override MemoryPool MemoryPool { get; } - public void Start() + private void EnsureStarted() { - try + if (_connectionStarted == 1 || Interlocked.CompareExchange(ref _connectionStarted, 1, 0) == 1) { - // Spawn send and receive logic - _receivingTask = DoReceive(); - _sendingTask = DoSend(); - } - catch (Exception ex) - { - _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(Start)}."); + return; } + + // Offload these to avoid potentially blocking the first read/write/flush + _receivingTask = Task.Run(DoReceive); + _sendingTask = Task.Run(DoSend); } public override void Abort(ConnectionAbortedException abortReason) @@ -106,6 +108,9 @@ public override void Abort(ConnectionAbortedException abortReason) // Only called after connection middleware is complete which means the ConnectionClosed token has fired. public override async ValueTask DisposeAsync() { + // Just in case we haven't started the connection, start it here so we can clean up properly. + EnsureStarted(); + _originalTransport.Input.Complete(); _originalTransport.Output.Complete(); @@ -125,7 +130,7 @@ public override async ValueTask DisposeAsync() } catch (Exception ex) { - _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(Start)}."); + _trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}."); } finally { diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs index 485dc9c99746..7bd8b820ee88 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs @@ -136,8 +136,6 @@ internal void Bind() setting.OutputOptions, waitForData: _options.WaitForDataBeforeAllocatingBuffer); - connection.Start(); - _settingsIndex = (_settingsIndex + 1) % _settingsCount; return connection; diff --git a/src/Servers/Kestrel/test/Sockets.FunctionalTests/SocketTranspotTests.cs b/src/Servers/Kestrel/test/Sockets.FunctionalTests/SocketTranspotTests.cs index ae619a1b7e21..bc70982088f2 100644 --- a/src/Servers/Kestrel/test/Sockets.FunctionalTests/SocketTranspotTests.cs +++ b/src/Servers/Kestrel/test/Sockets.FunctionalTests/SocketTranspotTests.cs @@ -1,15 +1,22 @@ +using System.Buffers; +using System.Diagnostics; using System.Net; using System.Net.Http; using System.Net.Sockets; +using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.FunctionalTests; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Hosting; using Xunit; +using KestrelHttpVersion = Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpVersion; +using KestrelHttpMethod = Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpMethod; + namespace Sockets.FunctionalTests { public class SocketTranspotTests : LoggedTestBase @@ -50,5 +57,158 @@ public async Task SocketTransportExposesSocketsFeature() await host.StopAsync(); } + + [Fact] + public async Task CanReadAndWriteFromSocketFeatureInConnectionMiddleware() + { + var builder = TransportSelector.GetHostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .UseKestrel(options => + { + options.ListenAnyIP(0, lo => + { + lo.Use(next => + { + return async connection => + { + var socket = connection.Features.Get().Socket; + Assert.NotNull(socket); + + var buffer = new byte[4096]; + + var read = await socket.ReceiveAsync(buffer, SocketFlags.None); + + static void ParseHttp(ReadOnlySequence data) + { + var parser = new HttpParser(); + var handler = new ParserHandler(); + + var reader = new SequenceReader(data); + + // Assume we can parse the HTTP request in a single buffer + Assert.True(parser.ParseRequestLine(handler, ref reader)); + Assert.True(parser.ParseHeaders(handler, ref reader)); + + Assert.Equal(KestrelHttpMethod.Get, handler.HttpMethod); + Assert.Equal(KestrelHttpVersion.Http11, handler.HttpVersion); + } + + ParseHttp(new ReadOnlySequence(buffer[0..read])); + + await socket.SendAsync(Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n"), SocketFlags.None); + }; + }); + }); + }) + .Configure(app => { }); + }) + .ConfigureServices(AddTestLogging); + + using var host = builder.Build(); + using var client = new HttpClient(); + + await host.StartAsync(); + + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + response.EnsureSuccessStatusCode(); + + await host.StopAsync(); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.Linux)] + [OSSkipCondition(OperatingSystems.MacOSX)] + public async Task CanDuplicateAndCloseSocketFeatureInConnectionMiddleware() + { + var builder = TransportSelector.GetHostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .UseKestrel(options => + { + options.ListenAnyIP(0, lo => + { + lo.Use(next => + { + return async connection => + { + var originalSocket = connection.Features.Get().Socket; + Assert.NotNull(originalSocket); + + var si = originalSocket.DuplicateAndClose(Process.GetCurrentProcess().Id); + + using var socket = new Socket(si); + var buffer = new byte[4096]; + + var read = await socket.ReceiveAsync(buffer, SocketFlags.None); + + static void ParseHttp(ReadOnlySequence data) + { + var parser = new HttpParser(); + var handler = new ParserHandler(); + + var reader = new SequenceReader(data); + + // Assume we can parse the HTTP request in a single buffer + Assert.True(parser.ParseRequestLine(handler, ref reader)); + Assert.True(parser.ParseHeaders(handler, ref reader)); + + Assert.Equal(KestrelHttpMethod.Get, handler.HttpMethod); + Assert.Equal(KestrelHttpVersion.Http11, handler.HttpVersion); + } + + ParseHttp(new ReadOnlySequence(buffer[0..read])); + + await socket.SendAsync(Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n"), SocketFlags.None); + }; + }); + }); + }) + .Configure(app => { }); + }) + .ConfigureServices(AddTestLogging); + + using var host = builder.Build(); + using var client = new HttpClient(); + + await host.StartAsync(); + + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + response.EnsureSuccessStatusCode(); + + await host.StopAsync(); + } + + private class ParserHandler : IHttpRequestLineHandler, IHttpHeadersHandler + { + public KestrelHttpVersion HttpVersion { get; set; } + public KestrelHttpMethod HttpMethod { get; set; } + public Dictionary Headers = new(); + + public void OnHeader(ReadOnlySpan name, ReadOnlySpan value) + { + Headers[Encoding.ASCII.GetString(name)] = Encoding.ASCII.GetString(value); + } + + public void OnHeadersComplete(bool endStream) + { + } + + public void OnStartLine(HttpVersionAndMethod versionAndMethod, TargetOffsetPathLength targetPath, Span startLine) + { + HttpMethod = versionAndMethod.Method; + HttpVersion = versionAndMethod.Version; + } + + public void OnStaticIndexedHeader(int index) + { + } + + public void OnStaticIndexedHeader(int index, ReadOnlySpan value) + { + } + } } }