diff --git a/src/Components/Server/src/Circuits/CircuitHost.cs b/src/Components/Server/src/Circuits/CircuitHost.cs index cebfc2284ff3..fcfcfee086a8 100644 --- a/src/Components/Server/src/Circuits/CircuitHost.cs +++ b/src/Components/Server/src/Circuits/CircuitHost.cs @@ -21,6 +21,7 @@ internal partial class CircuitHost : IAsyncDisposable private readonly CircuitHandler[] _circuitHandlers; private readonly RemoteNavigationManager _navigationManager; private readonly ILogger _logger; + private readonly Func, Task> _dispatchInboundActivity; private bool _initialized; private bool _disposed; @@ -66,6 +67,8 @@ public CircuitHost( Circuit = new Circuit(this); Handle = new CircuitHandle() { CircuitHost = this, }; + _dispatchInboundActivity = BuildInboundActivityDispatcher(_circuitHandlers, Circuit); + // An unhandled exception from the renderer is always fatal because it came from user code. Renderer.UnhandledException += ReportAndInvoke_UnhandledException; Renderer.UnhandledSynchronizationException += SynchronizationContext_UnhandledException; @@ -324,7 +327,7 @@ public async Task OnRenderCompletedAsync(long renderId, string errorMessageOrNul try { - _ = Renderer.OnRenderCompletedAsync(renderId, errorMessageOrNull); + _ = HandleInboundActivityAsync(() => Renderer.OnRenderCompletedAsync(renderId, errorMessageOrNull)); } catch (Exception e) { @@ -345,12 +348,12 @@ public async Task BeginInvokeDotNetFromJS(string callId, string assemblyName, st try { - await Renderer.Dispatcher.InvokeAsync(() => + await HandleInboundActivityAsync(() => Renderer.Dispatcher.InvokeAsync(() => { Log.BeginInvokeDotNet(_logger, callId, assemblyName, methodIdentifier, dotNetObjectId); var invocationInfo = new DotNetInvocationInfo(assemblyName, methodIdentifier, dotNetObjectId, callId); DotNetDispatcher.BeginInvokeDotNet(JSRuntime, invocationInfo, argsJson); - }); + })); } catch (Exception ex) { @@ -371,7 +374,7 @@ public async Task EndInvokeJSFromDotNet(long asyncCall, bool succeeded, string a try { - await Renderer.Dispatcher.InvokeAsync(() => + await HandleInboundActivityAsync(() => Renderer.Dispatcher.InvokeAsync(() => { if (!succeeded) { @@ -384,7 +387,7 @@ await Renderer.Dispatcher.InvokeAsync(() => } DotNetDispatcher.EndInvokeJS(JSRuntime, arguments); - }); + })); } catch (Exception ex) { @@ -405,11 +408,11 @@ internal async Task ReceiveByteArray(int id, byte[] data) try { - await Renderer.Dispatcher.InvokeAsync(() => + await HandleInboundActivityAsync(() => Renderer.Dispatcher.InvokeAsync(() => { Log.ReceiveByteArraySuccess(_logger, id); DotNetDispatcher.ReceiveByteArray(JSRuntime, id, data); - }); + })); } catch (Exception ex) { @@ -430,10 +433,10 @@ internal async Task ReceiveJSDataChunk(long streamId, long chunkId, byte[] try { - return await Renderer.Dispatcher.InvokeAsync(() => + return await HandleInboundActivityAsync(() => Renderer.Dispatcher.InvokeAsync(() => { return RemoteJSDataStream.ReceiveData(JSRuntime, streamId, chunkId, chunk, error); - }); + })); } catch (Exception ex) { @@ -453,7 +456,7 @@ public async Task SendDotNetStreamAsync(DotNetStreamReference dotNetStreamR try { - return await Renderer.Dispatcher.InvokeAsync(async () => await dotNetStreamReference.Stream.ReadAsync(buffer)); + return await Renderer.Dispatcher.InvokeAsync(async () => await dotNetStreamReference.Stream.ReadAsync(buffer)); } catch (Exception ex) { @@ -505,12 +508,12 @@ public async Task OnLocationChangedAsync(string uri, string state, bool intercep try { - await Renderer.Dispatcher.InvokeAsync(() => + await HandleInboundActivityAsync(() => Renderer.Dispatcher.InvokeAsync(() => { Log.LocationChange(_logger, uri, CircuitId); _navigationManager.NotifyLocationChanged(uri, state, intercepted); Log.LocationChangeSucceeded(_logger, uri, CircuitId); - }); + })); } // It's up to the NavigationManager implementation to validate the URI. @@ -547,11 +550,11 @@ public async Task OnLocationChangingAsync(int callId, string uri, string? state, try { - var shouldContinueNavigation = await Renderer.Dispatcher.InvokeAsync(async () => + var shouldContinueNavigation = await HandleInboundActivityAsync(() => Renderer.Dispatcher.InvokeAsync(async () => { Log.LocationChanging(_logger, uri, CircuitId); return await _navigationManager.HandleLocationChangingAsync(uri, state, intercepted); - }); + })); await Client.SendAsync("JS.EndLocationChanging", callId, shouldContinueNavigation); } @@ -589,6 +592,40 @@ public void SendPendingBatches() _ = Renderer.Dispatcher.InvokeAsync(Renderer.ProcessBufferedRenderBatches); } + // Internal for testing. + internal Task HandleInboundActivityAsync(Func handler) + => _dispatchInboundActivity(handler); + + // Internal for testing. + internal async Task HandleInboundActivityAsync(Func> handler) + { + TResult result = default; + await _dispatchInboundActivity(async () => result = await handler()); + return result; + } + + private static Func, Task> BuildInboundActivityDispatcher(IReadOnlyList circuitHandlers, Circuit circuit) + { + Func? result = null; + + for (var i = circuitHandlers.Count - 1; i >= 0; i--) + { + if (circuitHandlers[i] is IHandleCircuitActivity inboundActivityHandler) + { + var next = result ?? (static (context) => context.Handler()); + result = (context) => inboundActivityHandler.HandleInboundActivityAsync(context, next); + } + } + + if (result is null) + { + // If there are no registered handlers, there is no need to allocate a context on each call. + return static (handler) => handler(); + } + + return (handler) => result(new(handler, circuit)); + } + private void AssertInitialized() { if (!_initialized) diff --git a/src/Components/Server/src/Circuits/CircuitInboundActivityContext.cs b/src/Components/Server/src/Circuits/CircuitInboundActivityContext.cs new file mode 100644 index 000000000000..cabf6d0dcd40 --- /dev/null +++ b/src/Components/Server/src/Circuits/CircuitInboundActivityContext.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Components.Server.Circuits; + +/// +/// Contains information about inbound activity. +/// +public sealed class CircuitInboundActivityContext +{ + internal Func Handler { get; } + + /// + /// Gets the associated with the activity. + /// + public Circuit Circuit { get; } + + internal CircuitInboundActivityContext(Func handler, Circuit circuit) + { + Handler = handler; + Circuit = circuit; + } +} diff --git a/src/Components/Server/src/Circuits/IHandleCircuitActivity.cs b/src/Components/Server/src/Circuits/IHandleCircuitActivity.cs new file mode 100644 index 000000000000..f46109feadf8 --- /dev/null +++ b/src/Components/Server/src/Circuits/IHandleCircuitActivity.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Components.Server.Circuits; + +/// +/// A handler to process inbound circuit activity. +/// +public interface IHandleCircuitActivity +{ + /// + /// Invoked when inbound activity on the circuit causes an asynchronous task to be dispatched on the server. + /// + /// The . + /// The next handler to invoke. + /// A that completes when the activity has finished. + Task HandleInboundActivityAsync(CircuitInboundActivityContext context, Func next); +} diff --git a/src/Components/Server/src/PublicAPI.Unshipped.txt b/src/Components/Server/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..d470f8781656 100644 --- a/src/Components/Server/src/PublicAPI.Unshipped.txt +++ b/src/Components/Server/src/PublicAPI.Unshipped.txt @@ -1 +1,5 @@ #nullable enable +Microsoft.AspNetCore.Components.Server.Circuits.CircuitInboundActivityContext +Microsoft.AspNetCore.Components.Server.Circuits.CircuitInboundActivityContext.Circuit.get -> Microsoft.AspNetCore.Components.Server.Circuits.Circuit! +Microsoft.AspNetCore.Components.Server.Circuits.IHandleCircuitActivity +Microsoft.AspNetCore.Components.Server.Circuits.IHandleCircuitActivity.HandleInboundActivityAsync(Microsoft.AspNetCore.Components.Server.Circuits.CircuitInboundActivityContext! context, System.Func! next) -> System.Threading.Tasks.Task! diff --git a/src/Components/Server/test/Circuits/CircuitHostTest.cs b/src/Components/Server/test/Circuits/CircuitHostTest.cs index a1adae467b3a..2711534fe019 100644 --- a/src/Components/Server/test/Circuits/CircuitHostTest.cs +++ b/src/Components/Server/test/Circuits/CircuitHostTest.cs @@ -318,6 +318,83 @@ public async Task DisposeAsync_InvokesCircuitHandler() handler2.VerifyAll(); } + [Fact] + public async Task HandleInboundActivityAsync_InvokesCircuitActivityHandlers() + { + // Arrange + var handler1 = new Mock(MockBehavior.Strict); + var handler2 = new Mock(MockBehavior.Strict); + var handler3 = new Mock(MockBehavior.Strict); + var sequence = new MockSequence(); + + // We deliberately avoid making handler2 an inbound activity handler + var activityHandler1 = handler1.As(); + var activityHandler3 = handler3.As(); + + var asyncLocal1 = new AsyncLocal(); + var asyncLocal3 = new AsyncLocal(); + + activityHandler1 + .InSequence(sequence) + .Setup(h => h.HandleInboundActivityAsync(It.IsAny(), It.IsAny>())) + .Returns(async (CircuitInboundActivityContext context, Func next) => + { + asyncLocal1.Value = true; + await next(context); + }) + .Verifiable(); + + activityHandler3 + .InSequence(sequence) + .Setup(h => h.HandleInboundActivityAsync(It.IsAny(), It.IsAny>())) + .Returns(async (CircuitInboundActivityContext context, Func next) => + { + asyncLocal3.Value = true; + await next(context); + }) + .Verifiable(); + + var circuitHost = TestCircuitHost.Create(handlers: new[] { handler1.Object, handler2.Object, handler3.Object }); + var asyncLocal1ValueInHandler = false; + var asyncLocal3ValueInHandler = false; + + // Act + await circuitHost.HandleInboundActivityAsync(() => + { + asyncLocal1ValueInHandler = asyncLocal1.Value; + asyncLocal3ValueInHandler = asyncLocal3.Value; + return Task.CompletedTask; + }); + + // Assert + activityHandler1.VerifyAll(); + activityHandler3.VerifyAll(); + + Assert.False(asyncLocal1.Value); + Assert.False(asyncLocal3.Value); + + Assert.True(asyncLocal1ValueInHandler); + Assert.True(asyncLocal3ValueInHandler); + } + + [Fact] + public async Task HandleInboundActivityAsync_InvokesHandlerFunc_WhenNoCircuitActivityHandlersAreRegistered() + { + // Arrange + var circuitHost = TestCircuitHost.Create(); + var wasHandlerFuncInvoked = false; + + // Act + await circuitHost.HandleInboundActivityAsync(() => + { + wasHandlerFuncInvoked = true; + return Task.CompletedTask; + }); + + // Assert + Assert.True(wasHandlerFuncInvoked); + } + private static TestRemoteRenderer GetRemoteRenderer() { var serviceCollection = new ServiceCollection(); diff --git a/src/Components/test/E2ETest/ServerExecutionTests/CircuitContextTest.cs b/src/Components/test/E2ETest/ServerExecutionTests/CircuitContextTest.cs new file mode 100644 index 000000000000..d14759152303 --- /dev/null +++ b/src/Components/test/E2ETest/ServerExecutionTests/CircuitContextTest.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Components.TestServer; +using Microsoft.AspNetCore.Components.E2ETest; +using Microsoft.AspNetCore.Components.E2ETest.Infrastructure; +using Microsoft.AspNetCore.Components.E2ETest.Infrastructure.ServerFixtures; +using Microsoft.AspNetCore.E2ETesting; +using OpenQA.Selenium; +using TestServer; +using Xunit.Abstractions; + +namespace Microsoft.AspNetCore.Components.E2ETests.ServerExecutionTests; + +public class CircuitContextTest : ServerTestBase> +{ + public CircuitContextTest( + BrowserFixture browserFixture, + BasicTestAppServerSiteFixture serverFixture, + ITestOutputHelper output) + : base(browserFixture, serverFixture, output) + { + } + + protected override void InitializeAsyncCore() + { + Navigate(ServerPathBase, noReload: false); + Browser.MountTestComponent(); + Browser.Equal("Circuit Context", () => Browser.Exists(By.TagName("h1")).Text); + } + + [Fact] + public void ComponentMethods_HaveCircuitContext() + { + Browser.Click(By.Id("trigger-click-event-button")); + + Browser.True(() => HasCircuitContext("SetParametersAsync")); + Browser.True(() => HasCircuitContext("OnInitializedAsync")); + Browser.True(() => HasCircuitContext("OnParametersSetAsync")); + Browser.True(() => HasCircuitContext("OnAfterRenderAsync")); + Browser.True(() => HasCircuitContext("InvokeDotNet")); + Browser.True(() => HasCircuitContext("OnClickEvent")); + + bool HasCircuitContext(string eventName) + { + var resultText = Browser.FindElement(By.Id($"circuit-context-result-{eventName}")).Text; + var result = bool.Parse(resultText); + return result; + } + } +} diff --git a/src/Components/test/testassets/BasicTestApp/Index.razor b/src/Components/test/testassets/BasicTestApp/Index.razor index f9c3814bd069..b741639932dd 100644 --- a/src/Components/test/testassets/BasicTestApp/Index.razor +++ b/src/Components/test/testassets/BasicTestApp/Index.razor @@ -10,6 +10,7 @@ + diff --git a/src/Components/test/testassets/BasicTestApp/wwwroot/index.html b/src/Components/test/testassets/BasicTestApp/wwwroot/index.html index 86ee65077ebe..9a309622d52c 100644 --- a/src/Components/test/testassets/BasicTestApp/wwwroot/index.html +++ b/src/Components/test/testassets/BasicTestApp/wwwroot/index.html @@ -26,6 +26,7 @@ + diff --git a/src/Components/test/testassets/BasicTestApp/wwwroot/js/circuitContextTest.js b/src/Components/test/testassets/BasicTestApp/wwwroot/js/circuitContextTest.js new file mode 100644 index 000000000000..b0512967f654 --- /dev/null +++ b/src/Components/test/testassets/BasicTestApp/wwwroot/js/circuitContextTest.js @@ -0,0 +1,5 @@ +window.circuitContextTest = { + invokeDotNetMethod: async (dotNetObject) => { + await dotNetObject.invokeMethodAsync('InvokeDotNet'); + }, +}; diff --git a/src/Components/test/testassets/TestServer/CircuitContextComponent.razor b/src/Components/test/testassets/TestServer/CircuitContextComponent.razor new file mode 100644 index 000000000000..7a7b79e4f9f7 --- /dev/null +++ b/src/Components/test/testassets/TestServer/CircuitContextComponent.razor @@ -0,0 +1,79 @@ +@using Microsoft.JSInterop; + +@implements IDisposable + +@inject IJSRuntime JS +@inject TestCircuitContextAccessor CircuitContextAccessor + +

Circuit Context

+ + + +@foreach (var entry in _hasCircuitContextByEventName) +{ +

+ @entry.Key: @entry.Value +

+} + +@code { + private readonly DotNetObjectReference _selfRef; + private readonly Dictionary _hasCircuitContextByEventName = new(); + + public CircuitContextComponent() + { + _selfRef = DotNetObjectReference.Create(this); + } + + public override async Task SetParametersAsync(ParameterView parameters) + { + RecordHasCircuitContext(nameof(SetParametersAsync)); + await base.SetParametersAsync(parameters); + } + + protected override Task OnInitializedAsync() + { + RecordHasCircuitContext(nameof(OnInitializedAsync)); + return Task.CompletedTask; + } + + protected override async Task OnAfterRenderAsync(bool firstRender) + { + if (firstRender) + { + RecordHasCircuitContext(nameof(OnAfterRenderAsync)); + + await JS.InvokeVoidAsync("circuitContextTest.invokeDotNetMethod", _selfRef); + + StateHasChanged(); + } + } + + protected override Task OnParametersSetAsync() + { + RecordHasCircuitContext(nameof(OnParametersSetAsync)); + return Task.CompletedTask; + } + + private Task OnClickEvent() + { + RecordHasCircuitContext(nameof(OnClickEvent)); + return Task.CompletedTask; + } + + [JSInvokable] + public void InvokeDotNet() + { + RecordHasCircuitContext(nameof(InvokeDotNet)); + } + + private void RecordHasCircuitContext(string eventName) + { + _hasCircuitContextByEventName[eventName] = CircuitContextAccessor.HasCircuitContext; + } + + public void Dispose() + { + _selfRef.Dispose(); + } +} diff --git a/src/Components/test/testassets/TestServer/Pages/_ServerHost.cshtml b/src/Components/test/testassets/TestServer/Pages/_ServerHost.cshtml index 91ea15c3c374..d57502e248d8 100644 --- a/src/Components/test/testassets/TestServer/Pages/_ServerHost.cshtml +++ b/src/Components/test/testassets/TestServer/Pages/_ServerHost.cshtml @@ -17,6 +17,7 @@ + diff --git a/src/Components/test/testassets/TestServer/ServerStartup.cs b/src/Components/test/testassets/TestServer/ServerStartup.cs index 9da10dfe1a72..34e4ec495fe2 100644 --- a/src/Components/test/testassets/TestServer/ServerStartup.cs +++ b/src/Components/test/testassets/TestServer/ServerStartup.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Globalization; +using Microsoft.AspNetCore.Components.Server.Circuits; using Microsoft.AspNetCore.Components.Web; using Microsoft.AspNetCore.DataProtection; @@ -31,6 +32,10 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddTransient(); + var circuitContextAccessor = new TestCircuitContextAccessor(); + services.AddSingleton(circuitContextAccessor); + services.AddSingleton(circuitContextAccessor); + // Since tests run in parallel, we use an ephemeral key provider to avoid filesystem // contention issues. services.AddSingleton(); diff --git a/src/Components/test/testassets/TestServer/TestCircuitContextAccessor.cs b/src/Components/test/testassets/TestServer/TestCircuitContextAccessor.cs new file mode 100644 index 000000000000..de07b3b8ba29 --- /dev/null +++ b/src/Components/test/testassets/TestServer/TestCircuitContextAccessor.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Components.Server.Circuits; + +namespace TestServer; + +public class TestCircuitContextAccessor : CircuitHandler, IHandleCircuitActivity +{ + private readonly AsyncLocal _hasCircuitContext = new(); + + public bool HasCircuitContext => _hasCircuitContext.Value; + + public async Task HandleInboundActivityAsync(CircuitInboundActivityContext context, Func next) + { + _hasCircuitContext.Value = true; + await next(context); + _hasCircuitContext.Value = false; + } +} diff --git a/src/Components/test/testassets/TestServer/_Imports.razor b/src/Components/test/testassets/TestServer/_Imports.razor index a71616835c7b..3a9c63da2676 100644 --- a/src/Components/test/testassets/TestServer/_Imports.razor +++ b/src/Components/test/testassets/TestServer/_Imports.razor @@ -1,2 +1,3 @@ @using Microsoft.AspNetCore.Components.Web @using Microsoft.AspNetCore.Components.Web.Virtualization +@using global::TestServer;