Skip to content

Commit ae37e6f

Browse files
committed
Support IAsyncEnumerable<T> and ChannelReader<T> with ValueTypes in SignalR native AOT
Support streaming ValueTypes from a SignalR Hub method in both the client and the server in native AOT. In order to make this work, we need to use pure reflection to read from the streaming object. Support passing in an IAsyncEnumerable/ChannelReader of ValueType to a parameter in SignalR.Client. This works because the user code creates the concrete object, and the SignalR.Client library just needs to read from it using reflection. The only scenario that can't be supported is on the SignalR server we can't support receiving an IAsyncEnumerable/ChannelReader of ValueType. This is because there is no way for the SignalR library code to construct a concrete instance to pass into the user-defined method on native AOT. Fix dotnet#56179
1 parent 95a6473 commit ae37e6f

File tree

5 files changed

+352
-51
lines changed

5 files changed

+352
-51
lines changed

src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -857,26 +857,69 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
857857
[UnconditionalSuppressMessage("Trimming", "IL2060:MakeGenericMethod",
858858
Justification = "The methods passed into here (SendStreamItems and SendIAsyncEnumerableStreamItems) don't have trimming annotations.")]
859859
[UnconditionalSuppressMessage("AOT", "IL3050:RequiresDynamicCode",
860-
Justification = "There is a runtime check for ValueType streaming item type when PublishAot=true. Developers will get an exception in this situation before publishing.")]
860+
Justification = "ValueTypes are handled without using MakeGenericMethod.")]
861861
private void InvokeStreamMethod(MethodInfo methodInfo, Type[] genericTypes, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
862862
{
863-
#if NET
864863
Debug.Assert(genericTypes.Length == 1);
865-
864+
#if NET6_0_OR_GREATER
866865
if (!RuntimeFeature.IsDynamicCodeSupported && genericTypes[0].IsValueType)
867866
{
868-
// NativeAOT apps are not able to stream IAsyncEnumerable and ChannelReader of ValueTypes
869-
// since we cannot create SendStreamItems and SendIAsyncEnumerableStreamItems methods with a generic ValueType.
870-
throw new InvalidOperationException($"Unable to stream an item with type '{genericTypes[0]}' because it is a ValueType. Native code to support streaming this ValueType will not be available with native AOT.");
867+
_ = ReflectionSendStreamItems(methodInfo, connectionState, streamId, reader, tokenSource);
871868
}
869+
else
872870
#endif
871+
{
872+
_ = methodInfo
873+
.MakeGenericMethod(genericTypes)
874+
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
875+
}
876+
}
877+
878+
#if NET6_0_OR_GREATER
873879

874-
_ = methodInfo
875-
.MakeGenericMethod(genericTypes)
876-
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
880+
/// <summary>
881+
/// Uses reflection to read items from an IAsyncEnumerable{T} or ChannelReader{T} and send them to the server.
882+
///
883+
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
884+
/// we cannot use MakeGenericMethod to call the appropriate SendStreamItems method because the generic type is a value type.
885+
/// </summary>
886+
private Task ReflectionSendStreamItems(MethodInfo methodInfo, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
887+
{
888+
async Task ReadAsyncEnumeratorStream(IAsyncEnumerator<object?> enumerator)
889+
{
890+
try
891+
{
892+
while (await enumerator.MoveNextAsync().ConfigureAwait(false))
893+
{
894+
await SendWithLock(connectionState, new StreamItemMessage(streamId, enumerator.Current), tokenSource.Token).ConfigureAwait(false);
895+
Log.SendingStreamItem(_logger, streamId);
896+
}
897+
}
898+
finally
899+
{
900+
await enumerator.DisposeAsync().ConfigureAwait(false);
901+
}
902+
}
903+
904+
Func<Task> createAndConsumeStream;
905+
if (methodInfo == _sendStreamItemsMethod)
906+
{
907+
// reader is a ChannelReader<T>
908+
createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumeratorFromChannel(reader, tokenSource.Token));
909+
}
910+
else
911+
{
912+
// reader is an IAsyncEnumerable<T>
913+
Debug.Assert(methodInfo == _sendIAsyncStreamItemsMethod);
914+
915+
createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumerator(reader, tokenSource.Token));
916+
}
917+
918+
return CommonStreaming(connectionState, streamId, createAndConsumeStream, tokenSource);
877919
}
920+
#endif
878921

879-
// this is called via reflection using the `_sendStreamItems` field
922+
// this is called via reflection using the `_sendStreamItemsMethod` field
880923
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationTokenSource tokenSource)
881924
{
882925
async Task ReadChannelStream()
@@ -899,9 +942,7 @@ private Task SendIAsyncEnumerableStreamItems<T>(ConnectionState connectionState,
899942
{
900943
async Task ReadAsyncEnumerableStream()
901944
{
902-
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);
903-
904-
await foreach (var streamValue in streamValues.ConfigureAwait(false))
945+
await foreach (var streamValue in stream.WithCancellation(tokenSource.Token).ConfigureAwait(false))
905946
{
906947
await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue), tokenSource.Token).ConfigureAwait(false);
907948
Log.SendingStreamItem(_logger, streamId);

src/SignalR/common/Shared/AsyncEnumerableAdapters.cs

Lines changed: 114 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using System.Collections.Generic;
6+
using System.Reflection;
57
using System.Threading;
68
using System.Threading.Channels;
79
using System.Threading.Tasks;
@@ -11,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal;
1113
// True-internal because this is a weird and tricky class to use :)
1214
internal static class AsyncEnumerableAdapters
1315
{
14-
public static IAsyncEnumerator<object?> MakeCancelableAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
16+
public static IAsyncEnumerator<object?> MakeAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
1517
{
1618
var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken);
1719
return enumerator as IAsyncEnumerator<object?> ?? new BoxedAsyncEnumerator<T>(enumerator);
@@ -52,10 +54,13 @@ public ValueTask<bool> MoveNextAsync()
5254

5355
private async Task<bool> MoveNextAsyncAwaited()
5456
{
55-
if (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false) && _channel.TryRead(out var item))
57+
while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
5658
{
57-
Current = item;
58-
return true;
59+
if (_channel.TryRead(out var item))
60+
{
61+
Current = item;
62+
return true;
63+
}
5964
}
6065
return false;
6166
}
@@ -137,4 +142,109 @@ public ValueTask DisposeAsync()
137142
return _asyncEnumerator.DisposeAsync();
138143
}
139144
}
145+
146+
#if NET6_0_OR_GREATER
147+
148+
private static readonly MethodInfo _asyncEnumerableGetAsyncEnumeratorMethodInfo = typeof(IAsyncEnumerable<>).GetMethod("GetAsyncEnumerator")!;
149+
150+
/// <summary>
151+
/// Creates an IAsyncEnumerator{object} from an IAsyncEnumerable{T} using reflection.
152+
///
153+
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
154+
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
155+
/// </summary>
156+
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumerator(object asyncEnumerable, CancellationToken cancellationToken)
157+
{
158+
var constructedIAsyncEnumerableInterface = ReflectionHelper.GetIAsyncEnumerableInterface(asyncEnumerable.GetType())!;
159+
var enumerator = ((MethodInfo)constructedIAsyncEnumerableInterface.GetMemberWithSameMetadataDefinitionAs(_asyncEnumerableGetAsyncEnumeratorMethodInfo)).Invoke(asyncEnumerable, [cancellationToken])!;
160+
return new ReflectionAsyncEnumerator(enumerator);
161+
}
162+
163+
/// <summary>
164+
/// Creates an IAsyncEnumerator{object} from a ChannelReader{T} using reflection.
165+
///
166+
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
167+
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
168+
/// </summary>
169+
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumeratorFromChannel(object channelReader, CancellationToken cancellationToken)
170+
{
171+
return new ReflectionChannelAsyncEnumerator(channelReader, cancellationToken);
172+
}
173+
174+
private sealed class ReflectionAsyncEnumerator : IAsyncEnumerator<object?>
175+
{
176+
private static readonly MethodInfo _asyncEnumeratorMoveNextAsyncMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("MoveNextAsync")!;
177+
private static readonly MethodInfo _asyncEnumeratorGetCurrentMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("get_Current")!;
178+
179+
private readonly object _enumerator;
180+
private readonly MethodInfo _moveNextAsyncMethodInfo;
181+
private readonly MethodInfo _getCurrentMethodInfo;
182+
183+
public ReflectionAsyncEnumerator(object enumerator)
184+
{
185+
_enumerator = enumerator;
186+
187+
var type = ReflectionHelper.GetIAsyncEnumeratorInterface(enumerator.GetType())!;
188+
_moveNextAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorMoveNextAsyncMethodInfo)!;
189+
_getCurrentMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorGetCurrentMethodInfo)!;
190+
}
191+
192+
public object? Current => _getCurrentMethodInfo.Invoke(_enumerator, []);
193+
194+
public ValueTask<bool> MoveNextAsync() => (ValueTask<bool>)_moveNextAsyncMethodInfo.Invoke(_enumerator, [])!;
195+
196+
public ValueTask DisposeAsync() => ((IAsyncDisposable)_enumerator).DisposeAsync();
197+
}
198+
199+
private sealed class ReflectionChannelAsyncEnumerator : IAsyncEnumerator<object?>
200+
{
201+
private static readonly MethodInfo _channelReaderTryReadMethodInfo = typeof(ChannelReader<>).GetMethod("TryRead")!;
202+
private static readonly MethodInfo _channelReaderWaitToReadAsyncMethodInfo = typeof(ChannelReader<>).GetMethod("WaitToReadAsync")!;
203+
204+
private readonly object _channelReader;
205+
private readonly object?[] _tryReadResult = [null];
206+
private readonly object[] _waitToReadArgs;
207+
private readonly MethodInfo _tryReadMethodInfo;
208+
private readonly MethodInfo _waitToReadAsyncMethodInfo;
209+
210+
public ReflectionChannelAsyncEnumerator(object channelReader, CancellationToken cancellationToken)
211+
{
212+
_channelReader = channelReader;
213+
_waitToReadArgs = [cancellationToken];
214+
215+
var type = channelReader.GetType();
216+
_tryReadMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderTryReadMethodInfo)!;
217+
_waitToReadAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderWaitToReadAsyncMethodInfo)!;
218+
}
219+
220+
public object? Current { get; private set; }
221+
222+
public ValueTask<bool> MoveNextAsync()
223+
{
224+
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
225+
{
226+
Current = _tryReadResult[0];
227+
return new ValueTask<bool>(true);
228+
}
229+
230+
return new ValueTask<bool>(MoveNextAsyncAwaited());
231+
}
232+
233+
private async Task<bool> MoveNextAsyncAwaited()
234+
{
235+
while (await ((ValueTask<bool>)_waitToReadAsyncMethodInfo.Invoke(_channelReader, _waitToReadArgs)!).ConfigureAwait(false))
236+
{
237+
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
238+
{
239+
Current = _tryReadResult[0];
240+
return true;
241+
}
242+
}
243+
return false;
244+
}
245+
246+
public ValueTask DisposeAsync() => default;
247+
}
248+
249+
#endif
140250
}

src/SignalR/common/Shared/ReflectionHelper.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,27 @@ public static bool TryGetStreamType(Type streamType, [NotNullWhen(true)] out Typ
6969

7070
return null;
7171
}
72+
73+
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",
74+
Justification = "The 'IAsyncEnumerator<>' Type must exist and so trimmer kept it. In which case " +
75+
"It also kept it on any type which implements it. The below call to GetInterfaces " +
76+
"may return fewer results when trimmed but it will return 'IAsyncEnumerator<>' " +
77+
"if the type implemented it, even after trimming.")]
78+
public static Type? GetIAsyncEnumeratorInterface(Type type)
79+
{
80+
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
81+
{
82+
return type;
83+
}
84+
85+
foreach (Type typeToCheck in type.GetInterfaces())
86+
{
87+
if (typeToCheck.IsGenericType && typeToCheck.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
88+
{
89+
return typeToCheck;
90+
}
91+
}
92+
93+
return null;
94+
}
7295
}

0 commit comments

Comments
 (0)