Skip to content

Add option to restrict the maximum hub message size #8135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ public HubOptions() { }
public bool? EnableDetailedErrors { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public System.TimeSpan? HandshakeTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public System.TimeSpan? KeepAliveInterval { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public long? MaximumReceiveMessageSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public System.Collections.Generic.IList<string> SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
}
public partial class HubOptions<THub> : Microsoft.AspNetCore.SignalR.HubOptions where THub : Microsoft.AspNetCore.SignalR.Hub
Expand Down
46 changes: 43 additions & 3 deletions src/SignalR/server/Core/src/HubConnectionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class HubConnectionHandler<THub> : ConnectionHandler where THub : Hub
private readonly IUserIdProvider _userIdProvider;
private readonly HubDispatcher<THub> _dispatcher;
private readonly bool _enableDetailedErrors;
private readonly long? _maximumMessageSize;

/// <summary>
/// Initializes a new instance of the <see cref="HubConnectionHandler{THub}"/> class.
Expand Down Expand Up @@ -61,6 +62,7 @@ HubDispatcher<THub> dispatcher
_dispatcher = dispatcher;

_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _globalHubOptions.EnableDetailedErrors ?? false;
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize ?? _globalHubOptions.MaximumReceiveMessageSize;
}

/// <inheritdoc />
Expand All @@ -69,7 +71,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection)
// We check to see if HubOptions<THub> are set because those take precedence over global hub options.
// Then set the keepAlive and handshakeTimeout values to the defaults in HubOptionsSetup incase they were explicitly set to null.
var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval;
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
var handshakeTimeout = _hubOptions.HandshakeTimeout ?? _globalHubOptions.HandshakeTimeout ?? HubOptionsSetup.DefaultHandshakeTimeout;
var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols;

Expand Down Expand Up @@ -205,9 +207,47 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection)
{
connection.ResetClientTimeout();

while (protocol.TryParseMessage(ref buffer, binder, out var message))
// No message limit, just parse and dispatch
if (_maximumMessageSize == null)
{
await _dispatcher.DispatchMessageAsync(connection, message);
while (protocol.TryParseMessage(ref buffer, binder, out var message))
{
await _dispatcher.DispatchMessageAsync(connection, message);
}
}
else
{
// We give the parser a sliding window of the default message size
var maxMessageSize = _maximumMessageSize.Value;

while (!buffer.IsEmpty)
{
var segment = buffer;
var overLength = false;

if (segment.Length > maxMessageSize)
{
segment = segment.Slice(segment.Start, maxMessageSize);
overLength = true;
}

if (protocol.TryParseMessage(ref segment, binder, out var message))
{
await _dispatcher.DispatchMessageAsync(connection, message);
}
else if (overLength)
{
throw new InvalidDataException($"The maximum message size of {maxMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}
else
{
// No need to update the buffer since we didn't parse anything
break;
}

// Update the buffer to the remaining segment
buffer = buffer.Slice(segment.Start);
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions src/SignalR/server/Core/src/HubOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public class HubOptions
/// </summary>
public IList<string> SupportedProtocols { get; set; } = null;

/// <summary>
/// Gets or sets the maximum message size of a single incoming hub message. The default is 32KB.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Null means no maximum message size" or is that implicit with nullable types

/// </summary>
public long? MaximumReceiveMessageSize { get; set; } = null;

/// <summary>
/// Gets or sets a value indicating whether detailed error messages are sent to the client.
/// Detailed error messages include details from exceptions thrown on the server.
Expand Down
9 changes: 8 additions & 1 deletion src/SignalR/server/Core/src/Internal/HubOptionsSetup.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
Expand All @@ -16,6 +16,8 @@ public class HubOptionsSetup : IConfigureOptions<HubOptions>

internal static TimeSpan DefaultClientTimeoutInterval => TimeSpan.FromSeconds(30);

internal const int DefaultMaximumMessageSize = 32 * 1024 * 1024;

private readonly List<string> _defaultProtocols = new List<string>();

public HubOptionsSetup(IEnumerable<IHubProtocol> protocols)
Expand All @@ -40,6 +42,11 @@ public void Configure(HubOptions options)
options.HandshakeTimeout = DefaultHandshakeTimeout;
}

if (options.MaximumReceiveMessageSize == null)
{
options.MaximumReceiveMessageSize = DefaultMaximumMessageSize;
}

if (options.SupportedProtocols == null)
{
options.SupportedProtocols = new List<string>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public static IServiceProvider CreateServiceProvider(Action<ServiceCollection> a
return services.BuildServiceProvider();
}

public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, Action<ServiceCollection> addServices = null, ILoggerFactory loggerFactory = null)
public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, ILoggerFactory loggerFactory = null, Action<ServiceCollection> addServices = null)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lambdas should be last.

{
var serviceProvider = CreateServiceProvider(addServices, loggerFactory);
return (Connections.ConnectionHandler)serviceProvider.GetService(GetConnectionHandlerType(hubType));
Expand Down
192 changes: 189 additions & 3 deletions src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Security.Claims;
using System.Text;
Expand Down Expand Up @@ -452,6 +453,191 @@ public async Task HandshakeSuccessSendsResponseWithoutError()
}
}

[Fact]
public async Task HubMessageOverTheMaxMessageSizeThrows()
{
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
var maximumMessageSize = payload.Length - 10;
InvalidDataException exception = null;

bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
{
exception = ide;
return true;
}
return false;
}

using (StartVerifiableLog(ExpectedErrors))
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);

await client.Connection.Application.Output.WriteAsync(payload);

client.Dispose();

await connectionHandlerTask.OrTimeout();
}
}

Assert.NotNull(exception);
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}

[Fact]
public async Task ChunkedHubMessageOverTheMaxMessageSizeThrows()
{
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
var maximumMessageSize = payload.Length - 10;
InvalidDataException exception = null;

bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
{
exception = ide;
return true;
}
return false;
}

using (StartVerifiableLog(ExpectedErrors))
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);

await client.Connection.Application.Output.WriteAsync(payload.AsMemory(0, payload.Length / 2));
await client.Connection.Application.Output.WriteAsync(payload.AsMemory(payload.Length / 2));

client.Dispose();

await connectionHandlerTask.OrTimeout();
}
}

Assert.NotNull(exception);
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}

[Fact]
public async Task ManyHubMessagesOneOverTheMaxMessageSizeThrows()
{
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");

// Between the first and the second payload so we'll end up slicing with some remaining in the slice for
// the next message
var maximumMessageSize = payload1.Length + 1;
InvalidDataException exception = null;

bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
{
exception = ide;
return true;
}
return false;
}

using (StartVerifiableLog(ExpectedErrors))
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);

client.Connection.Application.Output.Write(payload1);
client.Connection.Application.Output.Write(payload2);
client.Connection.Application.Output.Write(payload3);
await client.Connection.Application.Output.FlushAsync();

// 2 invocations should be processed
var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("1", completionMessage.InvocationId);
Assert.Equal("one", completionMessage.Result);

completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("2", completionMessage.InvocationId);
Assert.Equal("two", completionMessage.Result);

// We never receive the 3rd message since it was over the maximum message size
CloseMessage closeMessage = await client.ReadAsync().OrTimeout() as CloseMessage;
Assert.NotNull(closeMessage);

client.Dispose();

await connectionHandlerTask.OrTimeout();
}
}

Assert.NotNull(exception);
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}

[Fact]
public async Task ManyHubMessagesUnderTheMessageSizeButConfiguredWithMax()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good test 👍

{
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");

// Bigger than all 3 messages
var maximumMessageSize = payload3.Length + 10;

using (StartVerifiableLog())
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);

client.Connection.Application.Output.Write(payload1);
client.Connection.Application.Output.Write(payload2);
client.Connection.Application.Output.Write(payload3);
await client.Connection.Application.Output.FlushAsync();

// 2 invocations should be processed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 3

var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("1", completionMessage.InvocationId);
Assert.Equal("one", completionMessage.Result);

completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("2", completionMessage.InvocationId);
Assert.Equal("two", completionMessage.Result);

completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("3", completionMessage.InvocationId);
Assert.Equal("three", completionMessage.Result);

client.Dispose();

await connectionHandlerTask.OrTimeout();
}
}
}

[Fact]
public async Task HandshakeFailureFromIncompatibleProtocolVersionSendsResponseWithError()
{
Expand Down Expand Up @@ -2789,7 +2975,7 @@ public async Task UploadManyStreams()

foreach (string id in ids)
{
await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
await client.BeginUploadStreamAsync("invocation_" + id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
}

var words = new[] { "zygapophyses", "qwerty", "abcd" };
Expand Down Expand Up @@ -2868,7 +3054,7 @@ public async Task UploadStreamItemInvalidTypeAutoCasts()
}
}
}

[Fact]
public async Task ServerReportsProtocolMinorVersion()
{
Expand All @@ -2881,7 +3067,7 @@ public async Task ServerReportsProtocolMinorVersion()
testProtocol.Setup(m => m.TransferFormat).Returns(TransferFormat.Binary);

var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT),
(services) => services.AddSingleton<IHubProtocol>(testProtocol.Object), LoggerFactory);
LoggerFactory, (services) => services.AddSingleton<IHubProtocol>(testProtocol.Object));

using (var client = new TestClient(protocol: testProtocol.Object))
{
Expand Down