Skip to content

Commit 642d7af

Browse files
committed
Add support for antiforgery
1 parent 9da6177 commit 642d7af

29 files changed

+503
-31
lines changed

src/Components/Components/src/Rendering/RenderTreeBuilder.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,11 @@ public void AddAttribute(int sequence, string name)
175175
throw new InvalidOperationException($"Valueless attributes may only be added immediately after frames of type {RenderTreeFrameType.Element}");
176176
}
177177

178+
if (TrackNamedEventHandlers && string.Equals(name, "@onsubmit:name", StringComparison.Ordinal))
179+
{
180+
_entries.AppendAttribute(sequence, name, "");
181+
}
182+
178183
_entries.AppendAttribute(sequence, name, BoxedTrue);
179184
}
180185

@@ -873,6 +878,18 @@ internal void ProcessDuplicateAttributes(int first)
873878
// This attribute has been overridden. For now, blank out its name to *mark* it. We'll do a pass
874879
// later to wipe it out.
875880
frame = default;
881+
// We are wiping out this frame, which means that if we are tracking named events, we have to adjust the
882+
// indexes of the named event handlers that come after this frame.
883+
if (_seenEventHandlerNames != null && _seenEventHandlerNames.Count > 0)
884+
{
885+
foreach (var (name, eventIndex) in _seenEventHandlerNames)
886+
{
887+
if (eventIndex >= i)
888+
{
889+
_seenEventHandlerNames[name] = eventIndex - 1;
890+
}
891+
}
892+
}
876893
}
877894
else
878895
{
@@ -935,6 +952,62 @@ public void Dispose()
935952

936953
internal Dictionary<string, int>? GetNamedEvents()
937954
{
955+
if (TrackNamedEventHandlers)
956+
{
957+
var i = _entries.Count - 1;
958+
while (i >= 0)
959+
{
960+
ref var frame = ref _entries.Buffer[i];
961+
if (frame.FrameType != RenderTreeFrameType.Attribute)
962+
{
963+
i--;
964+
continue;
965+
}
966+
var j = i;
967+
var submitHandlerIndex = -1;
968+
var submitHanderNameIndex = -1;
969+
while (j > 0)
970+
{
971+
// we are inside a list of attribute frames.
972+
// Walk backwards to find pairs of onsubmit and @onsubmit:name
973+
// and stop the first time we find an element or component frame.
974+
ref var attributeFrame = ref _entries.Buffer[j];
975+
if (attributeFrame.FrameType == RenderTreeFrameType.Component)
976+
{
977+
// If we were processing a component, ignore the values
978+
// as this feature is only for HTML elements.
979+
submitHanderNameIndex = -1;
980+
submitHandlerIndex = -1;
981+
break;
982+
}
983+
else if (attributeFrame.FrameType == RenderTreeFrameType.Element)
984+
{
985+
// We are at the end of the elements sequence
986+
break;
987+
}
988+
if (string.Equals(attributeFrame.AttributeName, "onsubmit", StringComparison.Ordinal))
989+
{
990+
submitHandlerIndex = j;
991+
}
992+
else if (string.Equals(attributeFrame.AttributeName, "@onsubmit:name", StringComparison.Ordinal))
993+
{
994+
submitHanderNameIndex = j;
995+
}
996+
997+
if (submitHandlerIndex != -1 && submitHanderNameIndex != -1)
998+
{
999+
// We found a pair, add it to the dictionary in case it was missing.
1000+
_seenEventHandlerNames ??= new Dictionary<string, int>(SimplifiedStringHashComparer.Instance);
1001+
var eventHandlerName = _entries.Buffer[submitHanderNameIndex].AttributeValue;
1002+
_seenEventHandlerNames[(eventHandlerName as string)!] = submitHandlerIndex;
1003+
submitHanderNameIndex = -1;
1004+
submitHandlerIndex = -1;
1005+
}
1006+
j--;
1007+
}
1008+
i = j;
1009+
}
1010+
}
9381011
return _seenEventHandlerNames;
9391012
}
9401013
}

src/Components/Endpoints/src/DependencyInjection/RazorComponentsServiceCollectionExtensions.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.AspNetCore.Components.Binding;
77
using Microsoft.AspNetCore.Components.Endpoints;
88
using Microsoft.AspNetCore.Components.Endpoints.DependencyInjection;
9+
using Microsoft.AspNetCore.Components.Endpoints.Forms;
910
using Microsoft.AspNetCore.Components.Forms;
1011
using Microsoft.AspNetCore.Components.Infrastructure;
1112
using Microsoft.AspNetCore.Components.Routing;
@@ -30,6 +31,9 @@ public static class RazorComponentsServiceCollectionExtensions
3031
[RequiresUnreferencedCode("Razor Components does not currently support native AOT.", Url = "https://aka.ms/aspnet/nativeaot")]
3132
public static IRazorComponentsBuilder AddRazorComponents(this IServiceCollection services)
3233
{
34+
// Dependencies
35+
services.AddAntiforgery();
36+
3337
services.TryAddSingleton<RazorComponentsMarkerService>();
3438

3539
// Results
@@ -60,7 +64,7 @@ public static IRazorComponentsBuilder AddRazorComponents(this IServiceCollection
6064
services.TryAddScoped<IFormValueSupplier, DefaultFormValuesSupplier>();
6165
services.TryAddEnumerable(ServiceDescriptor.Scoped<CascadingModelBindingProvider, CascadingQueryModelBindingProvider>());
6266
services.TryAddEnumerable(ServiceDescriptor.Scoped<CascadingModelBindingProvider, CascadingFormModelBindingProvider>());
63-
67+
services.TryAddScoped<AntiforgeryStateProvider, EndpointAntiforgeryStateProvider>();
6468
return new DefaultRazorComponentsBuilder(services);
6569
}
6670

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using Microsoft.AspNetCore.Antiforgery;
5+
using Microsoft.AspNetCore.Components.Forms;
6+
using Microsoft.AspNetCore.Http;
7+
8+
namespace Microsoft.AspNetCore.Components.Endpoints.Forms;
9+
10+
internal class EndpointAntiforgeryStateProvider(IAntiforgery antiforgery, PersistentComponentState state) : AntiforgeryStateProvider(state)
11+
{
12+
private HttpContext? _context;
13+
14+
internal void SetRequestContext(HttpContext context)
15+
{
16+
_context = context;
17+
}
18+
19+
public override AntiforgeryRequestToken? GetAntiforgeryToken()
20+
{
21+
if (_context == null)
22+
{
23+
return null;
24+
}
25+
26+
// We already have a callback setup to generate the token when the response starts if needed.
27+
// If we need the tokens before we start streaming the response, we'll generate and store them;
28+
// otherwise we'll just retrieve them.
29+
// In case there are no tokens available, we are going to return null and no-op.
30+
var tokens = !_context.Response.HasStarted ? antiforgery.GetAndStoreTokens(_context) : antiforgery.GetTokens(_context);
31+
if (tokens.RequestToken is null)
32+
{
33+
return null;
34+
}
35+
36+
return new AntiforgeryRequestToken(tokens.RequestToken, tokens.FormFieldName);
37+
}
38+
}

src/Components/Endpoints/src/Microsoft.AspNetCore.Components.Endpoints.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
</ItemGroup>
4545

4646
<ItemGroup>
47+
<Reference Include="Microsoft.AspNetCore.Antiforgery" />
4748
<Reference Include="Microsoft.AspNetCore.Components.Authorization" />
4849
<Reference Include="Microsoft.AspNetCore.Components.Web" />
4950
<Reference Include="Microsoft.AspNetCore.DataProtection.Extensions" />

src/Components/Endpoints/src/RazorComponentEndpointInvoker.cs

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Buffers;
5+
using System.Diagnostics;
56
using System.Diagnostics.CodeAnalysis;
67
using System.Text;
78
using System.Text.Encodings.Web;
9+
using Microsoft.AspNetCore.Antiforgery;
810
using Microsoft.AspNetCore.Http;
911
using Microsoft.AspNetCore.WebUtilities;
1012
using Microsoft.Extensions.DependencyInjection;
@@ -36,13 +38,23 @@ private async Task RenderComponentCore()
3638
_context.Response.ContentType = RazorComponentResultExecutor.DefaultContentType;
3739
_renderer.InitializeStreamingRenderingFraming(_context);
3840

39-
if (!await TryValidateRequestAsync(out var isPost, out var handler))
41+
var antiforgery = _context.RequestServices.GetRequiredService<IAntiforgery>();
42+
var (valid, isPost, handler) = await ValidateRequestAsync(antiforgery);
43+
if (!valid)
4044
{
4145
// If the request is not valid we've already set the response to a 400 or similar
4246
// and we can just exit early.
4347
return;
4448
}
4549

50+
_context.Response.OnStarting(() =>
51+
{
52+
// Generate the antiforgery tokens before we start streaming the response, as it needs
53+
// to set the cookie header.
54+
antiforgery.GetAndStoreTokens(_context);
55+
return Task.CompletedTask;
56+
});
57+
4658
await EndpointHtmlRenderer.InitializeStandardComponentServicesAsync(
4759
_context,
4860
componentType: _componentType,
@@ -89,16 +101,21 @@ await EndpointHtmlRenderer.InitializeStandardComponentServicesAsync(
89101
await writer.FlushAsync();
90102
}
91103

92-
private Task<bool> TryValidateRequestAsync(out bool isPost, out string? handler)
104+
private async Task<RequestValidationState> ValidateRequestAsync(IAntiforgery antiforgery)
93105
{
94-
handler = null;
95-
isPost = HttpMethods.IsPost(_context.Request.Method);
106+
var isPost = HttpMethods.IsPost(_context.Request.Method);
96107
if (isPost)
97108
{
98-
return Task.FromResult(TrySetFormHandler(out handler));
109+
var valid = await antiforgery.IsRequestValidAsync(_context);
110+
if (!valid)
111+
{
112+
_context.Response.StatusCode = StatusCodes.Status400BadRequest;
113+
}
114+
var formValid = TrySetFormHandler(out var handler);
115+
return new(valid && formValid, isPost, handler);
99116
}
100117

101-
return Task.FromResult(true);
118+
return new(true, false, null);
102119
}
103120

104121
private bool TrySetFormHandler([NotNullWhen(true)] out string? handler)
@@ -128,4 +145,26 @@ private static TextWriter CreateResponseWriter(Stream bodyStream)
128145
const int DefaultBufferSize = 16 * 1024;
129146
return new HttpResponseStreamWriter(bodyStream, Encoding.UTF8, DefaultBufferSize, ArrayPool<byte>.Shared, ArrayPool<char>.Shared);
130147
}
148+
149+
[DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")]
150+
private readonly struct RequestValidationState(bool isValid, bool isPost, string? handlerName)
151+
{
152+
public bool IsValid => isValid;
153+
154+
public bool IsPost => isPost;
155+
156+
public string? HandlerName => handlerName;
157+
158+
private string GetDebuggerDisplay()
159+
{
160+
return $"{nameof(RequestValidationState)}: {IsValid} {IsPost} {HandlerName}";
161+
}
162+
163+
public void Deconstruct(out bool isValid, out bool isPost, out string? handlerName)
164+
{
165+
isValid = IsValid;
166+
isPost = IsPost;
167+
handlerName = HandlerName;
168+
}
169+
}
131170
}

src/Components/Endpoints/src/Rendering/EndpointHtmlRenderer.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Text;
77
using Microsoft.AspNetCore.Components.Authorization;
88
using Microsoft.AspNetCore.Components.Endpoints.DependencyInjection;
9+
using Microsoft.AspNetCore.Components.Endpoints.Forms;
910
using Microsoft.AspNetCore.Components.Forms;
1011
using Microsoft.AspNetCore.Components.HtmlRendering.Infrastructure;
1112
using Microsoft.AspNetCore.Components.Infrastructure;
@@ -88,6 +89,12 @@ internal static async Task InitializeStandardComponentServicesAsync(
8889
formData.SetFormData(handler, new FormCollectionReadOnlyDictionary(form));
8990
}
9091

92+
var antiforgery = httpContext.RequestServices.GetRequiredService<AntiforgeryStateProvider>();
93+
if (antiforgery is EndpointAntiforgeryStateProvider endpointAntiforgery)
94+
{
95+
endpointAntiforgery.SetRequestContext(httpContext);
96+
}
97+
9198
// It's important that this is initialized since a component might try to restore state during prerendering
9299
// (which will obviously not work, but should not fail)
93100
var componentApplicationLifetime = httpContext.RequestServices.GetRequiredService<ComponentStatePersistenceManager>();

src/Components/Endpoints/test/EndpointHtmlRendererTest.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text.Encodings.Web;
55
using System.Text.Json;
66
using System.Text.RegularExpressions;
7+
using Microsoft.AspNetCore.Components.Endpoints.Forms;
78
using Microsoft.AspNetCore.Components.Endpoints.Tests.TestComponents;
89
using Microsoft.AspNetCore.Components.Forms;
910
using Microsoft.AspNetCore.Components.Infrastructure;
@@ -1297,6 +1298,11 @@ private static ServiceCollection CreateDefaultServiceCollection()
12971298
services.AddSingleton(sp => sp.GetRequiredService<ComponentStatePersistenceManager>().State);
12981299
services.AddSingleton<ServerComponentSerializer>();
12991300
services.AddSingleton<FormDataProvider, HttpContextFormDataProvider>();
1301+
services.AddAntiforgery();
1302+
services.AddSingleton<ComponentStatePersistenceManager>();
1303+
services.AddSingleton<PersistentComponentState>(sp => sp.GetRequiredService<ComponentStatePersistenceManager>().State);
1304+
services.AddSingleton<AntiforgeryStateProvider, EndpointAntiforgeryStateProvider>();
1305+
13001306
return services;
13011307
}
13021308

src/Components/Endpoints/test/RazorComponentResultExecutorTest.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using Microsoft.AspNetCore.Components.Endpoints.Tests.TestComponents;
1616
using System.Text.RegularExpressions;
1717
using Microsoft.AspNetCore.Components.Forms;
18+
using Microsoft.AspNetCore.Components.Endpoints.Forms;
1819

1920
namespace Microsoft.AspNetCore.Components.Endpoints;
2021

@@ -403,6 +404,7 @@ public static DefaultHttpContext GetTestHttpContext(string environmentName = nul
403404
var mockWebHostEnvironment = Mock.Of<IWebHostEnvironment>(
404405
x => x.EnvironmentName == (environmentName ?? Environments.Production));
405406
var serviceCollection = new ServiceCollection()
407+
.AddAntiforgery()
406408
.AddSingleton(new DiagnosticListener("test"))
407409
.AddSingleton<IWebHostEnvironment>(mockWebHostEnvironment)
408410
.AddSingleton<RazorComponentResultExecutor>()
@@ -413,6 +415,9 @@ public static DefaultHttpContext GetTestHttpContext(string environmentName = nul
413415
.AddSingleton<ComponentStatePersistenceManager>()
414416
.AddSingleton<IDataProtectionProvider, FakeDataProtectionProvider>()
415417
.AddSingleton<FormDataProvider, HttpContextFormDataProvider>()
418+
.AddSingleton<ComponentStatePersistenceManager>()
419+
.AddSingleton<PersistentComponentState>(sp => sp.GetRequiredService<ComponentStatePersistenceManager>().State)
420+
.AddSingleton<AntiforgeryStateProvider, EndpointAntiforgeryStateProvider>()
416421
.AddLogging();
417422

418423
var result = new DefaultHttpContext { RequestServices = serviceCollection.BuildServiceProvider() };
Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
@page "/"
2-
32
<PageTitle>Index</PageTitle>
43

5-
<h1>@Parameter</h1>
4+
<h1>@Value?.Parameter</h1>
65

7-
<EditForm method="POST" Model="Parameter">
8-
<InputText @bind-Value="Parameter" />
6+
<EditForm method="POST" Model="Value?.Parameter">
7+
<InputText @bind-Value="Value!.Parameter" />
8+
<AntiforgeryToken />
99
<input type="submit" value="Send" />
1010
</EditForm>
1111

@@ -15,11 +15,15 @@
1515
}
1616

1717
@code{
18-
[SupplyParameterFromForm] string Parameter { get; set; } = "Hello, world!";
18+
[SupplyParameterFromForm] Data? Value { get; set; }
19+
20+
protected override void OnInitialized() => Value ??= new();
1921

2022
bool _submitted = false;
21-
public void Submit()
23+
public void Submit() => _submitted = true;
24+
25+
public class Data
2226
{
23-
_submitted = true;
27+
public string Parameter { get; set; } = "";
2428
}
2529
}

src/Components/Server/src/DependencyInjection/ComponentServiceCollectionExtensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ public static IServerSideBlazorBuilder AddServerSideBlazor(this IServiceCollecti
6969
services.TryAddSingleton<ComponentParametersTypeCache>();
7070
services.TryAddSingleton<CircuitIdFactory>();
7171
services.TryAddScoped<IErrorBoundaryLogger, RemoteErrorBoundaryLogger>();
72+
services.TryAddScoped<AntiforgeryStateProvider>();
7273

7374
services.TryAddScoped(s => s.GetRequiredService<ICircuitAccessor>().Circuit);
7475
services.TryAddScoped<ICircuitAccessor, DefaultCircuitAccessor>();

0 commit comments

Comments
 (0)