diff --git a/src/Microsoft.AspNetCore.Http.Abstractions/Extensions/UseWhenExtensions.cs b/src/Microsoft.AspNetCore.Http.Abstractions/Extensions/UseWhenExtensions.cs new file mode 100644 index 00000000..3709a1e9 --- /dev/null +++ b/src/Microsoft.AspNetCore.Http.Abstractions/Extensions/UseWhenExtensions.cs @@ -0,0 +1,67 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Builder +{ + using Predicate = Func; + + /// + /// Extension methods for . + /// + public static class UseWhenExtensions + { + /// + /// Conditionally creates a branch in the request pipeline that is rejoined to the main pipeline. + /// + /// + /// Invoked with the request environment to determine if the branch should be taken + /// Configures a branch to take + /// + public static IApplicationBuilder UseWhen(this IApplicationBuilder app, Predicate predicate, Action configuration) + { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + if (predicate == null) + { + throw new ArgumentNullException(nameof(predicate)); + } + + if (configuration == null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + // Create and configure the branch builder right away; otherwise, + // we would end up running our branch after all the components + // that were subsequently added to the main builder. + var branchBuilder = app.New(); + configuration(branchBuilder); + + return app.Use(main => + { + // This is called only when the main application builder + // is built, not per request. + branchBuilder.Run(main); + var branch = branchBuilder.Build(); + + return async context => + { + if (predicate(context)) + { + await branch(context); + } + else + { + await main(context); + } + }; + }); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Http.Abstractions.Tests/UseWhenExtensionsTests.cs b/test/Microsoft.AspNetCore.Http.Abstractions.Tests/UseWhenExtensionsTests.cs new file mode 100644 index 00000000..902a003b --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Abstractions.Tests/UseWhenExtensionsTests.cs @@ -0,0 +1,170 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder.Internal; +using Microsoft.AspNetCore.Http; +using Xunit; + +namespace Microsoft.AspNetCore.Builder.Extensions +{ + public class UseWhenExtensionsTests + { + [Fact] + public void NullArguments_ArgumentNullException() + { + // Arrange + var builder = CreateBuilder(); + + // Act + Action nullPredicate = () => builder.UseWhen(null, app => { }); + Action nullConfiguration = () => builder.UseWhen(TruePredicate, null); + + // Assert + Assert.Throws(nullPredicate); + Assert.Throws(nullConfiguration); + } + + [Fact] + public void PredicateTrue_BranchTaken_WillRejoin() + { + // Arrange + var context = CreateContext(); + var parent = CreateBuilder(); + + parent.UseWhen(TruePredicate, child => + { + child.UseWhen(TruePredicate, grandchild => + { + grandchild.Use(Increment("grandchild")); + }); + + child.Use(Increment("child")); + }); + + parent.Use(Increment("parent")); + + // Act + parent.Build().Invoke(context).Wait(); + + // Assert + Assert.Equal(1, Count(context, "parent")); + Assert.Equal(1, Count(context, "child")); + Assert.Equal(1, Count(context, "grandchild")); + } + + [Fact] + public void PredicateTrue_BranchTaken_CanTerminate() + { + // Arrange + var context = CreateContext(); + var parent = CreateBuilder(); + + parent.UseWhen(TruePredicate, child => + { + child.UseWhen(TruePredicate, grandchild => + { + grandchild.Use(Increment("grandchild", terminate: true)); + }); + + child.Use(Increment("child")); + }); + + parent.Use(Increment("parent")); + + // Act + parent.Build().Invoke(context).Wait(); + + // Assert + Assert.Equal(0, Count(context, "parent")); + Assert.Equal(0, Count(context, "child")); + Assert.Equal(1, Count(context, "grandchild")); + } + + [Fact] + public void PredicateFalse_PassThrough() + { + // Arrange + var context = CreateContext(); + var parent = CreateBuilder(); + + parent.UseWhen(FalsePredicate, child => + { + child.Use(Increment("child")); + }); + + parent.Use(Increment("parent")); + + // Act + parent.Build().Invoke(context).Wait(); + + // Assert + Assert.Equal(1, Count(context, "parent")); + Assert.Equal(0, Count(context, "child")); + } + + private static HttpContext CreateContext() + { + return new DefaultHttpContext(); + } + + private static ApplicationBuilder CreateBuilder() + { + return new ApplicationBuilder(serviceProvider: null); + } + + private static bool TruePredicate(HttpContext context) + { + return true; + } + + private static bool FalsePredicate(HttpContext context) + { + return false; + } + + private static Func, Task> Increment(string key, bool terminate = false) + { + return (context, next) => + { + if (!context.Items.ContainsKey(key)) + { + context.Items[key] = 1; + } + else + { + var item = context.Items[key]; + + if (item is int) + { + context.Items[key] = 1 + (int)item; + } + else + { + context.Items[key] = 1; + } + } + + return terminate ? Task.FromResult(null) : next(); + }; + } + + private static int Count(HttpContext context, string key) + { + if (!context.Items.ContainsKey(key)) + { + return 0; + } + + var item = context.Items[key]; + + if (item is int) + { + return (int)item; + } + + return 0; + } + } +}