diff --git a/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs b/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs index 3e5dea838a4a..7e8e10911cc2 100644 --- a/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs +++ b/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs @@ -196,6 +196,7 @@ public HubOptions() { } public bool? EnableDetailedErrors { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.TimeSpan? HandshakeTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.TimeSpan? KeepAliveInterval { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public long? MaximumReceiveMessageSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.Collections.Generic.IList SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } public partial class HubOptions : Microsoft.AspNetCore.SignalR.HubOptions where THub : Microsoft.AspNetCore.SignalR.Hub diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 85e765571c58..30c007354720 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -28,6 +28,7 @@ public class HubConnectionHandler : ConnectionHandler where THub : Hub private readonly IUserIdProvider _userIdProvider; private readonly HubDispatcher _dispatcher; private readonly bool _enableDetailedErrors; + private readonly long? _maximumMessageSize; /// /// Initializes a new instance of the class. @@ -61,6 +62,7 @@ HubDispatcher dispatcher _dispatcher = dispatcher; _enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _globalHubOptions.EnableDetailedErrors ?? false; + _maximumMessageSize = _hubOptions.MaximumReceiveMessageSize ?? _globalHubOptions.MaximumReceiveMessageSize; } /// @@ -69,7 +71,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) // We check to see if HubOptions are set because those take precedence over global hub options. // Then set the keepAlive and handshakeTimeout values to the defaults in HubOptionsSetup incase they were explicitly set to null. var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval; - var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval; + var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval; var handshakeTimeout = _hubOptions.HandshakeTimeout ?? _globalHubOptions.HandshakeTimeout ?? HubOptionsSetup.DefaultHandshakeTimeout; var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols; @@ -205,9 +207,47 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection) { connection.ResetClientTimeout(); - while (protocol.TryParseMessage(ref buffer, binder, out var message)) + // No message limit, just parse and dispatch + if (_maximumMessageSize == null) { - await _dispatcher.DispatchMessageAsync(connection, message); + while (protocol.TryParseMessage(ref buffer, binder, out var message)) + { + await _dispatcher.DispatchMessageAsync(connection, message); + } + } + else + { + // We give the parser a sliding window of the default message size + var maxMessageSize = _maximumMessageSize.Value; + + while (!buffer.IsEmpty) + { + var segment = buffer; + var overLength = false; + + if (segment.Length > maxMessageSize) + { + segment = segment.Slice(segment.Start, maxMessageSize); + overLength = true; + } + + if (protocol.TryParseMessage(ref segment, binder, out var message)) + { + await _dispatcher.DispatchMessageAsync(connection, message); + } + else if (overLength) + { + throw new InvalidDataException($"The maximum message size of {maxMessageSize}B was exceeded. The message size can be configured in AddHubOptions."); + } + else + { + // No need to update the buffer since we didn't parse anything + break; + } + + // Update the buffer to the remaining segment + buffer = buffer.Slice(segment.Start); + } } } diff --git a/src/SignalR/server/Core/src/HubOptions.cs b/src/SignalR/server/Core/src/HubOptions.cs index 25d997dc4d83..a4238750cbab 100644 --- a/src/SignalR/server/Core/src/HubOptions.cs +++ b/src/SignalR/server/Core/src/HubOptions.cs @@ -36,6 +36,11 @@ public class HubOptions /// public IList SupportedProtocols { get; set; } = null; + /// + /// Gets or sets the maximum message size of a single incoming hub message. The default is 32KB. + /// + public long? MaximumReceiveMessageSize { get; set; } = null; + /// /// Gets or sets a value indicating whether detailed error messages are sent to the client. /// Detailed error messages include details from exceptions thrown on the server. diff --git a/src/SignalR/server/Core/src/Internal/HubOptionsSetup.cs b/src/SignalR/server/Core/src/Internal/HubOptionsSetup.cs index ed6498b964a5..de98df4c60e4 100644 --- a/src/SignalR/server/Core/src/Internal/HubOptionsSetup.cs +++ b/src/SignalR/server/Core/src/Internal/HubOptionsSetup.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -16,6 +16,8 @@ public class HubOptionsSetup : IConfigureOptions internal static TimeSpan DefaultClientTimeoutInterval => TimeSpan.FromSeconds(30); + internal const int DefaultMaximumMessageSize = 32 * 1024 * 1024; + private readonly List _defaultProtocols = new List(); public HubOptionsSetup(IEnumerable protocols) @@ -40,6 +42,11 @@ public void Configure(HubOptions options) options.HandshakeTimeout = DefaultHandshakeTimeout; } + if (options.MaximumReceiveMessageSize == null) + { + options.MaximumReceiveMessageSize = DefaultMaximumMessageSize; + } + if (options.SupportedProtocols == null) { options.SupportedProtocols = new List(); diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Utils.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Utils.cs index 6f9055d4c998..5861cab6ea6f 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Utils.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Utils.cs @@ -72,7 +72,7 @@ public static IServiceProvider CreateServiceProvider(Action a return services.BuildServiceProvider(); } - public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, Action addServices = null, ILoggerFactory loggerFactory = null) + public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, ILoggerFactory loggerFactory = null, Action addServices = null) { var serviceProvider = CreateServiceProvider(addServices, loggerFactory); return (Connections.ConnectionHandler)serviceProvider.GetService(GetConnectionHandlerType(hubType)); diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 817e2aa343a2..affc5108598e 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -5,6 +5,7 @@ using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Linq; using System.Security.Claims; using System.Text; @@ -452,6 +453,191 @@ public async Task HandshakeSuccessSendsResponseWithoutError() } } + [Fact] + public async Task HubMessageOverTheMaxMessageSizeThrows() + { + var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e"); + var maximumMessageSize = payload.Length - 10; + InvalidDataException exception = null; + + bool ExpectedErrors(WriteContext writeContext) + { + if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide)) + { + exception = ide; + return true; + } + return false; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory, + services => services.AddSignalR().AddHubOptions(o => o.MaximumReceiveMessageSize = maximumMessageSize)); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connection.Application.Output.WriteAsync(payload); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + Assert.NotNull(exception); + Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions."); + } + + [Fact] + public async Task ChunkedHubMessageOverTheMaxMessageSizeThrows() + { + var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e"); + var maximumMessageSize = payload.Length - 10; + InvalidDataException exception = null; + + bool ExpectedErrors(WriteContext writeContext) + { + if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide)) + { + exception = ide; + return true; + } + return false; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory, + services => services.AddSignalR().AddHubOptions(o => o.MaximumReceiveMessageSize = maximumMessageSize)); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connection.Application.Output.WriteAsync(payload.AsMemory(0, payload.Length / 2)); + await client.Connection.Application.Output.WriteAsync(payload.AsMemory(payload.Length / 2)); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + Assert.NotNull(exception); + Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions."); + } + + [Fact] + public async Task ManyHubMessagesOneOverTheMaxMessageSizeThrows() + { + var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e"); + var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e"); + var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e"); + + // Between the first and the second payload so we'll end up slicing with some remaining in the slice for + // the next message + var maximumMessageSize = payload1.Length + 1; + InvalidDataException exception = null; + + bool ExpectedErrors(WriteContext writeContext) + { + if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide)) + { + exception = ide; + return true; + } + return false; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory, + services => services.AddSignalR().AddHubOptions(o => o.MaximumReceiveMessageSize = maximumMessageSize)); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + client.Connection.Application.Output.Write(payload1); + client.Connection.Application.Output.Write(payload2); + client.Connection.Application.Output.Write(payload3); + await client.Connection.Application.Output.FlushAsync(); + + // 2 invocations should be processed + var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage; + Assert.NotNull(completionMessage); + Assert.Equal("1", completionMessage.InvocationId); + Assert.Equal("one", completionMessage.Result); + + completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage; + Assert.NotNull(completionMessage); + Assert.Equal("2", completionMessage.InvocationId); + Assert.Equal("two", completionMessage.Result); + + // We never receive the 3rd message since it was over the maximum message size + CloseMessage closeMessage = await client.ReadAsync().OrTimeout() as CloseMessage; + Assert.NotNull(closeMessage); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + Assert.NotNull(exception); + Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions."); + } + + [Fact] + public async Task ManyHubMessagesUnderTheMessageSizeButConfiguredWithMax() + { + var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e"); + var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e"); + var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e"); + + // Bigger than all 3 messages + var maximumMessageSize = payload3.Length + 10; + + using (StartVerifiableLog()) + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory, + services => services.AddSignalR().AddHubOptions(o => o.MaximumReceiveMessageSize = maximumMessageSize)); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + client.Connection.Application.Output.Write(payload1); + client.Connection.Application.Output.Write(payload2); + client.Connection.Application.Output.Write(payload3); + await client.Connection.Application.Output.FlushAsync(); + + // 2 invocations should be processed + var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage; + Assert.NotNull(completionMessage); + Assert.Equal("1", completionMessage.InvocationId); + Assert.Equal("one", completionMessage.Result); + + completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage; + Assert.NotNull(completionMessage); + Assert.Equal("2", completionMessage.InvocationId); + Assert.Equal("two", completionMessage.Result); + + completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage; + Assert.NotNull(completionMessage); + Assert.Equal("3", completionMessage.InvocationId); + Assert.Equal("three", completionMessage.Result); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + [Fact] public async Task HandshakeFailureFromIncompatibleProtocolVersionSendsResponseWithError() { @@ -2789,7 +2975,7 @@ public async Task UploadManyStreams() foreach (string id in ids) { - await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty()); + await client.BeginUploadStreamAsync("invocation_" + id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty()); } var words = new[] { "zygapophyses", "qwerty", "abcd" }; @@ -2868,7 +3054,7 @@ public async Task UploadStreamItemInvalidTypeAutoCasts() } } } - + [Fact] public async Task ServerReportsProtocolMinorVersion() { @@ -2881,7 +3067,7 @@ public async Task ServerReportsProtocolMinorVersion() testProtocol.Setup(m => m.TransferFormat).Returns(TransferFormat.Binary); var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), - (services) => services.AddSingleton(testProtocol.Object), LoggerFactory); + LoggerFactory, (services) => services.AddSingleton(testProtocol.Object)); using (var client = new TestClient(protocol: testProtocol.Object)) {