Skip to content

Commit 31dfe91

Browse files
veloekdavidfowl
authored andcommitted
Support async access token factory (#1911)
1 parent 6bc2ebb commit 31dfe91

File tree

11 files changed

+74
-34
lines changed

11 files changed

+74
-34
lines changed

clients/ts/FunctionalTests/ts/HubConnectionTests.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,33 @@ describe("hubConnection", () => {
532532
}
533533
});
534534

535+
it("can connect to hub with authorization using async token factory", async (done) => {
536+
const message = "你好,世界!";
537+
538+
try {
539+
const hubConnection = new HubConnection("/authorizedhub", {
540+
accessTokenFactory: () => getJwtToken("http://" + document.location.host + "/generateJwtToken"),
541+
...commonOptions,
542+
transport: transportType,
543+
});
544+
hubConnection.onclose((error) => {
545+
expect(error).toBe(undefined);
546+
done();
547+
});
548+
await hubConnection.start();
549+
const response = await hubConnection.invoke("Echo", message);
550+
551+
expect(response).toEqual(message);
552+
553+
await hubConnection.stop();
554+
555+
done();
556+
} catch (err) {
557+
fail(err);
558+
done();
559+
}
560+
});
561+
535562
if (transportType !== TransportType.LongPolling) {
536563
it("terminates if no messages received within timeout interval", (done) => {
537564
const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, {

clients/ts/signalr/src/HttpConnection.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export interface IHttpConnectionOptions {
1313
httpClient?: HttpClient;
1414
transport?: TransportType | ITransport;
1515
logger?: ILogger | LogLevel;
16-
accessTokenFactory?: () => string;
16+
accessTokenFactory?: () => string | Promise<string>;
1717
logMessageContent?: boolean;
1818
}
1919

@@ -87,7 +87,7 @@ export class HttpConnection implements IConnection {
8787
// No fallback or negotiate in this case.
8888
await this.transport.connect(this.url, transferFormat, this);
8989
} else {
90-
const token = this.options.accessTokenFactory();
90+
const token = await this.options.accessTokenFactory();
9191
let headers;
9292
if (token) {
9393
headers = {

clients/ts/signalr/src/Transports.ts

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@ export interface ITransport {
3030

3131
export class WebSocketTransport implements ITransport {
3232
private readonly logger: ILogger;
33-
private readonly accessTokenFactory: () => string;
33+
private readonly accessTokenFactory: () => string | Promise<string>;
3434
private readonly logMessageContent: boolean;
3535
private webSocket: WebSocket;
3636

37-
constructor(accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
37+
constructor(accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
3838
this.logger = logger;
3939
this.accessTokenFactory = accessTokenFactory || (() => null);
4040
this.logMessageContent = logMessageContent;
4141
}
4242

43-
public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
43+
public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
4444
Arg.isRequired(url, "url");
4545
Arg.isRequired(transferFormat, "transferFormat");
4646
Arg.isIn(transferFormat, TransferFormat, "transferFormat");
@@ -52,9 +52,9 @@ export class WebSocketTransport implements ITransport {
5252

5353
this.logger.log(LogLevel.Trace, "(WebSockets transport) Connecting");
5454

55+
const token = await this.accessTokenFactory();
5556
return new Promise<void>((resolve, reject) => {
5657
url = url.replace(/^http/, "ws");
57-
const token = this.accessTokenFactory();
5858
if (token) {
5959
url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`;
6060
}
@@ -118,20 +118,20 @@ export class WebSocketTransport implements ITransport {
118118

119119
export class ServerSentEventsTransport implements ITransport {
120120
private readonly httpClient: HttpClient;
121-
private readonly accessTokenFactory: () => string;
121+
private readonly accessTokenFactory: () => string | Promise<string>;
122122
private readonly logger: ILogger;
123123
private readonly logMessageContent: boolean;
124124
private eventSource: EventSource;
125125
private url: string;
126126

127-
constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
127+
constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
128128
this.httpClient = httpClient;
129129
this.accessTokenFactory = accessTokenFactory || (() => null);
130130
this.logger = logger;
131131
this.logMessageContent = logMessageContent;
132132
}
133133

134-
public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
134+
public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
135135
Arg.isRequired(url, "url");
136136
Arg.isRequired(transferFormat, "transferFormat");
137137
Arg.isIn(transferFormat, TransferFormat, "transferFormat");
@@ -144,12 +144,12 @@ export class ServerSentEventsTransport implements ITransport {
144144
this.logger.log(LogLevel.Trace, "(SSE transport) Connecting");
145145

146146
this.url = url;
147+
const token = await this.accessTokenFactory();
147148
return new Promise<void>((resolve, reject) => {
148149
if (transferFormat !== TransferFormat.Text) {
149150
reject(new Error("The Server-Sent Events transport only supports the 'Text' transfer format"));
150151
}
151152

152-
const token = this.accessTokenFactory();
153153
if (token) {
154154
url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`;
155155
}
@@ -210,15 +210,15 @@ export class ServerSentEventsTransport implements ITransport {
210210

211211
export class LongPollingTransport implements ITransport {
212212
private readonly httpClient: HttpClient;
213-
private readonly accessTokenFactory: () => string;
213+
private readonly accessTokenFactory: () => string | Promise<string>;
214214
private readonly logger: ILogger;
215215
private readonly logMessageContent: boolean;
216216

217217
private url: string;
218218
private pollXhr: XMLHttpRequest;
219219
private pollAbort: AbortController;
220220

221-
constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
221+
constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
222222
this.httpClient = httpClient;
223223
this.accessTokenFactory = accessTokenFactory || (() => null);
224224
this.logger = logger;
@@ -259,7 +259,7 @@ export class LongPollingTransport implements ITransport {
259259
pollOptions.responseType = "arraybuffer";
260260
}
261261

262-
const token = this.accessTokenFactory();
262+
const token = await this.accessTokenFactory();
263263
if (token) {
264264
// tslint:disable-next-line:no-string-literal
265265
pollOptions.headers["Authorization"] = `Bearer ${token}`;
@@ -356,12 +356,12 @@ function formatArrayBuffer(data: ArrayBuffer): string {
356356
return str.substr(0, str.length - 1);
357357
}
358358

359-
async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string, content: string | ArrayBuffer, logMessageContent: boolean): Promise<void> {
359+
async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string | Promise<string>, content: string | ArrayBuffer, logMessageContent: boolean): Promise<void> {
360360
let headers;
361-
const token = accessTokenFactory();
361+
const token = await accessTokenFactory();
362362
if (token) {
363363
headers = {
364-
["Authorization"]: `Bearer ${accessTokenFactory()}`,
364+
["Authorization"]: `Bearer ${token}`,
365365
};
366366
}
367367

samples/JwtClientSample/Program.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ await Task.WhenAll(
2323

2424
private const string ServerUrl = "http://localhost:54543";
2525

26-
private readonly ConcurrentDictionary<string, string> _tokens = new ConcurrentDictionary<string, string>(StringComparer.Ordinal);
26+
private readonly ConcurrentDictionary<string, Task<string>> _tokens = new ConcurrentDictionary<string, Task<string>>(StringComparer.Ordinal);
2727
private readonly Random _random = new Random();
2828

2929
private async Task RunConnection(HttpTransportType transportType)
3030
{
3131
var userId = "C#" + transportType;
32-
_tokens[userId] = await GetJwtToken(userId);
32+
_tokens[userId] = GetJwtToken(userId);
3333

3434
var hubConnection = new HubConnectionBuilder()
3535
.WithUrl(ServerUrl + "/broadcast", options =>
@@ -60,7 +60,7 @@ private async Task RunConnection(HttpTransportType transportType)
6060
// no need to refresh the token for websockets
6161
if (transportType != HttpTransportType.WebSockets)
6262
{
63-
_tokens[userId] = await GetJwtToken(userId);
63+
_tokens[userId] = GetJwtToken(userId);
6464
Console.WriteLine($"[{userId}] Token refreshed");
6565
}
6666
}

src/Microsoft.AspNetCore.Http.Connections.Client/HttpOptions.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Net.Http;
88
using System.Net.WebSockets;
99
using System.Security.Cryptography.X509Certificates;
10+
using System.Threading.Tasks;
1011

1112
namespace Microsoft.AspNetCore.Http.Connections.Client
1213
{
@@ -19,7 +20,7 @@ public class HttpOptions
1920
public Func<HttpMessageHandler, HttpMessageHandler> HttpMessageHandlerFactory { get; set; }
2021

2122
public IReadOnlyCollection<KeyValuePair<string, string>> Headers { get; set; }
22-
public Func<string> AccessTokenFactory { get; set; }
23+
public Func<Task<string>> AccessTokenFactory { get; set; }
2324
public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5);
2425
public ICredentials Credentials { get; set; }
2526
public X509CertificateCollection ClientCertificates { get; set; } = new X509CertificateCollection();

src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
1111
{
1212
public class AccessTokenHttpMessageHandler : DelegatingHandler
1313
{
14-
private readonly Func<string> _accessTokenFactory;
14+
private readonly Func<Task<string>> _accessTokenFactory;
1515

16-
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<string> accessTokenFactory) : base(inner)
16+
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<Task<string>> accessTokenFactory) : base(inner)
1717
{
1818
_accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory));
1919
}
2020

21-
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
21+
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
2222
{
23-
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessTokenFactory());
23+
var accessToken = await _accessTokenFactory();
24+
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
2425

25-
return base.SendAsync(request, cancellationToken);
26+
return await base.SendAsync(request, cancellationToken);
2627
}
2728
}
2829
}

src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
1717
public partial class WebSocketsTransport : ITransport
1818
{
1919
private readonly ClientWebSocket _webSocket;
20+
private readonly Func<Task<string>> _accessTokenFactory;
2021
private IDuplexPipe _application;
2122
private WebSocketMessageType _webSocketMessageType;
2223
private readonly ILogger _logger;
@@ -80,7 +81,7 @@ public WebSocketsTransport(HttpOptions httpOptions, ILoggerFactory loggerFactory
8081

8182
if (httpOptions.AccessTokenFactory != null)
8283
{
83-
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {httpOptions.AccessTokenFactory()}");
84+
_accessTokenFactory = httpOptions.AccessTokenFactory;
8485
}
8586

8687
httpOptions.WebSocketOptions?.Invoke(_webSocket.Options);
@@ -115,6 +116,12 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat)
115116

116117
Log.StartTransport(_logger, transferFormat, resolvedUrl);
117118

119+
if (_accessTokenFactory != null)
120+
{
121+
var accessToken = await _accessTokenFactory();
122+
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
123+
}
124+
118125
await _webSocket.ConnectAsync(resolvedUrl, CancellationToken.None);
119126

120127
// Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa)

src/Microsoft.AspNetCore.SignalR.Client/HttpConnectionOptions.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Net.Http;
88
using System.Net.WebSockets;
99
using System.Security.Cryptography.X509Certificates;
10+
using System.Threading.Tasks;
1011
using Microsoft.AspNetCore.Http.Connections;
1112
using Microsoft.AspNetCore.Http.Connections.Internal;
1213

@@ -29,7 +30,7 @@ public HttpConnectionOptions()
2930
public bool? UseDefaultCredentials { get; set; }
3031
public ICredentials Credentials { get; set; }
3132
public IWebProxy Proxy { get; set; }
32-
public Func<string> AccessTokenFactory { get; set; }
33+
public Func<Task<string>> AccessTokenFactory { get; set; }
3334
public Action<ClientWebSocketOptions> WebSocketOptions { get; set; }
3435

3536
public X509CertificateCollection ClientCertificates

test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -706,15 +706,18 @@ public async Task ClientCanUseJwtBearerTokenForAuthentication(HttpTransportType
706706
{
707707
using (StartLog(out var loggerFactory, $"{nameof(ClientCanUseJwtBearerTokenForAuthentication)}_{transportType}"))
708708
{
709-
var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken");
710-
httpResponse.EnsureSuccessStatusCode();
711-
var token = await httpResponse.Content.ReadAsStringAsync();
709+
async Task<string> AccessTokenFactory()
710+
{
711+
var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken");
712+
httpResponse.EnsureSuccessStatusCode();
713+
return await httpResponse.Content.ReadAsStringAsync();
714+
};
712715

713716
var hubConnection = new HubConnectionBuilder()
714717
.WithLoggerFactory(loggerFactory)
715718
.WithUrl(_serverFixture.Url + "/authorizedhub", transportType, options =>
716719
{
717-
options.AccessTokenFactory = () => token;
720+
options.AccessTokenFactory = AccessTokenFactory;
718721
})
719722
.Build();
720723
try

test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ private static HttpConnection CreateConnection(
2121
ITransport transport = null,
2222
ITransportFactory transportFactory = null,
2323
HttpTransportType transportType = HttpTransportType.LongPolling,
24-
Func<string> accessTokenFactory = null)
24+
Func<Task<string>> accessTokenFactory = null)
2525
{
2626
var httpOptions = new HttpOptions
2727
{

test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ public async Task HttpConnectionSetsAccessTokenOnAllRequests(HttpTransportType t
5050
return await next();
5151
});
5252

53-
string AccessTokenFactory()
53+
Task<string> AccessTokenFactory()
5454
{
5555
callCount++;
56-
return callCount.ToString();
56+
return Task.FromResult(callCount.ToString());
5757
}
5858

5959
await WithConnectionAsync(

0 commit comments

Comments
 (0)