Skip to content

Commit afc520c

Browse files
authored
Support IAsyncEnumerable<T> and ChannelReader<T> with ValueTypes in SignalR native AOT (#56583)
* Support IAsyncEnumerable<T> and ChannelReader<T> with ValueTypes in SignalR native AOT Support streaming ValueTypes from a SignalR Hub method in both the client and the server in native AOT. In order to make this work, we need to use pure reflection to read from the streaming object. Support passing in an IAsyncEnumerable/ChannelReader of ValueType to a parameter in SignalR.Client. This works because the user code creates the concrete object, and the SignalR.Client library just needs to read from it using reflection. The only scenario that can't be supported is on the SignalR server we can't support receiving an IAsyncEnumerable/ChannelReader of ValueType. This is because there is no way for the SignalR library code to construct a concrete instance to pass into the user-defined method on native AOT. Fix #56179
1 parent 7f9b45e commit afc520c

File tree

5 files changed

+362
-54
lines changed

5 files changed

+362
-54
lines changed

src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -857,26 +857,68 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
857857
[UnconditionalSuppressMessage("Trimming", "IL2060:MakeGenericMethod",
858858
Justification = "The methods passed into here (SendStreamItems and SendIAsyncEnumerableStreamItems) don't have trimming annotations.")]
859859
[UnconditionalSuppressMessage("AOT", "IL3050:RequiresDynamicCode",
860-
Justification = "There is a runtime check for ValueType streaming item type when PublishAot=true. Developers will get an exception in this situation before publishing.")]
860+
Justification = "ValueTypes are handled without using MakeGenericMethod.")]
861861
private void InvokeStreamMethod(MethodInfo methodInfo, Type[] genericTypes, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
862862
{
863-
#if NET
864863
Debug.Assert(genericTypes.Length == 1);
865-
864+
#if NET6_0_OR_GREATER
866865
if (!RuntimeFeature.IsDynamicCodeSupported && genericTypes[0].IsValueType)
867866
{
868-
// NativeAOT apps are not able to stream IAsyncEnumerable and ChannelReader of ValueTypes
869-
// since we cannot create SendStreamItems and SendIAsyncEnumerableStreamItems methods with a generic ValueType.
870-
throw new InvalidOperationException($"Unable to stream an item with type '{genericTypes[0]}' because it is a ValueType. Native code to support streaming this ValueType will not be available with native AOT.");
867+
_ = ReflectionSendStreamItems(methodInfo, connectionState, streamId, reader, tokenSource);
871868
}
869+
else
872870
#endif
871+
{
872+
_ = methodInfo
873+
.MakeGenericMethod(genericTypes)
874+
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
875+
}
876+
}
877+
878+
#if NET6_0_OR_GREATER
879+
880+
/// <summary>
881+
/// Uses reflection to read items from an IAsyncEnumerable{T} or ChannelReader{T} and send them to the server.
882+
///
883+
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
884+
/// we cannot use MakeGenericMethod to call the appropriate SendStreamItems method because the generic type is a value type.
885+
/// </summary>
886+
private Task ReflectionSendStreamItems(MethodInfo methodInfo, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
887+
{
888+
async Task ReadAsyncEnumeratorStream(IAsyncEnumerator<object?> enumerator)
889+
{
890+
try
891+
{
892+
while (await enumerator.MoveNextAsync().ConfigureAwait(false))
893+
{
894+
await SendStreamItemAsync(connectionState, streamId, enumerator.Current, tokenSource).ConfigureAwait(false);
895+
}
896+
}
897+
finally
898+
{
899+
await enumerator.DisposeAsync().ConfigureAwait(false);
900+
}
901+
}
902+
903+
Func<Task> createAndConsumeStream;
904+
if (methodInfo == _sendStreamItemsMethod)
905+
{
906+
// reader is a ChannelReader<T>
907+
createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumeratorFromChannel(reader, tokenSource.Token));
908+
}
909+
else
910+
{
911+
// reader is an IAsyncEnumerable<T>
912+
Debug.Assert(methodInfo == _sendIAsyncStreamItemsMethod);
913+
914+
createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumerator(reader, tokenSource.Token));
915+
}
873916

874-
_ = methodInfo
875-
.MakeGenericMethod(genericTypes)
876-
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
917+
return CommonStreaming(connectionState, streamId, createAndConsumeStream, tokenSource);
877918
}
919+
#endif
878920

879-
// this is called via reflection using the `_sendStreamItems` field
921+
// this is called via reflection using the `_sendStreamItemsMethod` field
880922
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationTokenSource tokenSource)
881923
{
882924
async Task ReadChannelStream()
@@ -885,8 +927,7 @@ async Task ReadChannelStream()
885927
{
886928
while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item))
887929
{
888-
await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token).ConfigureAwait(false);
889-
Log.SendingStreamItem(_logger, streamId);
930+
await SendStreamItemAsync(connectionState, streamId, item, tokenSource).ConfigureAwait(false);
890931
}
891932
}
892933
}
@@ -901,14 +942,19 @@ async Task ReadAsyncEnumerableStream()
901942
{
902943
await foreach (var streamValue in stream.WithCancellation(tokenSource.Token).ConfigureAwait(false))
903944
{
904-
await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue), tokenSource.Token).ConfigureAwait(false);
905-
Log.SendingStreamItem(_logger, streamId);
945+
await SendStreamItemAsync(connectionState, streamId, streamValue, tokenSource).ConfigureAwait(false);
906946
}
907947
}
908948

909949
return CommonStreaming(connectionState, streamId, ReadAsyncEnumerableStream, tokenSource);
910950
}
911951

952+
private async Task SendStreamItemAsync(ConnectionState connectionState, string streamId, object? item, CancellationTokenSource tokenSource)
953+
{
954+
await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token).ConfigureAwait(false);
955+
Log.SendingStreamItem(_logger, streamId);
956+
}
957+
912958
private async Task CommonStreaming(ConnectionState connectionState, string streamId, Func<Task> createAndConsumeStream, CancellationTokenSource cts)
913959
{
914960
// make sure we dispose the CTS created by StreamAsyncCore once streaming completes

src/SignalR/common/Shared/AsyncEnumerableAdapters.cs

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using System.Collections.Generic;
6+
using System.Reflection;
57
using System.Threading;
68
using System.Threading.Channels;
79
using System.Threading.Tasks;
@@ -11,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal;
1113
// True-internal because this is a weird and tricky class to use :)
1214
internal static class AsyncEnumerableAdapters
1315
{
14-
public static IAsyncEnumerator<object?> MakeCancelableAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
16+
public static IAsyncEnumerator<object?> MakeAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
1517
{
1618
var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken);
1719
return enumerator as IAsyncEnumerator<object?> ?? new BoxedAsyncEnumerator<T>(enumerator);
@@ -47,15 +49,18 @@ public ValueTask<bool> MoveNextAsync()
4749
return new ValueTask<bool>(true);
4850
}
4951

50-
return new ValueTask<bool>(MoveNextAsyncAwaited());
52+
return MoveNextAsyncAwaited();
5153
}
5254

53-
private async Task<bool> MoveNextAsyncAwaited()
55+
private async ValueTask<bool> MoveNextAsyncAwaited()
5456
{
55-
if (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false) && _channel.TryRead(out var item))
57+
while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
5658
{
57-
Current = item;
58-
return true;
59+
if (_channel.TryRead(out var item))
60+
{
61+
Current = item;
62+
return true;
63+
}
5964
}
6065
return false;
6166
}
@@ -137,4 +142,109 @@ public ValueTask DisposeAsync()
137142
return _asyncEnumerator.DisposeAsync();
138143
}
139144
}
145+
146+
#if NET6_0_OR_GREATER
147+
148+
private static readonly MethodInfo _asyncEnumerableGetAsyncEnumeratorMethodInfo = typeof(IAsyncEnumerable<>).GetMethod("GetAsyncEnumerator")!;
149+
150+
/// <summary>
151+
/// Creates an IAsyncEnumerator{object} from an IAsyncEnumerable{T} using reflection.
152+
///
153+
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
154+
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
155+
/// </summary>
156+
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumerator(object asyncEnumerable, CancellationToken cancellationToken)
157+
{
158+
var constructedIAsyncEnumerableInterface = ReflectionHelper.GetIAsyncEnumerableInterface(asyncEnumerable.GetType())!;
159+
var enumerator = ((MethodInfo)constructedIAsyncEnumerableInterface.GetMemberWithSameMetadataDefinitionAs(_asyncEnumerableGetAsyncEnumeratorMethodInfo)).Invoke(asyncEnumerable, [cancellationToken])!;
160+
return new ReflectionAsyncEnumerator(enumerator);
161+
}
162+
163+
/// <summary>
164+
/// Creates an IAsyncEnumerator{object} from a ChannelReader{T} using reflection.
165+
///
166+
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
167+
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
168+
/// </summary>
169+
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumeratorFromChannel(object channelReader, CancellationToken cancellationToken)
170+
{
171+
return new ReflectionChannelAsyncEnumerator(channelReader, cancellationToken);
172+
}
173+
174+
private sealed class ReflectionAsyncEnumerator : IAsyncEnumerator<object?>
175+
{
176+
private static readonly MethodInfo _asyncEnumeratorMoveNextAsyncMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("MoveNextAsync")!;
177+
private static readonly MethodInfo _asyncEnumeratorGetCurrentMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("get_Current")!;
178+
179+
private readonly object _enumerator;
180+
private readonly MethodInfo _moveNextAsyncMethodInfo;
181+
private readonly MethodInfo _getCurrentMethodInfo;
182+
183+
public ReflectionAsyncEnumerator(object enumerator)
184+
{
185+
_enumerator = enumerator;
186+
187+
var type = ReflectionHelper.GetIAsyncEnumeratorInterface(enumerator.GetType());
188+
_moveNextAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorMoveNextAsyncMethodInfo)!;
189+
_getCurrentMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorGetCurrentMethodInfo)!;
190+
}
191+
192+
public object? Current => _getCurrentMethodInfo.Invoke(_enumerator, []);
193+
194+
public ValueTask<bool> MoveNextAsync() => (ValueTask<bool>)_moveNextAsyncMethodInfo.Invoke(_enumerator, [])!;
195+
196+
public ValueTask DisposeAsync() => ((IAsyncDisposable)_enumerator).DisposeAsync();
197+
}
198+
199+
private sealed class ReflectionChannelAsyncEnumerator : IAsyncEnumerator<object?>
200+
{
201+
private static readonly MethodInfo _channelReaderTryReadMethodInfo = typeof(ChannelReader<>).GetMethod("TryRead")!;
202+
private static readonly MethodInfo _channelReaderWaitToReadAsyncMethodInfo = typeof(ChannelReader<>).GetMethod("WaitToReadAsync")!;
203+
204+
private readonly object _channelReader;
205+
private readonly object?[] _tryReadResult = [null];
206+
private readonly object[] _waitToReadArgs;
207+
private readonly MethodInfo _tryReadMethodInfo;
208+
private readonly MethodInfo _waitToReadAsyncMethodInfo;
209+
210+
public ReflectionChannelAsyncEnumerator(object channelReader, CancellationToken cancellationToken)
211+
{
212+
_channelReader = channelReader;
213+
_waitToReadArgs = [cancellationToken];
214+
215+
var type = channelReader.GetType();
216+
_tryReadMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderTryReadMethodInfo)!;
217+
_waitToReadAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderWaitToReadAsyncMethodInfo)!;
218+
}
219+
220+
public object? Current { get; private set; }
221+
222+
public ValueTask<bool> MoveNextAsync()
223+
{
224+
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
225+
{
226+
Current = _tryReadResult[0];
227+
return new ValueTask<bool>(true);
228+
}
229+
230+
return MoveNextAsyncAwaited();
231+
}
232+
233+
private async ValueTask<bool> MoveNextAsyncAwaited()
234+
{
235+
while (await ((ValueTask<bool>)_waitToReadAsyncMethodInfo.Invoke(_channelReader, _waitToReadArgs)!).ConfigureAwait(false))
236+
{
237+
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
238+
{
239+
Current = _tryReadResult[0];
240+
return true;
241+
}
242+
}
243+
return false;
244+
}
245+
246+
public ValueTask DisposeAsync() => default;
247+
}
248+
249+
#endif
140250
}

src/SignalR/common/Shared/ReflectionHelper.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,27 @@ public static bool TryGetStreamType(Type streamType, [NotNullWhen(true)] out Typ
6969

7070
return null;
7171
}
72+
73+
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",
74+
Justification = "The 'IAsyncEnumerator<>' Type must exist and so trimmer kept it. In which case " +
75+
"It also kept it on any type which implements it. The below call to GetInterfaces " +
76+
"may return fewer results when trimmed but it will return 'IAsyncEnumerator<>' " +
77+
"if the type implemented it, even after trimming.")]
78+
public static Type GetIAsyncEnumeratorInterface(Type type)
79+
{
80+
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
81+
{
82+
return type;
83+
}
84+
85+
foreach (Type typeToCheck in type.GetInterfaces())
86+
{
87+
if (typeToCheck.IsGenericType && typeToCheck.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
88+
{
89+
return typeToCheck;
90+
}
91+
}
92+
93+
throw new InvalidOperationException($"Type '{type}' does not implement IAsyncEnumerator<>");
94+
}
7295
}

0 commit comments

Comments
 (0)