Skip to content

Commit b38dabe

Browse files
committed
Updated with feedback
1 parent 6fbb9c0 commit b38dabe

File tree

4 files changed

+16
-220
lines changed

4 files changed

+16
-220
lines changed

src/Components/test/E2ETest/ServerExecutionTests/ServerReconnectionTest.cs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,13 @@ public void ReconnectUI()
4949
Disconnect();
5050

5151
// We should see the 'reconnecting' UI appear
52-
var reconnectionDialog = WaitUntilReconnectionDialogExists();
53-
Browser.True(() => reconnectionDialog.GetCssValue("display") == "block");
52+
Browser.True(
53+
() => Browser.FindElement(By.Id("components-reconnect-modal"))?.GetCssValue("display") == "block",
54+
TimeSpan.FromSeconds(10));
5455

5556
// Then it should disappear
56-
new WebDriverWait(Browser, TimeSpan.FromSeconds(10))
57-
.Until(driver => reconnectionDialog.GetCssValue("display") == "none");
57+
Browser.True(() => Browser.FindElement(By.Id("components-reconnect-modal"))?.GetCssValue("display") == "none",
58+
TimeSpan.FromSeconds(10));
5859

5960
counterButton = Browser.FindElement(By.Id("counter-click"));
6061
for (int i = 0; i < 10; i++)
@@ -78,12 +79,13 @@ public void RendersContinueAfterReconnect()
7879
Disconnect();
7980

8081
// We should see the 'reconnecting' UI appear
81-
var reconnectionDialog = WaitUntilReconnectionDialogExists();
82-
Browser.True(() => reconnectionDialog.GetCssValue("display") == "block");
82+
Browser.True(
83+
() => Browser.FindElement(By.Id("components-reconnect-modal"))?.GetCssValue("display") == "block",
84+
TimeSpan.FromSeconds(10));
8385

8486
// Then it should disappear
85-
new WebDriverWait(Browser, TimeSpan.FromSeconds(10))
86-
.Until(driver => reconnectionDialog.GetCssValue("display") == "none");
87+
Browser.True(() => Browser.FindElement(By.Id("components-reconnect-modal"))?.GetCssValue("display") == "none",
88+
TimeSpan.FromSeconds(10));
8789

8890
// We should receive a render that occurred while disconnected
8991
var currentValue = element.Text;
@@ -100,13 +102,5 @@ private void Disconnect()
100102
Browser.ExecuteAsyncScript($"fetch('/WebSockets/Interrupt?WebSockets.Identifier={SessionIdentifier}').then(r => window['WebSockets.{SessionIdentifier}'] = r.ok)");
101103
Browser.HasJavaScriptValue(true, $"window['WebSockets.{SessionIdentifier}']", (r) => r != null);
102104
}
103-
104-
private IWebElement WaitUntilReconnectionDialogExists()
105-
{
106-
IWebElement reconnectionDialog = null;
107-
new WebDriverWait(Browser, TimeSpan.FromSeconds(10))
108-
.Until(driver => (reconnectionDialog = driver.FindElement(By.Id("components-reconnect-modal"))) != null);
109-
return reconnectionDialog;
110-
}
111105
}
112106
}

src/Components/test/testassets/TestServer/Infrastructure/InterruptibleSocketMiddleware.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public InterruptibleSocketMiddleware(
2222

2323
public async Task Invoke(HttpContext context)
2424
{
25+
var socketsFeature = context.Features.Get<IHttpWebSocketFeature>();
2526
if (context.Request.Path.Equals(Options.InterruptPath) && context.Request.Query.TryGetValue(Options.WebSocketIdParameterName,out var currentIdentifier))
2627
{
2728
if (Registry.TryGetValue(currentIdentifier, out var webSocket))
@@ -39,7 +40,7 @@ public async Task Invoke(HttpContext context)
3940
if (context.Request.Path.Equals(Options.WebSocketPath, StringComparison.OrdinalIgnoreCase) &&
4041
context.Request.Cookies.TryGetValue(Options.WebSocketIdParameterName, out var identifier))
4142
{
42-
context.Features.Set<IHttpWebSocketFeature>(new InterruptibleWebSocketFeature(context, identifier, Registry));
43+
context.Features.Set<IHttpWebSocketFeature>(new InterruptibleWebSocketFeature(socketsFeature, identifier, Registry));
4344
}
4445

4546
await Next(context);

src/Components/test/testassets/TestServer/Infrastructure/InterruptibleWebSocketAppBuilderExtensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ public static IApplicationBuilder UseInterruptibleWebSockets(
88
this IApplicationBuilder builder,
99
InterruptibleWebSocketOptions options)
1010
{
11+
builder.UseWebSockets();
1112
builder.UseMiddleware<InterruptibleSocketMiddleware>(options);
1213
return builder;
1314
}
Lines changed: 3 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,30 @@
11
using System;
22
using System.Collections.Concurrent;
3-
using System.Collections.Generic;
4-
using System.IO;
53
using System.Net.WebSockets;
6-
using System.Security.Cryptography;
7-
using System.Text;
84
using System.Threading.Tasks;
9-
using Microsoft.AspNetCore.Builder;
105
using Microsoft.AspNetCore.Http;
116
using Microsoft.AspNetCore.Http.Features;
12-
using Microsoft.AspNetCore.WebSockets;
13-
using Microsoft.Extensions.DependencyInjection;
14-
using Microsoft.Extensions.Options;
15-
using Microsoft.Net.Http.Headers;
167

178
namespace Components.TestServer
189
{
1910
public class InterruptibleWebSocketFeature : IHttpWebSocketFeature
2011
{
2112
public InterruptibleWebSocketFeature(
22-
HttpContext httpContext,
13+
IHttpWebSocketFeature socketsFeature,
2314
string socketIdentifier,
2415
ConcurrentDictionary<string, InterruptibleWebSocket> registry)
2516
{
26-
HttpContext = httpContext;
17+
OriginalFeature = socketsFeature;
2718
SocketIdentifier = socketIdentifier;
28-
OriginalFeature = new UpgradeHandshake(
29-
httpContext,
30-
httpContext.Features.Get<IHttpUpgradeFeature>(),
31-
httpContext.RequestServices.GetRequiredService<IOptions<WebSocketOptions>>().Value);
3219
Registry = registry;
3320
}
3421

3522
public bool IsWebSocketRequest => OriginalFeature.IsWebSocketRequest;
3623

37-
public HttpContext HttpContext { get; }
3824
public string SocketIdentifier { get; }
3925

4026
private IHttpWebSocketFeature OriginalFeature { get; }
27+
4128
public ConcurrentDictionary<string, InterruptibleWebSocket> Registry { get; }
4229

4330
public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext context)
@@ -56,192 +43,5 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext context)
5643
return socket;
5744
});
5845
}
59-
60-
private class UpgradeHandshake : IHttpWebSocketFeature
61-
{
62-
public static readonly IEnumerable<string> NeededHeaders = new[]
63-
{
64-
HeaderNames.Upgrade,
65-
HeaderNames.Connection,
66-
HeaderNames.SecWebSocketKey,
67-
HeaderNames.SecWebSocketVersion
68-
};
69-
70-
private readonly HttpContext _context;
71-
private readonly IHttpUpgradeFeature _upgradeFeature;
72-
private readonly WebSocketOptions _options;
73-
private bool? _isWebSocketRequest;
74-
75-
public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options)
76-
{
77-
_context = context;
78-
_upgradeFeature = upgradeFeature;
79-
_options = options;
80-
}
81-
82-
public bool IsWebSocketRequest
83-
{
84-
get
85-
{
86-
if (_isWebSocketRequest == null)
87-
{
88-
if (!_upgradeFeature.IsUpgradableRequest)
89-
{
90-
_isWebSocketRequest = false;
91-
}
92-
else
93-
{
94-
var headers = new List<KeyValuePair<string, string>>();
95-
foreach (string headerName in NeededHeaders)
96-
{
97-
foreach (var value in _context.Request.Headers.GetCommaSeparatedValues(headerName))
98-
{
99-
headers.Add(new KeyValuePair<string, string>(headerName, value));
100-
}
101-
}
102-
_isWebSocketRequest = CheckSupportedWebSocketRequest(_context.Request.Method, headers);
103-
}
104-
}
105-
return _isWebSocketRequest.Value;
106-
}
107-
}
108-
109-
public static bool CheckSupportedWebSocketRequest(string method, IEnumerable<KeyValuePair<string, string>> headers)
110-
{
111-
bool validUpgrade = false, validConnection = false, validKey = false, validVersion = false;
112-
113-
if (!string.Equals("GET", method, StringComparison.OrdinalIgnoreCase))
114-
{
115-
return false;
116-
}
117-
118-
foreach (var pair in headers)
119-
{
120-
if (string.Equals(HeaderNames.Connection, pair.Key, StringComparison.OrdinalIgnoreCase))
121-
{
122-
if (string.Equals(Constants.Headers.ConnectionUpgrade, pair.Value, StringComparison.OrdinalIgnoreCase))
123-
{
124-
validConnection = true;
125-
}
126-
}
127-
else if (string.Equals(HeaderNames.Upgrade, pair.Key, StringComparison.OrdinalIgnoreCase))
128-
{
129-
if (string.Equals(Constants.Headers.UpgradeWebSocket, pair.Value, StringComparison.OrdinalIgnoreCase))
130-
{
131-
validUpgrade = true;
132-
}
133-
}
134-
else if (string.Equals(HeaderNames.SecWebSocketVersion, pair.Key, StringComparison.OrdinalIgnoreCase))
135-
{
136-
if (string.Equals(Constants.Headers.SupportedVersion, pair.Value, StringComparison.OrdinalIgnoreCase))
137-
{
138-
validVersion = true;
139-
}
140-
}
141-
else if (string.Equals(HeaderNames.SecWebSocketKey, pair.Key, StringComparison.OrdinalIgnoreCase))
142-
{
143-
validKey = IsRequestKeyValid(pair.Value);
144-
}
145-
}
146-
147-
return validConnection && validUpgrade && validVersion && validKey;
148-
}
149-
150-
public static bool IsRequestKeyValid(string value)
151-
{
152-
if (string.IsNullOrWhiteSpace(value))
153-
{
154-
return false;
155-
}
156-
try
157-
{
158-
byte[] data = Convert.FromBase64String(value);
159-
return data.Length == 16;
160-
}
161-
catch (Exception)
162-
{
163-
return false;
164-
}
165-
}
166-
167-
public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
168-
{
169-
if (!IsWebSocketRequest)
170-
{
171-
throw new InvalidOperationException("Not a WebSocket request."); // TODO: LOC
172-
}
173-
174-
string subProtocol = null;
175-
if (acceptContext != null)
176-
{
177-
subProtocol = acceptContext.SubProtocol;
178-
}
179-
180-
TimeSpan keepAliveInterval = _options.KeepAliveInterval;
181-
int receiveBufferSize = _options.ReceiveBufferSize;
182-
var advancedAcceptContext = acceptContext as ExtendedWebSocketAcceptContext;
183-
if (advancedAcceptContext != null)
184-
{
185-
if (advancedAcceptContext.ReceiveBufferSize.HasValue)
186-
{
187-
receiveBufferSize = advancedAcceptContext.ReceiveBufferSize.Value;
188-
}
189-
if (advancedAcceptContext.KeepAliveInterval.HasValue)
190-
{
191-
keepAliveInterval = advancedAcceptContext.KeepAliveInterval.Value;
192-
}
193-
}
194-
195-
string key = string.Join(", ", _context.Request.Headers[HeaderNames.SecWebSocketKey]);
196-
197-
GenerateResponseHeaders(key, subProtocol, _context.Response.Headers);
198-
199-
Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101
200-
201-
return WebSocket.CreateFromStream(opaqueTransport, isServer: true, subProtocol: subProtocol, keepAliveInterval: keepAliveInterval);
202-
}
203-
204-
public static void GenerateResponseHeaders(string key, string subProtocol, IHeaderDictionary headers)
205-
{
206-
headers[HeaderNames.Connection] = Constants.Headers.ConnectionUpgrade;
207-
headers[HeaderNames.Upgrade] = Constants.Headers.UpgradeWebSocket;
208-
headers[HeaderNames.SecWebSocketAccept] = CreateResponseKey(key);
209-
if (!string.IsNullOrWhiteSpace(subProtocol))
210-
{
211-
headers[HeaderNames.SecWebSocketProtocol] = subProtocol;
212-
}
213-
}
214-
215-
public static string CreateResponseKey(string requestKey)
216-
{
217-
// "The value of this header field is constructed by concatenating /key/, defined above in step 4
218-
// in Section 4.2.2, with the string "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of
219-
// this concatenated value to obtain a 20-byte value and base64-encoding"
220-
// https://tools.ietf.org/html/rfc6455#section-4.2.2
221-
222-
if (requestKey == null)
223-
{
224-
throw new ArgumentNullException(nameof(requestKey));
225-
}
226-
227-
using (var algorithm = SHA1.Create())
228-
{
229-
string merged = requestKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
230-
byte[] mergedBytes = Encoding.UTF8.GetBytes(merged);
231-
byte[] hashedBytes = algorithm.ComputeHash(mergedBytes);
232-
return Convert.ToBase64String(hashedBytes);
233-
}
234-
}
235-
236-
internal static class Constants
237-
{
238-
public static class Headers
239-
{
240-
public const string UpgradeWebSocket = "websocket";
241-
public const string ConnectionUpgrade = "Upgrade";
242-
public const string SupportedVersion = "13";
243-
}
244-
}
245-
}
24646
}
24747
}

0 commit comments

Comments
 (0)