diff --git a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs index 14affae6bd39e5..063ee71169d17e 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs @@ -12,6 +12,7 @@ using Xunit; using Xunit.Abstractions; +using Microsoft.DotNet.RemoteExecutor; namespace System.Net.WebSockets.Client.Tests { @@ -523,5 +524,42 @@ await Assert.ThrowsAnyAsync(async () => }), new LoopbackServer.Options { WebSocketEndpoint = true }); } + + // Regression test for https://github.com/dotnet/runtime/issues/80116. + [OuterLoop("Uses Task.Delay")] + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public async Task CloseHandshake_ExceptionsAreObserved() + { + await RemoteExecutor.Invoke(static (typeName) => + { + CloseTest test = (CloseTest)Activator.CreateInstance(typeof(CloseTest).Assembly.GetType(typeName), new object[] { null }); + using CancellationTokenSource timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + Exception unobserved = null; + TaskScheduler.UnobservedTaskException += (obj, args) => + { + unobserved = args.Exception; + }; + + TaskCompletionSource clientCompleted = new TaskCompletionSource(); + + return LoopbackWebSocketServer.RunAsync(async (clientWs, ct) => + { + await clientWs.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", ct); + await clientWs.ReceiveAsync(new byte[16], ct); + await Task.Delay(1500); + GC.Collect(2); + GC.WaitForPendingFinalizers(); + clientCompleted.SetResult(); + Assert.Null(unobserved); + }, + async (serverWs, ct) => + { + await serverWs.ReceiveAsync(new byte[16], ct); + await serverWs.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", ct); + await clientCompleted.Task; + }, new LoopbackWebSocketServer.Options(HttpVersion.Version11, true, test.GetInvoker()), timeoutCts.Token); + }, GetType().FullName).DisposeAsync(); + } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 98f61386b3eaa8..d5e28ff2f552cb 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -3,6 +3,7 @@ ../src/Resources/Strings.resx + true $(NetCoreAppCurrent);$(NetCoreAppCurrent)-browser $(DefineConstants);NETSTANDARD @@ -46,6 +47,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs index 4e38c4fc10042f..1583872c881031 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -29,7 +29,7 @@ private void UnsolicitedPongHeartBeat() { if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); - Observe( + LogExceptions( TrySendKeepAliveFrameAsync(MessageOpcode.Pong)); } @@ -98,7 +98,7 @@ private void KeepAlivePingHeartBeat() if (shouldSendPing) { - Observe( + LogExceptions( SendPingAsync(pingPayload)); } } @@ -122,52 +122,6 @@ private async ValueTask SendPingAsync(long pingPayload) if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload); } - // "Observe" either a ValueTask result, or any exception, ignoring it - // to prevent the unobserved exception event from being raised. - private void Observe(ValueTask t) - { - if (t.IsCompletedSuccessfully) - { - t.GetAwaiter().GetResult(); - } - else - { - Observe(t.AsTask()); - } - } - - // "Observe" any exception, ignoring it to prevent the unobserved task - // exception event from being raised. - private void Observe(Task t) - { - if (t.IsCompleted) - { - if (t.IsFaulted) - { - LogFaulted(t, this); - } - } - else - { - t.ContinueWith( - LogFaulted, - this, - CancellationToken.None, - TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, - TaskScheduler.Default); - } - - static void LogFaulted(Task task, object? thisObj) - { - Debug.Assert(task.IsFaulted); - - // accessing exception to observe it regardless of whether the tracing is enabled - Exception e = task.Exception!.InnerException!; - - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e); - } - } - private sealed class KeepAlivePingState { internal const int PingPayloadSize = sizeof(long); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index f40da1cb06e8c5..594c4bae2b96db 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -1144,6 +1144,7 @@ private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken ca // additional data, but at this point we're about to close the connection and we're just stalling // to try to get the server to close first. ValueTask finalReadTask = _stream.ReadAsync(_receiveBuffer, cancellationToken); + if (finalReadTask.IsCompletedSuccessfully) { finalReadTask.GetAwaiter().GetResult(); @@ -1151,16 +1152,19 @@ private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken ca else { const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same duration as .NET Framework) + Task task = finalReadTask.AsTask(); + try { #pragma warning disable CA2016 // Token was already provided to the ReadAsync - await finalReadTask.AsTask().WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false); + await task.WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false); #pragma warning restore CA2016 } catch { + // Eat any resulting exceptions. We were going to close the connection, anyway. + LogExceptions(task); Abort(); - // Eat any resulting exceptions. We were going to close the connection, anyway. } } } @@ -1851,6 +1855,52 @@ private static bool TryValidateUtf8(ReadOnlySpan span, bool endOfMessage, return !endOfMessage || !state.SequenceInProgress; } + // "Observe" either a ValueTask result, or any exception, logging and ignoring it + // to prevent the unobserved exception event from being raised. + private void LogExceptions(ValueTask t) + { + if (t.IsCompletedSuccessfully) + { + t.GetAwaiter().GetResult(); + } + else + { + LogExceptions(t.AsTask()); + } + } + + // "Observe" and log any exception, ignoring it to prevent the unobserved task + // exception event from being raised. + private void LogExceptions(Task t) + { + if (t.IsCompleted) + { + if (t.IsFaulted) + { + LogFaulted(t, this); + } + } + else + { + t.ContinueWith( + LogFaulted, + this, + CancellationToken.None, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + } + + static void LogFaulted(Task task, object? thisObj) + { + Debug.Assert(task.IsFaulted); + + // accessing exception to observe it regardless of whether the tracing is enabled + Exception e = task.Exception!.InnerException!; + + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e); + } + } + private sealed class Utf8MessageState { internal bool SequenceInProgress;