diff --git a/benchmarks/Kestrel.Performance/FrameWritingBenchmark.cs b/benchmarks/Kestrel.Performance/FrameWritingBenchmark.cs index bc0fa2058..82f943f52 100644 --- a/benchmarks/Kestrel.Performance/FrameWritingBenchmark.cs +++ b/benchmarks/Kestrel.Performance/FrameWritingBenchmark.cs @@ -24,16 +24,13 @@ public class FrameWritingBenchmark private static readonly Func _psuedoAsyncTaskFunc = (obj) => _psuedoAsyncTask; private readonly TestFrame _frame; - private readonly IPipe _outputPipe; + private (IPipeConnection Transport, IPipeConnection Application) _pair; private readonly byte[] _writeData; public FrameWritingBenchmark() { - var pipeFactory = new PipeFactory(); - - _outputPipe = pipeFactory.Create(); - _frame = MakeFrame(pipeFactory); + _frame = MakeFrame(); _writeData = Encoding.ASCII.GetBytes("Hello, World!"); } @@ -93,9 +90,11 @@ public Task WriteAsync() return _frame.ResponseBody.WriteAsync(_writeData, 0, _writeData.Length, default(CancellationToken)); } - private TestFrame MakeFrame(PipeFactory pipeFactory) + private TestFrame MakeFrame() { - var input = pipeFactory.Create(); + var pipeFactory = new PipeFactory(); + var pair = pipeFactory.CreateConnectionPair(); + _pair = pair; var serviceContext = new ServiceContext { @@ -109,8 +108,8 @@ private TestFrame MakeFrame(PipeFactory pipeFactory) { ServiceContext = serviceContext, PipeFactory = pipeFactory, - Input = input.Reader, - Output = _outputPipe + Application = pair.Application, + Transport = pair.Transport }); frame.Reset(); @@ -122,7 +121,7 @@ private TestFrame MakeFrame(PipeFactory pipeFactory) [IterationCleanup] public void Cleanup() { - var reader = _outputPipe.Reader; + var reader = _pair.Application.Input; if (reader.TryRead(out var readResult)) { reader.Advance(readResult.Buffer.End); diff --git a/benchmarks/Kestrel.Performance/ResponseHeadersWritingBenchmark.cs b/benchmarks/Kestrel.Performance/ResponseHeadersWritingBenchmark.cs index 980ef8464..52ffb6bb7 100644 --- a/benchmarks/Kestrel.Performance/ResponseHeadersWritingBenchmark.cs +++ b/benchmarks/Kestrel.Performance/ResponseHeadersWritingBenchmark.cs @@ -110,8 +110,7 @@ private Task PlaintextChunkedWithCookie() public void Setup() { var pipeFactory = new PipeFactory(); - var input = pipeFactory.Create(); - var output = pipeFactory.Create(); + var pair = pipeFactory.CreateConnectionPair(); var serviceContext = new ServiceContext { @@ -126,8 +125,8 @@ public void Setup() ServiceContext = serviceContext, PipeFactory = pipeFactory, TimeoutControl = new MockTimeoutControl(), - Input = input.Reader, - Output = output + Application = pair.Application, + Transport = pair.Transport }); frame.Reset(); diff --git a/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs b/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs index 823fa67b5..72c1780d0 100644 --- a/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs +++ b/src/Kestrel.Core/Adapter/Internal/AdaptedPipeline.cs @@ -5,36 +5,36 @@ using System.IO; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using System.IO.Pipelines; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal { - public class AdaptedPipeline + public class AdaptedPipeline : IPipeConnection { private const int MinAllocBufferSize = 2048; - private readonly IKestrelTrace _trace; - private readonly IPipe _transportOutputPipe; - private readonly IPipeReader _transportInputPipeReader; + private readonly IPipeConnection _transport; + private readonly IPipeConnection _application; - public AdaptedPipeline(IPipeReader transportInputPipeReader, - IPipe transportOutputPipe, + public AdaptedPipeline(IPipeConnection transport, + IPipeConnection application, IPipe inputPipe, - IPipe outputPipe, - IKestrelTrace trace) + IPipe outputPipe) { - _transportInputPipeReader = transportInputPipeReader; - _transportOutputPipe = transportOutputPipe; + _transport = transport; + _application = application; Input = inputPipe; Output = outputPipe; - _trace = trace; } public IPipe Input { get; } public IPipe Output { get; } + IPipeReader IPipeConnection.Input => Input.Reader; + + IPipeWriter IPipeConnection.Output => Output.Writer; + public async Task RunAsync(Stream stream) { var inputTask = ReadInputAsync(stream); @@ -65,7 +65,7 @@ private async Task WriteOutputAsync(Stream stream) if (result.IsCancelled) { // Forward the cancellation to the transport pipe - _transportOutputPipe.Reader.CancelPendingRead(); + _application.Input.CancelPendingRead(); break; } @@ -104,7 +104,7 @@ private async Task WriteOutputAsync(Stream stream) finally { Output.Reader.Complete(); - _transportOutputPipe.Writer.Complete(error); + _transport.Output.Complete(); } } @@ -161,8 +161,12 @@ private async Task ReadInputAsync(Stream stream) Input.Writer.Complete(error); // The application could have ended the input pipe so complete // the transport pipe as well - _transportInputPipeReader.Complete(); + _transport.Input.Complete(); } } + + public void Dispose() + { + } } } diff --git a/src/Kestrel.Core/Internal/ConnectionHandler.cs b/src/Kestrel.Core/Internal/ConnectionHandler.cs index 0dc7fbcbf..d94de1242 100644 --- a/src/Kestrel.Core/Internal/ConnectionHandler.cs +++ b/src/Kestrel.Core/Internal/ConnectionHandler.cs @@ -1,31 +1,27 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.IO.Pipelines; -using System.Net; -using System.Threading; -using Microsoft.AspNetCore.Hosting.Server; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Protocols; using Microsoft.AspNetCore.Protocols.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { - public class ConnectionHandler : IConnectionHandler + public class ConnectionHandler : IConnectionHandler { - private static long _lastFrameConnectionId = long.MinValue; - - private readonly ListenOptions _listenOptions; private readonly ServiceContext _serviceContext; - private readonly IHttpApplication _application; + private readonly ConnectionDelegate _connectionDelegate; - public ConnectionHandler(ListenOptions listenOptions, ServiceContext serviceContext, IHttpApplication application) + public ConnectionHandler(ServiceContext serviceContext, ConnectionDelegate connectionDelegate) { - _listenOptions = listenOptions; _serviceContext = serviceContext; - _application = application; + _connectionDelegate = connectionDelegate; } public void OnConnection(IFeatureCollection features) @@ -34,89 +30,57 @@ public void OnConnection(IFeatureCollection features) var transportFeature = connectionContext.Features.Get(); - var inputPipe = transportFeature.PipeFactory.Create(GetInputPipeOptions(transportFeature.InputWriterScheduler)); - var outputPipe = transportFeature.PipeFactory.Create(GetOutputPipeOptions(transportFeature.OutputReaderScheduler)); + // REVIEW: Unfortunately, we still need to use the service context to create the pipes since the settings + // for the scheduler and limits are specified here + var inputOptions = GetInputPipeOptions(_serviceContext, transportFeature.InputWriterScheduler); + var outputOptions = GetOutputPipeOptions(_serviceContext, transportFeature.OutputReaderScheduler); - var connectionId = CorrelationIdGenerator.GetNextId(); - var frameConnectionId = Interlocked.Increment(ref _lastFrameConnectionId); + var pair = connectionContext.PipeFactory.CreateConnectionPair(inputOptions, outputOptions); // Set the transport and connection id - connectionContext.ConnectionId = connectionId; - transportFeature.Connection = new PipeConnection(inputPipe.Reader, outputPipe.Writer); - var applicationConnection = new PipeConnection(outputPipe.Reader, inputPipe.Writer); - - if (!_serviceContext.ConnectionManager.NormalConnectionCount.TryLockOne()) - { - var goAway = new RejectionConnection(inputPipe, outputPipe, connectionId, _serviceContext) - { - Connection = applicationConnection - }; - - connectionContext.Features.Set(goAway); + connectionContext.ConnectionId = CorrelationIdGenerator.GetNextId(); + connectionContext.Transport = pair.Transport; - goAway.Reject(); - return; - } - - var frameConnectionContext = new FrameConnectionContext - { - ConnectionId = connectionId, - FrameConnectionId = frameConnectionId, - ServiceContext = _serviceContext, - PipeFactory = connectionContext.PipeFactory, - ConnectionAdapters = _listenOptions.ConnectionAdapters, - Input = inputPipe, - Output = outputPipe - }; + // This *must* be set before returning from OnConnection + transportFeature.Application = pair.Application; - var connectionFeature = connectionContext.Features.Get(); + // REVIEW: This task should be tracked by the server for graceful shutdown + // Today it's handled specifically for http but not for aribitrary middleware + _ = Execute(connectionContext); + } - if (connectionFeature != null) + private async Task Execute(ConnectionContext connectionContext) + { + try { - if (connectionFeature.LocalIpAddress != null) - { - frameConnectionContext.LocalEndPoint = new IPEndPoint(connectionFeature.LocalIpAddress, connectionFeature.LocalPort); - } - - if (connectionFeature.RemoteIpAddress != null) - { - frameConnectionContext.RemoteEndPoint = new IPEndPoint(connectionFeature.RemoteIpAddress, connectionFeature.RemotePort); - } + await _connectionDelegate(connectionContext); } - - var connection = new FrameConnection(frameConnectionContext) + catch (Exception ex) { - Connection = applicationConnection - }; - - connectionContext.Features.Set(connection); - - // Since data cannot be added to the inputPipe by the transport until OnConnection returns, - // Frame.ProcessRequestsAsync is guaranteed to unblock the transport thread before calling - // application code. - connection.StartRequestProcessing(_application); + _serviceContext.Log.LogCritical(0, ex, $"{nameof(ConnectionHandler)}.{nameof(Execute)}() {connectionContext.ConnectionId}"); + } } // Internal for testing - internal PipeOptions GetInputPipeOptions(IScheduler writerScheduler) => new PipeOptions + internal static PipeOptions GetInputPipeOptions(ServiceContext serviceContext, IScheduler writerScheduler) => new PipeOptions { - ReaderScheduler = _serviceContext.ThreadPool, + ReaderScheduler = serviceContext.ThreadPool, WriterScheduler = writerScheduler, - MaximumSizeHigh = _serviceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, - MaximumSizeLow = _serviceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0 + MaximumSizeHigh = serviceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + MaximumSizeLow = serviceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0 }; - internal PipeOptions GetOutputPipeOptions(IScheduler readerScheduler) => new PipeOptions + internal static PipeOptions GetOutputPipeOptions(ServiceContext serviceContext, IScheduler readerScheduler) => new PipeOptions { ReaderScheduler = readerScheduler, - WriterScheduler = _serviceContext.ThreadPool, - MaximumSizeHigh = GetOutputResponseBufferSize(), - MaximumSizeLow = GetOutputResponseBufferSize() + WriterScheduler = serviceContext.ThreadPool, + MaximumSizeHigh = GetOutputResponseBufferSize(serviceContext), + MaximumSizeLow = GetOutputResponseBufferSize(serviceContext) }; - private long GetOutputResponseBufferSize() + private static long GetOutputResponseBufferSize(ServiceContext serviceContext) { - var bufferSize = _serviceContext.ServerOptions.Limits.MaxResponseBufferSize; + var bufferSize = serviceContext.ServerOptions.Limits.MaxResponseBufferSize; if (bufferSize == 0) { // 0 = no buffering so we need to configure the pipe so the the writer waits on the reader directly diff --git a/src/Kestrel.Core/Internal/ConnectionLimitBuilderExtensions.cs b/src/Kestrel.Core/Internal/ConnectionLimitBuilderExtensions.cs new file mode 100644 index 000000000..2c02fa664 --- /dev/null +++ b/src/Kestrel.Core/Internal/ConnectionLimitBuilderExtensions.cs @@ -0,0 +1,16 @@ +using Microsoft.AspNetCore.Protocols; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public static class ConnectionLimitBuilderExtensions + { + public static IConnectionBuilder UseConnectionLimit(this IConnectionBuilder builder, ServiceContext serviceContext) + { + return builder.Use(next => + { + var middleware = new ConnectionLimitMiddleware(next, serviceContext); + return middleware.OnConnectionAsync; + }); + } + } +} diff --git a/src/Kestrel.Core/Internal/ConnectionLimitMiddleware.cs b/src/Kestrel.Core/Internal/ConnectionLimitMiddleware.cs new file mode 100644 index 000000000..cd058bc23 --- /dev/null +++ b/src/Kestrel.Core/Internal/ConnectionLimitMiddleware.cs @@ -0,0 +1,32 @@ +using System.Threading.Tasks; +using Microsoft.AspNetCore.Protocols; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class ConnectionLimitMiddleware + { + private readonly ServiceContext _serviceContext; + private readonly ConnectionDelegate _next; + + public ConnectionLimitMiddleware(ConnectionDelegate next, ServiceContext serviceContext) + { + _next = next; + _serviceContext = serviceContext; + } + + public Task OnConnectionAsync(ConnectionContext connection) + { + if (!_serviceContext.ConnectionManager.NormalConnectionCount.TryLockOne()) + { + KestrelEventSource.Log.ConnectionRejected(connection.ConnectionId); + _serviceContext.Log.ConnectionRejected(connection.ConnectionId); + connection.Transport.Input.Complete(); + connection.Transport.Output.Complete(); + return Task.CompletedTask; + } + + return _next(connection); + } + } +} diff --git a/src/Kestrel.Core/Internal/FrameConnection.cs b/src/Kestrel.Core/Internal/FrameConnection.cs index e1b2dbde3..f11c6102a 100644 --- a/src/Kestrel.Core/Internal/FrameConnection.cs +++ b/src/Kestrel.Core/Internal/FrameConnection.cs @@ -11,7 +11,6 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Protocols.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; @@ -21,14 +20,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { - public class FrameConnection : IConnectionApplicationFeature, ITimeoutControl + public class FrameConnection : ITimeoutControl { private const int Http2ConnectionNotStarted = 0; private const int Http2ConnectionStarted = 1; private const int Http2ConnectionClosed = 2; private readonly FrameConnectionContext _context; - private List _adaptedConnections; + private IList _adaptedConnections; private readonly TaskCompletionSource _socketClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private Frame _frame; private Http2Connection _http2Connection; @@ -62,8 +61,6 @@ public FrameConnection(FrameConnectionContext context) public bool TimedOut { get; private set; } public string ConnectionId => _context.ConnectionId; - public IPipeWriter Input => _context.Input.Writer; - public IPipeReader Output => _context.Output.Reader; public IPEndPoint LocalEndPoint => _context.LocalEndPoint; public IPEndPoint RemoteEndPoint => _context.RemoteEndPoint; @@ -88,14 +85,12 @@ public FrameConnection(FrameConnectionContext context) private IKestrelTrace Log => _context.ServiceContext.Log; - public IPipeConnection Connection { get; set; } - - public void StartRequestProcessing(IHttpApplication application) + public Task StartRequestProcessing(IHttpApplication application) { - _lifetimeTask = ProcessRequestsAsync(application); + return _lifetimeTask = ProcessRequestsAsync(application); } - private async Task ProcessRequestsAsync(IHttpApplication application) + private async Task ProcessRequestsAsync(IHttpApplication httpApplication) { using (BeginConnectionScope()) { @@ -106,23 +101,22 @@ private async Task ProcessRequestsAsync(IHttpApplication app AdaptedPipeline adaptedPipeline = null; var adaptedPipelineTask = Task.CompletedTask; - var input = _context.Input.Reader; - var output = _context.Output; + var transport = _context.Transport; + var application = _context.Application; + if (_context.ConnectionAdapters.Count > 0) { - adaptedPipeline = new AdaptedPipeline(input, - output, - PipeFactory.Create(AdaptedInputPipeOptions), - PipeFactory.Create(AdaptedOutputPipeOptions), - Log); - - input = adaptedPipeline.Input.Reader; - output = adaptedPipeline.Output; + adaptedPipeline = new AdaptedPipeline(transport, + application, + PipeFactory.Create(AdaptedInputPipeOptions), + PipeFactory.Create(AdaptedOutputPipeOptions)); + + transport = adaptedPipeline; } // _frame must be initialized before adding the connection to the connection manager - CreateFrame(application, input, output); + CreateFrame(httpApplication, transport, application); // _http2Connection must be initialized before yield control to the transport thread, // to prevent a race condition where _http2Connection.Abort() is called just as @@ -130,12 +124,12 @@ private async Task ProcessRequestsAsync(IHttpApplication app _http2Connection = new Http2Connection(new Http2ConnectionContext { ConnectionId = _context.ConnectionId, - ServiceContext = _context.ServiceContext, + ServiceContext = _context.ServiceContext, PipeFactory = PipeFactory, LocalEndPoint = LocalEndPoint, RemoteEndPoint = RemoteEndPoint, - Input = input, - Output = output + Application = application, + Transport = transport }); // Do this before the first await so we don't yield control to the transport until we've @@ -153,7 +147,7 @@ private async Task ProcessRequestsAsync(IHttpApplication app if (_frame.ConnectionFeatures?.Get()?.ApplicationProtocol == "h2" && Interlocked.CompareExchange(ref _http2ConnectionState, Http2ConnectionStarted, Http2ConnectionNotStarted) == Http2ConnectionNotStarted) { - await _http2Connection.ProcessAsync(application); + await _http2Connection.ProcessAsync(httpApplication); } else { @@ -187,9 +181,9 @@ private async Task ProcessRequestsAsync(IHttpApplication app } } - internal void CreateFrame(IHttpApplication application, IPipeReader input, IPipe output) + internal void CreateFrame(IHttpApplication httpApplication, IPipeConnection transport, IPipeConnection application) { - _frame = new Frame(application, new FrameContext + _frame = new Frame(httpApplication, new FrameContext { ConnectionId = _context.ConnectionId, PipeFactory = PipeFactory, @@ -197,8 +191,8 @@ internal void CreateFrame(IHttpApplication application, IPip RemoteEndPoint = RemoteEndPoint, ServiceContext = _context.ServiceContext, TimeoutControl = this, - Input = input, - Output = output + Transport = transport, + Application = application }); } @@ -268,7 +262,7 @@ private async Task ApplyConnectionAdaptersAsync() var features = new FeatureCollection(); var connectionAdapters = _context.ConnectionAdapters; - var stream = new RawStream(_context.Input.Reader, _context.Output.Writer); + var stream = new RawStream(_context.Transport.Input, _context.Transport.Output); var adapterContext = new ConnectionAdapterContext(features, stream); _adaptedConnections = new List(connectionAdapters.Count); diff --git a/src/Kestrel.Core/Internal/FrameConnectionContext.cs b/src/Kestrel.Core/Internal/FrameConnectionContext.cs index 57b08141f..5f5c5c2df 100644 --- a/src/Kestrel.Core/Internal/FrameConnectionContext.cs +++ b/src/Kestrel.Core/Internal/FrameConnectionContext.cs @@ -13,12 +13,11 @@ public class FrameConnectionContext public string ConnectionId { get; set; } public long FrameConnectionId { get; set; } public ServiceContext ServiceContext { get; set; } - public List ConnectionAdapters { get; set; } + public IList ConnectionAdapters { get; set; } public PipeFactory PipeFactory { get; set; } public IPEndPoint LocalEndPoint { get; set; } public IPEndPoint RemoteEndPoint { get; set; } - - public IPipe Input { get; set; } - public IPipe Output { get; set; } + public IPipeConnection Transport { get; set; } + public IPipeConnection Application { get; set; } } } diff --git a/src/Kestrel.Core/Internal/Http/Frame.cs b/src/Kestrel.Core/Internal/Http/Frame.cs index bdced76b6..1debda07d 100644 --- a/src/Kestrel.Core/Internal/Http/Frame.cs +++ b/src/Kestrel.Core/Internal/Http/Frame.cs @@ -96,7 +96,7 @@ public Frame(FrameContext frameContext) _keepAliveTicks = ServerOptions.Limits.KeepAliveTimeout.Ticks; _requestHeadersTimeoutTicks = ServerOptions.Limits.RequestHeadersTimeout.Ticks; - Output = new OutputProducer(frameContext.Output, frameContext.ConnectionId, frameContext.ServiceContext.Log, TimeoutControl); + Output = new OutputProducer(frameContext.Application.Input, frameContext.Transport.Output, frameContext.ConnectionId, frameContext.ServiceContext.Log, TimeoutControl); RequestBodyPipe = CreateRequestBodyPipe(); } @@ -107,7 +107,7 @@ public Frame(FrameContext frameContext) private IPEndPoint RemoteEndPoint => _frameContext.RemoteEndPoint; public IFeatureCollection ConnectionFeatures { get; set; } - public IPipeReader Input => _frameContext.Input; + public IPipeReader Input => _frameContext.Transport.Input; public OutputProducer Output { get; } public ITimeoutControl TimeoutControl => _frameContext.TimeoutControl; diff --git a/src/Kestrel.Core/Internal/Http/FrameContext.cs b/src/Kestrel.Core/Internal/Http/FrameContext.cs index 67274af48..5a52c03ea 100644 --- a/src/Kestrel.Core/Internal/Http/FrameContext.cs +++ b/src/Kestrel.Core/Internal/Http/FrameContext.cs @@ -16,7 +16,7 @@ public class FrameContext public IPEndPoint RemoteEndPoint { get; set; } public IPEndPoint LocalEndPoint { get; set; } public ITimeoutControl TimeoutControl { get; set; } - public IPipeReader Input { get; set; } - public IPipe Output { get; set; } + public IPipeConnection Transport { get; set; } + public IPipeConnection Application { get; set; } } } diff --git a/src/Kestrel.Core/Internal/Http/OutputProducer.cs b/src/Kestrel.Core/Internal/Http/OutputProducer.cs index b44af9f1b..2bd85c904 100644 --- a/src/Kestrel.Core/Internal/Http/OutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http/OutputProducer.cs @@ -24,7 +24,8 @@ public class OutputProducer : IDisposable private bool _completed = false; - private readonly IPipe _pipe; + private readonly IPipeWriter _pipeWriter; + private readonly IPipeReader _outputPipeReader; // https://github.com/dotnet/corefxlab/issues/1334 // Pipelines don't support multiple awaiters on flush @@ -34,12 +35,14 @@ public class OutputProducer : IDisposable private Action _flushCompleted; public OutputProducer( - IPipe pipe, + IPipeReader outputPipeReader, + IPipeWriter pipeWriter, string connectionId, IKestrelTrace log, ITimeoutControl timeoutControl) { - _pipe = pipe; + _outputPipeReader = outputPipeReader; + _pipeWriter = pipeWriter; _connectionId = connectionId; _timeoutControl = timeoutControl; _log = log; @@ -70,7 +73,7 @@ public void Write(Action callback, T state) return; } - var buffer = _pipe.Writer.Alloc(1); + var buffer = _pipeWriter.Alloc(1); callback(buffer, state); buffer.Commit(); } @@ -87,7 +90,7 @@ public void Dispose() _log.ConnectionDisconnect(_connectionId); _completed = true; - _pipe.Writer.Complete(); + _pipeWriter.Complete(); } } @@ -103,8 +106,8 @@ public void Abort(Exception error) _log.ConnectionDisconnect(_connectionId); _completed = true; - _pipe.Reader.CancelPendingRead(); - _pipe.Writer.Complete(error); + _outputPipeReader.CancelPendingRead(); + _pipeWriter.Complete(error); } } @@ -122,7 +125,7 @@ private Task WriteAsync( return Task.CompletedTask; } - writableBuffer = _pipe.Writer.Alloc(1); + writableBuffer = _pipeWriter.Alloc(1); var writer = new WritableBufferWriter(writableBuffer); if (buffer.Count > 0) { diff --git a/src/Kestrel.Core/Internal/Http2/Http2Connection.cs b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs index 4e3476d10..5419790ff 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2Connection.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs @@ -40,13 +40,13 @@ public class Http2Connection : ITimeoutControl, IHttp2StreamLifetimeHandler public Http2Connection(Http2ConnectionContext context) { _context = context; - _frameWriter = new Http2FrameWriter(context.Output); + _frameWriter = new Http2FrameWriter(context.Transport.Output, context.Application.Input); _hpackDecoder = new HPackDecoder(); } public string ConnectionId => _context.ConnectionId; - public IPipeReader Input => _context.Input; + public IPipeReader Input => _context.Transport.Input; public IKestrelTrace Log => _context.ServiceContext.Log; diff --git a/src/Kestrel.Core/Internal/Http2/Http2ConnectionContext.cs b/src/Kestrel.Core/Internal/Http2/Http2ConnectionContext.cs index ff22a1afb..2956040b1 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2ConnectionContext.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2ConnectionContext.cs @@ -14,7 +14,7 @@ public class Http2ConnectionContext public IPEndPoint LocalEndPoint { get; set; } public IPEndPoint RemoteEndPoint { get; set; } - public IPipeReader Input { get; set; } - public IPipe Output { get; set; } + public IPipeConnection Transport { get; set; } + public IPipeConnection Application { get; set; } } } diff --git a/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs b/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs index 698f6a19c..6623f489b 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2FrameWriter.cs @@ -19,13 +19,15 @@ public class Http2FrameWriter : IHttp2FrameWriter private readonly Http2Frame _outgoingFrame = new Http2Frame(); private readonly object _writeLock = new object(); private readonly HPackEncoder _hpackEncoder = new HPackEncoder(); - private readonly IPipe _output; + private readonly IPipeWriter _outputWriter; + private readonly IPipeReader _outputReader; private bool _completed; - public Http2FrameWriter(IPipe output) + public Http2FrameWriter(IPipeWriter outputPipeWriter, IPipeReader outputPipeReader) { - _output = output; + _outputWriter = outputPipeWriter; + _outputReader = outputPipeReader; } public void Abort(Exception ex) @@ -33,8 +35,8 @@ public void Abort(Exception ex) lock (_writeLock) { _completed = true; - _output.Reader.CancelPendingRead(); - _output.Writer.Complete(ex); + _outputReader.CancelPendingRead(); + _outputWriter.Complete(ex); } } @@ -173,7 +175,7 @@ public Task WriteGoAwayAsync(int lastStreamId, Http2ErrorCode errorCode) return; } - var writeableBuffer = _output.Writer.Alloc(1); + var writeableBuffer = _outputWriter.Alloc(1); writeableBuffer.Write(data); await writeableBuffer.FlushAsync(cancellationToken); } diff --git a/src/Kestrel.Core/Internal/HttpConnectionBuilderExtensions.cs b/src/Kestrel.Core/Internal/HttpConnectionBuilderExtensions.cs new file mode 100644 index 000000000..ca3ca5c9a --- /dev/null +++ b/src/Kestrel.Core/Internal/HttpConnectionBuilderExtensions.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Protocols; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public static class HttpConnectionBuilderExtensions + { + public static IConnectionBuilder UseHttpServer(this IConnectionBuilder builder, ServiceContext serviceContext, IHttpApplication application) + { + return builder.UseHttpServer(Array.Empty(), serviceContext, application); + } + + public static IConnectionBuilder UseHttpServer(this IConnectionBuilder builder, IList adapters, ServiceContext serviceContext, IHttpApplication application) + { + var middleware = new HttpConnectionMiddleware(adapters, serviceContext, application); + return builder.Use(next => + { + return middleware.OnConnectionAsync; + }); + } + } +} diff --git a/src/Kestrel.Core/Internal/HttpConnectionMiddleware.cs b/src/Kestrel.Core/Internal/HttpConnectionMiddleware.cs new file mode 100644 index 000000000..e79658de2 --- /dev/null +++ b/src/Kestrel.Core/Internal/HttpConnectionMiddleware.cs @@ -0,0 +1,91 @@ +using System.Collections.Generic; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Protocols; +using Microsoft.AspNetCore.Protocols.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + public class HttpConnectionMiddleware + { + private static long _lastFrameConnectionId = long.MinValue; + + private readonly IList _connectionAdapters; + private readonly ServiceContext _serviceContext; + private readonly IHttpApplication _application; + + public HttpConnectionMiddleware(IList adapters, ServiceContext serviceContext, IHttpApplication application) + { + _serviceContext = serviceContext; + _application = application; + + // Keeping these around for now so progress can be made without updating tests + _connectionAdapters = adapters; + } + + public Task OnConnectionAsync(ConnectionContext connectionContext) + { + // We need the transport feature so that we can cancel the output reader that the transport is using + // This is a bit of a hack but it preserves the existing semantics + var transportFeature = connectionContext.Features.Get(); + + var frameConnectionId = Interlocked.Increment(ref _lastFrameConnectionId); + + var frameConnectionContext = new FrameConnectionContext + { + ConnectionId = connectionContext.ConnectionId, + FrameConnectionId = frameConnectionId, + ServiceContext = _serviceContext, + PipeFactory = connectionContext.PipeFactory, + ConnectionAdapters = _connectionAdapters, + Transport = connectionContext.Transport, + Application = transportFeature.Application + }; + + var connectionFeature = connectionContext.Features.Get(); + + if (connectionFeature != null) + { + if (connectionFeature.LocalIpAddress != null) + { + frameConnectionContext.LocalEndPoint = new IPEndPoint(connectionFeature.LocalIpAddress, connectionFeature.LocalPort); + } + + if (connectionFeature.RemoteIpAddress != null) + { + frameConnectionContext.RemoteEndPoint = new IPEndPoint(connectionFeature.RemoteIpAddress, connectionFeature.RemotePort); + } + } + + var connection = new FrameConnection(frameConnectionContext); + + // The order here is important, start request processing so that + // the frame is created before this yields. Events need to be wired up + // afterwards + var processingTask = connection.StartRequestProcessing(_application); + + // Wire up the events an forward calls to the frame connection + // It's important that these execute synchronously because graceful + // connection close is order sensative (for now) + connectionContext.ConnectionAborted.ContinueWith((task, state) => + { + // Unwrap the aggregate exception + ((FrameConnection)state).Abort(task.Exception?.InnerException); + }, + connection, TaskContinuationOptions.ExecuteSynchronously); + + connectionContext.ConnectionClosed.ContinueWith((task, state) => + { + // Unwrap the aggregate exception + ((FrameConnection)state).OnConnectionClosed(task.Exception?.InnerException); + }, + connection, TaskContinuationOptions.ExecuteSynchronously); + + return processingTask; + } + } +} diff --git a/src/Kestrel.Core/Internal/RejectionConnection.cs b/src/Kestrel.Core/Internal/RejectionConnection.cs deleted file mode 100644 index 776a63725..000000000 --- a/src/Kestrel.Core/Internal/RejectionConnection.cs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.IO.Pipelines; -using Microsoft.AspNetCore.Protocols.Features; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; - -namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal -{ - public class RejectionConnection : IConnectionApplicationFeature - { - private readonly IKestrelTrace _log; - private readonly IPipe _input; - private readonly IPipe _output; - - public RejectionConnection(IPipe input, IPipe output, string connectionId, ServiceContext serviceContext) - { - ConnectionId = connectionId; - _log = serviceContext.Log; - _input = input; - _output = output; - } - - public string ConnectionId { get; } - public IPipeWriter Input => _input.Writer; - public IPipeReader Output => _output.Reader; - - public IPipeConnection Connection { get; set; } - - public void Reject() - { - KestrelEventSource.Log.ConnectionRejected(ConnectionId); - _log.ConnectionRejected(ConnectionId); - _input.Reader.Complete(); - _output.Writer.Complete(); - } - - void IConnectionApplicationFeature.OnConnectionClosed(Exception ex) - { - } - - void IConnectionApplicationFeature.Abort(Exception ex) - { - } - } -} \ No newline at end of file diff --git a/src/Kestrel.Core/KestrelServer.cs b/src/Kestrel.Core/KestrelServer.cs index e52e3895c..a1f5864f9 100644 --- a/src/Kestrel.Core/KestrelServer.cs +++ b/src/Kestrel.Core/KestrelServer.cs @@ -135,7 +135,16 @@ public async Task StartAsync(IHttpApplication application, C async Task OnBind(ListenOptions endpoint) { - var connectionHandler = new ConnectionHandler(endpoint, ServiceContext, application); + // Add the connection limit middleware + endpoint.UseConnectionLimit(ServiceContext); + + // Configure the user delegate + endpoint.Configure(endpoint); + + // Add the HTTP middleware as the terminal connection middleware + endpoint.UseHttpServer(endpoint.ConnectionAdapters, ServiceContext, application); + + var connectionHandler = new ConnectionHandler(ServiceContext, endpoint.Build()); var transport = _transportFactory.Create(endpoint, connectionHandler); _transports.Add(transport); diff --git a/src/Kestrel.Core/KestrelServerOptions.cs b/src/Kestrel.Core/KestrelServerOptions.cs index 136b98353..6a3363971 100644 --- a/src/Kestrel.Core/KestrelServerOptions.cs +++ b/src/Kestrel.Core/KestrelServerOptions.cs @@ -100,8 +100,7 @@ public void Listen(IPEndPoint endPoint, Action configure) throw new ArgumentNullException(nameof(configure)); } - var listenOptions = new ListenOptions(endPoint) { KestrelServerOptions = this }; - configure(listenOptions); + var listenOptions = new ListenOptions(endPoint) { KestrelServerOptions = this, Configure = configure }; ListenOptions.Add(listenOptions); } @@ -132,8 +131,7 @@ public void ListenUnixSocket(string socketPath, Action configure) throw new ArgumentNullException(nameof(configure)); } - var listenOptions = new ListenOptions(socketPath) { KestrelServerOptions = this }; - configure(listenOptions); + var listenOptions = new ListenOptions(socketPath) { KestrelServerOptions = this, Configure = configure }; ListenOptions.Add(listenOptions); } @@ -156,8 +154,7 @@ public void ListenHandle(ulong handle, Action configure) throw new ArgumentNullException(nameof(configure)); } - var listenOptions = new ListenOptions(handle) { KestrelServerOptions = this }; - configure(listenOptions); + var listenOptions = new ListenOptions(handle) { KestrelServerOptions = this, Configure = configure }; ListenOptions.Add(listenOptions); } } diff --git a/src/Kestrel.Core/ListenOptions.cs b/src/Kestrel.Core/ListenOptions.cs index 46bae7fba..f1f66209a 100644 --- a/src/Kestrel.Core/ListenOptions.cs +++ b/src/Kestrel.Core/ListenOptions.cs @@ -5,6 +5,8 @@ using System.Collections.Generic; using System.Linq; using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Protocols; using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; @@ -14,9 +16,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core /// Describes either an , Unix domain socket path, or a file descriptor for an already open /// socket that Kestrel should bind to or open. /// - public class ListenOptions : IEndPointInformation + public class ListenOptions : IEndPointInformation, IConnectionBuilder { private FileHandleType _handleType; + private readonly List> _components = new List>(); internal ListenOptions(IPEndPoint endPoint) { @@ -126,6 +129,10 @@ public FileHandleType HandleType /// public List ConnectionAdapters { get; } = new List(); + public IServiceProvider ApplicationServices => KestrelServerOptions?.ApplicationServices; + + internal Action Configure { get; set; } = _ => { }; + /// /// Gets the name of this endpoint to display on command-line when the web server starts. /// @@ -149,5 +156,27 @@ internal string GetDisplayName() } public override string ToString() => GetDisplayName(); + + public IConnectionBuilder Use(Func middleware) + { + _components.Add(middleware); + return this; + } + + public ConnectionDelegate Build() + { + ConnectionDelegate app = context => + { + return Task.CompletedTask; + }; + + for (int i = _components.Count - 1; i >= 0; i--) + { + var component = _components[i]; + app = component(app); + } + + return app; + } } } diff --git a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs index fd5d60050..7a8cd3ce2 100644 --- a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs +++ b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.Features.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.IO.Pipelines; using System.Net; -using System.Text; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Protocols.Features; @@ -17,12 +17,10 @@ public partial class TransportConnection : IFeatureCollection, private static readonly Type IHttpConnectionFeatureType = typeof(IHttpConnectionFeature); private static readonly Type IConnectionIdFeatureType = typeof(IConnectionIdFeature); private static readonly Type IConnectionTransportFeatureType = typeof(IConnectionTransportFeature); - private static readonly Type IConnectionApplicationFeatureType = typeof(IConnectionApplicationFeature); private object _currentIHttpConnectionFeature; private object _currentIConnectionIdFeature; private object _currentIConnectionTransportFeature; - private object _currentIConnectionApplicationFeature; private int _featureRevision; @@ -99,12 +97,28 @@ int IHttpConnectionFeature.LocalPort PipeFactory IConnectionTransportFeature.PipeFactory => PipeFactory; - IPipeConnection IConnectionTransportFeature.Connection + IPipeConnection IConnectionTransportFeature.Transport { get => Transport; set => Transport = value; } + IPipeConnection IConnectionTransportFeature.Application + { + get => Application; + set => Application = value; + } + + Task IConnectionTransportFeature.ConnectionAborted + { + get => _abortTcs.Task; + } + + Task IConnectionTransportFeature.ConnectionClosed + { + get => _closedTcs.Task; + } + object IFeatureCollection.this[Type key] { get => FastFeatureGet(key); @@ -142,11 +156,6 @@ private object FastFeatureGet(Type key) return _currentIConnectionTransportFeature; } - if (key == IConnectionApplicationFeatureType) - { - return _currentIConnectionApplicationFeature; - } - return ExtraFeatureGet(key); } @@ -172,12 +181,6 @@ private void FastFeatureSet(Type key, object feature) return; } - if (key == IConnectionApplicationFeatureType) - { - _currentIConnectionApplicationFeature = feature; - return; - } - ExtraFeatureSet(key, feature); } @@ -198,11 +201,6 @@ private IEnumerable> FastEnumerable() yield return new KeyValuePair(IConnectionTransportFeatureType, _currentIConnectionTransportFeature); } - if (_currentIConnectionApplicationFeature != null) - { - yield return new KeyValuePair(IConnectionApplicationFeatureType, _currentIConnectionApplicationFeature); - } - if (MaybeExtra != null) { foreach (var item in MaybeExtra) diff --git a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs index a499dc8e6..a447fae39 100644 --- a/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs +++ b/src/Kestrel.Transport.Abstractions/Internal/TransportConnection.cs @@ -1,14 +1,15 @@ using System; -using System.Collections.Generic; using System.IO.Pipelines; using System.Net; -using System.Text; -using Microsoft.AspNetCore.Protocols.Features; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal { public abstract partial class TransportConnection { + private readonly TaskCompletionSource _abortTcs = new TaskCompletionSource(); + private readonly TaskCompletionSource _closedTcs = new TaskCompletionSource(); + public TransportConnection() { _currentIConnectionIdFeature = this; @@ -28,6 +29,30 @@ public TransportConnection() public virtual IScheduler OutputReaderScheduler { get; } public IPipeConnection Transport { get; set; } - public IConnectionApplicationFeature Application => (IConnectionApplicationFeature)_currentIConnectionApplicationFeature; + public IPipeConnection Application { get; set; } + + protected void Abort(Exception exception) + { + if (exception == null) + { + _abortTcs.TrySetResult(null); + } + else + { + _abortTcs.TrySetException(exception); + } + } + + protected void Close(Exception exception) + { + if (exception == null) + { + _closedTcs.TrySetResult(null); + } + else + { + _closedTcs.TrySetException(exception); + } + } } } diff --git a/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs b/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs index db47620cd..1ed08a309 100644 --- a/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs +++ b/src/Kestrel.Transport.Libuv/Internal/LibuvConnection.cs @@ -46,8 +46,8 @@ public LibuvConnection(ListenerContext context, UvStreamHandle socket) : base(co } } - public IPipeWriter Input => Application.Connection.Output; - public IPipeReader Output => Application.Connection.Input; + public IPipeWriter Input => Application.Output; + public IPipeReader Output => Application.Input; public LibuvOutputConsumer OutputConsumer { get; set; } @@ -83,7 +83,7 @@ public async Task Start() // Now, complete the input so that no more reads can happen Input.Complete(error ?? new ConnectionAbortedException()); Output.Complete(error); - Application.OnConnectionClosed(error); + Close(error); // Make sure it isn't possible for a paused read to resume reading after calling uv_close // on the stream handle @@ -178,7 +178,7 @@ private void OnRead(UvStreamHandle handle, int status) } } - Application.Abort(error); + Abort(error); // Complete after aborting the connection Input.Complete(error); } @@ -216,7 +216,7 @@ private void StartReading() Log.ConnectionReadFin(ConnectionId); var error = new IOException(ex.Message, ex); - Application.Abort(error); + Abort(error); Input.Complete(error); } } diff --git a/src/Kestrel.Transport.Sockets/SocketConnection.cs b/src/Kestrel.Transport.Sockets/SocketConnection.cs index 86f5df0d4..e39e605d6 100644 --- a/src/Kestrel.Transport.Sockets/SocketConnection.cs +++ b/src/Kestrel.Transport.Sockets/SocketConnection.cs @@ -48,8 +48,8 @@ public async Task StartAsync(IConnectionHandler connectionHandler) { connectionHandler.OnConnection(this); - _input = Application.Connection.Output; - _output = Application.Connection.Input; + _input = Application.Output; + _output = Application.Input; // Spawn send and receive logic Task receiveTask = DoReceive(); @@ -135,7 +135,7 @@ private async Task DoReceive() } finally { - Application.Abort(error); + Abort(error); _input.Complete(error); } } @@ -229,7 +229,7 @@ private async Task DoSend() } finally { - Application.OnConnectionClosed(error); + Close(error); _output.Complete(error); } } diff --git a/src/Protocols.Abstractions/ConnectionContext.cs b/src/Protocols.Abstractions/ConnectionContext.cs index ebbcb458a..c10a884f4 100644 --- a/src/Protocols.Abstractions/ConnectionContext.cs +++ b/src/Protocols.Abstractions/ConnectionContext.cs @@ -1,5 +1,7 @@ using System; using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Protocols @@ -13,5 +15,9 @@ public abstract class ConnectionContext public abstract IPipeConnection Transport { get; set; } public abstract PipeFactory PipeFactory { get; } + + public abstract Task ConnectionAborted { get; } + + public abstract Task ConnectionClosed { get; } } } diff --git a/src/Protocols.Abstractions/DefaultConnectionContext.cs b/src/Protocols.Abstractions/DefaultConnectionContext.cs index 0a759630b..814d4470f 100644 --- a/src/Protocols.Abstractions/DefaultConnectionContext.cs +++ b/src/Protocols.Abstractions/DefaultConnectionContext.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.IO.Pipelines; using System.Text; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Protocols.Features; @@ -34,10 +35,14 @@ public override string ConnectionId public override IPipeConnection Transport { - get => ConnectionTransportFeature.Connection; - set => ConnectionTransportFeature.Connection = value; + get => ConnectionTransportFeature.Transport; + set => ConnectionTransportFeature.Transport = value; } + public override Task ConnectionAborted => ConnectionTransportFeature.ConnectionAborted; + + public override Task ConnectionClosed => ConnectionTransportFeature.ConnectionClosed; + struct FeatureInterfaces { public IConnectionIdFeature ConnectionId; diff --git a/src/Protocols.Abstractions/Features/IConnectionApplicationFeature.cs b/src/Protocols.Abstractions/Features/IConnectionApplicationFeature.cs deleted file mode 100644 index a610cd06d..000000000 --- a/src/Protocols.Abstractions/Features/IConnectionApplicationFeature.cs +++ /dev/null @@ -1,18 +0,0 @@ -using System; -using System.IO.Pipelines; - -namespace Microsoft.AspNetCore.Protocols.Features -{ - public interface IConnectionApplicationFeature - { - IPipeConnection Connection { get; set; } - - // TODO: Remove these (https://github.com/aspnet/KestrelHttpServer/issues/1772) - // REVIEW: These are around for now because handling pipe events messes with the order - // of operations an that breaks tons of tests. Instead, we preserve the existing semantics - // and ordering. - void Abort(Exception exception); - - void OnConnectionClosed(Exception exception); - } -} diff --git a/src/Protocols.Abstractions/Features/IConnectionTransportFeature.cs b/src/Protocols.Abstractions/Features/IConnectionTransportFeature.cs index 463797265..e9cccb4a3 100644 --- a/src/Protocols.Abstractions/Features/IConnectionTransportFeature.cs +++ b/src/Protocols.Abstractions/Features/IConnectionTransportFeature.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.IO.Pipelines; using System.Text; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Protocols.Features { @@ -9,10 +10,16 @@ public interface IConnectionTransportFeature { PipeFactory PipeFactory { get; } - IPipeConnection Connection { get; set; } + IPipeConnection Transport { get; set; } + + IPipeConnection Application { get; set; } IScheduler InputWriterScheduler { get; } IScheduler OutputReaderScheduler { get; } + + Task ConnectionAborted { get; } + + Task ConnectionClosed { get; } } } diff --git a/src/Protocols.Abstractions/PipeFactoryExtensions.cs b/src/Protocols.Abstractions/PipeFactoryExtensions.cs new file mode 100644 index 000000000..9ded8a8f9 --- /dev/null +++ b/src/Protocols.Abstractions/PipeFactoryExtensions.cs @@ -0,0 +1,21 @@ +namespace System.IO.Pipelines +{ + public static class PipeFactoryExtensions + { + public static (IPipeConnection Transport, IPipeConnection Application) CreateConnectionPair(this PipeFactory pipeFactory) + { + return pipeFactory.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + } + + public static (IPipeConnection Transport, IPipeConnection Application) CreateConnectionPair(this PipeFactory pipeFactory, PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = pipeFactory.Create(inputOptions); + var output = pipeFactory.Create(outputOptions); + + var transportToApplication = new PipeConnection(output.Reader, input.Writer); + var applicationToTransport = new PipeConnection(input.Reader, output.Writer); + + return (applicationToTransport, transportToApplication); + } + } +} diff --git a/test/Kestrel.Core.Tests/FrameConnectionTests.cs b/test/Kestrel.Core.Tests/FrameConnectionTests.cs index 189187222..e77de9b87 100644 --- a/test/Kestrel.Core.Tests/FrameConnectionTests.cs +++ b/test/Kestrel.Core.Tests/FrameConnectionTests.cs @@ -25,6 +25,7 @@ public class FrameConnectionTests : IDisposable public FrameConnectionTests() { _pipeFactory = new PipeFactory(); + var pair = _pipeFactory.CreateConnectionPair(); _frameConnectionContext = new FrameConnectionContext { @@ -32,8 +33,8 @@ public FrameConnectionTests() ConnectionAdapters = new List(), PipeFactory = _pipeFactory, FrameConnectionId = long.MinValue, - Input = _pipeFactory.Create(), - Output = _pipeFactory.Create(), + Application = pair.Application, + Transport = pair.Transport, ServiceContext = new TestServiceContext { SystemClock = new SystemClock() @@ -54,7 +55,7 @@ public void DoesNotTimeOutWhenDebuggerIsAttached() var mockDebugger = new Mock(); mockDebugger.SetupGet(g => g.IsAttached).Returns(true); _frameConnection.Debugger = mockDebugger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); var now = DateTimeOffset.Now; _frameConnection.Tick(now); @@ -101,7 +102,7 @@ private void TickBodyWithMinimumDataRate(IKestrelTrace logger, int bytesPerSecon _frameConnectionContext.ServiceContext.Log = logger; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); // Initialize timestamp @@ -128,7 +129,7 @@ public void RequestBodyMinimumDataRateNotEnforcedDuringGracePeriod() var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); // Initialize timestamp @@ -170,7 +171,7 @@ public void RequestBodyDataRateIsAveragedOverTimeSpentReadingRequestBody() var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); // Initialize timestamp @@ -247,7 +248,7 @@ public void RequestBodyDataRateNotComputedOnPausedTime() var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); // Initialize timestamp @@ -315,7 +316,7 @@ public void ReadTimingNotPausedWhenResumeCalledBeforeNextTick() var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); // Initialize timestamp @@ -377,7 +378,7 @@ public void ReadTimingNotEnforcedWhenTimeoutIsSet() var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); var startTime = systemClock.UtcNow; @@ -418,7 +419,7 @@ public void WriteTimingAbortsConnectionWhenWriteDoesNotCompleteWithMinimumDataRa var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); _frameConnection.Frame.RequestAborted.Register(() => { @@ -452,7 +453,7 @@ public void WriteTimingAbortsConnectionWhenSmallWriteDoesNotCompleteWithinGraceP var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); _frameConnection.Frame.RequestAborted.Register(() => { @@ -494,7 +495,7 @@ public void WriteTimingTimeoutPushedOnConcurrentWrite() var mockLogger = new Mock(); _frameConnectionContext.ServiceContext.Log = mockLogger.Object; - _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Input.Reader, _frameConnectionContext.Output); + _frameConnection.CreateFrame(new DummyApplication(), _frameConnectionContext.Transport, _frameConnectionContext.Application); _frameConnection.Frame.Reset(); _frameConnection.Frame.RequestAborted.Register(() => { @@ -531,7 +532,7 @@ public void WriteTimingTimeoutPushedOnConcurrentWrite() [Fact] public async Task StartRequestProcessingCreatesLogScopeWithConnectionId() { - _frameConnection.StartRequestProcessing(new DummyApplication()); + _ = _frameConnection.StartRequestProcessing(new DummyApplication()); var scopeObjects = ((TestKestrelTrace)_frameConnectionContext.ServiceContext.Log) .Logger diff --git a/test/Kestrel.Core.Tests/FrameResponseHeadersTests.cs b/test/Kestrel.Core.Tests/FrameResponseHeadersTests.cs index 8cd807bc1..7c3dfa0f0 100644 --- a/test/Kestrel.Core.Tests/FrameResponseHeadersTests.cs +++ b/test/Kestrel.Core.Tests/FrameResponseHeadersTests.cs @@ -18,10 +18,14 @@ public class FrameResponseHeadersTests [Fact] public void InitialDictionaryIsEmpty() { + var factory = new PipeFactory(); + var pair = factory.CreateConnectionPair(); var frameContext = new FrameContext { ServiceContext = new TestServiceContext(), - PipeFactory = new PipeFactory(), + PipeFactory = factory, + Application = pair.Application, + Transport = pair.Transport, TimeoutControl = null }; diff --git a/test/Kestrel.Core.Tests/FrameTests.cs b/test/Kestrel.Core.Tests/FrameTests.cs index d5f14333b..2910bedc1 100644 --- a/test/Kestrel.Core.Tests/FrameTests.cs +++ b/test/Kestrel.Core.Tests/FrameTests.cs @@ -27,7 +27,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests { public class FrameTests : IDisposable { - private readonly IPipe _input; + private readonly IPipeConnection _transport; + private readonly IPipeConnection _application; private readonly TestFrame _frame; private readonly ServiceContext _serviceContext; private readonly FrameContext _frameContext; @@ -52,8 +53,10 @@ public Task ProduceEndAsync() public FrameTests() { _pipelineFactory = new PipeFactory(); - _input = _pipelineFactory.Create(); - var output = _pipelineFactory.Create(); + var pair = _pipelineFactory.CreateConnectionPair(); + + _transport = pair.Transport; + _application = pair.Application; _serviceContext = new TestServiceContext(); _timeoutControl = new Mock(); @@ -62,8 +65,8 @@ public FrameTests() ServiceContext = _serviceContext, PipeFactory = _pipelineFactory, TimeoutControl = _timeoutControl.Object, - Input = _input.Reader, - Output = output + Application = pair.Application, + Transport = pair.Transport }; _frame = new TestFrame(application: null, context: _frameContext); @@ -72,8 +75,12 @@ public FrameTests() public void Dispose() { - _input.Reader.Complete(); - _input.Writer.Complete(); + _transport.Input.Complete(); + _application.Output.Complete(); + + _application.Input.Complete(); + _application.Output.Complete(); + _pipelineFactory.Dispose(); } @@ -84,11 +91,11 @@ public async Task TakeMessageHeadersThrowsWhenHeadersExceedTotalSizeLimit() _serviceContext.ServerOptions.Limits.MaxRequestHeadersTotalSize = headerLine.Length - 1; _frame.Reset(); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine}\r\n")); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine}\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.BadRequest_HeadersExceedMaxTotalSize, exception.Message); Assert.Equal(StatusCodes.Status431RequestHeaderFieldsTooLarge, exception.StatusCode); @@ -100,11 +107,11 @@ public async Task TakeMessageHeadersThrowsWhenHeadersExceedCountLimit() const string headerLines = "Header-1: value1\r\nHeader-2: value2\r\n"; _serviceContext.ServerOptions.Limits.MaxRequestHeaderCount = 1; - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLines}\r\n")); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLines}\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeMessageHeaders(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.BadRequest_TooManyHeaders, exception.Message); Assert.Equal(StatusCodes.Status431RequestHeaderFieldsTooLarge, exception.StatusCode); @@ -205,11 +212,11 @@ public async Task ResetResetsHeaderLimits() options.Limits.MaxRequestHeaderCount = 1; _serviceContext.ServerOptions = options; - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine1}\r\n")); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine1}\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var takeMessageHeaders = _frame.TakeMessageHeaders(readableBuffer, out _consumed, out _examined); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.True(takeMessageHeaders); Assert.Equal(1, _frame.RequestHeaders.Count); @@ -217,11 +224,11 @@ public async Task ResetResetsHeaderLimits() _frame.Reset(); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine2}\r\n")); - readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"{headerLine2}\r\n")); + readableBuffer = (await _transport.Input.ReadAsync()).Buffer; takeMessageHeaders = _frame.TakeMessageHeaders(readableBuffer, out _consumed, out _examined); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.True(takeMessageHeaders); Assert.Equal(1, _frame.RequestHeaders.Count); @@ -332,8 +339,8 @@ public async Task TakeStartLineSetsFrameProperties( string requestLine, string expectedMethod, string expectedRawTarget, -// This warns that theory methods should use all of their parameters, -// but this method is using a shared data collection with HttpParserTests.ParsesRequestLine and others. + // This warns that theory methods should use all of their parameters, + // but this method is using a shared data collection with HttpParserTests.ParsesRequestLine and others. #pragma warning disable xUnit1026 string expectedRawPath, #pragma warning restore xUnit1026 @@ -342,11 +349,11 @@ public async Task TakeStartLineSetsFrameProperties( string expectedHttpVersion) { var requestLineBytes = Encoding.ASCII.GetBytes(requestLine); - await _input.Writer.WriteAsync(requestLineBytes); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(requestLineBytes); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var returnValue = _frame.TakeStartLine(readableBuffer, out _consumed, out _examined); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.True(returnValue); Assert.Equal(expectedMethod, _frame.Method); @@ -365,11 +372,11 @@ public async Task TakeStartLineRemovesDotSegmentsFromTarget( string expectedQueryString) { var requestLineBytes = Encoding.ASCII.GetBytes(requestLine); - await _input.Writer.WriteAsync(requestLineBytes); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(requestLineBytes); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var returnValue = _frame.TakeStartLine(readableBuffer, out _consumed, out _examined); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.True(returnValue); Assert.Equal(expectedRawTarget, _frame.RawTarget); @@ -380,10 +387,10 @@ public async Task TakeStartLineRemovesDotSegmentsFromTarget( [Fact] public async Task ParseRequestStartsRequestHeadersTimeoutOnFirstByteAvailable() { - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes("G")); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("G")); - _frame.ParseRequest((await _input.Reader.ReadAsync()).Buffer, out _consumed, out _examined); - _input.Reader.Advance(_consumed, _examined); + _frame.ParseRequest((await _transport.Input.ReadAsync()).Buffer, out _consumed, out _examined); + _transport.Input.Advance(_consumed, _examined); var expectedRequestHeadersTimeout = _serviceContext.ServerOptions.Limits.RequestHeadersTimeout.Ticks; _timeoutControl.Verify(cc => cc.ResetTimeout(expectedRequestHeadersTimeout, TimeoutAction.SendTimeoutResponse)); @@ -395,11 +402,11 @@ public async Task TakeStartLineThrowsWhenTooLong() _serviceContext.ServerOptions.Limits.MaxRequestLineSize = "GET / HTTP/1.1\r\n".Length; var requestLineBytes = Encoding.ASCII.GetBytes("GET /a HTTP/1.1\r\n"); - await _input.Writer.WriteAsync(requestLineBytes); + await _application.Output.WriteAsync(requestLineBytes); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.BadRequest_RequestLineTooLong, exception.Message); Assert.Equal(StatusCodes.Status414UriTooLong, exception.StatusCode); @@ -409,12 +416,12 @@ public async Task TakeStartLineThrowsWhenTooLong() [MemberData(nameof(TargetWithEncodedNullCharData))] public async Task TakeStartLineThrowsOnEncodedNullCharInTarget(string target) { - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target), exception.Message); } @@ -423,12 +430,12 @@ public async Task TakeStartLineThrowsOnEncodedNullCharInTarget(string target) [MemberData(nameof(TargetWithNullCharData))] public async Task TakeStartLineThrowsOnNullCharInTarget(string target) { - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable()), exception.Message); } @@ -439,12 +446,12 @@ public async Task TakeStartLineThrowsOnNullCharInMethod(string method) { var requestLine = $"{method} / HTTP/1.1\r\n"; - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestLine_Detail(requestLine.EscapeNonPrintable()), exception.Message); } @@ -455,12 +462,12 @@ public async Task TakeStartLineThrowsOnNullCharInQueryString(string queryString) { var target = $"/{queryString}"; - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET {target} HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable()), exception.Message); } @@ -471,12 +478,12 @@ public async Task TakeStartLineThrowsWhenRequestTargetIsInvalid(string method, s { var requestLine = $"{method} {target} HTTP/1.1\r\n"; - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(target.EscapeNonPrintable()), exception.Message); } @@ -485,12 +492,12 @@ public async Task TakeStartLineThrowsWhenRequestTargetIsInvalid(string method, s [MemberData(nameof(MethodNotAllowedTargetData))] public async Task TakeStartLineThrowsWhenMethodNotAllowed(string requestLine, HttpMethod allowedMethod) { - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(requestLine)); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(405, exception.StatusCode); Assert.Equal(CoreStrings.BadRequest_MethodNotAllowed, exception.Message); @@ -506,7 +513,7 @@ public void ProcessRequestsAsyncEnablesKeepAliveTimeout() _timeoutControl.Verify(cc => cc.SetTimeout(expectedKeepAliveTimeout, TimeoutAction.CloseConnection)); _frame.Stop(); - _input.Writer.Complete(); + _application.Output.Complete(); requestProcessingTask.Wait(); } @@ -588,13 +595,13 @@ public async Task RequestProcessingTaskIsUnwrapped() var requestProcessingTask = _frame.ProcessRequestsAsync(); var data = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"); - await _input.Writer.WriteAsync(data); + await _application.Output.WriteAsync(data); _frame.Stop(); Assert.IsNotType>(requestProcessingTask); await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); - _input.Writer.Complete(); + _application.Output.Complete(); } [Fact] @@ -684,12 +691,12 @@ public async Task ExceptionDetailNotIncludedWhenLogLevelInformationNotEnabled() _serviceContext.Log = mockTrace.Object; - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes($"GET /%00 HTTP/1.1\r\n")); - var readableBuffer = (await _input.Reader.ReadAsync()).Buffer; + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes($"GET /%00 HTTP/1.1\r\n")); + var readableBuffer = (await _transport.Input.ReadAsync()).Buffer; var exception = Assert.Throws(() => _frame.TakeStartLine(readableBuffer, out _consumed, out _examined)); - _input.Reader.Advance(_consumed, _examined); + _transport.Input.Advance(_consumed, _examined); Assert.Equal(CoreStrings.FormatBadRequest_InvalidRequestTarget_Detail(string.Empty), exception.Message); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); @@ -716,19 +723,19 @@ public async Task AcceptsHeadersAcrossSends(int header0Count, int header1Count) var requestProcessingTask = _frame.ProcessRequestsAsync(); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes("GET / HTTP/1.0\r\n")); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("GET / HTTP/1.0\r\n")); await WaitForCondition(TimeSpan.FromSeconds(1), () => _frame.RequestHeaders != null); Assert.Equal(0, _frame.RequestHeaders.Count); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes(headers0)); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers0)); await WaitForCondition(TimeSpan.FromSeconds(1), () => _frame.RequestHeaders.Count >= header0Count); Assert.Equal(header0Count, _frame.RequestHeaders.Count); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes(headers1)); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers1)); await WaitForCondition(TimeSpan.FromSeconds(1), () => _frame.RequestHeaders.Count >= header0Count + header1Count); Assert.Equal(header0Count + header1Count, _frame.RequestHeaders.Count); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); Assert.Equal(header0Count + header1Count, _frame.RequestHeaders.Count); await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); @@ -750,7 +757,7 @@ public async Task KeepsSameHeaderCollectionAcrossSends(int header0Count, int hea var requestProcessingTask = _frame.ProcessRequestsAsync(); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes("GET / HTTP/1.0\r\n")); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("GET / HTTP/1.0\r\n")); await WaitForCondition(TimeSpan.FromSeconds(1), () => _frame.RequestHeaders != null); Assert.Equal(0, _frame.RequestHeaders.Count); @@ -758,17 +765,17 @@ public async Task KeepsSameHeaderCollectionAcrossSends(int header0Count, int hea _frame.RequestHeaders = newRequestHeaders; Assert.Same(newRequestHeaders, _frame.RequestHeaders); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes(headers0)); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers0)); await WaitForCondition(TimeSpan.FromSeconds(1), () => _frame.RequestHeaders.Count >= header0Count); Assert.Same(newRequestHeaders, _frame.RequestHeaders); Assert.Equal(header0Count, _frame.RequestHeaders.Count); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes(headers1)); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes(headers1)); await WaitForCondition(TimeSpan.FromSeconds(1), () => _frame.RequestHeaders.Count >= header0Count + header1Count); Assert.Same(newRequestHeaders, _frame.RequestHeaders); Assert.Equal(header0Count + header1Count, _frame.RequestHeaders.Count); - await _input.Writer.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); + await _application.Output.WriteAsync(Encoding.ASCII.GetBytes("\r\n")); Assert.Same(newRequestHeaders, _frame.RequestHeaders); Assert.Equal(header0Count + header1Count, _frame.RequestHeaders.Count); diff --git a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs index 0e1e4668f..67743cad1 100644 --- a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs +++ b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs @@ -73,8 +73,7 @@ public class Http2ConnectionTests : IDisposable private static readonly byte[] _noData = new byte[0]; private readonly PipeFactory _pipeFactory = new PipeFactory(); - private readonly IPipe _inputPipe; - private readonly IPipe _outputPipe; + private readonly (IPipeConnection Transport, IPipeConnection Application) _pair; private readonly Http2ConnectionContext _connectionContext; private readonly Http2Connection _connection; private readonly HPackEncoder _hpackEncoder = new HPackEncoder(); @@ -99,8 +98,7 @@ public class Http2ConnectionTests : IDisposable public Http2ConnectionTests() { - _inputPipe = _pipeFactory.Create(); - _outputPipe = _pipeFactory.Create(); + _pair = _pipeFactory.CreateConnectionPair(); _noopApplication = context => Task.CompletedTask; @@ -213,8 +211,8 @@ public Http2ConnectionTests() { ServiceContext = new TestServiceContext(), PipeFactory = _pipeFactory, - Input = _inputPipe.Reader, - Output = _outputPipe + Application = _pair.Application, + Transport = _pair.Transport }; _connection = new Http2Connection(_connectionContext); } @@ -1201,7 +1199,7 @@ private Task WaitForAllStreamsAsync() private async Task SendAsync(ArraySegment span) { - var writableBuffer = _inputPipe.Writer.Alloc(1); + var writableBuffer = _pair.Application.Output.Alloc(1); writableBuffer.Write(span); await writableBuffer.FlushAsync(); } @@ -1413,7 +1411,7 @@ private async Task ReceiveFrameAsync() while (true) { - var result = await _outputPipe.Reader.ReadAsync(); + var result = await _pair.Application.Input.ReadAsync(); var buffer = result.Buffer; var consumed = buffer.Start; var examined = buffer.End; @@ -1429,7 +1427,7 @@ private async Task ReceiveFrameAsync() } finally { - _outputPipe.Reader.Advance(consumed, examined); + _pair.Application.Input.Advance(consumed, examined); } } } @@ -1456,7 +1454,7 @@ private async Task ExpectAsync(Http2FrameType type, int withLength, private Task StopConnectionAsync(int expectedLastStreamId, bool ignoreNonGoAwayFrames) { - _inputPipe.Writer.Complete(); + _pair.Application.Output.Complete(); return WaitForConnectionStopAsync(expectedLastStreamId, ignoreNonGoAwayFrames); } @@ -1486,7 +1484,7 @@ private async Task WaitForConnectionErrorAsync(int expectedLastStreamId, Http2Er Assert.Equal(expectedErrorCode, frame.GoAwayErrorCode); await _connectionTask; - _inputPipe.Writer.Complete(); + _pair.Application.Output.Complete(); } private async Task WaitForStreamErrorAsync(int expectedStreamId, Http2ErrorCode expectedErrorCode, bool ignoreNonRstStreamFrames) diff --git a/test/Kestrel.Core.Tests/KestrelServerOptionsTests.cs b/test/Kestrel.Core.Tests/KestrelServerOptionsTests.cs index f0209cbea..b34f92a1d 100644 --- a/test/Kestrel.Core.Tests/KestrelServerOptionsTests.cs +++ b/test/Kestrel.Core.Tests/KestrelServerOptionsTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Net; -using Microsoft.AspNetCore.Server.Kestrel.Core; using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests @@ -19,6 +18,9 @@ public void NoDelayDefaultsToTrue() d.NoDelay = false; }); + // Execute the callback + o1.ListenOptions[1].Configure(o1.ListenOptions[1]); + Assert.True(o1.ListenOptions[0].NoDelay); Assert.False(o1.ListenOptions[1].NoDelay); } diff --git a/test/Kestrel.Core.Tests/MessageBodyTests.cs b/test/Kestrel.Core.Tests/MessageBodyTests.cs index c6d228330..0d5a11ee2 100644 --- a/test/Kestrel.Core.Tests/MessageBodyTests.cs +++ b/test/Kestrel.Core.Tests/MessageBodyTests.cs @@ -445,7 +445,7 @@ public async Task CopyToAsyncDoesNotCopyBlocks(FrameRequestHeaders headers, stri // The block returned by IncomingStart always has at least 2048 available bytes, // so no need to bounds check in this test. var bytes = Encoding.ASCII.GetBytes(data[0]); - var buffer = input.Pipe.Writer.Alloc(2048); + var buffer = input.Application.Output.Alloc(2048); ArraySegment block; Assert.True(buffer.Buffer.TryGetArray(out block)); Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); @@ -457,7 +457,7 @@ public async Task CopyToAsyncDoesNotCopyBlocks(FrameRequestHeaders headers, stri writeTcs = new TaskCompletionSource(); bytes = Encoding.ASCII.GetBytes(data[1]); - buffer = input.Pipe.Writer.Alloc(2048); + buffer = input.Application.Output.Alloc(2048); Assert.True(buffer.Buffer.TryGetArray(out block)); Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); buffer.Advance(bytes.Length); @@ -467,7 +467,7 @@ public async Task CopyToAsyncDoesNotCopyBlocks(FrameRequestHeaders headers, stri if (headers.HeaderConnection == "close") { - input.Pipe.Writer.Complete(); + input.Application.Output.Complete(); } await copyToAsyncTask; @@ -516,7 +516,7 @@ public async Task PumpAsyncDoesNotReturnAfterCancelingInput() input.Add("a"); Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); - input.Pipe.Reader.CancelPendingRead(); + input.Transport.Input.CancelPendingRead(); // Add more input and verify is read input.Add("b"); diff --git a/test/Kestrel.Core.Tests/OutputProducerTests.cs b/test/Kestrel.Core.Tests/OutputProducerTests.cs index d539a4012..f95b13052 100644 --- a/test/Kestrel.Core.Tests/OutputProducerTests.cs +++ b/test/Kestrel.Core.Tests/OutputProducerTests.cs @@ -55,7 +55,8 @@ private OutputProducer CreateOutputProducer(PipeOptions pipeOptions) var pipe = _pipeFactory.Create(pipeOptions); var serviceContext = new TestServiceContext(); var socketOutput = new OutputProducer( - pipe, + pipe.Reader, + pipe.Writer, "0", serviceContext.Log, Mock.Of()); diff --git a/test/Kestrel.Core.Tests/PipeOptionsTests.cs b/test/Kestrel.Core.Tests/PipeOptionsTests.cs index d35ac3c8a..2d0db7c07 100644 --- a/test/Kestrel.Core.Tests/PipeOptionsTests.cs +++ b/test/Kestrel.Core.Tests/PipeOptionsTests.cs @@ -1,7 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using System.IO.Pipelines; +using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; @@ -22,9 +24,8 @@ public void OutputPipeOptionsConfiguredCorrectly(long? maxResponseBufferSize, lo serviceContext.ServerOptions.Limits.MaxResponseBufferSize = maxResponseBufferSize; serviceContext.ThreadPool = new LoggingThreadPool(null); - var connectionHandler = new ConnectionHandler(listenOptions: null, serviceContext: serviceContext, application: null); var mockScheduler = Mock.Of(); - var outputPipeOptions = connectionHandler.GetOutputPipeOptions(readerScheduler: mockScheduler); + var outputPipeOptions = ConnectionHandler.GetOutputPipeOptions(serviceContext, readerScheduler: mockScheduler); Assert.Equal(expectedMaximumSizeLow, outputPipeOptions.MaximumSizeLow); Assert.Equal(expectedMaximumSizeHigh, outputPipeOptions.MaximumSizeHigh); @@ -41,9 +42,8 @@ public void InputPipeOptionsConfiguredCorrectly(long? maxRequestBufferSize, long serviceContext.ServerOptions.Limits.MaxRequestBufferSize = maxRequestBufferSize; serviceContext.ThreadPool = new LoggingThreadPool(null); - var connectionHandler = new ConnectionHandler(listenOptions: null, serviceContext: serviceContext, application: null); var mockScheduler = Mock.Of(); - var inputPipeOptions = connectionHandler.GetInputPipeOptions(writerScheduler: mockScheduler); + var inputPipeOptions = ConnectionHandler.GetInputPipeOptions(serviceContext, writerScheduler: mockScheduler); Assert.Equal(expectedMaximumSizeLow, inputPipeOptions.MaximumSizeLow); Assert.Equal(expectedMaximumSizeHigh, inputPipeOptions.MaximumSizeHigh); diff --git a/test/Kestrel.Core.Tests/TestInput.cs b/test/Kestrel.Core.Tests/TestInput.cs index 018dcb111..220f9393c 100644 --- a/test/Kestrel.Core.Tests/TestInput.cs +++ b/test/Kestrel.Core.Tests/TestInput.cs @@ -20,12 +20,15 @@ public TestInput() { _memoryPool = new MemoryPool(); _pipelineFactory = new PipeFactory(); - Pipe = _pipelineFactory.Create(); + var pair = _pipelineFactory.CreateConnectionPair(); + Transport = pair.Transport; + Application = pair.Application; FrameContext = new FrameContext { ServiceContext = new TestServiceContext(), - Input = Pipe.Reader, + Application = Application, + Transport = Transport, PipeFactory = _pipelineFactory, TimeoutControl = Mock.Of() }; @@ -34,28 +37,30 @@ public TestInput() Frame.FrameControl = Mock.Of(); } - public IPipe Pipe { get; } + public IPipeConnection Transport { get; } + + public IPipeConnection Application { get; } public PipeFactory PipeFactory => _pipelineFactory; - public FrameContext FrameContext { get; } + public FrameContext FrameContext { get; } public Frame Frame { get; set; } public void Add(string text) { var data = Encoding.ASCII.GetBytes(text); - Pipe.Writer.WriteAsync(data).Wait(); + Application.Output.WriteAsync(data).Wait(); } public void Fin() { - Pipe.Writer.Complete(); + Application.Output.Complete(); } public void Cancel() { - Pipe.Reader.CancelPendingRead(); + Transport.Input.CancelPendingRead(); } public void Dispose() diff --git a/test/Kestrel.FunctionalTests/RequestTests.cs b/test/Kestrel.FunctionalTests/RequestTests.cs index 017787772..3a55e9a4c 100644 --- a/test/Kestrel.FunctionalTests/RequestTests.cs +++ b/test/Kestrel.FunctionalTests/RequestTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; +using System.Diagnostics; using System.Globalization; using System.IO; using System.Linq; @@ -37,7 +38,7 @@ public class RequestTests { private const int _connectionStartedEventId = 1; private const int _connectionResetEventId = 19; - private const int _semaphoreWaitTimeout = 2500; + private static readonly int _semaphoreWaitTimeout = Debugger.IsAttached ? 10000 : 2500; private readonly ITestOutputHelper _output; diff --git a/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs b/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs index 5bf0ff1b3..cd604be6f 100644 --- a/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs +++ b/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs @@ -683,7 +683,7 @@ public async Task WritesAreAggregated(long? maxResponseBufferSize) private OutputProducer CreateOutputProducer(PipeOptions pipeOptions, CancellationTokenSource cts = null) { - var pipe = _pipeFactory.Create(pipeOptions); + var pair = _pipeFactory.CreateConnectionPair(pipeOptions, pipeOptions); var logger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext @@ -694,14 +694,15 @@ private OutputProducer CreateOutputProducer(PipeOptions pipeOptions, Cancellatio var transportContext = new TestLibuvTransportContext { Log = new LibuvTrace(logger) }; var socket = new MockSocket(_mockLibuv, _libuvThread.Loop.ThreadId, transportContext.Log); - var consumer = new LibuvOutputConsumer(pipe.Reader, _libuvThread, socket, "0", transportContext.Log); + var consumer = new LibuvOutputConsumer(pair.Application.Input, _libuvThread, socket, "0", transportContext.Log); var frame = new Frame(null, new FrameContext { ServiceContext = serviceContext, PipeFactory = _pipeFactory, TimeoutControl = Mock.Of(), - Output = pipe + Application = pair.Application, + Transport = pair.Transport }); if (cts != null) @@ -709,7 +710,7 @@ private OutputProducer CreateOutputProducer(PipeOptions pipeOptions, Cancellatio frame.RequestAborted.Register(cts.Cancel); } - var ignore = WriteOutputAsync(consumer, pipe.Reader, frame); + var ignore = WriteOutputAsync(consumer, pair.Application.Input, frame); return frame.Output; } diff --git a/test/Kestrel.Transport.Libuv.Tests/LibuvTransportTests.cs b/test/Kestrel.Transport.Libuv.Tests/LibuvTransportTests.cs index 33c971a4b..6f3e9126a 100644 --- a/test/Kestrel.Transport.Libuv.Tests/LibuvTransportTests.cs +++ b/test/Kestrel.Transport.Libuv.Tests/LibuvTransportTests.cs @@ -53,10 +53,14 @@ public async Task TransportCanBindUnbindAndStop(ListenOptions listenOptions) [MemberData(nameof(ConnectionAdapterData))] public async Task ConnectionCanReadAndWrite(ListenOptions listenOptions) { + var serviceContext = new TestServiceContext(); + listenOptions.UseHttpServer(listenOptions.ConnectionAdapters, serviceContext, new DummyApplication(TestApp.EchoApp)); + var transportContext = new TestLibuvTransportContext() { - ConnectionHandler = new ConnectionHandler(listenOptions, new TestServiceContext(), new DummyApplication(TestApp.EchoApp)) + ConnectionHandler = new ConnectionHandler(serviceContext, listenOptions.Build()) }; + var transport = new LibuvTransport(transportContext, listenOptions); await transport.BindAsync(); diff --git a/test/Kestrel.Transport.Libuv.Tests/ListenerPrimaryTests.cs b/test/Kestrel.Transport.Libuv.Tests/ListenerPrimaryTests.cs index 3d4eea164..f751fa2de 100644 --- a/test/Kestrel.Transport.Libuv.Tests/ListenerPrimaryTests.cs +++ b/test/Kestrel.Transport.Libuv.Tests/ListenerPrimaryTests.cs @@ -2,11 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Net; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Protocols; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; @@ -25,17 +27,20 @@ public class ListenerPrimaryTests public async Task ConnectionsGetRoundRobinedToSecondaryListeners() { var libuv = new LibuvFunctions(); + var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); var serviceContextPrimary = new TestServiceContext(); var transportContextPrimary = new TestLibuvTransportContext(); - transportContextPrimary.ConnectionHandler = new ConnectionHandler( - listenOptions, serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary"))); + var builderPrimary = new ConnectionBuilder(); + builderPrimary.UseHttpServer(serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary"))); + transportContextPrimary.ConnectionHandler = new ConnectionHandler(serviceContextPrimary, builderPrimary.Build()); var serviceContextSecondary = new TestServiceContext(); + var builderSecondary = new ConnectionBuilder(); + builderSecondary.UseHttpServer(serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary"))); var transportContextSecondary = new TestLibuvTransportContext(); - transportContextSecondary.ConnectionHandler = new ConnectionHandler( - listenOptions, serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary"))); + transportContextSecondary.ConnectionHandler = new ConnectionHandler(serviceContextSecondary, builderSecondary.Build()); var libuvTransport = new LibuvTransport(libuv, transportContextPrimary, listenOptions); @@ -92,13 +97,13 @@ public async Task NonListenerPipeConnectionsAreLoggedAndIgnored() { var libuv = new LibuvFunctions(); var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); - var logger = new TestApplicationErrorLogger(); var serviceContextPrimary = new TestServiceContext(); + var builderPrimary = new ConnectionBuilder(); + builderPrimary.UseHttpServer(serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary"))); var transportContextPrimary = new TestLibuvTransportContext() { Log = new LibuvTrace(logger) }; - transportContextPrimary.ConnectionHandler = new ConnectionHandler( - listenOptions, serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary"))); + transportContextPrimary.ConnectionHandler = new ConnectionHandler(serviceContextPrimary, builderPrimary.Build()); var serviceContextSecondary = new TestServiceContext { @@ -107,9 +112,10 @@ public async Task NonListenerPipeConnectionsAreLoggedAndIgnored() ThreadPool = serviceContextPrimary.ThreadPool, HttpParserFactory = serviceContextPrimary.HttpParserFactory, }; + var builderSecondary = new ConnectionBuilder(); + builderSecondary.UseHttpServer(serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary"))); var transportContextSecondary = new TestLibuvTransportContext(); - transportContextSecondary.ConnectionHandler = new ConnectionHandler( - listenOptions, serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary"))); + transportContextSecondary.ConnectionHandler = new ConnectionHandler(serviceContextSecondary, builderSecondary.Build()); var libuvTransport = new LibuvTransport(libuv, transportContextPrimary, listenOptions); @@ -205,9 +211,10 @@ public async Task PipeConnectionsWithWrongMessageAreLoggedAndIgnored() var logger = new TestApplicationErrorLogger(); var serviceContextPrimary = new TestServiceContext(); + var builderPrimary = new ConnectionBuilder(); + builderPrimary.UseHttpServer(serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary"))); var transportContextPrimary = new TestLibuvTransportContext() { Log = new LibuvTrace(logger) }; - transportContextPrimary.ConnectionHandler = new ConnectionHandler( - listenOptions, serviceContextPrimary, new DummyApplication(c => c.Response.WriteAsync("Primary"))); + transportContextPrimary.ConnectionHandler = new ConnectionHandler(serviceContextPrimary, builderPrimary.Build()); var serviceContextSecondary = new TestServiceContext { @@ -216,9 +223,10 @@ public async Task PipeConnectionsWithWrongMessageAreLoggedAndIgnored() ThreadPool = serviceContextPrimary.ThreadPool, HttpParserFactory = serviceContextPrimary.HttpParserFactory, }; + var builderSecondary = new ConnectionBuilder(); + builderSecondary.UseHttpServer(serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary"))); var transportContextSecondary = new TestLibuvTransportContext(); - transportContextSecondary.ConnectionHandler = new ConnectionHandler( - listenOptions, serviceContextSecondary, new DummyApplication(c => c.Response.WriteAsync("Secondary"))); + transportContextSecondary.ConnectionHandler = new ConnectionHandler(serviceContextSecondary, builderSecondary.Build()); var libuvTransport = new LibuvTransport(libuv, transportContextPrimary, listenOptions); @@ -300,5 +308,34 @@ private static Uri GetUri(ListenOptions options) return new Uri($"{scheme}://{options.IPEndPoint}"); } + + private class ConnectionBuilder : IConnectionBuilder + { + private readonly List> _components = new List>(); + + public IServiceProvider ApplicationServices { get; set; } + + public IConnectionBuilder Use(Func middleware) + { + _components.Add(middleware); + return this; + } + + public ConnectionDelegate Build() + { + ConnectionDelegate app = context => + { + return Task.CompletedTask; + }; + + for (int i = _components.Count - 1; i >= 0; i--) + { + var component = _components[i]; + app = component(app); + } + + return app; + } + } } } diff --git a/test/Kestrel.Transport.Libuv.Tests/TestHelpers/MockConnectionHandler.cs b/test/Kestrel.Transport.Libuv.Tests/TestHelpers/MockConnectionHandler.cs index 5b3480b0e..ef856124f 100644 --- a/test/Kestrel.Transport.Libuv.Tests/TestHelpers/MockConnectionHandler.cs +++ b/test/Kestrel.Transport.Libuv.Tests/TestHelpers/MockConnectionHandler.cs @@ -22,29 +22,13 @@ public void OnConnection(IFeatureCollection features) Input = connectionContext.PipeFactory.Create(InputOptions ?? new PipeOptions()); Output = connectionContext.PipeFactory.Create(OutputOptions ?? new PipeOptions()); - var context = new TestConnectionContext - { - Connection = new PipeConnection(Output.Reader, Input.Writer) - }; + var feature = connectionContext.Features.Get(); - connectionContext.Features.Set(context); + connectionContext.Transport = new PipeConnection(Input.Reader, Output.Writer); + feature.Application = new PipeConnection(Output.Reader, Input.Writer); } public IPipe Input { get; private set; } public IPipe Output { get; private set; } - - private class TestConnectionContext : IConnectionApplicationFeature - { - public string ConnectionId { get; } - public IPipeConnection Connection { get; set; } - - public void Abort(Exception ex) - { - } - - public void OnConnectionClosed(Exception ex) - { - } - } } }