From c3f7649d48aa6e8a92f7e2ed4ce5db21dd6a420b Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Fri, 7 Jun 2024 14:13:02 -0500 Subject: [PATCH 1/3] Support IAsyncEnumerable and ChannelReader with ValueTypes in SignalR native AOT Support streaming ValueTypes from a SignalR Hub method in both the client and the server in native AOT. In order to make this work, we need to use pure reflection to read from the streaming object. Support passing in an IAsyncEnumerable/ChannelReader of ValueType to a parameter in SignalR.Client. This works because the user code creates the concrete object, and the SignalR.Client library just needs to read from it using reflection. The only scenario that can't be supported is on the SignalR server we can't support receiving an IAsyncEnumerable/ChannelReader of ValueType. This is because there is no way for the SignalR library code to construct a concrete instance to pass into the user-defined method on native AOT. Fix #56179 --- .../csharp/Client.Core/src/HubConnection.cs | 63 +++++++-- .../common/Shared/AsyncEnumerableAdapters.cs | 118 +++++++++++++++- src/SignalR/common/Shared/ReflectionHelper.cs | 23 +++ .../Core/src/Internal/HubMethodDescriptor.cs | 64 +++++---- .../NativeAotTests.cs | 131 ++++++++++++++++-- 5 files changed, 351 insertions(+), 48 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 93d4c3492b9b..35a4ade375bf 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -857,26 +857,69 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary + /// Uses reflection to read items from an IAsyncEnumerable{T} or ChannelReader{T} and send them to the server. + /// + /// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario, + /// we cannot use MakeGenericMethod to call the appropriate SendStreamItems method because the generic type is a value type. + /// + private Task ReflectionSendStreamItems(MethodInfo methodInfo, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource) + { + async Task ReadAsyncEnumeratorStream(IAsyncEnumerator enumerator) + { + try + { + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + await SendWithLock(connectionState, new StreamItemMessage(streamId, enumerator.Current), tokenSource.Token).ConfigureAwait(false); + Log.SendingStreamItem(_logger, streamId); + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } - _ = methodInfo - .MakeGenericMethod(genericTypes) - .Invoke(this, [connectionState, streamId, reader, tokenSource]); + Func createAndConsumeStream; + if (methodInfo == _sendStreamItemsMethod) + { + // reader is a ChannelReader + createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumeratorFromChannel(reader, tokenSource.Token)); + } + else + { + // reader is an IAsyncEnumerable + Debug.Assert(methodInfo == _sendIAsyncStreamItemsMethod); + + createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumerator(reader, tokenSource.Token)); + } + + return CommonStreaming(connectionState, streamId, createAndConsumeStream, tokenSource); } +#endif - // this is called via reflection using the `_sendStreamItems` field + // this is called via reflection using the `_sendStreamItemsMethod` field private Task SendStreamItems(ConnectionState connectionState, string streamId, ChannelReader reader, CancellationTokenSource tokenSource) { async Task ReadChannelStream() diff --git a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs index 9df94d279128..0644a81473be 100644 --- a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs +++ b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs @@ -1,7 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; +using System.Reflection; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -11,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal; // True-internal because this is a weird and tricky class to use :) internal static class AsyncEnumerableAdapters { - public static IAsyncEnumerator MakeCancelableAsyncEnumerator(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default) + public static IAsyncEnumerator MakeAsyncEnumerator(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default) { var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken); return enumerator as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumerator); @@ -52,10 +54,13 @@ public ValueTask MoveNextAsync() private async Task MoveNextAsyncAwaited() { - if (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false) && _channel.TryRead(out var item)) + while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false)) { - Current = item; - return true; + if (_channel.TryRead(out var item)) + { + Current = item; + return true; + } } return false; } @@ -137,4 +142,109 @@ public ValueTask DisposeAsync() return _asyncEnumerator.DisposeAsync(); } } + +#if NET6_0_OR_GREATER + + private static readonly MethodInfo _asyncEnumerableGetAsyncEnumeratorMethodInfo = typeof(IAsyncEnumerable<>).GetMethod("GetAsyncEnumerator")!; + + /// + /// Creates an IAsyncEnumerator{object} from an IAsyncEnumerable{T} using reflection. + /// + /// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario, + /// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type. + /// + public static IAsyncEnumerator MakeReflectionAsyncEnumerator(object asyncEnumerable, CancellationToken cancellationToken) + { + var constructedIAsyncEnumerableInterface = ReflectionHelper.GetIAsyncEnumerableInterface(asyncEnumerable.GetType())!; + var enumerator = ((MethodInfo)constructedIAsyncEnumerableInterface.GetMemberWithSameMetadataDefinitionAs(_asyncEnumerableGetAsyncEnumeratorMethodInfo)).Invoke(asyncEnumerable, [cancellationToken])!; + return new ReflectionAsyncEnumerator(enumerator); + } + + /// + /// Creates an IAsyncEnumerator{object} from a ChannelReader{T} using reflection. + /// + /// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario, + /// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type. + /// + public static IAsyncEnumerator MakeReflectionAsyncEnumeratorFromChannel(object channelReader, CancellationToken cancellationToken) + { + return new ReflectionChannelAsyncEnumerator(channelReader, cancellationToken); + } + + private sealed class ReflectionAsyncEnumerator : IAsyncEnumerator + { + private static readonly MethodInfo _asyncEnumeratorMoveNextAsyncMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("MoveNextAsync")!; + private static readonly MethodInfo _asyncEnumeratorGetCurrentMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("get_Current")!; + + private readonly object _enumerator; + private readonly MethodInfo _moveNextAsyncMethodInfo; + private readonly MethodInfo _getCurrentMethodInfo; + + public ReflectionAsyncEnumerator(object enumerator) + { + _enumerator = enumerator; + + var type = ReflectionHelper.GetIAsyncEnumeratorInterface(enumerator.GetType())!; + _moveNextAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorMoveNextAsyncMethodInfo)!; + _getCurrentMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorGetCurrentMethodInfo)!; + } + + public object? Current => _getCurrentMethodInfo.Invoke(_enumerator, []); + + public ValueTask MoveNextAsync() => (ValueTask)_moveNextAsyncMethodInfo.Invoke(_enumerator, [])!; + + public ValueTask DisposeAsync() => ((IAsyncDisposable)_enumerator).DisposeAsync(); + } + + private sealed class ReflectionChannelAsyncEnumerator : IAsyncEnumerator + { + private static readonly MethodInfo _channelReaderTryReadMethodInfo = typeof(ChannelReader<>).GetMethod("TryRead")!; + private static readonly MethodInfo _channelReaderWaitToReadAsyncMethodInfo = typeof(ChannelReader<>).GetMethod("WaitToReadAsync")!; + + private readonly object _channelReader; + private readonly object?[] _tryReadResult = [null]; + private readonly object[] _waitToReadArgs; + private readonly MethodInfo _tryReadMethodInfo; + private readonly MethodInfo _waitToReadAsyncMethodInfo; + + public ReflectionChannelAsyncEnumerator(object channelReader, CancellationToken cancellationToken) + { + _channelReader = channelReader; + _waitToReadArgs = [cancellationToken]; + + var type = channelReader.GetType(); + _tryReadMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderTryReadMethodInfo)!; + _waitToReadAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderWaitToReadAsyncMethodInfo)!; + } + + public object? Current { get; private set; } + + public ValueTask MoveNextAsync() + { + if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!) + { + Current = _tryReadResult[0]; + return new ValueTask(true); + } + + return new ValueTask(MoveNextAsyncAwaited()); + } + + private async Task MoveNextAsyncAwaited() + { + while (await ((ValueTask)_waitToReadAsyncMethodInfo.Invoke(_channelReader, _waitToReadArgs)!).ConfigureAwait(false)) + { + if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!) + { + Current = _tryReadResult[0]; + return true; + } + } + return false; + } + + public ValueTask DisposeAsync() => default; + } + +#endif } diff --git a/src/SignalR/common/Shared/ReflectionHelper.cs b/src/SignalR/common/Shared/ReflectionHelper.cs index 1713ffa06da7..64a6f5dcfb2a 100644 --- a/src/SignalR/common/Shared/ReflectionHelper.cs +++ b/src/SignalR/common/Shared/ReflectionHelper.cs @@ -69,4 +69,27 @@ public static bool TryGetStreamType(Type streamType, [NotNullWhen(true)] out Typ return null; } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern", + Justification = "The 'IAsyncEnumerator<>' Type must exist and so trimmer kept it. In which case " + + "It also kept it on any type which implements it. The below call to GetInterfaces " + + "may return fewer results when trimmed but it will return 'IAsyncEnumerator<>' " + + "if the type implemented it, even after trimming.")] + public static Type? GetIAsyncEnumeratorInterface(Type type) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>)) + { + return type; + } + + foreach (Type typeToCheck in type.GetInterfaces()) + { + if (typeToCheck.IsGenericType && typeToCheck.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>)) + { + return typeToCheck; + } + } + + return null; + } } diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index 01ec440004eb..56f0e8a4311c 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -17,16 +17,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal; internal sealed class HubMethodDescriptor { - private static readonly MethodInfo MakeCancelableAsyncEnumeratorMethod = typeof(AsyncEnumerableAdapters) + private static readonly MethodInfo MakeAsyncEnumeratorMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerator)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeAsyncEnumerator)) && m.IsGenericMethod); private static readonly MethodInfo MakeAsyncEnumeratorFromChannelMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeAsyncEnumeratorFromChannel)) && m.IsGenericMethod); private readonly MethodInfo? _makeCancelableEnumeratorMethodInfo; - private Func>? _makeCancelableEnumerator; + private Func>? _makeCancelableEnumerator; // bitset to store which parameters come from DI up to 64 arguments private ulong _isServiceArgument; @@ -41,8 +41,8 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider var asyncEnumerableType = ReflectionHelper.GetIAsyncEnumerableInterface(NonAsyncReturnType); if (asyncEnumerableType is not null) { - StreamReturnType = ValidateStreamType(asyncEnumerableType.GetGenericArguments()[0]); - _makeCancelableEnumeratorMethodInfo = MakeCancelableAsyncEnumeratorMethod; + StreamReturnType = asyncEnumerableType.GetGenericArguments()[0]; + _makeCancelableEnumeratorMethodInfo = MakeAsyncEnumeratorMethod; } else { @@ -50,7 +50,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider { if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(ChannelReader<>)) { - StreamReturnType = ValidateStreamType(returnType.GetGenericArguments()[0]); + StreamReturnType = returnType.GetGenericArguments()[0]; _makeCancelableEnumeratorMethodInfo = MakeAsyncEnumeratorFromChannelMethod; break; } @@ -73,7 +73,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider StreamingParameters = new List(); } - StreamingParameters.Add(ValidateStreamType(p.ParameterType.GetGenericArguments()[0])); + StreamingParameters.Add(ValidateParameterStreamType(p.ParameterType.GetGenericArguments()[0], p.ParameterType)); HasSyntheticArguments = true; return false; } @@ -201,7 +201,7 @@ public object GetService(IServiceProvider serviceProvider, int index, Type param return serviceProvider.GetRequiredService(parameterType); } - public IAsyncEnumerator FromReturnedStream(object stream, CancellationToken cancellationToken) + public IAsyncEnumerator FromReturnedStream(object stream, CancellationToken cancellationToken) { // there is the potential for _makeCancelableEnumerator to be set multiple times but this has no harmful effect other than startup perf if (_makeCancelableEnumerator == null) @@ -220,12 +220,12 @@ public IAsyncEnumerator FromReturnedStream(object stream, CancellationTo } [UnconditionalSuppressMessage("Trimming", "IL2060:MakeGenericMethod", - Justification = "The adapter methods passed into here (MakeCancelableAsyncEnumerator and MakeAsyncEnumeratorFromChannel) don't have trimming annotations.")] + Justification = "The adapter methods passed into here (MakeAsyncEnumerator and MakeAsyncEnumeratorFromChannel) don't have trimming annotations.")] [RequiresDynamicCode("Calls MakeGenericMethod with types that may be ValueTypes")] - private static Func> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType) + private static Func> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType) { // This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable: - // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerator(asyncEnumerable, cancellationToken); + // - AsyncEnumerableAdapters.MakeAsyncEnumerator(asyncEnumerable, cancellationToken); // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumeratorFromChannel(channelReader, cancellationToken); var parameters = new[] @@ -243,23 +243,39 @@ private static Func> Compile }; var methodCall = Expression.Call(null, genericMethodInfo, methodArguments); - var lambda = Expression.Lambda>>(methodCall, parameters); + var lambda = Expression.Lambda>>(methodCall, parameters); return lambda.Compile(); } [UnconditionalSuppressMessage("Trimming", "IL2060:MakeGenericMethod", - Justification = "The adapter methods passed into here (MakeCancelableAsyncEnumerator and MakeAsyncEnumeratorFromChannel) don't have trimming annotations.")] + Justification = "The adapter methods passed into here (MakeAsyncEnumerator and MakeAsyncEnumeratorFromChannel) don't have trimming annotations.")] [UnconditionalSuppressMessage("AOT", "IL3050:RequiresDynamicCode", - Justification = "There is a runtime check for ValueType streaming item type when PublishAot=true. Developers will get an exception in this situation before publishing.")] - private static Func> ConvertToEnumeratorWithReflection(MethodInfo adapterMethodInfo, Type streamReturnType) + Justification = "ValueTypes are handled without using MakeGenericMethod.")] + private static Func> ConvertToEnumeratorWithReflection(MethodInfo adapterMethodInfo, Type streamReturnType) { - Debug.Assert(!streamReturnType.IsValueType, "ValidateStreamType will throw during the ctor if the streamReturnType is a ValueType when PublishAot=true."); + if (streamReturnType.IsValueType) + { + if (adapterMethodInfo == MakeAsyncEnumeratorMethod) + { + // return type is an IAsyncEnumerable + return AsyncEnumerableAdapters.MakeReflectionAsyncEnumerator; + } + else + { + // must be a ChannelReader + Debug.Assert(adapterMethodInfo == MakeAsyncEnumeratorFromChannelMethod); - var genericAdapterMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType); - return (stream, cancellationToken) => + return AsyncEnumerableAdapters.MakeReflectionAsyncEnumeratorFromChannel; + } + } + else { - return (IAsyncEnumerator)genericAdapterMethodInfo.Invoke(null, [stream, cancellationToken])!; - }; + var genericAdapterMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType); + return (stream, cancellationToken) => + { + return (IAsyncEnumerator)genericAdapterMethodInfo.Invoke(null, [stream, cancellationToken])!; + }; + } } private static Type GetServiceType(Type type) @@ -276,14 +292,14 @@ private static Type GetServiceType(Type type) return type; } - private Type ValidateStreamType(Type streamType) + private Type ValidateParameterStreamType(Type streamType, Type parameterType) { if (!RuntimeFeature.IsDynamicCodeSupported && streamType.IsValueType) { - // NativeAOT apps are not able to stream IAsyncEnumerable and ChannelReader of ValueTypes - // since we cannot create AsyncEnumerableAdapters.MakeCancelableAsyncEnumerator and AsyncEnumerableAdapters.MakeAsyncEnumeratorFromChannel methods with a generic ValueType. + // NativeAOT apps are not able to stream IAsyncEnumerable and ChannelReader of ValueTypes as parameters + // since we cannot create a concrete IAsyncEnumerable and ChannelReader of ValueType to pass into the Hub method. var methodInfo = MethodExecutor.MethodInfo; - throw new InvalidOperationException($"Unable to stream an item with type '{streamType}' on method '{methodInfo.DeclaringType}.{methodInfo.Name}' because it is a ValueType. Native code to support streaming this ValueType will not be available with native AOT."); + throw new InvalidOperationException($"Method '{methodInfo.DeclaringType}.{methodInfo.Name}' is not supported with native AOT because it has a parameter of type '{parameterType}'. Streaming parameters of ValueTypes is not supported because the native code to support the ValueType will not be available with native AOT."); } return streamType; diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs index 94f37318eff7..ff3c1cb60ec4 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs @@ -3,6 +3,7 @@ using System.Globalization; using System.Text; +using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Channels; using Microsoft.AspNetCore.Builder; @@ -87,6 +88,35 @@ public void CanCallAsyncMethods() echoResults.Add(item); } Assert.Equal(["echo:some data", "echo:some more data", "echo:even more data"], echoResults); + + var streamValueTypeResults = new List(); + await foreach (var item in connection.StreamAsync(nameof(AsyncMethodHub.ReturnEnumerableValueType))) + { + streamValueTypeResults.Add(item); + } + Assert.Equal([1, 2], streamValueTypeResults); + + var returnChannelValueTypeResults = new List(); + var returnChannelValueTypeReader = await connection.StreamAsChannelAsync(nameof(AsyncMethodHub.ReturnChannelValueType), "Hello"); + await foreach (var item in returnChannelValueTypeReader.ReadAllAsync()) + { + returnChannelValueTypeResults.Add(item); + } + Assert.Equal(['H', 'e', 'l', 'l', 'o'], returnChannelValueTypeResults); + + // Even though SignalR server doesn't support Hub methods with streaming value types in native AOT (https://github.com/dotnet/aspnetcore/issues/56179), + // still test that the client can send them. + var stringResult = await connection.InvokeAsync(nameof(AsyncMethodHub.EnumerableIntParameter), StreamInts()); + Assert.Equal("1, 2, 3", stringResult); + + var channelShorts = Channel.CreateBounded(10); + await channelShorts.Writer.WriteAsync(9); + await channelShorts.Writer.WriteAsync(8); + await channelShorts.Writer.WriteAsync(7); + channelShorts.Writer.Complete(); + + stringResult = await connection.InvokeAsync(nameof(AsyncMethodHub.ChannelShortParameter), channelShorts.Reader); + Assert.Equal("9, 8, 7", stringResult); } }); } @@ -99,20 +129,30 @@ private static async IAsyncEnumerable StreamMessages() yield return "message two"; } + private static async IAsyncEnumerable StreamInts() + { + await Task.Yield(); + yield return 1; + await Task.Yield(); + yield return 2; + await Task.Yield(); + yield return 3; + } + [ConditionalFact] [RemoteExecutionSupported] public void UsingValueTypesInStreamingThrows() { RunNativeAotTest(static async () => { - var e = await Assert.ThrowsAsync(() => InProcessTestServer>.StartServer(NullLoggerFactory.Instance)); - Assert.Contains("Unable to stream an item with type 'System.Int32' on method 'Microsoft.AspNetCore.SignalR.Tests.NativeAotTests+AsyncEnumerableIntMethodHub.StreamValueType' because it is a ValueType.", e.Message); + var e = await Assert.ThrowsAsync(() => InProcessTestServer>.StartServer(NullLoggerFactory.Instance)); + Assert.Contains("Method 'Microsoft.AspNetCore.SignalR.Tests.NativeAotTests+ChannelValueTypeMethodHub.StreamValueType' is not supported with native AOT because it has a parameter of type 'System.Threading.Channels.ChannelReader`1[System.Double]'.", e.Message); }); RunNativeAotTest(static async () => { - var e = await Assert.ThrowsAsync(() => InProcessTestServer>.StartServer(NullLoggerFactory.Instance)); - Assert.Contains("Unable to stream an item with type 'System.Double' on method 'Microsoft.AspNetCore.SignalR.Tests.NativeAotTests+ChannelDoubleMethodHub.StreamValueType' because it is a ValueType.", e.Message); + var e = await Assert.ThrowsAsync(() => InProcessTestServer>.StartServer(NullLoggerFactory.Instance)); + Assert.Contains("Method 'Microsoft.AspNetCore.SignalR.Tests.NativeAotTests+EnumerableValueTypeMethodHub.StreamValueType' is not supported with native AOT because it has a parameter of type 'System.Collections.Generic.IAsyncEnumerable`1[System.Single]'.", e.Message); }); } @@ -228,22 +268,79 @@ public async IAsyncEnumerable StreamEchoAsyncEnumerable(IAsyncEnumerable yield return "echo:" + item; } } - } - public class AsyncEnumerableIntMethodHub : TestHub - { - public async IAsyncEnumerable StreamValueType() + public async IAsyncEnumerable ReturnEnumerableValueType() { await Task.Yield(); yield return 1; await Task.Yield(); yield return 2; } + + public ChannelReader ReturnChannelValueType(string source) + { + Channel output = Channel.CreateUnbounded(); + + _ = Task.Run(async () => + { + foreach (var item in source) + { + await Task.Yield(); + await output.Writer.WriteAsync(item); + } + + output.Writer.TryComplete(); + }); + + return output.Reader; + } + + public async Task EnumerableIntParameter(IAsyncEnumerable source) + { + var result = new StringBuilder(); + var first = true; + // These get deserialized as JsonElement since the parameter is 'ChannelReader' + await foreach (JsonElement item in source) + { + if (first) + { + first = false; + } + else + { + result.Append(", "); + } + + result.Append(item.GetInt32()); + } + return result.ToString(); + } + + public async Task ChannelShortParameter(ChannelReader source) + { + var result = new StringBuilder(); + var first = true; + // These get deserialized as JsonElement since the parameter is 'ChannelReader' + await foreach (JsonElement item in source.ReadAllAsync()) + { + if (first) + { + first = false; + } + else + { + result.Append(", "); + } + + result.Append(item.GetInt16()); + } + return result.ToString(); + } } - public class ChannelDoubleMethodHub : TestHub + public class ChannelValueTypeMethodHub : TestHub { - public async Task StreamValueType(ILogger logger, ChannelReader source) + public async Task StreamValueType(ILogger logger, ChannelReader source) { await foreach (var item in source.ReadAllAsync()) { @@ -252,6 +349,17 @@ public async Task StreamValueType(ILogger logger, Channe } } + public class EnumerableValueTypeMethodHub : TestHub + { + public async Task StreamValueType(ILogger logger, IAsyncEnumerable source) + { + await foreach (var item in source) + { + logger.LogInformation("Received: {item}", item); + } + } + } + public class TaskDerivedType : Task { public TaskDerivedType() @@ -325,8 +433,11 @@ public void Dispose() } } + [JsonSerializable(typeof(object))] [JsonSerializable(typeof(string))] [JsonSerializable(typeof(int))] + [JsonSerializable(typeof(short))] + [JsonSerializable(typeof(char))] internal partial class AppJsonSerializerContext : JsonSerializerContext { public static void AddToJsonHubProtocol(IServiceCollection services) From 14da607e249e5ce294ae9a8e0934f594efca00e9 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Wed, 3 Jul 2024 18:46:24 -0500 Subject: [PATCH 2/3] Respond to PR feedback --- .../csharp/Client.Core/src/HubConnection.cs | 15 +++++++++------ .../common/Shared/AsyncEnumerableAdapters.cs | 10 +++++----- src/SignalR/common/Shared/ReflectionHelper.cs | 4 ++-- .../NativeAotTests.cs | 6 ++++-- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 35a4ade375bf..e6d343cd7465 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -891,8 +891,7 @@ async Task ReadAsyncEnumeratorStream(IAsyncEnumerator enumerator) { while (await enumerator.MoveNextAsync().ConfigureAwait(false)) { - await SendWithLock(connectionState, new StreamItemMessage(streamId, enumerator.Current), tokenSource.Token).ConfigureAwait(false); - Log.SendingStreamItem(_logger, streamId); + await SendStreamItemAsync(connectionState, streamId, enumerator.Current, tokenSource).ConfigureAwait(false); } } finally @@ -928,8 +927,7 @@ async Task ReadChannelStream() { while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item)) { - await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token).ConfigureAwait(false); - Log.SendingStreamItem(_logger, streamId); + await SendStreamItemAsync(connectionState, streamId, item, tokenSource).ConfigureAwait(false); } } } @@ -944,14 +942,19 @@ async Task ReadAsyncEnumerableStream() { await foreach (var streamValue in stream.WithCancellation(tokenSource.Token).ConfigureAwait(false)) { - await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue), tokenSource.Token).ConfigureAwait(false); - Log.SendingStreamItem(_logger, streamId); + await SendStreamItemAsync(connectionState, streamId, streamValue, tokenSource).ConfigureAwait(false); } } return CommonStreaming(connectionState, streamId, ReadAsyncEnumerableStream, tokenSource); } + private async Task SendStreamItemAsync(ConnectionState connectionState, string streamId, object? item, CancellationTokenSource tokenSource) + { + await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token).ConfigureAwait(false); + Log.SendingStreamItem(_logger, streamId); + } + private async Task CommonStreaming(ConnectionState connectionState, string streamId, Func createAndConsumeStream, CancellationTokenSource cts) { // make sure we dispose the CTS created by StreamAsyncCore once streaming completes diff --git a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs index 0644a81473be..6d42b25194c9 100644 --- a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs +++ b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs @@ -49,10 +49,10 @@ public ValueTask MoveNextAsync() return new ValueTask(true); } - return new ValueTask(MoveNextAsyncAwaited()); + return MoveNextAsyncAwaited(); } - private async Task MoveNextAsyncAwaited() + private async ValueTask MoveNextAsyncAwaited() { while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false)) { @@ -184,7 +184,7 @@ public ReflectionAsyncEnumerator(object enumerator) { _enumerator = enumerator; - var type = ReflectionHelper.GetIAsyncEnumeratorInterface(enumerator.GetType())!; + var type = ReflectionHelper.GetIAsyncEnumeratorInterface(enumerator.GetType()); _moveNextAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorMoveNextAsyncMethodInfo)!; _getCurrentMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorGetCurrentMethodInfo)!; } @@ -227,10 +227,10 @@ public ValueTask MoveNextAsync() return new ValueTask(true); } - return new ValueTask(MoveNextAsyncAwaited()); + return MoveNextAsyncAwaited(); } - private async Task MoveNextAsyncAwaited() + private async ValueTask MoveNextAsyncAwaited() { while (await ((ValueTask)_waitToReadAsyncMethodInfo.Invoke(_channelReader, _waitToReadArgs)!).ConfigureAwait(false)) { diff --git a/src/SignalR/common/Shared/ReflectionHelper.cs b/src/SignalR/common/Shared/ReflectionHelper.cs index 64a6f5dcfb2a..5cdf12c4b49b 100644 --- a/src/SignalR/common/Shared/ReflectionHelper.cs +++ b/src/SignalR/common/Shared/ReflectionHelper.cs @@ -75,7 +75,7 @@ public static bool TryGetStreamType(Type streamType, [NotNullWhen(true)] out Typ "It also kept it on any type which implements it. The below call to GetInterfaces " + "may return fewer results when trimmed but it will return 'IAsyncEnumerator<>' " + "if the type implemented it, even after trimming.")] - public static Type? GetIAsyncEnumeratorInterface(Type type) + public static Type GetIAsyncEnumeratorInterface(Type type) { if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>)) { @@ -90,6 +90,6 @@ public static bool TryGetStreamType(Type streamType, [NotNullWhen(true)] out Typ } } - return null; + throw new InvalidOperationException($"Type '{type}' does not implement IAsyncEnumerator<>"); } } diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs index ff3c1cb60ec4..4b8b97baa72d 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/NativeAotTests.cs @@ -295,11 +295,12 @@ public ChannelReader ReturnChannelValueType(string source) return output.Reader; } + // using 'object' as the streaming parameter type because streaming ValueTypes is not supported on the server public async Task EnumerableIntParameter(IAsyncEnumerable source) { var result = new StringBuilder(); var first = true; - // These get deserialized as JsonElement since the parameter is 'ChannelReader' + // These get deserialized as JsonElement since the streaming parameter is 'object' await foreach (JsonElement item in source) { if (first) @@ -316,11 +317,12 @@ public async Task EnumerableIntParameter(IAsyncEnumerable source return result.ToString(); } + // using 'object' as the streaming parameter type because streaming ValueTypes is not supported on the server public async Task ChannelShortParameter(ChannelReader source) { var result = new StringBuilder(); var first = true; - // These get deserialized as JsonElement since the parameter is 'ChannelReader' + // These get deserialized as JsonElement since the streaming parameter is 'object' await foreach (JsonElement item in source.ReadAllAsync()) { if (first) From 2d2e147fa824cdb0a7b08fbcb73c3c500828d588 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Mon, 8 Jul 2024 11:32:57 -0500 Subject: [PATCH 3/3] Update exception message wording. --- src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index 56f0e8a4311c..d2d064c10378 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -299,7 +299,7 @@ private Type ValidateParameterStreamType(Type streamType, Type parameterType) // NativeAOT apps are not able to stream IAsyncEnumerable and ChannelReader of ValueTypes as parameters // since we cannot create a concrete IAsyncEnumerable and ChannelReader of ValueType to pass into the Hub method. var methodInfo = MethodExecutor.MethodInfo; - throw new InvalidOperationException($"Method '{methodInfo.DeclaringType}.{methodInfo.Name}' is not supported with native AOT because it has a parameter of type '{parameterType}'. Streaming parameters of ValueTypes is not supported because the native code to support the ValueType will not be available with native AOT."); + throw new InvalidOperationException($"Method '{methodInfo.DeclaringType}.{methodInfo.Name}' is not supported with native AOT because it has a parameter of type '{parameterType}'. A ValueType streaming parameter is not supported because the native code to support the ValueType will not be available with native AOT."); } return streamType;