Skip to content

Avoid client result invocation ID collisions #43716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/SignalR/common/Shared/ClientResultsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public void AddInvocation(string invocationId, (Type Type, string ConnectionId,
{
var result = _pendingInvocations.TryAdd(invocationId, invocationInfo);
Debug.Assert(result);
// Should have a 50% chance of happening once every 2.71 quintillion invocations (see UUID in Wikipedia)
Copy link
Member

@davidfowl davidfowl Sep 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤣 . Are we taking bets on if we ever see this error message filed in an issue?

if (!result)
{
invocationInfo.Complete(invocationInfo.Tcs, CompletionMessage.WithError(invocationId, "ID collision occurred when using client results. This is likely a bug in SignalR."));
}
}

public void TryCompleteResult(string connectionId, CompletionMessage message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,25 +588,22 @@ public async Task InvocationsFromDifferentServersUseUniqueIDs()
var manager1 = CreateNewHubLifetimeManager(backplane);
var manager2 = CreateNewHubLifetimeManager(backplane);

using (var client1 = new TestClient())
using (var client2 = new TestClient())
using (var client = new TestClient())
{
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
var connection = HubConnectionContextUtils.Create(client.Connection);

await manager1.OnConnectedAsync(connection1).DefaultTimeout();
await manager2.OnConnectedAsync(connection2).DefaultTimeout();
await manager1.OnConnectedAsync(connection).DefaultTimeout();

var invoke1 = manager1.InvokeConnectionAsync<int>(connection2.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
var invocation2 = Assert.IsType<InvocationMessage>(await client2.ReadAsync().DefaultTimeout());
var invoke1 = manager1.InvokeConnectionAsync<int>(connection.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
var invocation2 = Assert.IsType<InvocationMessage>(await client.ReadAsync().DefaultTimeout());

var invoke2 = manager2.InvokeConnectionAsync<int>(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
var invocation1 = Assert.IsType<InvocationMessage>(await client1.ReadAsync().DefaultTimeout());
var invoke2 = manager2.InvokeConnectionAsync<int>(connection.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default);
var invocation1 = Assert.IsType<InvocationMessage>(await client.ReadAsync().DefaultTimeout());

Assert.NotEqual(invocation1.InvocationId, invocation2.InvocationId);

await manager1.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout();
await manager2.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout();
await manager1.SetConnectionResultAsync(connection.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout();
await manager2.SetConnectionResultAsync(connection.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout();

var res = await invoke1.DefaultTimeout();
Assert.Equal(2, res);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ internal sealed class RedisChannels
/// </summary>
public string GroupManagement { get; }

public RedisChannels(string prefix)
/// <summary>
/// Gets the name of the internal channel for receiving client results.
/// </summary>
public string ReturnResults { get; }

public RedisChannels(string prefix, string serverName)
{
_prefix = prefix;

All = prefix + ":all";
GroupManagement = prefix + ":internal:groups";
ReturnResults = _prefix + ":internal:return:" + serverName;
}

/// <summary>
Expand Down Expand Up @@ -71,15 +77,4 @@ public string Ack(string serverName)
{
return _prefix + ":internal:ack:" + serverName;
}

/// <summary>
/// Gets the name of the client return results channel for the specified server.
/// </summary>
/// <param name="serverName">The name of the server to get the client return results channel for.</param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public string ReturnResults(string serverName)
{
return _prefix + ":internal:return:" + serverName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab

private readonly AckHandler _ackHandler;
private int _internalAckId;
private ulong _lastInvocationId;

/// <summary>
/// Constructs the <see cref="RedisHubLifetimeManager{THub}"/> with types from Dependency Injection.
Expand Down Expand Up @@ -72,7 +71,7 @@ public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
_logger = logger;
_options = options.Value;
_ackHandler = new AckHandler();
_channels = new RedisChannels(typeof(THub).FullName!);
_channels = new RedisChannels(typeof(THub).FullName!, _serverName);
if (globalHubOptions != null && hubOptions != null)
{
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols));
Expand Down Expand Up @@ -416,8 +415,8 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri

var connection = _connections[connectionId];

// Needs to be unique across servers, easiest way to do that is prefix with connection ID.
var invocationId = $"{connectionId}{Interlocked.Increment(ref _lastInvocationId)}";
// ID needs to be unique for each invocation and across servers, we generate a GUID every time, that should provide enough uniqueness guarantees.
var invocationId = GenerateInvocationId();

using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken,
connection?.ConnectionAborted ?? default, out var linkedToken);
Expand All @@ -428,7 +427,7 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri
if (connection == null)
{
// TODO: Need to handle other server going away while waiting for connection result
var messageBytes = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults(_serverName));
var messageBytes = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults);
var received = await PublishAsync(_channels.Connection(connectionId), messageBytes);
if (received < 1)
{
Expand Down Expand Up @@ -674,7 +673,7 @@ private async Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore

private async Task SubscribeToReturnResultsAsync()
{
var channel = await _bus!.SubscribeAsync(_channels.ReturnResults(_serverName));
var channel = await _bus!.SubscribeAsync(_channels.ReturnResults);
channel.OnMessage((channelMessage) =>
{
var completion = RedisProtocol.ReadCompletion(channelMessage.Message);
Expand All @@ -700,6 +699,7 @@ private async Task SubscribeToReturnResultsAsync()
Debug.Assert(parseSuccess);

var invocationInfo = _clientResultsManager.RemoveInvocation(((CompletionMessage)hubMessage!).InvocationId!);

invocationInfo?.Completion(invocationInfo?.Tcs!, (CompletionMessage)hubMessage!);
});
}
Expand Down Expand Up @@ -784,6 +784,21 @@ private static string GenerateServerName()
return $"{Environment.MachineName}_{Guid.NewGuid():N}";
}

private static string GenerateInvocationId()
{
Span<byte> buffer = stackalloc byte[16];
var success = Guid.NewGuid().TryWriteBytes(buffer);
Debug.Assert(success);
// 16 * 4/3 = 21.333 which means base64 encoding will use 22 characters of actual data and 2 characters of padding ('=')
Span<char> base64 = stackalloc char[24];
success = Convert.TryToBase64Chars(buffer, base64, out var written);
Debug.Assert(success);
Debug.Assert(written == 24);
// Trim the two '=='
Debug.Assert(base64.EndsWith("=="));
return new string(base64[..^2]);
}

private sealed class LoggerTextWriter : TextWriter
{
private readonly ILogger _logger;
Expand Down