Skip to content

Commit baff28b

Browse files
Avoid client result invocation ID collisions (#43716)
1 parent 9952c6b commit baff28b

File tree

4 files changed

+42
-30
lines changed

4 files changed

+42
-30
lines changed

src/SignalR/common/Shared/ClientResultsManager.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ public void AddInvocation(string invocationId, (Type Type, string ConnectionId,
4242
{
4343
var result = _pendingInvocations.TryAdd(invocationId, invocationInfo);
4444
Debug.Assert(result);
45+
// Should have a 50% chance of happening once every 2.71 quintillion invocations (see UUID in Wikipedia)
46+
if (!result)
47+
{
48+
invocationInfo.Complete(invocationInfo.Tcs, CompletionMessage.WithError(invocationId, "ID collision occurred when using client results. This is likely a bug in SignalR."));
49+
}
4550
}
4651

4752
public void TryCompleteResult(string connectionId, CompletionMessage message)

src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -588,25 +588,22 @@ public async Task InvocationsFromDifferentServersUseUniqueIDs()
588588
var manager1 = CreateNewHubLifetimeManager(backplane);
589589
var manager2 = CreateNewHubLifetimeManager(backplane);
590590

591-
using (var client1 = new TestClient())
592-
using (var client2 = new TestClient())
591+
using (var client = new TestClient())
593592
{
594-
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
595-
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
593+
var connection = HubConnectionContextUtils.Create(client.Connection);
596594

597-
await manager1.OnConnectedAsync(connection1).DefaultTimeout();
598-
await manager2.OnConnectedAsync(connection2).DefaultTimeout();
595+
await manager1.OnConnectedAsync(connection).DefaultTimeout();
599596

600-
var invoke1 = manager1.InvokeConnectionAsync<int>(connection2.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
601-
var invocation2 = Assert.IsType<InvocationMessage>(await client2.ReadAsync().DefaultTimeout());
597+
var invoke1 = manager1.InvokeConnectionAsync<int>(connection.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
598+
var invocation2 = Assert.IsType<InvocationMessage>(await client.ReadAsync().DefaultTimeout());
602599

603-
var invoke2 = manager2.InvokeConnectionAsync<int>(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
604-
var invocation1 = Assert.IsType<InvocationMessage>(await client1.ReadAsync().DefaultTimeout());
600+
var invoke2 = manager2.InvokeConnectionAsync<int>(connection.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
601+
var invocation1 = Assert.IsType<InvocationMessage>(await client.ReadAsync().DefaultTimeout());
605602

606603
Assert.NotEqual(invocation1.InvocationId, invocation2.InvocationId);
607604

608-
await manager1.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout();
609-
await manager2.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout();
605+
await manager1.SetConnectionResultAsync(connection.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout();
606+
await manager2.SetConnectionResultAsync(connection.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout();
610607

611608
var res = await invoke1.DefaultTimeout();
612609
Assert.Equal(2, res);

src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,18 @@ internal sealed class RedisChannels
2323
/// </summary>
2424
public string GroupManagement { get; }
2525

26-
public RedisChannels(string prefix)
26+
/// <summary>
27+
/// Gets the name of the internal channel for receiving client results.
28+
/// </summary>
29+
public string ReturnResults { get; }
30+
31+
public RedisChannels(string prefix, string serverName)
2732
{
2833
_prefix = prefix;
2934

3035
All = prefix + ":all";
3136
GroupManagement = prefix + ":internal:groups";
37+
ReturnResults = _prefix + ":internal:return:" + serverName;
3238
}
3339

3440
/// <summary>
@@ -71,15 +77,4 @@ public string Ack(string serverName)
7177
{
7278
return _prefix + ":internal:ack:" + serverName;
7379
}
74-
75-
/// <summary>
76-
/// Gets the name of the client return results channel for the specified server.
77-
/// </summary>
78-
/// <param name="serverName">The name of the server to get the client return results channel for.</param>
79-
/// <returns></returns>
80-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
81-
public string ReturnResults(string serverName)
82-
{
83-
return _prefix + ":internal:return:" + serverName;
84-
}
8580
}

src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
3939

4040
private readonly AckHandler _ackHandler;
4141
private int _internalAckId;
42-
private ulong _lastInvocationId;
4342

4443
/// <summary>
4544
/// Constructs the <see cref="RedisHubLifetimeManager{THub}"/> with types from Dependency Injection.
@@ -72,7 +71,7 @@ public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
7271
_logger = logger;
7372
_options = options.Value;
7473
_ackHandler = new AckHandler();
75-
_channels = new RedisChannels(typeof(THub).FullName!);
74+
_channels = new RedisChannels(typeof(THub).FullName!, _serverName);
7675
if (globalHubOptions != null && hubOptions != null)
7776
{
7877
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols));
@@ -416,8 +415,8 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri
416415

417416
var connection = _connections[connectionId];
418417

419-
// Needs to be unique across servers, easiest way to do that is prefix with connection ID.
420-
var invocationId = $"{connectionId}{Interlocked.Increment(ref _lastInvocationId)}";
418+
// ID needs to be unique for each invocation and across servers, we generate a GUID every time, that should provide enough uniqueness guarantees.
419+
var invocationId = GenerateInvocationId();
421420

422421
using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken,
423422
connection?.ConnectionAborted ?? default, out var linkedToken);
@@ -428,7 +427,7 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri
428427
if (connection == null)
429428
{
430429
// TODO: Need to handle other server going away while waiting for connection result
431-
var messageBytes = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults(_serverName));
430+
var messageBytes = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults);
432431
var received = await PublishAsync(_channels.Connection(connectionId), messageBytes);
433432
if (received < 1)
434433
{
@@ -674,7 +673,7 @@ private async Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore
674673

675674
private async Task SubscribeToReturnResultsAsync()
676675
{
677-
var channel = await _bus!.SubscribeAsync(_channels.ReturnResults(_serverName));
676+
var channel = await _bus!.SubscribeAsync(_channels.ReturnResults);
678677
channel.OnMessage((channelMessage) =>
679678
{
680679
var completion = RedisProtocol.ReadCompletion(channelMessage.Message);
@@ -700,6 +699,7 @@ private async Task SubscribeToReturnResultsAsync()
700699
Debug.Assert(parseSuccess);
701700

702701
var invocationInfo = _clientResultsManager.RemoveInvocation(((CompletionMessage)hubMessage!).InvocationId!);
702+
703703
invocationInfo?.Completion(invocationInfo?.Tcs!, (CompletionMessage)hubMessage!);
704704
});
705705
}
@@ -784,6 +784,21 @@ private static string GenerateServerName()
784784
return $"{Environment.MachineName}_{Guid.NewGuid():N}";
785785
}
786786

787+
private static string GenerateInvocationId()
788+
{
789+
Span<byte> buffer = stackalloc byte[16];
790+
var success = Guid.NewGuid().TryWriteBytes(buffer);
791+
Debug.Assert(success);
792+
// 16 * 4/3 = 21.333 which means base64 encoding will use 22 characters of actual data and 2 characters of padding ('=')
793+
Span<char> base64 = stackalloc char[24];
794+
success = Convert.TryToBase64Chars(buffer, base64, out var written);
795+
Debug.Assert(success);
796+
Debug.Assert(written == 24);
797+
// Trim the two '=='
798+
Debug.Assert(base64.EndsWith("=="));
799+
return new string(base64[..^2]);
800+
}
801+
787802
private sealed class LoggerTextWriter : TextWriter
788803
{
789804
private readonly ILogger _logger;

0 commit comments

Comments
 (0)