Skip to content
This repository was archived by the owner on Nov 20, 2018. It is now read-only.

Commit 59b605c

Browse files
tuespetreTratcher
authored andcommitted
Add UseWhenExtensions and UseWhenExtensionsTests
1 parent 62eaf16 commit 59b605c

File tree

2 files changed

+237
-0
lines changed

2 files changed

+237
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using Microsoft.AspNetCore.Http;
6+
7+
namespace Microsoft.AspNetCore.Builder
8+
{
9+
using Predicate = Func<HttpContext, bool>;
10+
11+
/// <summary>
12+
/// Extension methods for <see cref="IApplicationBuilder"/>.
13+
/// </summary>
14+
public static class UseWhenExtensions
15+
{
16+
/// <summary>
17+
/// Conditionally creates a branch in the request pipeline that is rejoined to the main pipeline.
18+
/// </summary>
19+
/// <param name="app"></param>
20+
/// <param name="predicate">Invoked with the request environment to determine if the branch should be taken</param>
21+
/// <param name="configuration">Configures a branch to take</param>
22+
/// <returns></returns>
23+
public static IApplicationBuilder UseWhen(this IApplicationBuilder app, Predicate predicate, Action<IApplicationBuilder> configuration)
24+
{
25+
if (app == null)
26+
{
27+
throw new ArgumentNullException(nameof(app));
28+
}
29+
30+
if (predicate == null)
31+
{
32+
throw new ArgumentNullException(nameof(predicate));
33+
}
34+
35+
if (configuration == null)
36+
{
37+
throw new ArgumentNullException(nameof(configuration));
38+
}
39+
40+
// Create and configure the branch builder right away; otherwise,
41+
// we would end up running our branch after all the components
42+
// that were subsequently added to the main builder.
43+
var branchBuilder = app.New();
44+
configuration(branchBuilder);
45+
46+
return app.Use(main =>
47+
{
48+
// This is called only when the main application builder
49+
// is built, not per request.
50+
branchBuilder.Run(main);
51+
var branch = branchBuilder.Build();
52+
53+
return async context =>
54+
{
55+
if (predicate(context))
56+
{
57+
await branch(context);
58+
}
59+
else
60+
{
61+
await main(context);
62+
}
63+
};
64+
});
65+
}
66+
}
67+
}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using System.Threading.Tasks;
6+
using Microsoft.AspNetCore.Builder.Internal;
7+
using Microsoft.AspNetCore.Http;
8+
using Xunit;
9+
10+
namespace Microsoft.AspNetCore.Builder.Extensions
11+
{
12+
public class UseWhenExtensionsTests
13+
{
14+
[Fact]
15+
public void NullArguments_ArgumentNullException()
16+
{
17+
// Arrange
18+
var builder = CreateBuilder();
19+
20+
// Act
21+
Action nullPredicate = () => builder.UseWhen(null, app => { });
22+
Action nullConfiguration = () => builder.UseWhen(TruePredicate, null);
23+
24+
// Assert
25+
Assert.Throws<ArgumentNullException>(nullPredicate);
26+
Assert.Throws<ArgumentNullException>(nullConfiguration);
27+
}
28+
29+
[Fact]
30+
public void PredicateTrue_BranchTaken_WillRejoin()
31+
{
32+
// Arrange
33+
var context = CreateContext();
34+
var parent = CreateBuilder();
35+
36+
parent.UseWhen(TruePredicate, child =>
37+
{
38+
child.UseWhen(TruePredicate, grandchild =>
39+
{
40+
grandchild.Use(Increment("grandchild"));
41+
});
42+
43+
child.Use(Increment("child"));
44+
});
45+
46+
parent.Use(Increment("parent"));
47+
48+
// Act
49+
parent.Build().Invoke(context).Wait();
50+
51+
// Assert
52+
Assert.Equal(1, Count(context, "parent"));
53+
Assert.Equal(1, Count(context, "child"));
54+
Assert.Equal(1, Count(context, "grandchild"));
55+
}
56+
57+
[Fact]
58+
public void PredicateTrue_BranchTaken_CanTerminate()
59+
{
60+
// Arrange
61+
var context = CreateContext();
62+
var parent = CreateBuilder();
63+
64+
parent.UseWhen(TruePredicate, child =>
65+
{
66+
child.UseWhen(TruePredicate, grandchild =>
67+
{
68+
grandchild.Use(Increment("grandchild", terminate: true));
69+
});
70+
71+
child.Use(Increment("child"));
72+
});
73+
74+
parent.Use(Increment("parent"));
75+
76+
// Act
77+
parent.Build().Invoke(context).Wait();
78+
79+
// Assert
80+
Assert.Equal(0, Count(context, "parent"));
81+
Assert.Equal(0, Count(context, "child"));
82+
Assert.Equal(1, Count(context, "grandchild"));
83+
}
84+
85+
[Fact]
86+
public void PredicateFalse_PassThrough()
87+
{
88+
// Arrange
89+
var context = CreateContext();
90+
var parent = CreateBuilder();
91+
92+
parent.UseWhen(FalsePredicate, child =>
93+
{
94+
child.Use(Increment("child"));
95+
});
96+
97+
parent.Use(Increment("parent"));
98+
99+
// Act
100+
parent.Build().Invoke(context).Wait();
101+
102+
// Assert
103+
Assert.Equal(1, Count(context, "parent"));
104+
Assert.Equal(0, Count(context, "child"));
105+
}
106+
107+
private static HttpContext CreateContext()
108+
{
109+
return new DefaultHttpContext();
110+
}
111+
112+
private static ApplicationBuilder CreateBuilder()
113+
{
114+
return new ApplicationBuilder(serviceProvider: null);
115+
}
116+
117+
private static bool TruePredicate(HttpContext context)
118+
{
119+
return true;
120+
}
121+
122+
private static bool FalsePredicate(HttpContext context)
123+
{
124+
return false;
125+
}
126+
127+
private static Func<HttpContext, Func<Task>, Task> Increment(string key, bool terminate = false)
128+
{
129+
return (context, next) =>
130+
{
131+
if (!context.Items.ContainsKey(key))
132+
{
133+
context.Items[key] = 1;
134+
}
135+
else
136+
{
137+
var item = context.Items[key];
138+
139+
if (item is int)
140+
{
141+
context.Items[key] = 1 + (int)item;
142+
}
143+
else
144+
{
145+
context.Items[key] = 1;
146+
}
147+
}
148+
149+
return terminate ? Task.FromResult<object>(null) : next();
150+
};
151+
}
152+
153+
private static int Count(HttpContext context, string key)
154+
{
155+
if (!context.Items.ContainsKey(key))
156+
{
157+
return 0;
158+
}
159+
160+
var item = context.Items[key];
161+
162+
if (item is int)
163+
{
164+
return (int)item;
165+
}
166+
167+
return 0;
168+
}
169+
}
170+
}

0 commit comments

Comments
 (0)