Skip to content

Commit 0a02c24

Browse files
committed
Spike a connection factory for the sockets transport
- Made the unix domain sockets test use it
1 parent 1c0014c commit 0a02c24

File tree

6 files changed

+178
-21
lines changed

6 files changed

+178
-21
lines changed

src/Servers/Kestrel/Transport.Sockets/ref/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.netcoreapp3.0.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ public static partial class WebHostBuilderSocketExtensions
1111
}
1212
namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets
1313
{
14+
public partial class SocketConnectionFactory : Microsoft.AspNetCore.Connections.IConnectionFactory
15+
{
16+
public SocketConnectionFactory(Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) { }
17+
[System.Diagnostics.DebuggerStepThroughAttribute]
18+
public System.Threading.Tasks.ValueTask<Microsoft.AspNetCore.Connections.ConnectionContext> ConnectAsync(System.Net.EndPoint endPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
19+
}
1420
public sealed partial class SocketTransportFactory : Microsoft.AspNetCore.Connections.IConnectionListenerFactory
1521
{
1622
public SocketTransportFactory(Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions> options, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) { }
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO.Pipelines;
4+
using System.Net;
5+
using System.Net.Sockets;
6+
using System.Text;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
using Microsoft.AspNetCore.Connections;
10+
using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal;
11+
using Microsoft.Extensions.Logging;
12+
13+
namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets
14+
{
15+
/// <summary>
16+
/// Defines a class for creating socket connections based on the specified endpoint.
17+
/// </summary>
18+
public class SocketConnectionFactory : IConnectionFactory
19+
{
20+
private readonly ILogger _logger;
21+
22+
/// <summary>
23+
/// Creates the <see cref="SocketConnectionFactory"/>.
24+
/// </summary>
25+
/// <param name="loggerFactory">The logger factory</param>
26+
public SocketConnectionFactory(ILoggerFactory loggerFactory)
27+
{
28+
_logger = loggerFactory.CreateLogger<SocketConnectionFactory>();
29+
}
30+
31+
/// <summary>
32+
/// Creates a new socket connection to the specified endpoint.
33+
/// </summary>
34+
/// <param name="endPoint">The <see cref="EndPoint"/> to connect to.</param>
35+
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None" />.</param>
36+
/// <returns>
37+
/// A <see cref="ValueTask{TResult}" /> that represents the asynchronous connect, yielding the <see cref="ConnectionContext" /> for the new connection when completed.
38+
/// </returns>
39+
public async ValueTask<ConnectionContext> ConnectAsync(EndPoint endPoint, CancellationToken cancellationToken = default)
40+
{
41+
var protocolType = endPoint is UnixDomainSocketEndPoint ? ProtocolType.Unspecified : ProtocolType.Tcp;
42+
43+
var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, protocolType);
44+
45+
await socket.ConnectAsync(endPoint);
46+
47+
var connection = new SocketConnection(socket, memoryPool: null, PipeScheduler.ThreadPool, new SocketsTrace(_logger));
48+
connection.Start();
49+
50+
return connection;
51+
}
52+
}
53+
}

src/Servers/Kestrel/test/FunctionalTests/UnixDomainSocketsTests.cs

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Buffers;
66
using System.IO;
7+
using System.IO.Pipelines;
78
using System.Linq;
89
using System.Net.Sockets;
910
using System.Text;
@@ -12,6 +13,7 @@
1213
using Microsoft.AspNetCore.Connections;
1314
using Microsoft.AspNetCore.Connections.Features;
1415
using Microsoft.AspNetCore.Hosting;
16+
using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets;
1517
using Microsoft.AspNetCore.Testing;
1618
using Microsoft.AspNetCore.Testing.xunit;
1719
using Microsoft.Extensions.Logging;
@@ -41,14 +43,11 @@ public async Task TestUnixDomainSocket()
4143

4244
async Task EchoServer(ConnectionContext connection)
4345
{
44-
// For graceful shutdown
45-
var notificationFeature = connection.Features.Get<IConnectionLifetimeNotificationFeature>();
46-
4746
try
4847
{
4948
while (true)
5049
{
51-
var result = await connection.Transport.Input.ReadAsync(notificationFeature.ConnectionClosedRequested);
50+
var result = await connection.Transport.Input.ReadAsync();
5251

5352
if (result.IsCompleted)
5453
{
@@ -61,10 +60,6 @@ async Task EchoServer(ConnectionContext connection)
6160
connection.Transport.Input.AdvanceTo(result.Buffer.End);
6261
}
6362
}
64-
catch (OperationCanceledException)
65-
{
66-
Logger.LogDebug("Graceful shutdown triggered for {connectionId}.", connection.ConnectionId);
67-
}
6863
finally
6964
{
7065
serverConnectionCompletedTcs.TrySetResult(null);
@@ -86,20 +81,14 @@ async Task EchoServer(ConnectionContext connection)
8681
{
8782
await host.StartAsync();
8883

89-
using (var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
84+
var factory = new SocketConnectionFactory(LoggerFactory);
85+
var endPoint = new UnixDomainSocketEndPoint(path);
86+
await using (var connection = await factory.ConnectAsync(endPoint))
9087
{
91-
await socket.ConnectAsync(new UnixDomainSocketEndPoint(path));
92-
9388
var data = Encoding.ASCII.GetBytes("Hello World");
94-
await socket.SendAsync(data, SocketFlags.None);
95-
96-
var buffer = new byte[data.Length];
97-
var read = 0;
98-
while (read < data.Length)
99-
{
100-
read += await socket.ReceiveAsync(buffer.AsMemory(read, buffer.Length - read), SocketFlags.None);
101-
}
89+
await connection.Transport.Output.WriteAsync(data);
10290

91+
var buffer = await connection.Transport.Input.ReadAsync(data.Length);
10392
Assert.Equal(data, buffer);
10493
}
10594

src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
44
<TargetFramework>netcoreapp3.0</TargetFramework>
@@ -11,6 +11,7 @@
1111

1212
<ItemGroup>
1313
<Compile Include="..\FunctionalTests\**\*.cs" />
14+
<Compile Include="$(SharedSourceRoot)\Pipelines\*.cs" LinkBase="Pipelines" />
1415
<Compile Include="$(SharedSourceRoot)NullScope.cs" />
1516
<Compile Include="$(SharedSourceRoot)test\SkipOnHelixAttribute.cs" />
1617
<Compile Include="$(KestrelSharedSourceRoot)test\*.cs" LinkBase="shared" />

src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
44
<TargetFramework>netcoreapp3.0</TargetFramework>
@@ -9,6 +9,7 @@
99

1010
<ItemGroup>
1111
<Compile Include="..\FunctionalTests\**\*.cs" />
12+
<Compile Include="$(SharedSourceRoot)\Pipelines\*.cs" LinkBase="Pipelines" />
1213
<Compile Include="$(SharedSourceRoot)NullScope.cs" />
1314
<Compile Include="$(SharedSourceRoot)test\SkipOnHelixAttribute.cs" />
1415
<Compile Include="$(KestrelSharedSourceRoot)test\*.cs" LinkBase="shared" />
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
using System;
2+
using System.Buffers;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
7+
namespace System.IO.Pipelines
8+
{
9+
public static class PipeReaderExtensions
10+
{
11+
public static async Task<bool> WaitToReadAsync(this PipeReader pipeReader)
12+
{
13+
while (true)
14+
{
15+
var result = await pipeReader.ReadAsync();
16+
17+
try
18+
{
19+
if (!result.Buffer.IsEmpty)
20+
{
21+
return true;
22+
}
23+
24+
if (result.IsCompleted)
25+
{
26+
return false;
27+
}
28+
}
29+
finally
30+
{
31+
// Don't consume or advance
32+
pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.Start);
33+
}
34+
}
35+
}
36+
37+
public static async Task<byte[]> ReadSingleAsync(this PipeReader pipeReader)
38+
{
39+
while (true)
40+
{
41+
var result = await pipeReader.ReadAsync();
42+
43+
try
44+
{
45+
return result.Buffer.ToArray();
46+
}
47+
finally
48+
{
49+
pipeReader.AdvanceTo(result.Buffer.End);
50+
}
51+
}
52+
}
53+
54+
public static async Task ConsumeAsync(this PipeReader pipeReader, int numBytes)
55+
{
56+
while (true)
57+
{
58+
var result = await pipeReader.ReadAsync();
59+
if (result.Buffer.Length < numBytes)
60+
{
61+
pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End);
62+
continue;
63+
}
64+
65+
pipeReader.AdvanceTo(result.Buffer.GetPosition(numBytes));
66+
break;
67+
}
68+
}
69+
70+
public static async Task<byte[]> ReadAllAsync(this PipeReader pipeReader)
71+
{
72+
while (true)
73+
{
74+
var result = await pipeReader.ReadAsync();
75+
76+
if (result.IsCompleted)
77+
{
78+
return result.Buffer.ToArray();
79+
}
80+
81+
// Consume nothing, just wait for everything
82+
pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End);
83+
}
84+
}
85+
86+
public static async Task<byte[]> ReadAsync(this PipeReader pipeReader, int numBytes)
87+
{
88+
while (true)
89+
{
90+
var result = await pipeReader.ReadAsync();
91+
if (result.Buffer.Length < numBytes)
92+
{
93+
pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End);
94+
continue;
95+
}
96+
97+
var buffer = result.Buffer.Slice(0, numBytes);
98+
99+
var bytes = buffer.ToArray();
100+
101+
pipeReader.AdvanceTo(buffer.End);
102+
103+
return bytes;
104+
}
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)