Skip to content

Commit f283f8b

Browse files
Add support for DI in Hub methods (#34047)
1 parent 4fc8081 commit f283f8b

File tree

10 files changed

+369
-20
lines changed

10 files changed

+369
-20
lines changed

src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public void GlobalSetup()
3333
serviceScopeFactory,
3434
new HubContext<TestHub>(new DefaultHubLifetimeManager<TestHub>(NullLogger<DefaultHubLifetimeManager<TestHub>>.Instance)),
3535
enableDetailedErrors: false,
36+
disableImplicitFromServiceParameters: true,
3637
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance),
3738
hubFilters: null);
3839

src/SignalR/server/Core/src/HubConnectionHandler.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,15 @@ IServiceScopeFactory serviceScopeFactory
6262
_userIdProvider = userIdProvider;
6363

6464
_enableDetailedErrors = false;
65+
bool disableImplicitFromServiceParameters;
6566

6667
List<IHubFilter>? hubFilters = null;
6768
if (_hubOptions.UserHasSetValues)
6869
{
6970
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize;
7071
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
7172
_maxParallelInvokes = _hubOptions.MaximumParallelInvocationsPerClient;
73+
disableImplicitFromServiceParameters = _hubOptions.DisableImplicitFromServiceParameters;
7274

7375
if (_hubOptions.HubFilters != null)
7476
{
@@ -80,6 +82,7 @@ IServiceScopeFactory serviceScopeFactory
8082
_maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize;
8183
_enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
8284
_maxParallelInvokes = _globalHubOptions.MaximumParallelInvocationsPerClient;
85+
disableImplicitFromServiceParameters = _globalHubOptions.DisableImplicitFromServiceParameters;
8386

8487
if (_globalHubOptions.HubFilters != null)
8588
{
@@ -91,6 +94,7 @@ IServiceScopeFactory serviceScopeFactory
9194
serviceScopeFactory,
9295
new HubContext<THub>(lifetimeManager),
9396
_enableDetailedErrors,
97+
disableImplicitFromServiceParameters,
9498
new Logger<DefaultHubDispatcher<THub>>(loggerFactory),
9599
hubFilters);
96100
}

src/SignalR/server/Core/src/HubOptions.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using Microsoft.AspNetCore.Http.Metadata;
5+
using Microsoft.Extensions.DependencyInjection;
6+
47
namespace Microsoft.AspNetCore.SignalR;
58

69
/// <summary>
@@ -70,4 +73,13 @@ public int MaximumParallelInvocationsPerClient
7073
_maximumParallelInvocationsPerClient = value;
7174
}
7275
}
76+
77+
/// <summary>
78+
/// When <see langword="false"/>, <see cref="IServiceProviderIsService"/> determines if a Hub method parameter will be injected from the DI container.
79+
/// Parameters can be explicitly marked with an attribute that implements <see cref="IFromServiceMetadata"/> with or without this option set.
80+
/// </summary>
81+
/// <remarks>
82+
/// False by default. Hub method arguments will be resolved from a DI container if possible.
83+
/// </remarks>
84+
public bool DisableImplicitFromServiceParameters { get; set; }
7385
}

src/SignalR/server/Core/src/HubOptionsSetup`T.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public void Configure(HubOptions<THub> options)
3737
options.MaximumReceiveMessageSize = _hubOptions.MaximumReceiveMessageSize;
3838
options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity;
3939
options.MaximumParallelInvocationsPerClient = _hubOptions.MaximumParallelInvocationsPerClient;
40+
options.DisableImplicitFromServiceParameters = _hubOptions.DisableImplicitFromServiceParameters;
4041

4142
options.UserHasSetValues = true;
4243

src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ internal partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> where TH
2828
private readonly Func<HubLifetimeContext, Exception?, Task>? _onDisconnectedMiddleware;
2929

3030
public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors,
31-
ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter>? hubFilters)
31+
bool disableImplicitFromServiceParameters, ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter>? hubFilters)
3232
{
3333
_serviceScopeFactory = serviceScopeFactory;
3434
_hubContext = hubContext;
3535
_enableDetailedErrors = enableDetailedErrors;
3636
_logger = logger;
37-
DiscoverHubMethods();
37+
DiscoverHubMethods(disableImplicitFromServiceParameters);
3838

3939
var count = hubFilters?.Count ?? 0;
4040
if (count != 0)
@@ -307,7 +307,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
307307
CancellationTokenSource? cts = null;
308308
if (descriptor.HasSyntheticArguments)
309309
{
310-
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, ref arguments, out cts);
310+
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, scope, ref arguments, out cts);
311311
}
312312

313313
if (isStreamResponse)
@@ -601,7 +601,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
601601
}
602602

603603
private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall,
604-
HubConnectionContext connection, ref object?[] arguments, out CancellationTokenSource? cts)
604+
HubConnectionContext connection, AsyncServiceScope scope, ref object?[] arguments, out CancellationTokenSource? cts)
605605
{
606606
cts = null;
607607
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
@@ -626,6 +626,10 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio
626626
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
627627
arguments[parameterPointer] = cts.Token;
628628
}
629+
else if (descriptor.IsServiceArgument(parameterPointer))
630+
{
631+
arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes[parameterPointer]);
632+
}
629633
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
630634
{
631635
Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds![streamPointer]);
@@ -644,12 +648,20 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio
644648
}
645649
}
646650

647-
private void DiscoverHubMethods()
651+
private void DiscoverHubMethods(bool disableImplicitFromServiceParameters)
648652
{
649653
var hubType = typeof(THub);
650654
var hubTypeInfo = hubType.GetTypeInfo();
651655
var hubName = hubType.Name;
652656

657+
using var scope = _serviceScopeFactory.CreateScope();
658+
659+
IServiceProviderIsService? serviceProviderIsService = null;
660+
if (!disableImplicitFromServiceParameters)
661+
{
662+
serviceProviderIsService = scope.ServiceProvider.GetService<IServiceProviderIsService>();
663+
}
664+
653665
foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType))
654666
{
655667
if (methodInfo.IsGenericMethod)
@@ -668,7 +680,7 @@ private void DiscoverHubMethods()
668680

669681
var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
670682
var authorizeAttributes = methodInfo.GetCustomAttributes<AuthorizeAttribute>(inherit: true);
671-
_methods[methodName] = new HubMethodDescriptor(executor, authorizeAttributes);
683+
_methods[methodName] = new HubMethodDescriptor(executor, serviceProviderIsService, authorizeAttributes);
672684

673685
Log.HubMethodBound(_logger, hubName, methodName);
674686
}

src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
using System.Reflection;
77
using System.Threading.Channels;
88
using Microsoft.AspNetCore.Authorization;
9+
using Microsoft.AspNetCore.Http.Metadata;
10+
using Microsoft.Extensions.DependencyInjection;
911
using Microsoft.Extensions.Internal;
1012

1113
namespace Microsoft.AspNetCore.SignalR.Internal;
@@ -22,8 +24,10 @@ internal class HubMethodDescriptor
2224

2325
private readonly MethodInfo? _makeCancelableEnumeratorMethodInfo;
2426
private Func<object, CancellationToken, IAsyncEnumerator<object>>? _makeCancelableEnumerator;
27+
// bitset to store which parameters come from DI up to 64 arguments
28+
private ulong _isServiceArgument;
2529

26-
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
30+
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProviderIsService? serviceProviderIsService, IEnumerable<IAuthorizeData> policies)
2731
{
2832
MethodExecutor = methodExecutor;
2933

@@ -56,7 +60,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
5660
}
5761

5862
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
59-
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
63+
ParameterTypes = methodExecutor.MethodParameters.Where((p, index) =>
6064
{
6165
// Only streams can take CancellationTokens currently
6266
if (IsStreamResponse && p.ParameterType == typeof(CancellationToken))
@@ -75,6 +79,18 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
7579
HasSyntheticArguments = true;
7680
return false;
7781
}
82+
else if (p.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)) ||
83+
serviceProviderIsService?.IsService(p.ParameterType) == true)
84+
{
85+
if (index >= 64)
86+
{
87+
throw new InvalidOperationException(
88+
"Hub methods can't use services from DI in the parameters after the 64th parameter.");
89+
}
90+
_isServiceArgument |= (1UL << index);
91+
HasSyntheticArguments = true;
92+
return false;
93+
}
7894
return true;
7995
}).Select(p => p.ParameterType).ToArray();
8096

@@ -104,6 +120,11 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
104120

105121
public bool HasSyntheticArguments { get; private set; }
106122

123+
public bool IsServiceArgument(int argumentIndex)
124+
{
125+
return (_isServiceArgument & (1UL << argumentIndex)) != 0;
126+
}
127+
107128
public IAsyncEnumerator<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
108129
{
109130
// there is the potential for compile to be called times but this has no harmful effect other than perf
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
#nullable enable
2+
Microsoft.AspNetCore.SignalR.HubOptions.DisableImplicitFromServiceParameters.get -> bool
3+
Microsoft.AspNetCore.SignalR.HubOptions.DisableImplicitFromServiceParameters.set -> void

src/SignalR/server/SignalR/test/AddSignalRTests.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ public void HubSpecificOptionsHaveSameValuesAsGlobalHubOptions()
111111
Assert.Equal(globalHubOptions.SupportedProtocols, hubOptions.SupportedProtocols);
112112
Assert.Equal(globalHubOptions.ClientTimeoutInterval, hubOptions.ClientTimeoutInterval);
113113
Assert.Equal(globalHubOptions.MaximumParallelInvocationsPerClient, hubOptions.MaximumParallelInvocationsPerClient);
114+
Assert.Equal(globalHubOptions.DisableImplicitFromServiceParameters, hubOptions.DisableImplicitFromServiceParameters);
114115
Assert.True(hubOptions.UserHasSetValues);
115116
}
116117

@@ -145,6 +146,7 @@ public void UserSpecifiedOptionsRunAfterDefaultOptions()
145146
options.SupportedProtocols = null;
146147
options.ClientTimeoutInterval = TimeSpan.FromSeconds(1);
147148
options.MaximumParallelInvocationsPerClient = 3;
149+
options.DisableImplicitFromServiceParameters = true;
148150
});
149151

150152
var serviceProvider = serviceCollection.BuildServiceProvider();
@@ -158,6 +160,7 @@ public void UserSpecifiedOptionsRunAfterDefaultOptions()
158160
Assert.Null(globalOptions.SupportedProtocols);
159161
Assert.Equal(3, globalOptions.MaximumParallelInvocationsPerClient);
160162
Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval);
163+
Assert.True(globalOptions.DisableImplicitFromServiceParameters);
161164
}
162165

163166
[Fact]

src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
5-
using System.Collections.Generic;
64
using System.Globalization;
7-
using System.Linq;
85
using System.Runtime.CompilerServices;
96
using System.Text;
10-
using System.Threading;
117
using System.Threading.Channels;
12-
using System.Threading.Tasks;
138
using Microsoft.AspNetCore.Authorization;
9+
using Microsoft.AspNetCore.Http.Metadata;
1410
using Newtonsoft.Json.Serialization;
1511

1612
namespace Microsoft.AspNetCore.SignalR.Tests;
@@ -1247,3 +1243,65 @@ public void SetCaller(IClientProxy caller)
12471243
Caller = caller;
12481244
}
12491245
}
1246+
1247+
[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)]
1248+
public class FromService : Attribute, IFromServiceMetadata
1249+
{ }
1250+
public class Service1
1251+
{ }
1252+
public class Service2
1253+
{ }
1254+
public class Service3
1255+
{ }
1256+
1257+
public class ServicesHub : TestHub
1258+
{
1259+
public bool SingleService([FromService] Service1 service)
1260+
{
1261+
return true;
1262+
}
1263+
1264+
public bool MultipleServices([FromService] Service1 service, [FromService] Service2 service2, [FromService] Service3 service3)
1265+
{
1266+
return true;
1267+
}
1268+
1269+
public async Task<int> ServicesAndParams(int value, [FromService] Service1 service, ChannelReader<int> channelReader, [FromService] Service2 service2, bool value2)
1270+
{
1271+
int total = 0;
1272+
while (await channelReader.WaitToReadAsync())
1273+
{
1274+
total += await channelReader.ReadAsync();
1275+
}
1276+
return total + value;
1277+
}
1278+
1279+
public int ServiceWithoutAttribute(Service1 service)
1280+
{
1281+
return 1;
1282+
}
1283+
1284+
public int ServiceWithAndWithoutAttribute(Service1 service, [FromService] Service2 service2)
1285+
{
1286+
return 1;
1287+
}
1288+
1289+
public async Task Stream(ChannelReader<int> channelReader)
1290+
{
1291+
while (await channelReader.WaitToReadAsync())
1292+
{
1293+
await channelReader.ReadAsync();
1294+
}
1295+
}
1296+
}
1297+
1298+
public class TooManyParamsHub : Hub
1299+
{
1300+
public void ManyParams(int a1, string a2, bool a3, float a4, string a5, int a6, int a7, int a8, int a9, int a10, int a11,
1301+
int a12, int a13, int a14, int a15, int a16, int a17, int a18, int a19, int a20, int a21, int a22, int a23, int a24,
1302+
int a25, int a26, int a27, int a28, int a29, int a30, int a31, int a32, int a33, int a34, int a35, int a36, int a37,
1303+
int a38, int a39, int a40, int a41, int a42, int a43, int a44, int a45, int a46, int a47, int a48, int a49, int a50,
1304+
int a51, int a52, int a53, int a54, int a55, int a56, int a57, int a58, int a59, int a60, int a61, int a62, int a63,
1305+
int a64, [FromService] Service1 service)
1306+
{ }
1307+
}

0 commit comments

Comments
 (0)