Skip to content

Commit 9cb1185

Browse files
authored
Add option to restrict the maximum hub message size (#8135)
- This change moves the limit checking from the transport layer to the protocol parsing layer. One nice side effect is that it gives us better control over error handling.
1 parent 67d339e commit 9cb1185

File tree

6 files changed

+247
-8
lines changed

6 files changed

+247
-8
lines changed

src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ public HubOptions() { }
196196
public bool? EnableDetailedErrors { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
197197
public System.TimeSpan? HandshakeTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
198198
public System.TimeSpan? KeepAliveInterval { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
199+
public long? MaximumReceiveMessageSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
199200
public System.Collections.Generic.IList<string> SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
200201
}
201202
public partial class HubOptions<THub> : Microsoft.AspNetCore.SignalR.HubOptions where THub : Microsoft.AspNetCore.SignalR.Hub

src/SignalR/server/Core/src/HubConnectionHandler.cs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public class HubConnectionHandler<THub> : ConnectionHandler where THub : Hub
2828
private readonly IUserIdProvider _userIdProvider;
2929
private readonly HubDispatcher<THub> _dispatcher;
3030
private readonly bool _enableDetailedErrors;
31+
private readonly long? _maximumMessageSize;
3132

3233
/// <summary>
3334
/// Initializes a new instance of the <see cref="HubConnectionHandler{THub}"/> class.
@@ -61,6 +62,7 @@ HubDispatcher<THub> dispatcher
6162
_dispatcher = dispatcher;
6263

6364
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _globalHubOptions.EnableDetailedErrors ?? false;
65+
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize ?? _globalHubOptions.MaximumReceiveMessageSize;
6466
}
6567

6668
/// <inheritdoc />
@@ -69,7 +71,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection)
6971
// We check to see if HubOptions<THub> are set because those take precedence over global hub options.
7072
// Then set the keepAlive and handshakeTimeout values to the defaults in HubOptionsSetup incase they were explicitly set to null.
7173
var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval;
72-
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
74+
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
7375
var handshakeTimeout = _hubOptions.HandshakeTimeout ?? _globalHubOptions.HandshakeTimeout ?? HubOptionsSetup.DefaultHandshakeTimeout;
7476
var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols;
7577

@@ -205,9 +207,47 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection)
205207
{
206208
connection.ResetClientTimeout();
207209

208-
while (protocol.TryParseMessage(ref buffer, binder, out var message))
210+
// No message limit, just parse and dispatch
211+
if (_maximumMessageSize == null)
209212
{
210-
await _dispatcher.DispatchMessageAsync(connection, message);
213+
while (protocol.TryParseMessage(ref buffer, binder, out var message))
214+
{
215+
await _dispatcher.DispatchMessageAsync(connection, message);
216+
}
217+
}
218+
else
219+
{
220+
// We give the parser a sliding window of the default message size
221+
var maxMessageSize = _maximumMessageSize.Value;
222+
223+
while (!buffer.IsEmpty)
224+
{
225+
var segment = buffer;
226+
var overLength = false;
227+
228+
if (segment.Length > maxMessageSize)
229+
{
230+
segment = segment.Slice(segment.Start, maxMessageSize);
231+
overLength = true;
232+
}
233+
234+
if (protocol.TryParseMessage(ref segment, binder, out var message))
235+
{
236+
await _dispatcher.DispatchMessageAsync(connection, message);
237+
}
238+
else if (overLength)
239+
{
240+
throw new InvalidDataException($"The maximum message size of {maxMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
241+
}
242+
else
243+
{
244+
// No need to update the buffer since we didn't parse anything
245+
break;
246+
}
247+
248+
// Update the buffer to the remaining segment
249+
buffer = buffer.Slice(segment.Start);
250+
}
211251
}
212252
}
213253

src/SignalR/server/Core/src/HubOptions.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ public class HubOptions
3636
/// </summary>
3737
public IList<string> SupportedProtocols { get; set; } = null;
3838

39+
/// <summary>
40+
/// Gets or sets the maximum message size of a single incoming hub message. The default is 32KB.
41+
/// </summary>
42+
public long? MaximumReceiveMessageSize { get; set; } = null;
43+
3944
/// <summary>
4045
/// Gets or sets a value indicating whether detailed error messages are sent to the client.
4146
/// Detailed error messages include details from exceptions thrown on the server.

src/SignalR/server/Core/src/Internal/HubOptionsSetup.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) .NET Foundation. All rights reserved.
1+
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
@@ -16,6 +16,8 @@ public class HubOptionsSetup : IConfigureOptions<HubOptions>
1616

1717
internal static TimeSpan DefaultClientTimeoutInterval => TimeSpan.FromSeconds(30);
1818

19+
internal const int DefaultMaximumMessageSize = 32 * 1024 * 1024;
20+
1921
private readonly List<string> _defaultProtocols = new List<string>();
2022

2123
public HubOptionsSetup(IEnumerable<IHubProtocol> protocols)
@@ -40,6 +42,11 @@ public void Configure(HubOptions options)
4042
options.HandshakeTimeout = DefaultHandshakeTimeout;
4143
}
4244

45+
if (options.MaximumReceiveMessageSize == null)
46+
{
47+
options.MaximumReceiveMessageSize = DefaultMaximumMessageSize;
48+
}
49+
4350
if (options.SupportedProtocols == null)
4451
{
4552
options.SupportedProtocols = new List<string>();

src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Utils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public static IServiceProvider CreateServiceProvider(Action<ServiceCollection> a
7272
return services.BuildServiceProvider();
7373
}
7474

75-
public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, Action<ServiceCollection> addServices = null, ILoggerFactory loggerFactory = null)
75+
public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, ILoggerFactory loggerFactory = null, Action<ServiceCollection> addServices = null)
7676
{
7777
var serviceProvider = CreateServiceProvider(addServices, loggerFactory);
7878
return (Connections.ConnectionHandler)serviceProvider.GetService(GetConnectionHandlerType(hubType));

src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs

Lines changed: 189 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Buffers;
66
using System.Collections.Generic;
77
using System.Diagnostics;
8+
using System.IO;
89
using System.Linq;
910
using System.Security.Claims;
1011
using System.Text;
@@ -452,6 +453,191 @@ public async Task HandshakeSuccessSendsResponseWithoutError()
452453
}
453454
}
454455

456+
[Fact]
457+
public async Task HubMessageOverTheMaxMessageSizeThrows()
458+
{
459+
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
460+
var maximumMessageSize = payload.Length - 10;
461+
InvalidDataException exception = null;
462+
463+
bool ExpectedErrors(WriteContext writeContext)
464+
{
465+
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
466+
{
467+
exception = ide;
468+
return true;
469+
}
470+
return false;
471+
}
472+
473+
using (StartVerifiableLog(ExpectedErrors))
474+
{
475+
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
476+
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
477+
478+
using (var client = new TestClient())
479+
{
480+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
481+
482+
await client.Connection.Application.Output.WriteAsync(payload);
483+
484+
client.Dispose();
485+
486+
await connectionHandlerTask.OrTimeout();
487+
}
488+
}
489+
490+
Assert.NotNull(exception);
491+
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
492+
}
493+
494+
[Fact]
495+
public async Task ChunkedHubMessageOverTheMaxMessageSizeThrows()
496+
{
497+
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
498+
var maximumMessageSize = payload.Length - 10;
499+
InvalidDataException exception = null;
500+
501+
bool ExpectedErrors(WriteContext writeContext)
502+
{
503+
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
504+
{
505+
exception = ide;
506+
return true;
507+
}
508+
return false;
509+
}
510+
511+
using (StartVerifiableLog(ExpectedErrors))
512+
{
513+
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
514+
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
515+
516+
using (var client = new TestClient())
517+
{
518+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
519+
520+
await client.Connection.Application.Output.WriteAsync(payload.AsMemory(0, payload.Length / 2));
521+
await client.Connection.Application.Output.WriteAsync(payload.AsMemory(payload.Length / 2));
522+
523+
client.Dispose();
524+
525+
await connectionHandlerTask.OrTimeout();
526+
}
527+
}
528+
529+
Assert.NotNull(exception);
530+
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
531+
}
532+
533+
[Fact]
534+
public async Task ManyHubMessagesOneOverTheMaxMessageSizeThrows()
535+
{
536+
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
537+
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
538+
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");
539+
540+
// Between the first and the second payload so we'll end up slicing with some remaining in the slice for
541+
// the next message
542+
var maximumMessageSize = payload1.Length + 1;
543+
InvalidDataException exception = null;
544+
545+
bool ExpectedErrors(WriteContext writeContext)
546+
{
547+
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
548+
{
549+
exception = ide;
550+
return true;
551+
}
552+
return false;
553+
}
554+
555+
using (StartVerifiableLog(ExpectedErrors))
556+
{
557+
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
558+
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
559+
560+
using (var client = new TestClient())
561+
{
562+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
563+
564+
client.Connection.Application.Output.Write(payload1);
565+
client.Connection.Application.Output.Write(payload2);
566+
client.Connection.Application.Output.Write(payload3);
567+
await client.Connection.Application.Output.FlushAsync();
568+
569+
// 2 invocations should be processed
570+
var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
571+
Assert.NotNull(completionMessage);
572+
Assert.Equal("1", completionMessage.InvocationId);
573+
Assert.Equal("one", completionMessage.Result);
574+
575+
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
576+
Assert.NotNull(completionMessage);
577+
Assert.Equal("2", completionMessage.InvocationId);
578+
Assert.Equal("two", completionMessage.Result);
579+
580+
// We never receive the 3rd message since it was over the maximum message size
581+
CloseMessage closeMessage = await client.ReadAsync().OrTimeout() as CloseMessage;
582+
Assert.NotNull(closeMessage);
583+
584+
client.Dispose();
585+
586+
await connectionHandlerTask.OrTimeout();
587+
}
588+
}
589+
590+
Assert.NotNull(exception);
591+
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
592+
}
593+
594+
[Fact]
595+
public async Task ManyHubMessagesUnderTheMessageSizeButConfiguredWithMax()
596+
{
597+
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
598+
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
599+
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");
600+
601+
// Bigger than all 3 messages
602+
var maximumMessageSize = payload3.Length + 10;
603+
604+
using (StartVerifiableLog())
605+
{
606+
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
607+
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
608+
609+
using (var client = new TestClient())
610+
{
611+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
612+
613+
client.Connection.Application.Output.Write(payload1);
614+
client.Connection.Application.Output.Write(payload2);
615+
client.Connection.Application.Output.Write(payload3);
616+
await client.Connection.Application.Output.FlushAsync();
617+
618+
// 2 invocations should be processed
619+
var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
620+
Assert.NotNull(completionMessage);
621+
Assert.Equal("1", completionMessage.InvocationId);
622+
Assert.Equal("one", completionMessage.Result);
623+
624+
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
625+
Assert.NotNull(completionMessage);
626+
Assert.Equal("2", completionMessage.InvocationId);
627+
Assert.Equal("two", completionMessage.Result);
628+
629+
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
630+
Assert.NotNull(completionMessage);
631+
Assert.Equal("3", completionMessage.InvocationId);
632+
Assert.Equal("three", completionMessage.Result);
633+
634+
client.Dispose();
635+
636+
await connectionHandlerTask.OrTimeout();
637+
}
638+
}
639+
}
640+
455641
[Fact]
456642
public async Task HandshakeFailureFromIncompatibleProtocolVersionSendsResponseWithError()
457643
{
@@ -2789,7 +2975,7 @@ public async Task UploadManyStreams()
27892975

27902976
foreach (string id in ids)
27912977
{
2792-
await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
2978+
await client.BeginUploadStreamAsync("invocation_" + id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
27932979
}
27942980

27952981
var words = new[] { "zygapophyses", "qwerty", "abcd" };
@@ -2868,7 +3054,7 @@ public async Task UploadStreamItemInvalidTypeAutoCasts()
28683054
}
28693055
}
28703056
}
2871-
3057+
28723058
[Fact]
28733059
public async Task ServerReportsProtocolMinorVersion()
28743060
{
@@ -2881,7 +3067,7 @@ public async Task ServerReportsProtocolMinorVersion()
28813067
testProtocol.Setup(m => m.TransferFormat).Returns(TransferFormat.Binary);
28823068

28833069
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT),
2884-
(services) => services.AddSingleton<IHubProtocol>(testProtocol.Object), LoggerFactory);
3070+
LoggerFactory, (services) => services.AddSingleton<IHubProtocol>(testProtocol.Object));
28853071

28863072
using (var client = new TestClient(protocol: testProtocol.Object))
28873073
{

0 commit comments

Comments
 (0)