Skip to content

Support IAsyncEnumerable<T> and ChannelReader<T> with ValueTypes in SignalR native AOT #56583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 60 additions & 14 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -857,26 +857,68 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
[UnconditionalSuppressMessage("Trimming", "IL2060:MakeGenericMethod",
Justification = "The methods passed into here (SendStreamItems and SendIAsyncEnumerableStreamItems) don't have trimming annotations.")]
[UnconditionalSuppressMessage("AOT", "IL3050:RequiresDynamicCode",
Justification = "There is a runtime check for ValueType streaming item type when PublishAot=true. Developers will get an exception in this situation before publishing.")]
Justification = "ValueTypes are handled without using MakeGenericMethod.")]
private void InvokeStreamMethod(MethodInfo methodInfo, Type[] genericTypes, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
{
#if NET
Debug.Assert(genericTypes.Length == 1);

#if NET6_0_OR_GREATER
if (!RuntimeFeature.IsDynamicCodeSupported && genericTypes[0].IsValueType)
{
// NativeAOT apps are not able to stream IAsyncEnumerable and ChannelReader of ValueTypes
// since we cannot create SendStreamItems and SendIAsyncEnumerableStreamItems methods with a generic ValueType.
throw new InvalidOperationException($"Unable to stream an item with type '{genericTypes[0]}' because it is a ValueType. Native code to support streaming this ValueType will not be available with native AOT.");
_ = ReflectionSendStreamItems(methodInfo, connectionState, streamId, reader, tokenSource);
}
else
#endif
{
_ = methodInfo
.MakeGenericMethod(genericTypes)
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
}
}

#if NET6_0_OR_GREATER

/// <summary>
/// Uses reflection to read items from an IAsyncEnumerable{T} or ChannelReader{T} and send them to the server.
///
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
/// we cannot use MakeGenericMethod to call the appropriate SendStreamItems method because the generic type is a value type.
/// </summary>
private Task ReflectionSendStreamItems(MethodInfo methodInfo, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
{
async Task ReadAsyncEnumeratorStream(IAsyncEnumerator<object?> enumerator)
{
try
{
while (await enumerator.MoveNextAsync().ConfigureAwait(false))
{
await SendStreamItemAsync(connectionState, streamId, enumerator.Current, tokenSource).ConfigureAwait(false);
}
}
finally
{
await enumerator.DisposeAsync().ConfigureAwait(false);
}
}

Func<Task> createAndConsumeStream;
if (methodInfo == _sendStreamItemsMethod)
{
// reader is a ChannelReader<T>
createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumeratorFromChannel(reader, tokenSource.Token));
}
else
{
// reader is an IAsyncEnumerable<T>
Debug.Assert(methodInfo == _sendIAsyncStreamItemsMethod);

createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumerator(reader, tokenSource.Token));
}

_ = methodInfo
.MakeGenericMethod(genericTypes)
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
return CommonStreaming(connectionState, streamId, createAndConsumeStream, tokenSource);
}
#endif

// this is called via reflection using the `_sendStreamItems` field
// this is called via reflection using the `_sendStreamItemsMethod` field
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationTokenSource tokenSource)
{
async Task ReadChannelStream()
Expand All @@ -885,8 +927,7 @@ async Task ReadChannelStream()
{
while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item))
{
await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token).ConfigureAwait(false);
Log.SendingStreamItem(_logger, streamId);
await SendStreamItemAsync(connectionState, streamId, item, tokenSource).ConfigureAwait(false);
}
}
}
Expand All @@ -901,14 +942,19 @@ async Task ReadAsyncEnumerableStream()
{
await foreach (var streamValue in stream.WithCancellation(tokenSource.Token).ConfigureAwait(false))
{
await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue), tokenSource.Token).ConfigureAwait(false);
Log.SendingStreamItem(_logger, streamId);
await SendStreamItemAsync(connectionState, streamId, streamValue, tokenSource).ConfigureAwait(false);
}
}

return CommonStreaming(connectionState, streamId, ReadAsyncEnumerableStream, tokenSource);
}

private async Task SendStreamItemAsync(ConnectionState connectionState, string streamId, object? item, CancellationTokenSource tokenSource)
{
await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token).ConfigureAwait(false);
Log.SendingStreamItem(_logger, streamId);
}

private async Task CommonStreaming(ConnectionState connectionState, string streamId, Func<Task> createAndConsumeStream, CancellationTokenSource cts)
{
// make sure we dispose the CTS created by StreamAsyncCore once streaming completes
Expand Down
122 changes: 116 additions & 6 deletions src/SignalR/common/Shared/AsyncEnumerableAdapters.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Reflection;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
Expand All @@ -11,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal;
// True-internal because this is a weird and tricky class to use :)
internal static class AsyncEnumerableAdapters
{
public static IAsyncEnumerator<object?> MakeCancelableAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
public static IAsyncEnumerator<object?> MakeAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
{
var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken);
return enumerator as IAsyncEnumerator<object?> ?? new BoxedAsyncEnumerator<T>(enumerator);
Expand Down Expand Up @@ -47,15 +49,18 @@ public ValueTask<bool> MoveNextAsync()
return new ValueTask<bool>(true);
}

return new ValueTask<bool>(MoveNextAsyncAwaited());
return MoveNextAsyncAwaited();
}

private async Task<bool> MoveNextAsyncAwaited()
private async ValueTask<bool> MoveNextAsyncAwaited()
{
if (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false) && _channel.TryRead(out var item))
while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
{
Current = item;
return true;
if (_channel.TryRead(out var item))
{
Current = item;
return true;
}
}
return false;
}
Expand Down Expand Up @@ -137,4 +142,109 @@ public ValueTask DisposeAsync()
return _asyncEnumerator.DisposeAsync();
}
}

#if NET6_0_OR_GREATER

private static readonly MethodInfo _asyncEnumerableGetAsyncEnumeratorMethodInfo = typeof(IAsyncEnumerable<>).GetMethod("GetAsyncEnumerator")!;

/// <summary>
/// Creates an IAsyncEnumerator{object} from an IAsyncEnumerable{T} using reflection.
///
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
/// </summary>
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumerator(object asyncEnumerable, CancellationToken cancellationToken)
{
var constructedIAsyncEnumerableInterface = ReflectionHelper.GetIAsyncEnumerableInterface(asyncEnumerable.GetType())!;
var enumerator = ((MethodInfo)constructedIAsyncEnumerableInterface.GetMemberWithSameMetadataDefinitionAs(_asyncEnumerableGetAsyncEnumeratorMethodInfo)).Invoke(asyncEnumerable, [cancellationToken])!;
return new ReflectionAsyncEnumerator(enumerator);
}

/// <summary>
/// Creates an IAsyncEnumerator{object} from a ChannelReader{T} using reflection.
///
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
/// </summary>
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumeratorFromChannel(object channelReader, CancellationToken cancellationToken)
{
return new ReflectionChannelAsyncEnumerator(channelReader, cancellationToken);
}

private sealed class ReflectionAsyncEnumerator : IAsyncEnumerator<object?>
{
private static readonly MethodInfo _asyncEnumeratorMoveNextAsyncMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("MoveNextAsync")!;
private static readonly MethodInfo _asyncEnumeratorGetCurrentMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("get_Current")!;

private readonly object _enumerator;
private readonly MethodInfo _moveNextAsyncMethodInfo;
private readonly MethodInfo _getCurrentMethodInfo;

public ReflectionAsyncEnumerator(object enumerator)
{
_enumerator = enumerator;

var type = ReflectionHelper.GetIAsyncEnumeratorInterface(enumerator.GetType());
_moveNextAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorMoveNextAsyncMethodInfo)!;
_getCurrentMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorGetCurrentMethodInfo)!;
}

public object? Current => _getCurrentMethodInfo.Invoke(_enumerator, []);

public ValueTask<bool> MoveNextAsync() => (ValueTask<bool>)_moveNextAsyncMethodInfo.Invoke(_enumerator, [])!;

public ValueTask DisposeAsync() => ((IAsyncDisposable)_enumerator).DisposeAsync();
}

private sealed class ReflectionChannelAsyncEnumerator : IAsyncEnumerator<object?>
{
private static readonly MethodInfo _channelReaderTryReadMethodInfo = typeof(ChannelReader<>).GetMethod("TryRead")!;
private static readonly MethodInfo _channelReaderWaitToReadAsyncMethodInfo = typeof(ChannelReader<>).GetMethod("WaitToReadAsync")!;

private readonly object _channelReader;
private readonly object?[] _tryReadResult = [null];
private readonly object[] _waitToReadArgs;
private readonly MethodInfo _tryReadMethodInfo;
private readonly MethodInfo _waitToReadAsyncMethodInfo;

public ReflectionChannelAsyncEnumerator(object channelReader, CancellationToken cancellationToken)
{
_channelReader = channelReader;
_waitToReadArgs = [cancellationToken];

var type = channelReader.GetType();
_tryReadMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderTryReadMethodInfo)!;
_waitToReadAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderWaitToReadAsyncMethodInfo)!;
}

public object? Current { get; private set; }

public ValueTask<bool> MoveNextAsync()
{
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
{
Current = _tryReadResult[0];
return new ValueTask<bool>(true);
}

return MoveNextAsyncAwaited();
}

private async ValueTask<bool> MoveNextAsyncAwaited()
{
while (await ((ValueTask<bool>)_waitToReadAsyncMethodInfo.Invoke(_channelReader, _waitToReadArgs)!).ConfigureAwait(false))
{
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
{
Current = _tryReadResult[0];
return true;
}
}
return false;
}

public ValueTask DisposeAsync() => default;
}

#endif
}
23 changes: 23 additions & 0 deletions src/SignalR/common/Shared/ReflectionHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,27 @@ public static bool TryGetStreamType(Type streamType, [NotNullWhen(true)] out Typ

return null;
}

[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",
Justification = "The 'IAsyncEnumerator<>' Type must exist and so trimmer kept it. In which case " +
"It also kept it on any type which implements it. The below call to GetInterfaces " +
"may return fewer results when trimmed but it will return 'IAsyncEnumerator<>' " +
"if the type implemented it, even after trimming.")]
public static Type GetIAsyncEnumeratorInterface(Type type)
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
{
return type;
}

foreach (Type typeToCheck in type.GetInterfaces())
{
if (typeToCheck.IsGenericType && typeToCheck.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
{
return typeToCheck;
}
}

throw new InvalidOperationException($"Type '{type}' does not implement IAsyncEnumerator<>");
}
}
Loading
Loading