Skip to content

Commit ad95dfc

Browse files
authored
Merge pull request #152 from graphql-dotnet/fix-connection-init-and-tls-credentials
Fix connection init and TLS credentials
2 parents 26e9532 + fe6d899 commit ad95dfc

File tree

5 files changed

+115
-86
lines changed

5 files changed

+115
-86
lines changed

src/GraphQL.Client.Http/GraphQLHttpClient.cs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@ public class GraphQLHttpClient : IGraphQLClient {
1212

1313
private readonly GraphQLHttpWebSocket graphQlHttpWebSocket;
1414
private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource();
15-
private readonly HttpClient httpClient;
1615
private readonly ConcurrentDictionary<Tuple<GraphQLRequest, Type>, object> subscriptionStreams = new ConcurrentDictionary<Tuple<GraphQLRequest, Type>, object>();
1716

17+
/// <summary>
18+
/// the instance of <see cref="HttpClient"/> which is used internally
19+
/// </summary>
20+
public HttpClient HttpClient { get; }
21+
1822
/// <summary>
1923
/// The Options to be used
2024
/// </summary>
@@ -30,33 +34,33 @@ public GraphQLHttpClient(Uri endPoint) : this(o => o.EndPoint = endPoint) { }
3034
public GraphQLHttpClient(Action<GraphQLHttpClientOptions> configure) {
3135
Options = new GraphQLHttpClientOptions();
3236
configure(Options);
33-
this.httpClient = new HttpClient();
37+
this.HttpClient = new HttpClient(Options.HttpMessageHandler);
3438
this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options);
3539
}
3640

3741
public GraphQLHttpClient(GraphQLHttpClientOptions options) {
3842
Options = options;
39-
this.httpClient = new HttpClient();
43+
this.HttpClient = new HttpClient(Options.HttpMessageHandler);
4044
this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options);
4145
}
4246

4347
public GraphQLHttpClient(GraphQLHttpClientOptions options, HttpClient httpClient) {
4448
Options = options;
45-
this.httpClient = httpClient;
49+
this.HttpClient = httpClient;
4650
this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options);
4751
}
4852

53+
/// <inheritdoc />
4954
public Task<GraphQLResponse<TResponse>> SendQueryAsync<TResponse>(GraphQLRequest request, CancellationToken cancellationToken = default) {
5055
return Options.UseWebSocketForQueriesAndMutations
51-
? this.graphQlHttpWebSocket.Request<TResponse>(request, Options, cancellationToken)
56+
? this.graphQlHttpWebSocket.SendRequest<TResponse>(request, this, cancellationToken)
5257
: this.SendHttpPostRequestAsync<TResponse>(request, cancellationToken);
5358
}
5459

55-
public Task<GraphQLResponse<TResponse>> SendMutationAsync<TResponse>(GraphQLRequest request, CancellationToken cancellationToken = default) {
56-
return Options.UseWebSocketForQueriesAndMutations
57-
? this.graphQlHttpWebSocket.Request<TResponse>(request, Options, cancellationToken)
58-
: this.SendHttpPostRequestAsync<TResponse>(request, cancellationToken);
59-
}
60+
/// <inheritdoc />
61+
public Task<GraphQLResponse<TResponse>> SendMutationAsync<TResponse>(GraphQLRequest request,
62+
CancellationToken cancellationToken = default)
63+
=> SendQueryAsync<TResponse>(request, cancellationToken);
6064

6165
/// <inheritdoc />
6266
public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TResponse>(GraphQLRequest request) {
@@ -68,7 +72,7 @@ public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TRespons
6872
if (subscriptionStreams.ContainsKey(key))
6973
return (IObservable<GraphQLResponse<TResponse>>)subscriptionStreams[key];
7074

71-
var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, Options, cancellationToken: cancellationTokenSource.Token);
75+
var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, this, cancellationToken: cancellationTokenSource.Token);
7276

7377
subscriptionStreams.TryAdd(key, observable);
7478
return observable;
@@ -84,7 +88,7 @@ public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TRespons
8488
if (subscriptionStreams.ContainsKey(key))
8589
return (IObservable<GraphQLResponse<TResponse>>)subscriptionStreams[key];
8690

87-
var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, Options, exceptionHandler, cancellationTokenSource.Token);
91+
var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, this, exceptionHandler, cancellationTokenSource.Token);
8892
subscriptionStreams.TryAdd(key, observable);
8993
return observable;
9094
}
@@ -98,8 +102,9 @@ public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TRespons
98102
#region Private Methods
99103

100104
private async Task<GraphQLResponse<TResponse>> SendHttpPostRequestAsync<TResponse>(GraphQLRequest request, CancellationToken cancellationToken = default) {
101-
using var httpRequestMessage = this.GenerateHttpRequestMessage(request.SerializeToJson(Options));
102-
using var httpResponseMessage = await this.httpClient.SendAsync(httpRequestMessage, cancellationToken);
105+
var preprocessedRequest = await Options.PreprocessRequest(request, this);
106+
using var httpRequestMessage = this.GenerateHttpRequestMessage(preprocessedRequest.SerializeToJson(Options));
107+
using var httpResponseMessage = await this.HttpClient.SendAsync(httpRequestMessage, cancellationToken);
103108
if (!httpResponseMessage.IsSuccessStatusCode) {
104109
throw new GraphQLHttpException(httpResponseMessage);
105110
}
@@ -140,7 +145,7 @@ public void Dispose() {
140145

141146
private void _dispose() {
142147
disposed = true;
143-
this.httpClient.Dispose();
148+
this.HttpClient.Dispose();
144149
this.graphQlHttpWebSocket.Dispose();
145150
cancellationTokenSource.Cancel();
146151
cancellationTokenSource.Dispose();

src/GraphQL.Client.Http/GraphQLHttpClientOptions.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Net.Http;
33
using System.Net.Http.Headers;
44
using System.Text.Json;
5+
using System.Threading.Tasks;
56
using Dahomey.Json;
67

78
namespace GraphQL.Client.Http {
@@ -46,6 +47,10 @@ public class GraphQLHttpClientOptions {
4647
/// If <see langword="true"/>, the websocket connection is also used for regular queries and mutations
4748
/// </summary>
4849
public bool UseWebSocketForQueriesAndMutations { get; set; } = false;
49-
}
5050

51+
/// <summary>
52+
/// Request preprocessing function. Can be used i.e. to inject authorization info into a GraphQL request payload.
53+
/// </summary>
54+
public Func<GraphQLRequest, GraphQLHttpClient, Task<GraphQLRequest>> PreprocessRequest { get; set; } = (request, client) => Task.FromResult(request);
55+
}
5156
}

src/GraphQL.Client.Http/Websocket/GraphQLHttpWebSocket.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Diagnostics;
33
using System.IO;
4+
using System.Net.Http;
45
using System.Net.WebSockets;
56
using System.Reactive.Disposables;
67
using System.Reactive.Linq;
@@ -108,9 +109,13 @@ public Task InitializeWebSocket() {
108109
switch (clientWebSocket) {
109110
case ClientWebSocket nativeWebSocket:
110111
nativeWebSocket.Options.AddSubProtocol("graphql-ws");
112+
nativeWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates;
113+
nativeWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials;
111114
break;
112115
case System.Net.WebSockets.Managed.ClientWebSocket managedWebSocket:
113116
managedWebSocket.Options.AddSubProtocol("graphql-ws");
117+
managedWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates;
118+
managedWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials;
114119
break;
115120
default:
116121
throw new NotSupportedException($"unknown websocket type {clientWebSocket.GetType().Name}");

src/GraphQL.Client.Http/Websocket/GraphQLHttpWebsocketHelpers.cs

Lines changed: 82 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ public static class GraphQLHttpWebsocketHelpers {
1313
internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TResponse>(
1414
this GraphQLHttpWebSocket graphQlHttpWebSocket,
1515
GraphQLRequest request,
16-
GraphQLHttpClientOptions options,
16+
GraphQLHttpClient client,
1717
Action<Exception> exceptionHandler = null,
1818
CancellationToken cancellationToken = default) {
1919
return Observable.Defer(() =>
2020
Observable.Create<GraphQLResponse<TResponse>>(async observer => {
21+
await client.Options.PreprocessRequest(request, client);
2122
var startRequest = new GraphQLWebSocketRequest {
2223
Id = Guid.NewGuid().ToString("N"),
2324
Type = GraphQLWebSocketMessageType.GQL_START,
@@ -27,34 +28,38 @@ internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream
2728
Id = startRequest.Id,
2829
Type = GraphQLWebSocketMessageType.GQL_STOP
2930
};
31+
var initRequest = new GraphQLWebSocketRequest {
32+
Id = startRequest.Id,
33+
Type = GraphQLWebSocketMessageType.GQL_CONNECTION_INIT,
34+
};
3035

3136
var observable = Observable.Create<GraphQLResponse<TResponse>>(o =>
3237
graphQlHttpWebSocket.ResponseStream
3338
// ignore null values and messages for other requests
3439
.Where(response => response != null && response.Id == startRequest.Id)
3540
.Subscribe(response => {
36-
// terminate the sequence when a 'complete' message is received
37-
if (response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) {
38-
Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}");
39-
o.OnCompleted();
40-
return;
41-
}
42-
43-
// post the GraphQLResponse to the stream (even if a GraphQL error occurred)
44-
Debug.WriteLine($"received payload on subscription {startRequest.Id}");
45-
var typedResponse =
46-
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
47-
options.JsonSerializerOptions);
48-
o.OnNext(typedResponse.Payload);
49-
50-
// in case of a GraphQL error, terminate the sequence after the response has been posted
51-
if (response.Type == GraphQLWebSocketMessageType.GQL_ERROR) {
52-
Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error");
53-
o.OnCompleted();
54-
}
55-
},
56-
o.OnError,
57-
o.OnCompleted)
41+
// terminate the sequence when a 'complete' message is received
42+
if (response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) {
43+
Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}");
44+
o.OnCompleted();
45+
return;
46+
}
47+
48+
// post the GraphQLResponse to the stream (even if a GraphQL error occurred)
49+
Debug.WriteLine($"received payload on subscription {startRequest.Id}");
50+
var typedResponse =
51+
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
52+
client.Options.JsonSerializerOptions);
53+
o.OnNext(typedResponse.Payload);
54+
55+
// in case of a GraphQL error, terminate the sequence after the response has been posted
56+
if (response.Type == GraphQLWebSocketMessageType.GQL_ERROR) {
57+
Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error");
58+
o.OnCompleted();
59+
}
60+
},
61+
o.OnError,
62+
o.OnCompleted)
5863
);
5964

6065
try {
@@ -81,6 +86,16 @@ internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream
8186
})
8287
);
8388

89+
// send connection init
90+
Debug.WriteLine($"sending connection init on subscription {startRequest.Id}");
91+
try {
92+
await graphQlHttpWebSocket.SendWebSocketRequest(initRequest).ConfigureAwait(false);
93+
}
94+
catch (Exception e) {
95+
Console.WriteLine(e);
96+
throw;
97+
}
98+
8499
Debug.WriteLine($"sending initial message on subscription {startRequest.Id}");
85100
// send subscription request
86101
try {
@@ -137,53 +152,54 @@ internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream
137152
.Publish().RefCount();
138153
}
139154

140-
internal static Task<GraphQLResponse<TResponse>> Request<TResponse>(
155+
internal static Task<GraphQLResponse<TResponse>> SendRequest<TResponse>(
141156
this GraphQLHttpWebSocket graphQlHttpWebSocket,
142157
GraphQLRequest request,
143-
GraphQLHttpClientOptions options,
158+
GraphQLHttpClient client,
144159
CancellationToken cancellationToken = default) {
145160
return Observable.Create<GraphQLResponse<TResponse>>(async observer => {
146-
var websocketRequest = new GraphQLWebSocketRequest {
147-
Id = Guid.NewGuid().ToString("N"),
148-
Type = GraphQLWebSocketMessageType.GQL_START,
149-
Payload = request
150-
};
151-
var observable = graphQlHttpWebSocket.ResponseStream
152-
.Where(response => response != null && response.Id == websocketRequest.Id)
153-
.TakeUntil(response => response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE)
154-
.Select(response => {
155-
Debug.WriteLine($"received response for request {websocketRequest.Id}");
156-
var typedResponse =
157-
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
158-
options.JsonSerializerOptions);
159-
return typedResponse.Payload;
160-
});
161-
162-
try {
163-
// intialize websocket (completes immediately if socket is already open)
164-
await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false);
165-
}
166-
catch (Exception e) {
167-
// subscribe observer to failed observable
168-
return Observable.Throw<GraphQLResponse<TResponse>>(e).Subscribe(observer);
169-
}
170-
171-
var disposable = new CompositeDisposable(
172-
observable.Subscribe(observer)
173-
);
174-
175-
Debug.WriteLine($"submitting request {websocketRequest.Id}");
176-
// send request
177-
try {
178-
await graphQlHttpWebSocket.SendWebSocketRequest(websocketRequest).ConfigureAwait(false);
179-
}
180-
catch (Exception e) {
181-
Console.WriteLine(e);
182-
throw;
183-
}
184-
185-
return disposable;
186-
})
161+
await client.Options.PreprocessRequest(request, client);
162+
var websocketRequest = new GraphQLWebSocketRequest {
163+
Id = Guid.NewGuid().ToString("N"),
164+
Type = GraphQLWebSocketMessageType.GQL_START,
165+
Payload = request
166+
};
167+
var observable = graphQlHttpWebSocket.ResponseStream
168+
.Where(response => response != null && response.Id == websocketRequest.Id)
169+
.TakeUntil(response => response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE)
170+
.Select(response => {
171+
Debug.WriteLine($"received response for request {websocketRequest.Id}");
172+
var typedResponse =
173+
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
174+
client.Options.JsonSerializerOptions);
175+
return typedResponse.Payload;
176+
});
177+
178+
try {
179+
// intialize websocket (completes immediately if socket is already open)
180+
await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false);
181+
}
182+
catch (Exception e) {
183+
// subscribe observer to failed observable
184+
return Observable.Throw<GraphQLResponse<TResponse>>(e).Subscribe(observer);
185+
}
186+
187+
var disposable = new CompositeDisposable(
188+
observable.Subscribe(observer)
189+
);
190+
191+
Debug.WriteLine($"submitting request {websocketRequest.Id}");
192+
// send request
193+
try {
194+
await graphQlHttpWebSocket.SendWebSocketRequest(websocketRequest).ConfigureAwait(false);
195+
}
196+
catch (Exception e) {
197+
Console.WriteLine(e);
198+
throw;
199+
}
200+
201+
return disposable;
202+
})
187203
// complete sequence on OperationCanceledException, this is triggered by the cancellation token
188204
.Catch<GraphQLResponse<TResponse>, OperationCanceledException>(exception =>
189205
Observable.Empty<GraphQLResponse<TResponse>>())

tests/GraphQL.Integration.Tests/SubscriptionsTest.cs renamed to tests/GraphQL.Integration.Tests/WebsocketTest.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@
1212
using Xunit.Abstractions;
1313

1414
namespace GraphQL.Integration.Tests {
15-
public class SubscriptionsTest {
15+
public class WebsocketTest {
1616
private readonly ITestOutputHelper output;
1717

1818
private static IWebHost CreateServer(int port) => WebHostHelpers.CreateServer<StartupChat>(port);
1919

20-
private static TimeSpan WaitForConnectionDelay = TimeSpan.FromMilliseconds(200);
21-
22-
public SubscriptionsTest(ITestOutputHelper output) {
20+
public WebsocketTest(ITestOutputHelper output) {
2321
this.output = output;
2422
}
2523

0 commit comments

Comments
 (0)