diff --git a/src/Transports.Subscriptions.Abstractions/OperationMessage.cs b/src/Transports.Subscriptions.Abstractions/OperationMessage.cs index bda4bb91..d172e404 100644 --- a/src/Transports.Subscriptions.Abstractions/OperationMessage.cs +++ b/src/Transports.Subscriptions.Abstractions/OperationMessage.cs @@ -1,5 +1,3 @@ -using Newtonsoft.Json.Linq; - namespace GraphQL.Server.Transports.Subscriptions.Abstractions { /// @@ -20,7 +18,7 @@ public class OperationMessage /// /// Nullable payload /// - public JObject Payload { get; set; } + public object Payload { get; set; } /// diff --git a/src/Transports.Subscriptions.Abstractions/ProtocolMessageListener.cs b/src/Transports.Subscriptions.Abstractions/ProtocolMessageListener.cs index ee4ee9ef..bd035b0d 100644 --- a/src/Transports.Subscriptions.Abstractions/ProtocolMessageListener.cs +++ b/src/Transports.Subscriptions.Abstractions/ProtocolMessageListener.cs @@ -54,14 +54,13 @@ private Task HandleUnknownAsync(MessageHandlingContext context) { Type = MessageType.GQL_CONNECTION_ERROR, Id = message.Id, - Payload = JObject.FromObject(new + Payload = new ExecutionResult { - message.Id, Errors = new ExecutionErrors { new ExecutionError($"Unexpected message type {message.Type}") } - }) + } }); } @@ -76,7 +75,7 @@ private Task HandleStartAsync(MessageHandlingContext context) { var message = context.Message; _logger.LogDebug("Handle start: {id}", message.Id); - var payload = message.Payload.ToObject(); + var payload = ((JObject)message.Payload).ToObject(); if (payload == null) throw new InvalidOperationException($"Could not get OperationMessagePayload from message.Payload"); diff --git a/src/Transports.Subscriptions.Abstractions/Subscription.cs b/src/Transports.Subscriptions.Abstractions/Subscription.cs index a65e3121..012fb1c6 100644 --- a/src/Transports.Subscriptions.Abstractions/Subscription.cs +++ b/src/Transports.Subscriptions.Abstractions/Subscription.cs @@ -63,7 +63,7 @@ public void OnNext(ExecutionResult value) { Type = MessageType.GQL_DATA, Id = Id, - Payload = JObject.FromObject(value) + Payload = value }); } diff --git a/src/Transports.Subscriptions.Abstractions/SubscriptionManager.cs b/src/Transports.Subscriptions.Abstractions/SubscriptionManager.cs index b953921f..15e5f3cb 100644 --- a/src/Transports.Subscriptions.Abstractions/SubscriptionManager.cs +++ b/src/Transports.Subscriptions.Abstractions/SubscriptionManager.cs @@ -93,7 +93,7 @@ await writer.SendAsync(new OperationMessage { Type = MessageType.GQL_ERROR, Id = id, - Payload = JObject.FromObject(result) + Payload = result }); return null; @@ -110,7 +110,7 @@ await writer.SendAsync(new OperationMessage { Type = MessageType.GQL_ERROR, Id = id, - Payload = JObject.FromObject(result) + Payload = result }); return null; @@ -131,7 +131,7 @@ await writer.SendAsync(new OperationMessage { Type = MessageType.GQL_DATA, Id = id, - Payload = JObject.FromObject(result) + Payload = result }); await writer.SendAsync(new OperationMessage diff --git a/src/Transports.Subscriptions.WebSockets/WebSocketConnectionFactory.cs b/src/Transports.Subscriptions.WebSockets/WebSocketConnectionFactory.cs index 5637e1d6..ce9c4318 100644 --- a/src/Transports.Subscriptions.WebSockets/WebSocketConnectionFactory.cs +++ b/src/Transports.Subscriptions.WebSockets/WebSocketConnectionFactory.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.Net.WebSockets; +using GraphQL.Http; using GraphQL.Server.Internal; using GraphQL.Server.Transports.Subscriptions.Abstractions; using GraphQL.Types; @@ -14,23 +15,26 @@ public class WebSocketConnectionFactory : IWebSocketConnectionFactory _executer; private readonly IEnumerable _messageListeners; + private readonly IDocumentWriter _documentWriter; public WebSocketConnectionFactory(ILogger> logger, ILoggerFactory loggerFactory, IGraphQLExecuter executer, - IEnumerable messageListeners) + IEnumerable messageListeners, + IDocumentWriter documentWriter) { _logger = logger; _loggerFactory = loggerFactory; _executer = executer; _messageListeners = messageListeners; + _documentWriter = documentWriter; } public WebSocketConnection CreateConnection(WebSocket socket, string connectionId) { _logger.LogDebug("Creating server for connection {connectionId}", connectionId); - var transport = new WebSocketTransport(socket); + var transport = new WebSocketTransport(socket, _documentWriter); var manager = new SubscriptionManager(_executer, _loggerFactory); var server = new SubscriptionServer( transport, diff --git a/src/Transports.Subscriptions.WebSockets/WebSocketTransport.cs b/src/Transports.Subscriptions.WebSockets/WebSocketTransport.cs index 433bfa4c..0b54453c 100644 --- a/src/Transports.Subscriptions.WebSockets/WebSocketTransport.cs +++ b/src/Transports.Subscriptions.WebSockets/WebSocketTransport.cs @@ -1,6 +1,7 @@ using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; +using GraphQL.Http; using GraphQL.Server.Transports.Subscriptions.Abstractions; using Newtonsoft.Json; using Newtonsoft.Json.Serialization; @@ -11,7 +12,7 @@ public class WebSocketTransport : IMessageTransport { private readonly WebSocket _socket; - public WebSocketTransport(WebSocket socket) + public WebSocketTransport(WebSocket socket, IDocumentWriter documentWriter) { _socket = socket; var serializerSettings = new JsonSerializerSettings @@ -22,7 +23,7 @@ public WebSocketTransport(WebSocket socket) }; Reader = new WebSocketReaderPipeline(_socket, serializerSettings); - Writer = new WebSocketWriterPipeline(_socket, serializerSettings); + Writer = new WebSocketWriterPipeline(_socket, documentWriter); } diff --git a/src/Transports.Subscriptions.WebSockets/WebSocketWriterPipeline.cs b/src/Transports.Subscriptions.WebSockets/WebSocketWriterPipeline.cs index 92cddbc4..ac8f570e 100644 --- a/src/Transports.Subscriptions.WebSockets/WebSocketWriterPipeline.cs +++ b/src/Transports.Subscriptions.WebSockets/WebSocketWriterPipeline.cs @@ -1,33 +1,23 @@ -using System; using System.Net.WebSockets; -using System.Text; -using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Dataflow; +using GraphQL.Http; using GraphQL.Server.Transports.Subscriptions.Abstractions; -using Newtonsoft.Json; namespace GraphQL.Server.Transports.WebSockets { public class WebSocketWriterPipeline : IWriterPipeline { - private readonly ITargetBlock _endBlock; - private readonly JsonSerializerSettings _serializerSettings; private readonly WebSocket _socket; - private readonly IPropagatorBlock _startBlock; + private readonly IDocumentWriter _documentWriter; + private readonly ITargetBlock _startBlock; - public WebSocketWriterPipeline(WebSocket socket, JsonSerializerSettings serializerSettings) + public WebSocketWriterPipeline(WebSocket socket, IDocumentWriter documentWriter) { _socket = socket; - _serializerSettings = serializerSettings; + _documentWriter = documentWriter; - _endBlock = CreateMessageWriter(); - _startBlock = CreateWriterJsonTransformer(); - - _startBlock.LinkTo(_endBlock, new DataflowLinkOptions - { - PropagateCompletion = true - }); + _startBlock = CreateMessageWriter(); } public bool Post(OperationMessage message) @@ -40,7 +30,7 @@ public Task SendAsync(OperationMessage message) return _startBlock.SendAsync(message); } - public Task Completion => _endBlock.Completion; + public Task Completion => _startBlock.Completion; public Task Complete() { @@ -48,21 +38,9 @@ public Task Complete() return Task.CompletedTask; } - protected IPropagatorBlock CreateWriterJsonTransformer() + private ITargetBlock CreateMessageWriter() { - var transformer = new TransformBlock( - input => JsonConvert.SerializeObject(input, _serializerSettings), - new ExecutionDataflowBlockOptions - { - EnsureOrdered = true - }); - - return transformer; - } - - private ITargetBlock CreateMessageWriter() - { - var target = new ActionBlock( + var target = new ActionBlock( WriteMessageAsync, new ExecutionDataflowBlockOptions { BoundedCapacity = 1, @@ -73,12 +51,20 @@ private ITargetBlock CreateMessageWriter() return target; } - private Task WriteMessageAsync(string message) + private async Task WriteMessageAsync(OperationMessage message) { - if (_socket.CloseStatus.HasValue) return Task.CompletedTask; + if (_socket.CloseStatus.HasValue) return; - var messageSegment = new ArraySegment(Encoding.UTF8.GetBytes(message)); - return _socket.SendAsync(messageSegment, WebSocketMessageType.Text, true, CancellationToken.None); + var stream = new WebsocketWriterStream(_socket); + try + { + await _documentWriter.WriteAsync(stream, message); + } + finally + { + await stream.FlushAsync(); + stream.Dispose(); + } } } } \ No newline at end of file diff --git a/src/Transports.Subscriptions.WebSockets/WebsocketWriterStream.cs b/src/Transports.Subscriptions.WebSockets/WebsocketWriterStream.cs new file mode 100644 index 00000000..1a4164b6 --- /dev/null +++ b/src/Transports.Subscriptions.WebSockets/WebsocketWriterStream.cs @@ -0,0 +1,66 @@ +using System; +using System.IO; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace GraphQL.Server.Transports.WebSockets +{ + public class WebsocketWriterStream : Stream + { + private readonly WebSocket _webSocket; + + public WebsocketWriterStream(WebSocket webSocket) + { + _webSocket = webSocket; + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _webSocket.SendAsync(new ArraySegment(buffer, offset, count), WebSocketMessageType.Text, false, + cancellationToken); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _webSocket.SendAsync(new ArraySegment(Array.Empty()), WebSocketMessageType.Text, true, cancellationToken); + } + + public override void Flush() + { + FlushAsync().GetAwaiter().GetResult(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new System.NotSupportedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new System.NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new System.NotSupportedException(); + } + + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => true; + + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + } +} diff --git a/tests/Transports.Subscriptions.Abstractions.Tests/Specs/ChatSpec.cs b/tests/Transports.Subscriptions.Abstractions.Tests/Specs/ChatSpec.cs index 7dfbe18c..50bdb9f3 100644 --- a/tests/Transports.Subscriptions.Abstractions.Tests/Specs/ChatSpec.cs +++ b/tests/Transports.Subscriptions.Abstractions.Tests/Specs/ChatSpec.cs @@ -49,7 +49,7 @@ public ChatSpec() private void AssertReceivedData(List writtenMessages, Predicate predicate) { var dataMessages = writtenMessages.Where(m => m.Type == MessageType.GQL_DATA); - var results = dataMessages.Select(m => m.Payload["data"] as JObject) + var results = dataMessages.Select(m => JObject.FromObject(((ExecutionResult)m.Payload).Data)) .ToList(); Assert.Contains(results, predicate); diff --git a/tests/Transports.Subscriptions.WebSockets.Tests/TestMessage.cs b/tests/Transports.Subscriptions.WebSockets.Tests/TestMessage.cs new file mode 100644 index 00000000..805c1236 --- /dev/null +++ b/tests/Transports.Subscriptions.WebSockets.Tests/TestMessage.cs @@ -0,0 +1,11 @@ +using System; + +namespace GraphQL.Server.Transports.WebSockets.Tests +{ + public class TestMessage + { + public string Content { get; set; } + + public DateTimeOffset SentAt { get; set; } + } +} \ No newline at end of file diff --git a/tests/Transports.Subscriptions.WebSockets.Tests/TestWebSocket.cs b/tests/Transports.Subscriptions.WebSockets.Tests/TestWebSocket.cs new file mode 100644 index 00000000..3e5abe6e --- /dev/null +++ b/tests/Transports.Subscriptions.WebSockets.Tests/TestWebSocket.cs @@ -0,0 +1,211 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace GraphQL.Server.Transports.WebSockets.Tests +{ + public class TestWebSocket : WebSocket + { + public TestWebSocket() + { + CurrentMessage = new ChunkedMemoryStream(); + } + + public override void Abort() + { + } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, + CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, + CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public override void Dispose() + { + } + + public override Task ReceiveAsync(ArraySegment buffer, + CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, + CancellationToken cancellationToken) + { + if (buffer.Array != null) + { + CurrentMessage.Write(buffer.Array, buffer.Offset, buffer.Count); + } + + if (endOfMessage) + { + Messages.Add(CurrentMessage); + CurrentMessage = new ChunkedMemoryStream(); + } + + return Task.CompletedTask; + } + + internal List Messages { get; } = new List(); + private ChunkedMemoryStream CurrentMessage { get; set; } + + public override WebSocketCloseStatus? CloseStatus { get; } = null; + public override string CloseStatusDescription { get; } = ""; + public override string SubProtocol { get; } = ""; + public override WebSocketState State { get; } = WebSocketState.Open; + } + + internal class ChunkedMemoryStream : Stream + { + private readonly List _chunks = new List(); + private int _positionChunk; + private int _positionOffset; + private long _position; + + public override bool CanRead + { + get { return true; } + } + + public override bool CanSeek + { + get { return true; } + } + + public override bool CanWrite + { + get { return true; } + } + + public override void Flush() + { + } + + public override long Length + { + get { return _chunks.Sum(c => c.Length); } + } + + public override long Position + { + get { return _position; } + set + { + _position = value; + + _positionChunk = 0; + + while (_positionOffset != 0) + { + if (_positionChunk >= _chunks.Count) + throw new OverflowException(); + + if (_positionOffset < _chunks[_positionChunk].Length) + return; + + _positionOffset -= _chunks[_positionChunk].Length; + _positionChunk++; + } + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + int result = 0; + while ((count != 0) && (_positionChunk != _chunks.Count)) + { + int fromChunk = Math.Min(count, _chunks[_positionChunk].Length - _positionOffset); + if (fromChunk != 0) + { + Array.Copy(_chunks[_positionChunk], _positionOffset, buffer, offset, fromChunk); + offset += fromChunk; + count -= fromChunk; + result += fromChunk; + _position += fromChunk; + } + + _positionOffset = 0; + _positionChunk++; + } + + return result; + } + + public override long Seek(long offset, SeekOrigin origin) + { + long newPos = 0; + + switch (origin) + { + case SeekOrigin.Begin: + newPos = offset; + break; + case SeekOrigin.Current: + newPos = Position + offset; + break; + case SeekOrigin.End: + newPos = Length - offset; + break; + } + + Position = Math.Max(0, Math.Min(newPos, Length)); + return newPos; + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + while ((count != 0) && (_positionChunk != _chunks.Count)) + { + int toChunk = Math.Min(count, _chunks[_positionChunk].Length - _positionOffset); + if (toChunk != 0) + { + Array.Copy(buffer, offset, _chunks[_positionChunk], _positionOffset, toChunk); + offset += toChunk; + count -= toChunk; + _position += toChunk; + } + + _positionOffset = 0; + _positionChunk++; + } + + if (count != 0) + { + byte[] chunk = new byte[count]; + Array.Copy(buffer, offset, chunk, 0, count); + _chunks.Add(chunk); + _positionChunk = _chunks.Count; + _position += count; + } + } + + public byte[] ToArray() + { + using (MemoryStream ms = new MemoryStream()) + { + foreach (var bytes in _chunks) + { + ms.Write(bytes, 0, bytes.Length); + } + return ms.ToArray(); + } + } + } +} \ No newline at end of file diff --git a/tests/Transports.Subscriptions.WebSockets.Tests/WebSocketWriterPipelineTests.cs b/tests/Transports.Subscriptions.WebSockets.Tests/WebSocketWriterPipelineTests.cs new file mode 100644 index 00000000..0f576378 --- /dev/null +++ b/tests/Transports.Subscriptions.WebSockets.Tests/WebSocketWriterPipelineTests.cs @@ -0,0 +1,193 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using GraphQL.Http; +using GraphQL.Server.Transports.Subscriptions.Abstractions; +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Xunit; + +namespace GraphQL.Server.Transports.WebSockets.Tests +{ + public class WebSocketWriterPipelineFacts + { + private readonly WebSocketWriterPipeline _webSocketWriterPipeline; + private readonly TestWebSocket _testWebSocket; + + public WebSocketWriterPipelineFacts() + { + _testWebSocket = new TestWebSocket(); + _webSocketWriterPipeline = new WebSocketWriterPipeline(_testWebSocket, new DocumentWriter(Formatting.None, + new JsonSerializerSettings + { + ContractResolver = new CamelCasePropertyNamesContractResolver(), + NullValueHandling = NullValueHandling.Ignore + })); + } + + public static IEnumerable TestData => + new List + { + new object[] + { + new OperationMessage + { + Payload = new ExecutionResult + { + Data = new TestMessage + { + Content = "Hello world", + SentAt = new DateTimeOffset(2018, 12, 12, 10, 0,0, TimeSpan.Zero) + } + } + }, + 83 + }, + new object[] + { + new OperationMessage + { + Payload = new ExecutionResult + { + Data = Enumerable.Repeat(new TestMessage + { + Content = "Hello world", + SentAt = new DateTimeOffset(2018, 12, 12, 10, 0,0, TimeSpan.Zero) + }, 10) + } + }, + 652 + }, + new object[] + { + new OperationMessage + { + Payload = new ExecutionResult + { + Data = Enumerable.Repeat(new TestMessage + { + Content = "Hello world", + SentAt = new DateTimeOffset(2018, 12, 12, 10, 0,0, TimeSpan.Zero) + }, 16_000) + } + }, + // About 1 megabyte + 1_008_022 + }, + new object[] + { + new OperationMessage + { + Payload = new ExecutionResult + { + Data = Enumerable.Repeat(new TestMessage + { + Content = "Hello world", + SentAt = new DateTimeOffset(2018, 12, 12, 10, 0,0, TimeSpan.Zero) + }, 160_000) + } + }, + // About 10 megabytes + 10_080_022 + }, + new object[] + { + new OperationMessage + { + Payload = new ExecutionResult + { + Data = Enumerable.Repeat(new TestMessage + { + Content = "Hello world", + SentAt = new DateTimeOffset(2018, 12, 12, 10, 0,0, TimeSpan.Zero) + }, 1_600_000) + } + }, + // About 100 megabytes + 100_800_022 + }, + }; + + [Fact] + public async Task should_post_single_message() + { + var message = new OperationMessage + { + Payload = new ExecutionResult + { + Data = new TestMessage + { + Content = "Hello world", + SentAt = new DateTimeOffset(2018, 12, 12, 10, 0, 0, TimeSpan.Zero) + } + } + }; + Assert.True(_webSocketWriterPipeline.Post(message)); + await _webSocketWriterPipeline.Complete(); + await _webSocketWriterPipeline.Completion; + Assert.Single(_testWebSocket.Messages); + + var resultingJson = Encoding.UTF8.GetString(_testWebSocket.Messages.First().ToArray()); + Assert.Equal( + "{\"payload\":{\"data\":{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}}}", + resultingJson); + } + + [Fact] + public async Task should_post_array_of_10_messages() + { + var message = new OperationMessage + { + Payload = new ExecutionResult + { + Data = Enumerable.Repeat(new TestMessage + { + Content = "Hello world", + SentAt = new DateTimeOffset(2018, 12, 12, 10, 0, 0, TimeSpan.Zero) + }, 10) + } + }; + Assert.True(_webSocketWriterPipeline.Post(message)); + await _webSocketWriterPipeline.Complete(); + await _webSocketWriterPipeline.Completion; + Assert.Single(_testWebSocket.Messages); + + var resultingJson = Encoding.UTF8.GetString(_testWebSocket.Messages.First().ToArray()); + Assert.Equal("{\"payload\":{\"data\":[{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}," + + "{\"content\":\"Hello world\",\"sentAt\":\"2018-12-12T10:00:00+00:00\"}]}}", + resultingJson); + } + + [Theory] + [MemberData(nameof(TestData))] + public async Task should_post_for_any_message_length(OperationMessage message, long expectedLength) + { + Assert.True(_webSocketWriterPipeline.Post(message)); + await _webSocketWriterPipeline.Complete(); + await _webSocketWriterPipeline.Completion; + Assert.Single(_testWebSocket.Messages); + Assert.Equal(expectedLength, _testWebSocket.Messages.First().Length); + } + + [Theory] + [MemberData(nameof(TestData))] + public async Task should_send_for_any_message_length(OperationMessage message, long expectedLength) + { + await _webSocketWriterPipeline.SendAsync(message); + await _webSocketWriterPipeline.Complete(); + await _webSocketWriterPipeline.Completion; + Assert.Single(_testWebSocket.Messages); + Assert.Equal(expectedLength, _testWebSocket.Messages.First().Length); + } + } +} \ No newline at end of file