Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,11 @@ public async Task ServerLogsErrorIfClientInvokeCannotBeSerialized(string protoco
};

var protocol = HubProtocols[protocolName];
await using (var server = await StartServer<Startup>(write => write.EventId.Name == "FailedWritingMessage"))
await using (var server = await StartServer<Startup>(write =>
{
return write.EventId.Name == "FailedWritingMessage" || write.EventId.Name == "ReceivedCloseWithError"
|| write.EventId.Name == "ShutdownWithError";
}))
{
var connection = CreateHubConnection(server.Url, "/default", HttpTransportType.WebSockets, protocol, LoggerFactory);
var closedTcs = new TaskCompletionSource<Exception>(TaskCreationOptions.RunContinuationsAsynchronously);
Expand All @@ -1361,9 +1365,12 @@ public async Task ServerLogsErrorIfClientInvokeCannotBeSerialized(string protoco
var result = connection.InvokeAsync<string>(nameof(TestHub.CallWithUnserializableObject));

// The connection should close.
Assert.Null(await closedTcs.Task.OrTimeout());
var exception = await closedTcs.Task.OrTimeout();
Assert.Contains("Connection closed with an error.", exception.Message);

await Assert.ThrowsAsync<TaskCanceledException>(() => result).OrTimeout();
var hubException = await Assert.ThrowsAsync<HubException>(() => result).OrTimeout();
Assert.Contains("Connection closed with an error.", hubException.Message);
Assert.Contains(exceptionSubstring, hubException.Message);
}
catch (Exception ex)
{
Expand Down Expand Up @@ -1396,7 +1403,11 @@ public async Task ServerLogsErrorIfReturnValueCannotBeSerialized(string protocol
};

var protocol = HubProtocols[protocolName];
await using (var server = await StartServer<Startup>(write => write.EventId.Name == "FailedWritingMessage"))
await using (var server = await StartServer<Startup>(write =>
{
return write.EventId.Name == "FailedWritingMessage" || write.EventId.Name == "ReceivedCloseWithError"
|| write.EventId.Name == "ShutdownWithError";
}))
{
var connection = CreateHubConnection(server.Url, "/default", HttpTransportType.LongPolling, protocol, LoggerFactory);
var closedTcs = new TaskCompletionSource<Exception>(TaskCreationOptions.RunContinuationsAsynchronously);
Expand All @@ -1408,9 +1419,12 @@ public async Task ServerLogsErrorIfReturnValueCannotBeSerialized(string protocol
var result = connection.InvokeAsync<string>(nameof(TestHub.GetUnserializableObject)).OrTimeout();

// The connection should close.
Assert.Null(await closedTcs.Task.OrTimeout());
var exception = await closedTcs.Task.OrTimeout();
Assert.Contains("Connection closed with an error.", exception.Message);

await Assert.ThrowsAsync<TaskCanceledException>(() => result).OrTimeout();
var hubException = await Assert.ThrowsAsync<HubException>(() => result).OrTimeout();
Assert.Contains("Connection closed with an error.", hubException.Message);
Assert.Contains(exceptionSubstring, hubException.Message);
}
catch (Exception ex)
{
Expand Down
46 changes: 46 additions & 0 deletions src/SignalR/server/Core/src/HubConnectionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,52 @@ internal StreamTracker StreamTracker
// Currently used only for streaming methods
internal ConcurrentDictionary<string, CancellationTokenSource> ActiveRequestCancellationSources { get; } = new ConcurrentDictionary<string, CancellationTokenSource>(StringComparer.Ordinal);

// Only used when writing CloseMessage, it ignores _connectionAborted to be the final message sent to the client
internal ValueTask WriteCloseAsync(CloseMessage message, CancellationToken cancellationToken = default)
{
// Try to grab the lock synchronously, if we fail, go to the slower path
if (!_writeLock.Wait(0))
{
return new ValueTask(WriteCloseSlowAsync(message, cancellationToken));
}

// This method should never throw synchronously
var task = WriteCore(message, cancellationToken);

// The write didn't complete synchronously so await completion
if (!task.IsCompletedSuccessfully)
{
return new ValueTask(CompleteWriteAsync(task));
}

// Otherwise, release the lock acquired when entering WriteAsync
_writeLock.Release();

return default;
}

private async Task WriteCloseSlowAsync(CloseMessage message, CancellationToken cancellationToken)
{
// Failed to get the lock immediately when entering WriteAsync so await until it is available
await _writeLock.WaitAsync(cancellationToken);

try
{

await WriteCore(message, cancellationToken);
}
catch (Exception ex)
{
CloseException = ex;
Log.FailedWritingMessage(_logger, ex);
AbortAllowReconnect();
}
finally
{
_writeLock.Release();
}
}

public virtual ValueTask WriteAsync(HubMessage message, CancellationToken cancellationToken = default)
{
// Try to grab the lock synchronously, if we fail, go to the slower path
Expand Down
2 changes: 1 addition & 1 deletion src/SignalR/server/Core/src/HubConnectionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ private async Task SendCloseAsync(HubConnectionContext connection, Exception? ex

try
{
await connection.WriteAsync(closeMessage);
await connection.WriteCloseAsync(closeMessage);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer adding an ignoreAbort = false parameter to WriteAsync. Where it checks if (_connectionAborted) today, we'd change that to if (_connectionAborted && !ignoreAbort). That way if we make a change to WriteAsync in the future, we get it for free without having to remember to also update WriteCloseAsync.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and add a branch to every single write!? 😮

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the branch is already there 😆

}
catch (Exception ex)
{
Expand Down
23 changes: 12 additions & 11 deletions src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ public async Task AbortFromHubMethodForcesClientDisconnect()

await client.SendInvocationAsync(nameof(AbortHub.Kill)).OrTimeout();

var close = Assert.IsType<CloseMessage>(await client.ReadAsync().OrTimeout());
Assert.False(close.AllowReconnect);

await connectionHandlerTask.OrTimeout();

Assert.Null(client.TryRead());
Expand Down Expand Up @@ -955,15 +958,18 @@ public async Task HubMethodListeningToConnectionAbortedClosesOnConnectionContext
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);

var invokeTask = client.InvokeAsync(nameof(MethodHub.BlockingMethod));
await client.SendInvocationAsync(nameof(MethodHub.BlockingMethod)).OrTimeout();

client.Connection.Abort();

var closeMessage = Assert.IsType<CloseMessage>(await client.ReadAsync().OrTimeout());
Assert.False(closeMessage.AllowReconnect);

// If this completes then the server has completed the connection
await connectionHandlerTask.OrTimeout();

// Nothing written to connection because it was closed
Assert.False(invokeTask.IsCompleted);
Assert.Null(client.TryRead());
}
}
}
Expand Down Expand Up @@ -1019,16 +1025,11 @@ public async Task HubMethodDoesNotSendResultWhenInvocationIsNonBlocking()
// kill the connection
client.Dispose();

var message = Assert.IsType<CloseMessage>(client.TryRead());
Assert.True(message.AllowReconnect);

// Ensure the client channel is empty
var message = client.TryRead();
switch (message)
{
case CloseMessage close:
break;
default:
Assert.Null(message);
break;
}
Assert.Null(client.TryRead());

await connectionHandlerTask.OrTimeout();
}
Expand Down