diff --git a/src/Microsoft.AspNet.TestHost/TestServer.cs b/src/Microsoft.AspNet.TestHost/TestServer.cs
index 571ceb26..8bdd89f8 100644
--- a/src/Microsoft.AspNet.TestHost/TestServer.cs
+++ b/src/Microsoft.AspNet.TestHost/TestServer.cs
@@ -97,6 +97,12 @@ public HttpClient CreateClient()
return new HttpClient(CreateHandler()) { BaseAddress = BaseAddress };
}
+ public WebSocketClient CreateWebSocketClient()
+ {
+ var pathBase = BaseAddress == null ? PathString.Empty : PathString.FromUriComponent(BaseAddress);
+ return new WebSocketClient(Invoke, pathBase);
+ }
+
///
/// Begins constructing a request message for submission.
///
diff --git a/src/Microsoft.AspNet.TestHost/TestWebSocket.cs b/src/Microsoft.AspNet.TestHost/TestWebSocket.cs
new file mode 100644
index 00000000..f028d8f0
--- /dev/null
+++ b/src/Microsoft.AspNet.TestHost/TestWebSocket.cs
@@ -0,0 +1,354 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Net.WebSockets;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.AspNet.TestHost
+{
+ internal class TestWebSocket : WebSocket
+ {
+ private ReceiverSenderBuffer _receiveBuffer;
+ private ReceiverSenderBuffer _sendBuffer;
+ private readonly string _subProtocol;
+ private WebSocketState _state;
+ private WebSocketCloseStatus? _closeStatus;
+ private string _closeStatusDescription;
+ private Message _receiveMessage;
+
+ public static Tuple CreatePair(string subProtocol)
+ {
+ var buffers = new[] { new ReceiverSenderBuffer(), new ReceiverSenderBuffer() };
+ return Tuple.Create(
+ new TestWebSocket(subProtocol, buffers[0], buffers[1]),
+ new TestWebSocket(subProtocol, buffers[1], buffers[0]));
+ }
+
+ public override WebSocketCloseStatus? CloseStatus
+ {
+ get { return _closeStatus; }
+ }
+
+ public override string CloseStatusDescription
+ {
+ get { return _closeStatusDescription; }
+ }
+
+ public override WebSocketState State
+ {
+ get { return _state; }
+ }
+
+ public override string SubProtocol
+ {
+ get { return _subProtocol; }
+ }
+
+ public async override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+ {
+ ThrowIfDisposed();
+
+ if (State == WebSocketState.Open || State == WebSocketState.CloseReceived)
+ {
+ // Send a close message.
+ await CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
+ }
+
+ if (State == WebSocketState.CloseSent)
+ {
+ // Do a receiving drain
+ var data = new byte[1024];
+ WebSocketReceiveResult result;
+ do
+ {
+ result = await ReceiveAsync(new ArraySegment(data), cancellationToken);
+ }
+ while (result.MessageType != WebSocketMessageType.Close);
+ }
+ }
+
+ public async override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+ {
+ ThrowIfDisposed();
+ ThrowIfOutputClosed();
+
+ var message = new Message(closeStatus, statusDescription);
+ await _sendBuffer.SendAsync(message, cancellationToken);
+
+ if (State == WebSocketState.Open)
+ {
+ _state = WebSocketState.CloseSent;
+ }
+ else if (State == WebSocketState.CloseReceived)
+ {
+ _state = WebSocketState.Closed;
+ Close();
+ }
+ }
+
+ public override void Abort()
+ {
+ if (_state >= WebSocketState.Closed) // or Aborted
+ {
+ return;
+ }
+
+ _state = WebSocketState.Aborted;
+ Close();
+ }
+
+ public override void Dispose()
+ {
+ if (_state >= WebSocketState.Closed) // or Aborted
+ {
+ return;
+ }
+
+ _state = WebSocketState.Closed;
+ Close();
+ }
+
+ public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken)
+ {
+ ThrowIfDisposed();
+ ThrowIfInputClosed();
+ ValidateSegment(buffer);
+ // TODO: InvalidOperationException if any receives are currently in progress.
+
+ Message receiveMessage = _receiveMessage;
+ _receiveMessage = null;
+ if (receiveMessage == null)
+ {
+ receiveMessage = await _receiveBuffer.ReceiveAsync(cancellationToken);
+ }
+ if (receiveMessage.MessageType == WebSocketMessageType.Close)
+ {
+ _closeStatus = receiveMessage.CloseStatus;
+ _closeStatusDescription = receiveMessage.CloseStatusDescription ?? string.Empty;
+ var result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription);
+ if (_state == WebSocketState.Open)
+ {
+ _state = WebSocketState.CloseReceived;
+ }
+ else if (_state == WebSocketState.CloseSent)
+ {
+ _state = WebSocketState.Closed;
+ Close();
+ }
+ return result;
+ }
+ else
+ {
+ int count = Math.Min(buffer.Count, receiveMessage.Buffer.Count);
+ bool endOfMessage = count == receiveMessage.Buffer.Count;
+ Array.Copy(receiveMessage.Buffer.Array, receiveMessage.Buffer.Offset, buffer.Array, buffer.Offset, count);
+ if (!endOfMessage)
+ {
+ receiveMessage.Buffer = new ArraySegment(receiveMessage.Buffer.Array, receiveMessage.Buffer.Offset + count, receiveMessage.Buffer.Count - count);
+ _receiveMessage = receiveMessage;
+ }
+ endOfMessage = endOfMessage && receiveMessage.EndOfMessage;
+ return new WebSocketReceiveResult(count, receiveMessage.MessageType, endOfMessage);
+ }
+ }
+
+ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
+ {
+ ValidateSegment(buffer);
+ if (messageType != WebSocketMessageType.Binary && messageType != WebSocketMessageType.Text)
+ {
+ // Block control frames
+ throw new ArgumentOutOfRangeException(nameof(messageType), messageType, string.Empty);
+ }
+
+ var message = new Message(buffer, messageType, endOfMessage, cancellationToken);
+ return _sendBuffer.SendAsync(message, cancellationToken);
+ }
+
+ private void Close()
+ {
+ _receiveBuffer.SetReceiverClosed();
+ _sendBuffer.SetSenderClosed();
+ }
+
+ private void ThrowIfDisposed()
+ {
+ if (_state >= WebSocketState.Closed) // or Aborted
+ {
+ throw new ObjectDisposedException(typeof(TestWebSocket).FullName);
+ }
+ }
+
+ private void ThrowIfOutputClosed()
+ {
+ if (State == WebSocketState.CloseSent)
+ {
+ throw new InvalidOperationException("Close already sent.");
+ }
+ }
+
+ private void ThrowIfInputClosed()
+ {
+ if (State == WebSocketState.CloseReceived)
+ {
+ throw new InvalidOperationException("Close already received.");
+ }
+ }
+
+ private void ValidateSegment(ArraySegment buffer)
+ {
+ if (buffer.Array == null)
+ {
+ throw new ArgumentNullException(nameof(buffer));
+ }
+ if (buffer.Offset < 0 || buffer.Offset > buffer.Array.Length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(buffer.Offset), buffer.Offset, string.Empty);
+ }
+ if (buffer.Count < 0 || buffer.Count > buffer.Array.Length - buffer.Offset)
+ {
+ throw new ArgumentOutOfRangeException(nameof(buffer.Count), buffer.Count, string.Empty);
+ }
+ }
+
+ private TestWebSocket(string subProtocol, ReceiverSenderBuffer readBuffer, ReceiverSenderBuffer writeBuffer)
+ {
+ _state = WebSocketState.Open;
+ _subProtocol = subProtocol;
+ _receiveBuffer = readBuffer;
+ _sendBuffer = writeBuffer;
+ }
+
+ private class Message
+ {
+ public Message(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken token)
+ {
+ Buffer = buffer;
+ CloseStatus = null;
+ CloseStatusDescription = null;
+ EndOfMessage = endOfMessage;
+ MessageType = messageType;
+ }
+
+ public Message(WebSocketCloseStatus? closeStatus, string closeStatusDescription)
+ {
+ Buffer = new ArraySegment(new byte[0]);
+ CloseStatus = closeStatus;
+ CloseStatusDescription = closeStatusDescription;
+ MessageType = WebSocketMessageType.Close;
+ EndOfMessage = true;
+ }
+
+ public WebSocketCloseStatus? CloseStatus { get; set; }
+ public string CloseStatusDescription { get; set; }
+ public ArraySegment Buffer { get; set; }
+ public bool EndOfMessage { get; set; }
+ public WebSocketMessageType MessageType { get; set; }
+ }
+
+ private class ReceiverSenderBuffer
+ {
+ private bool _receiverClosed;
+ private bool _senderClosed;
+ private bool _disposed;
+ private SemaphoreSlim _sem;
+ private Queue _messageQueue;
+
+ public ReceiverSenderBuffer()
+ {
+ _sem = new SemaphoreSlim(0);
+ _messageQueue = new Queue();
+ }
+
+ public async virtual Task ReceiveAsync(CancellationToken cancellationToken)
+ {
+ if (_disposed)
+ {
+ ThrowNoReceive();
+ }
+ await _sem.WaitAsync(cancellationToken);
+ lock (_messageQueue)
+ {
+ if (_messageQueue.Count == 0)
+ {
+ _disposed = true;
+ _sem.Dispose();
+ ThrowNoReceive();
+ }
+ return _messageQueue.Dequeue();
+ }
+ }
+
+ public virtual Task SendAsync(Message message, CancellationToken cancellationToken)
+ {
+ lock (_messageQueue)
+ {
+ if (_senderClosed)
+ {
+ throw new ObjectDisposedException(typeof(TestWebSocket).FullName);
+ }
+ if (_receiverClosed)
+ {
+ throw new IOException("The remote end closed the connection.", new ObjectDisposedException(typeof(TestWebSocket).FullName));
+ }
+
+ // we return immediately so we need to copy the buffer since the sender can re-use it
+ var array = new byte[message.Buffer.Count];
+ Array.Copy(message.Buffer.Array, message.Buffer.Offset, array, 0, message.Buffer.Count);
+ message.Buffer = new ArraySegment(array);
+
+ _messageQueue.Enqueue(message);
+ _sem.Release();
+
+ return Task.FromResult(true);
+ }
+ }
+
+ public void SetReceiverClosed()
+ {
+ lock (_messageQueue)
+ {
+ if (!_receiverClosed)
+ {
+ _receiverClosed = true;
+ if (!_disposed)
+ {
+ _sem.Release();
+ }
+ }
+ }
+ }
+
+ public void SetSenderClosed()
+ {
+ lock (_messageQueue)
+ {
+ if (!_senderClosed)
+ {
+ _senderClosed = true;
+ if (!_disposed)
+ {
+ _sem.Release();
+ }
+ }
+ }
+ }
+
+ private void ThrowNoReceive()
+ {
+ if (_receiverClosed)
+ {
+ throw new ObjectDisposedException(typeof(TestWebSocket).FullName);
+ }
+ else // _senderClosed
+ {
+ throw new IOException("The remote end closed the connection.", new ObjectDisposedException(typeof(TestWebSocket).FullName));
+ }
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.AspNet.TestHost/WebSocketClient.cs b/src/Microsoft.AspNet.TestHost/WebSocketClient.cs
new file mode 100644
index 00000000..0c0f958f
--- /dev/null
+++ b/src/Microsoft.AspNet.TestHost/WebSocketClient.cs
@@ -0,0 +1,179 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Net.WebSockets;
+using System.Security.Cryptography;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.AspNet.Http.Features;
+using Microsoft.AspNet.Http;
+using Microsoft.AspNet.Http.Internal;
+using Microsoft.Framework.Internal;
+
+namespace Microsoft.AspNet.TestHost
+{
+ public class WebSocketClient
+ {
+ private readonly Func _next;
+ private readonly PathString _pathBase;
+
+ internal WebSocketClient([NotNull] Func next, PathString pathBase)
+ {
+ _next = next;
+
+ // PathString.StartsWithSegments that we use below requires the base path to not end in a slash.
+ if (pathBase.HasValue && pathBase.Value.EndsWith("/"))
+ {
+ pathBase = new PathString(pathBase.Value.Substring(0, pathBase.Value.Length - 1));
+ }
+ _pathBase = pathBase;
+
+ SubProtocols = new List();
+ }
+
+ public IList SubProtocols
+ {
+ get;
+ private set;
+ }
+
+ public Action ConfigureRequest
+ {
+ get;
+ set;
+ }
+
+ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken)
+ {
+ var state = new RequestState(uri, _pathBase, cancellationToken);
+
+ if (ConfigureRequest != null)
+ {
+ ConfigureRequest(state.HttpContext.Request);
+ }
+
+ // Async offload, don't let the test code block the caller.
+ var offload = Task.Factory.StartNew(async () =>
+ {
+ try
+ {
+ await _next(state.FeatureCollection);
+ state.PipelineComplete();
+ }
+ catch (Exception ex)
+ {
+ state.PipelineFailed(ex);
+ }
+ finally
+ {
+ state.Dispose();
+ }
+ });
+
+ return await state.WebSocketTask;
+ }
+
+ private class RequestState : IDisposable, IHttpWebSocketFeature
+ {
+ private TaskCompletionSource _clientWebSocketTcs;
+ private WebSocket _serverWebSocket;
+
+ public IFeatureCollection FeatureCollection { get; private set; }
+ public HttpContext HttpContext { get; private set; }
+ public Task WebSocketTask { get { return _clientWebSocketTcs.Task; } }
+
+ public RequestState(Uri uri, PathString pathBase, CancellationToken cancellationToken)
+ {
+ _clientWebSocketTcs = new TaskCompletionSource();
+
+ // HttpContext
+ FeatureCollection = new FeatureCollection();
+ HttpContext = new DefaultHttpContext(FeatureCollection);
+
+ // Request
+ HttpContext.SetFeature(new RequestFeature());
+ var request = HttpContext.Request;
+ request.Protocol = "HTTP/1.1";
+ var scheme = uri.Scheme;
+ scheme = (scheme == "ws") ? "http" : scheme;
+ scheme = (scheme == "wss") ? "https" : scheme;
+ request.Scheme = scheme;
+ request.Method = "GET";
+ var fullPath = PathString.FromUriComponent(uri);
+ PathString remainder;
+ if (fullPath.StartsWithSegments(pathBase, out remainder))
+ {
+ request.PathBase = pathBase;
+ request.Path = remainder;
+ }
+ else
+ {
+ request.PathBase = PathString.Empty;
+ request.Path = fullPath;
+ }
+ request.QueryString = QueryString.FromUriComponent(uri);
+ request.Headers.Add("Connection", new string[] { "Upgrade" });
+ request.Headers.Add("Upgrade", new string[] { "websocket" });
+ request.Headers.Add("Sec-WebSocket-Version", new string[] { "13" });
+ request.Headers.Add("Sec-WebSocket-Key", new string[] { CreateRequestKey() });
+ request.Body = Stream.Null;
+
+ // Response
+ HttpContext.SetFeature(new ResponseFeature());
+ var response = HttpContext.Response;
+ response.Body = Stream.Null;
+ response.StatusCode = 200;
+
+ // WebSocket
+ HttpContext.SetFeature(this);
+ }
+
+ public void PipelineComplete()
+ {
+ PipelineFailed(new InvalidOperationException("Incomplete handshake, status code: " + HttpContext.Response.StatusCode));
+ }
+
+ public void PipelineFailed(Exception ex)
+ {
+ _clientWebSocketTcs.TrySetException(new InvalidOperationException("The websocket was not accepted.", ex));
+ }
+
+ public void Dispose()
+ {
+ if (_serverWebSocket != null)
+ {
+ _serverWebSocket.Dispose();
+ }
+ }
+
+ private string CreateRequestKey()
+ {
+ byte[] data = new byte[16];
+ var rng = RandomNumberGenerator.Create();
+ rng.GetBytes(data);
+ return Convert.ToBase64String(data);
+ }
+
+ bool IHttpWebSocketFeature.IsWebSocketRequest
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ Task IHttpWebSocketFeature.AcceptAsync(WebSocketAcceptContext context)
+ {
+ HttpContext.Response.StatusCode = 101; // Switching Protocols
+
+ var websockets = TestWebSocket.CreatePair(context.SubProtocol);
+ _clientWebSocketTcs.SetResult(websockets.Item1);
+ _serverWebSocket = websockets.Item2;
+ return Task.FromResult(_serverWebSocket);
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs b/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs
index b77a6818..e207d499 100644
--- a/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs
+++ b/test/Microsoft.AspNet.TestHost.Tests/TestClientTests.cs
@@ -2,7 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
+using System.Linq;
using System.Net.Http;
+using System.Net.WebSockets;
+using System.Text;
+using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Builder;
using Microsoft.AspNet.Http;
@@ -111,5 +115,135 @@ public async Task PostAsyncWorks()
// Assert
Assert.Equal("Hello world POST Response", await response.Content.ReadAsStringAsync());
}
+
+ [Fact]
+ public async Task WebSocketWorks()
+ {
+ // Arrange
+ RequestDelegate appDelegate = async ctx =>
+ {
+ if (ctx.WebSockets.IsWebSocketRequest)
+ {
+ var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
+ var receiveArray = new byte[1024];
+ while (true)
+ {
+ var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment(receiveArray), CancellationToken.None);
+ if (receiveResult.MessageType == WebSocketMessageType.Close)
+ {
+ await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
+ break;
+ }
+ else
+ {
+ var sendBuffer = new System.ArraySegment(receiveArray, 0, receiveResult.Count);
+ await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None);
+ }
+ }
+ }
+ };
+ var server = TestServer.Create(app =>
+ {
+ app.Run(appDelegate);
+ });
+
+ // Act
+ var client = server.CreateWebSocketClient();
+ var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
+ var hello = Encoding.UTF8.GetBytes("hello");
+ await clientSocket.SendAsync(new System.ArraySegment(hello), WebSocketMessageType.Text, true, CancellationToken.None);
+ var world = Encoding.UTF8.GetBytes("world!");
+ await clientSocket.SendAsync(new System.ArraySegment(world), WebSocketMessageType.Binary, true, CancellationToken.None);
+ await clientSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
+
+ // Assert
+ Assert.Equal(WebSocketState.CloseSent, clientSocket.State);
+
+ var buffer = new byte[1024];
+ var result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None);
+ Assert.Equal(hello.Length, result.Count);
+ Assert.True(hello.SequenceEqual(buffer.Take(hello.Length)));
+ Assert.Equal(WebSocketMessageType.Text, result.MessageType);
+
+ result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None);
+ Assert.Equal(world.Length, result.Count);
+ Assert.True(world.SequenceEqual(buffer.Take(world.Length)));
+ Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
+
+ result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None);
+ Assert.Equal(WebSocketMessageType.Close, result.MessageType);
+ Assert.Equal(WebSocketState.Closed, clientSocket.State);
+
+ clientSocket.Dispose();
+ }
+
+ [Fact]
+ public async Task WebSocketDisposalThrowsOnPeer()
+ {
+ // Arrange
+ RequestDelegate appDelegate = async ctx =>
+ {
+ if (ctx.WebSockets.IsWebSocketRequest)
+ {
+ var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
+ websocket.Dispose();
+ }
+ };
+ var server = TestServer.Create(app =>
+ {
+ app.Run(appDelegate);
+ });
+
+ // Act
+ var client = server.CreateWebSocketClient();
+ var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
+ var buffer = new byte[1024];
+ await Assert.ThrowsAsync(async () => await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None));
+
+ clientSocket.Dispose();
+ }
+
+ [Fact]
+ public async Task WebSocketTinyReceiveGeneratesEndOfMessage()
+ {
+ // Arrange
+ RequestDelegate appDelegate = async ctx =>
+ {
+ if (ctx.WebSockets.IsWebSocketRequest)
+ {
+ var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
+ var receiveArray = new byte[1024];
+ while (true)
+ {
+ var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment(receiveArray), CancellationToken.None);
+ var sendBuffer = new System.ArraySegment(receiveArray, 0, receiveResult.Count);
+ await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None);
+ }
+ }
+ };
+ var server = TestServer.Create(app =>
+ {
+ app.Run(appDelegate);
+ });
+
+ // Act
+ var client = server.CreateWebSocketClient();
+ var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
+ var hello = Encoding.UTF8.GetBytes("hello");
+ await clientSocket.SendAsync(new System.ArraySegment(hello), WebSocketMessageType.Text, true, CancellationToken.None);
+
+ // Assert
+ var buffer = new byte[1];
+ for (int i = 0; i < hello.Length; i++)
+ {
+ bool last = i == (hello.Length - 1);
+ var result = await clientSocket.ReceiveAsync(new System.ArraySegment(buffer), CancellationToken.None);
+ Assert.Equal(buffer.Length, result.Count);
+ Assert.Equal(buffer[0], hello[i]);
+ Assert.Equal(last, result.EndOfMessage);
+ }
+
+ clientSocket.Dispose();
+ }
}
}