Skip to content

WebSocket: observe exceptions in WaitForServerToCloseConnectionAsync #114689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

using Xunit;
using Xunit.Abstractions;
using Microsoft.DotNet.RemoteExecutor;

namespace System.Net.WebSockets.Client.Tests
{
Expand Down Expand Up @@ -523,5 +524,42 @@ await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>

}), new LoopbackServer.Options { WebSocketEndpoint = true });
}

// Regression test for https://github.com/dotnet/runtime/issues/80116.
[OuterLoop("Uses Task.Delay")]
[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public async Task CloseHandshake_ExceptionsAreObserved()
{
await RemoteExecutor.Invoke(static (typeName) =>
{
CloseTest test = (CloseTest)Activator.CreateInstance(typeof(CloseTest).Assembly.GetType(typeName), new object[] { null });
using CancellationTokenSource timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);

Exception unobserved = null;
TaskScheduler.UnobservedTaskException += (obj, args) =>
{
unobserved = args.Exception;
};

TaskCompletionSource clientCompleted = new TaskCompletionSource();

return LoopbackWebSocketServer.RunAsync(async (clientWs, ct) =>
{
await clientWs.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", ct);
await clientWs.ReceiveAsync(new byte[16], ct);
await Task.Delay(1500);
GC.Collect(2);
GC.WaitForPendingFinalizers();
clientCompleted.SetResult();
Assert.Null(unobserved);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was that enough to consistently get an unobserved task exception before the PR change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's failing without the change.

},
async (serverWs, ct) =>
{
await serverWs.ReceiveAsync(new byte[16], ct);
await serverWs.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", ct);
await clientCompleted.Task;
}, new LoopbackWebSocketServer.Options(HttpVersion.Version11, true, test.GetInvoker()), timeoutCts.Token);
}, GetType().FullName).DisposeAsync();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

<PropertyGroup>
<StringResourcesPath>../src/Resources/Strings.resx</StringResourcesPath>
<IncludeRemoteExecutor>true</IncludeRemoteExecutor>
<TargetFrameworks>$(NetCoreAppCurrent);$(NetCoreAppCurrent)-browser</TargetFrameworks>
<DefineConstants>$(DefineConstants);NETSTANDARD</DefineConstants>
</PropertyGroup>
Expand Down Expand Up @@ -46,6 +47,7 @@
<Compile Include="$(CommonTestPath)System\Net\Http\HuffmanEncoder.cs" Link="Common\System\Net\Http\HuffmanEncoder.cs" />
<Compile Include="$(CommonTestPath)System\Net\Http\HPackEncoder.cs" Link="Common\System\Net\Http\HPackEncoder.cs" />
<Compile Include="$(CommonTestPath)System\Net\Http\GenericLoopbackServer.cs" Link="Common\System\Net\Http\GenericLoopbackServer.cs" />
<Compile Include="$(CommonTestPath)System\Net\RemoteExecutorExtensions.cs" Link="Common\System\Net\RemoteExecutorExtensions.cs" />
<Compile Include="$(CommonTestPath)System\Threading\Tasks\TaskTimeoutExtensions.cs" Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
<Compile Include="AbortTest.cs" />
<Compile Include="AbortTest.Loopback.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private void UnsolicitedPongHeartBeat()
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);

Observe(
LogExceptions(
TrySendKeepAliveFrameAsync(MessageOpcode.Pong));
}

Expand Down Expand Up @@ -98,7 +98,7 @@ private void KeepAlivePingHeartBeat()

if (shouldSendPing)
{
Observe(
LogExceptions(
SendPingAsync(pingPayload));
}
}
Expand All @@ -122,52 +122,6 @@ private async ValueTask SendPingAsync(long pingPayload)
if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
}

// "Observe" either a ValueTask result, or any exception, ignoring it
// to prevent the unobserved exception event from being raised.
private void Observe(ValueTask t)
{
if (t.IsCompletedSuccessfully)
{
t.GetAwaiter().GetResult();
}
else
{
Observe(t.AsTask());
}
}

// "Observe" any exception, ignoring it to prevent the unobserved task
// exception event from being raised.
private void Observe(Task t)
{
if (t.IsCompleted)
{
if (t.IsFaulted)
{
LogFaulted(t, this);
}
}
else
{
t.ContinueWith(
LogFaulted,
this,
CancellationToken.None,
TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
TaskScheduler.Default);
}

static void LogFaulted(Task task, object? thisObj)
{
Debug.Assert(task.IsFaulted);

// accessing exception to observe it regardless of whether the tracing is enabled
Exception e = task.Exception!.InnerException!;

if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e);
}
}

private sealed class KeepAlivePingState
{
internal const int PingPayloadSize = sizeof(long);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1144,23 +1144,27 @@ private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken ca
// additional data, but at this point we're about to close the connection and we're just stalling
// to try to get the server to close first.
ValueTask<int> finalReadTask = _stream.ReadAsync(_receiveBuffer, cancellationToken);

if (finalReadTask.IsCompletedSuccessfully)
{
finalReadTask.GetAwaiter().GetResult();
}
else
{
const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same duration as .NET Framework)
Task task = finalReadTask.AsTask();

try
{
#pragma warning disable CA2016 // Token was already provided to the ReadAsync
await finalReadTask.AsTask().WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false);
await task.WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false);
#pragma warning restore CA2016
}
catch
{
// Eat any resulting exceptions. We were going to close the connection, anyway.
LogExceptions(task);
Abort();
// Eat any resulting exceptions. We were going to close the connection, anyway.
}
}
}
Expand Down Expand Up @@ -1851,6 +1855,52 @@ private static bool TryValidateUtf8(ReadOnlySpan<byte> span, bool endOfMessage,
return !endOfMessage || !state.SequenceInProgress;
}

// "Observe" either a ValueTask result, or any exception, logging and ignoring it
// to prevent the unobserved exception event from being raised.
private void LogExceptions(ValueTask t)
{
if (t.IsCompletedSuccessfully)
{
t.GetAwaiter().GetResult();
}
else
{
LogExceptions(t.AsTask());
}
}

// "Observe" and log any exception, ignoring it to prevent the unobserved task
// exception event from being raised.
private void LogExceptions(Task t)
{
if (t.IsCompleted)
{
if (t.IsFaulted)
{
LogFaulted(t, this);
}
}
else
{
t.ContinueWith(
LogFaulted,
this,
CancellationToken.None,
TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
TaskScheduler.Default);
}

static void LogFaulted(Task task, object? thisObj)
{
Debug.Assert(task.IsFaulted);

// accessing exception to observe it regardless of whether the tracing is enabled
Exception e = task.Exception!.InnerException!;

if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e);
}
}

private sealed class Utf8MessageState
{
internal bool SequenceInProgress;
Expand Down
Loading