diff --git a/src/Microsoft.AspNet.TestHost/TestServer.cs b/src/Microsoft.AspNet.TestHost/TestServer.cs index 571ceb26..8bdd89f8 100644 --- a/src/Microsoft.AspNet.TestHost/TestServer.cs +++ b/src/Microsoft.AspNet.TestHost/TestServer.cs @@ -97,6 +97,12 @@ public HttpClient CreateClient() return new HttpClient(CreateHandler()) { BaseAddress = BaseAddress }; } + public WebSocketClient CreateWebSocketClient() + { + var pathBase = BaseAddress == null ? PathString.Empty : PathString.FromUriComponent(BaseAddress); + return new WebSocketClient(Invoke, pathBase); + } + /// /// Begins constructing a request message for submission. /// diff --git a/src/Microsoft.AspNet.TestHost/TestWebSocket.cs b/src/Microsoft.AspNet.TestHost/TestWebSocket.cs new file mode 100644 index 00000000..f028d8f0 --- /dev/null +++ b/src/Microsoft.AspNet.TestHost/TestWebSocket.cs @@ -0,0 +1,354 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.TestHost +{ + internal class TestWebSocket : WebSocket + { + private ReceiverSenderBuffer _receiveBuffer; + private ReceiverSenderBuffer _sendBuffer; + private readonly string _subProtocol; + private WebSocketState _state; + private WebSocketCloseStatus? _closeStatus; + private string _closeStatusDescription; + private Message _receiveMessage; + + public static Tuple CreatePair(string subProtocol) + { + var buffers = new[] { new ReceiverSenderBuffer(), new ReceiverSenderBuffer() }; + return Tuple.Create( + new TestWebSocket(subProtocol, buffers[0], buffers[1]), + new TestWebSocket(subProtocol, buffers[1], buffers[0])); + } + + public override WebSocketCloseStatus? CloseStatus + { + get { return _closeStatus; } + } + + public override string CloseStatusDescription + { + get { return _closeStatusDescription; } + } + + public override WebSocketState State + { + get { return _state; } + } + + public override string SubProtocol + { + get { return _subProtocol; } + } + + public async override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + ThrowIfDisposed(); + + if (State == WebSocketState.Open || State == WebSocketState.CloseReceived) + { + // Send a close message. + await CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + } + + if (State == WebSocketState.CloseSent) + { + // Do a receiving drain + var data = new byte[1024]; + WebSocketReceiveResult result; + do + { + result = await ReceiveAsync(new ArraySegment(data), cancellationToken); + } + while (result.MessageType != WebSocketMessageType.Close); + } + } + + public async override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + { + ThrowIfDisposed(); + ThrowIfOutputClosed(); + + var message = new Message(closeStatus, statusDescription); + await _sendBuffer.SendAsync(message, cancellationToken); + + if (State == WebSocketState.Open) + { + _state = WebSocketState.CloseSent; + } + else if (State == WebSocketState.CloseReceived) + { + _state = WebSocketState.Closed; + Close(); + } + } + + public override void Abort() + { + if (_state >= WebSocketState.Closed) // or Aborted + { + return; + } + + _state = WebSocketState.Aborted; + Close(); + } + + public override void Dispose() + { + if (_state >= WebSocketState.Closed) // or Aborted + { + return; + } + + _state = WebSocketState.Closed; + Close(); + } + + public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + ThrowIfDisposed(); + ThrowIfInputClosed(); + ValidateSegment(buffer); + // TODO: InvalidOperationException if any receives are currently in progress. + + Message receiveMessage = _receiveMessage; + _receiveMessage = null; + if (receiveMessage == null) + { + receiveMessage = await _receiveBuffer.ReceiveAsync(cancellationToken); + } + if (receiveMessage.MessageType == WebSocketMessageType.Close) + { + _closeStatus = receiveMessage.CloseStatus; + _closeStatusDescription = receiveMessage.CloseStatusDescription ?? string.Empty; + var result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription); + if (_state == WebSocketState.Open) + { + _state = WebSocketState.CloseReceived; + } + else if (_state == WebSocketState.CloseSent) + { + _state = WebSocketState.Closed; + Close(); + } + return result; + } + else + { + int count = Math.Min(buffer.Count, receiveMessage.Buffer.Count); + bool endOfMessage = count == receiveMessage.Buffer.Count; + Array.Copy(receiveMessage.Buffer.Array, receiveMessage.Buffer.Offset, buffer.Array, buffer.Offset, count); + if (!endOfMessage) + { + receiveMessage.Buffer = new ArraySegment(receiveMessage.Buffer.Array, receiveMessage.Buffer.Offset + count, receiveMessage.Buffer.Count - count); + _receiveMessage = receiveMessage; + } + endOfMessage = endOfMessage && receiveMessage.EndOfMessage; + return new WebSocketReceiveResult(count, receiveMessage.MessageType, endOfMessage); + } + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + ValidateSegment(buffer); + if (messageType != WebSocketMessageType.Binary && messageType != WebSocketMessageType.Text) + { + // Block control frames + throw new ArgumentOutOfRangeException(nameof(messageType), messageType, string.Empty); + } + + var message = new Message(buffer, messageType, endOfMessage, cancellationToken); + return _sendBuffer.SendAsync(message, cancellationToken); + } + + private void Close() + { + _receiveBuffer.SetReceiverClosed(); + _sendBuffer.SetSenderClosed(); + } + + private void ThrowIfDisposed() + { + if (_state >= WebSocketState.Closed) // or Aborted + { + throw new ObjectDisposedException(typeof(TestWebSocket).FullName); + } + } + + private void ThrowIfOutputClosed() + { + if (State == WebSocketState.CloseSent) + { + throw new InvalidOperationException("Close already sent."); + } + } + + private void ThrowIfInputClosed() + { + if (State == WebSocketState.CloseReceived) + { + throw new InvalidOperationException("Close already received."); + } + } + + private void ValidateSegment(ArraySegment buffer) + { + if (buffer.Array == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + if (buffer.Offset < 0 || buffer.Offset > buffer.Array.Length) + { + throw new ArgumentOutOfRangeException(nameof(buffer.Offset), buffer.Offset, string.Empty); + } + if (buffer.Count < 0 || buffer.Count > buffer.Array.Length - buffer.Offset) + { + throw new ArgumentOutOfRangeException(nameof(buffer.Count), buffer.Count, string.Empty); + } + } + + private TestWebSocket(string subProtocol, ReceiverSenderBuffer readBuffer, ReceiverSenderBuffer writeBuffer) + { + _state = WebSocketState.Open; + _subProtocol = subProtocol; + _receiveBuffer = readBuffer; + _sendBuffer = writeBuffer; + } + + private class Message + { + public Message(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken token) + { + Buffer = buffer; + CloseStatus = null; + CloseStatusDescription = null; + EndOfMessage = endOfMessage; + MessageType = messageType; + } + + public Message(WebSocketCloseStatus? closeStatus, string closeStatusDescription) + { + Buffer = new ArraySegment(new byte[0]); + CloseStatus = closeStatus; + CloseStatusDescription = closeStatusDescription; + MessageType = WebSocketMessageType.Close; + EndOfMessage = true; + } + + public WebSocketCloseStatus? CloseStatus { get; set; } + public string CloseStatusDescription { get; set; } + public ArraySegment Buffer { get; set; } + public bool EndOfMessage { get; set; } + public WebSocketMessageType MessageType { get; set; } + } + + private class ReceiverSenderBuffer + { + private bool _receiverClosed; + private bool _senderClosed; + private bool _disposed; + private SemaphoreSlim _sem; + private Queue _messageQueue; + + public ReceiverSenderBuffer() + { + _sem = new SemaphoreSlim(0); + _messageQueue = new Queue(); + } + + public async virtual Task ReceiveAsync(CancellationToken cancellationToken) + { + if (_disposed) + { + ThrowNoReceive(); + } + await _sem.WaitAsync(cancellationToken); + lock (_messageQueue) + { + if (_messageQueue.Count == 0) + { + _disposed = true; + _sem.Dispose(); + ThrowNoReceive(); + } + return _messageQueue.Dequeue(); + } + } + + public virtual Task SendAsync(Message message, CancellationToken cancellationToken) + { + lock (_messageQueue) + { + if (_senderClosed) + { + throw new ObjectDisposedException(typeof(TestWebSocket).FullName); + } + if (_receiverClosed) + { + throw new IOException("The remote end closed the connection.", new ObjectDisposedException(typeof(TestWebSocket).FullName)); + } + + // we return immediately so we need to copy the buffer since the sender can re-use it + var array = new byte[message.Buffer.Count]; + Array.Copy(message.Buffer.Array, message.Buffer.Offset, array, 0, message.Buffer.Count); + message.Buffer = new ArraySegment(array); + + _messageQueue.Enqueue(message); + _sem.Release(); + + return Task.FromResult(true); + } + } + + public void SetReceiverClosed() + { + lock (_messageQueue) + { + if (!_receiverClosed) + { + _receiverClosed = true; + if (!_disposed) + { + _sem.Release(); + } + } + } + } + + public void SetSenderClosed() + { + lock (_messageQueue) + { + if (!_senderClosed) + { + _senderClosed = true; + if (!_disposed) + { + _sem.Release(); + } + } + } + } + + private void ThrowNoReceive() + { + if (_receiverClosed) + { + throw new ObjectDisposedException(typeof(TestWebSocket).FullName); + } + else // _senderClosed + { + throw new IOException("The remote end closed the connection.", new ObjectDisposedException(typeof(TestWebSocket).FullName)); + } + } + } + } +} diff --git a/src/Microsoft.AspNet.TestHost/WebSocketClient.cs b/src/Microsoft.AspNet.TestHost/WebSocketClient.cs new file mode 100644 index 00000000..0c0f958f --- /dev/null +++ b/src/Microsoft.AspNet.TestHost/WebSocketClient.cs @@ -0,0 +1,179 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.WebSockets; +using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNet.Http.Features; +using Microsoft.AspNet.Http; +using Microsoft.AspNet.Http.Internal; +using Microsoft.Framework.Internal; + +namespace Microsoft.AspNet.TestHost +{ + public class WebSocketClient + { + private readonly Func _next; + private readonly PathString _pathBase; + + internal WebSocketClient([NotNull] Func next, PathString pathBase) + { + _next = next; + + // PathString.StartsWithSegments that we use below requires the base path to not end in a slash. + if (pathBase.HasValue && pathBase.Value.EndsWith("/")) + { + pathBase = new PathString(pathBase.Value.Substring(0, pathBase.Value.Length - 1)); + } + _pathBase = pathBase; + + SubProtocols = new List(); + } + + public IList SubProtocols + { + get; + private set; + } + + public Action ConfigureRequest + { + get; + set; + } + + public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken) + { + var state = new RequestState(uri, _pathBase, cancellationToken); + + if (ConfigureRequest != null) + { + ConfigureRequest(state.HttpContext.Request); + } + + // Async offload, don't let the test code block the caller. + var offload = Task.Factory.StartNew(async () => + { + try + { + await _next(state.FeatureCollection); + state.PipelineComplete(); + } + catch (Exception ex) + { + state.PipelineFailed(ex); + } + finally + { + state.Dispose(); + } + }); + + return await state.WebSocketTask; + } + + private class RequestState : IDisposable, IHttpWebSocketFeature + { + private TaskCompletionSource _clientWebSocketTcs; + private WebSocket _serverWebSocket; + + public IFeatureCollection FeatureCollection { get; private set; } + public HttpContext HttpContext { get; private set; } + public Task WebSocketTask { get { return _clientWebSocketTcs.Task; } } + + public RequestState(Uri uri, PathString pathBase, CancellationToken cancellationToken) + { + _clientWebSocketTcs = new TaskCompletionSource(); + + // HttpContext + FeatureCollection = new FeatureCollection(); + HttpContext = new DefaultHttpContext(FeatureCollection); + + // Request + HttpContext.SetFeature(new RequestFeature()); + var request = HttpContext.Request; + request.Protocol = "HTTP/1.1"; + var scheme = uri.Scheme; + scheme = (scheme == "ws") ? "http" : scheme; + scheme = (scheme == "wss") ? "https" : scheme; + request.Scheme = scheme; + request.Method = "GET"; + var fullPath = PathString.FromUriComponent(uri); + PathString remainder; + if (fullPath.StartsWithSegments(pathBase, out remainder)) + { + request.PathBase = pathBase; + request.Path = remainder; + } + else + { + request.PathBase = PathString.Empty; + request.Path = fullPath; + } + request.QueryString = QueryString.FromUriComponent(uri); + request.Headers.Add("Connection", new string[] { "Upgrade" }); + request.Headers.Add("Upgrade", new string[] { "websocket" }); + request.Headers.Add("Sec-WebSocket-Version", new string[] { "13" }); + request.Headers.Add("Sec-WebSocket-Key", new string[] { CreateRequestKey() }); + request.Body = Stream.Null; + + // Response + HttpContext.SetFeature(new ResponseFeature()); + var response = HttpContext.Response; + response.Body = Stream.Null; + response.StatusCode = 200; + + // WebSocket + HttpContext.SetFeature(this); + } + + public void PipelineComplete() + { + PipelineFailed(new InvalidOperationException("Incomplete handshake, status code: " + HttpContext.Response.StatusCode)); + } + + public void PipelineFailed(Exception ex) + { + _clientWebSocketTcs.TrySetException(new InvalidOperationException("The websocket was not accepted.", ex)); + } + + public void Dispose() + { + if (_serverWebSocket != null) + { + _serverWebSocket.Dispose(); + } + } + + private string CreateRequestKey() + { + byte[] data = new byte[16]; + var rng = RandomNumberGenerator.Create(); + rng.GetBytes(data); + return Convert.ToBase64String(data); + } + + bool IHttpWebSocketFeature.IsWebSocketRequest + { + get + { + return true; + } + } + + Task IHttpWebSocketFeature.AcceptAsync(WebSocketAcceptContext context) + { + HttpContext.Response.StatusCode = 101; // Switching Protocols + + var websockets = TestWebSocket.CreatePair(context.SubProtocol); + _clientWebSocketTcs.SetResult(websockets.Item1); + _serverWebSocket = websockets.Item2; + return Task.FromResult(_serverWebSocket); + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs b/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs index b77a6818..e207d499 100644 --- a/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs +++ b/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs @@ -2,7 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.IO; +using System.Linq; using System.Net.Http; +using System.Net.WebSockets; +using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNet.Builder; using Microsoft.AspNet.Http; @@ -111,5 +115,135 @@ public async Task PostAsyncWorks() // Assert Assert.Equal("Hello world POST Response", await response.Content.ReadAsStringAsync()); } + + [Fact] + public async Task WebSocketWorks() + { + // Arrange + RequestDelegate appDelegate = async ctx => + { + if (ctx.WebSockets.IsWebSocketRequest) + { + var websocket = await ctx.WebSockets.AcceptWebSocketAsync(); + var receiveArray = new byte[1024]; + while (true) + { + var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment(receiveArray), CancellationToken.None); + if (receiveResult.MessageType == WebSocketMessageType.Close) + { + await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None); + break; + } + else + { + var sendBuffer = new System.ArraySegment(receiveArray, 0, receiveResult.Count); + await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None); + } + } + } + }; + var server = TestServer.Create(app => + { + app.Run(appDelegate); + }); + + // Act + var client = server.CreateWebSocketClient(); + var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None); + var hello = Encoding.UTF8.GetBytes("hello"); + await clientSocket.SendAsync(new System.ArraySegment(hello), WebSocketMessageType.Text, true, CancellationToken.None); + var world = Encoding.UTF8.GetBytes("world!"); + await clientSocket.SendAsync(new System.ArraySegment(world), WebSocketMessageType.Binary, true, CancellationToken.None); + await clientSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None); + + // Assert + Assert.Equal(WebSocketState.CloseSent, clientSocket.State); + + var buffer = new byte[1024]; + var result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None); + Assert.Equal(hello.Length, result.Count); + Assert.True(hello.SequenceEqual(buffer.Take(hello.Length))); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None); + Assert.Equal(world.Length, result.Count); + Assert.True(world.SequenceEqual(buffer.Take(world.Length))); + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketState.Closed, clientSocket.State); + + clientSocket.Dispose(); + } + + [Fact] + public async Task WebSocketDisposalThrowsOnPeer() + { + // Arrange + RequestDelegate appDelegate = async ctx => + { + if (ctx.WebSockets.IsWebSocketRequest) + { + var websocket = await ctx.WebSockets.AcceptWebSocketAsync(); + websocket.Dispose(); + } + }; + var server = TestServer.Create(app => + { + app.Run(appDelegate); + }); + + // Act + var client = server.CreateWebSocketClient(); + var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None); + var buffer = new byte[1024]; + await Assert.ThrowsAsync(async () => await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None)); + + clientSocket.Dispose(); + } + + [Fact] + public async Task WebSocketTinyReceiveGeneratesEndOfMessage() + { + // Arrange + RequestDelegate appDelegate = async ctx => + { + if (ctx.WebSockets.IsWebSocketRequest) + { + var websocket = await ctx.WebSockets.AcceptWebSocketAsync(); + var receiveArray = new byte[1024]; + while (true) + { + var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment(receiveArray), CancellationToken.None); + var sendBuffer = new System.ArraySegment(receiveArray, 0, receiveResult.Count); + await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None); + } + } + }; + var server = TestServer.Create(app => + { + app.Run(appDelegate); + }); + + // Act + var client = server.CreateWebSocketClient(); + var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None); + var hello = Encoding.UTF8.GetBytes("hello"); + await clientSocket.SendAsync(new System.ArraySegment(hello), WebSocketMessageType.Text, true, CancellationToken.None); + + // Assert + var buffer = new byte[1]; + for (int i = 0; i < hello.Length; i++) + { + bool last = i == (hello.Length - 1); + var result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(buffer[0], hello[i]); + Assert.Equal(last, result.EndOfMessage); + } + + clientSocket.Dispose(); + } } }