diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index 15cc579688698f..d7b7a97d724950 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -3,6 +3,7 @@ // ------------------------------------------------------------------------------ // Changes to this file must follow the https://aka.ms/api-review process. // ------------------------------------------------------------------------------ + namespace System.Net.WebSockets { public sealed partial class ClientWebSocket : System.Net.WebSockets.WebSocket @@ -10,6 +11,8 @@ public sealed partial class ClientWebSocket : System.Net.WebSockets.WebSocket public ClientWebSocket() { } public override System.Net.WebSockets.WebSocketCloseStatus? CloseStatus { get { throw null; } } public override string? CloseStatusDescription { get { throw null; } } + public System.Net.HttpStatusCode HttpStatusCode { get { throw null; } } + public System.Collections.Generic.IReadOnlyDictionary>? HttpResponseHeaders { get { throw null; } set { } } public System.Net.WebSockets.ClientWebSocketOptions Options { get { throw null; } } public override System.Net.WebSockets.WebSocketState State { get { throw null; } } public override string? SubProtocol { get { throw null; } } @@ -32,6 +35,8 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.CookieContainer? Cookies { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + public bool CollectHttpResponseDetails { get { throw null; } set { } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.ICredentials? Credentials { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index f15169b1fdd47f..b7da9d3e90becb 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -13,6 +13,7 @@ + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index e01b4fcf46a876..59096fc864d3a9 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -82,6 +82,13 @@ public System.Net.CookieContainer Cookies set => throw new PlatformNotSupportedException(); } + [UnsupportedOSPlatform("browser")] + public bool CollectHttpResponseDetails + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + #endregion HTTP Settings #region WebSocket Settings diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs index 2a32dee3b1ee9f..e8a652e0796b0d 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.Diagnostics; using System.Net.Http; using System.Threading; @@ -51,6 +52,21 @@ public override WebSocketState State } } + public System.Net.HttpStatusCode HttpStatusCode => _innerWebSocket?.HttpStatusCode ?? 0; + + // setter to clean up when not needed anymore + public IReadOnlyDictionary>? HttpResponseHeaders + { + get => _innerWebSocket?.HttpResponseHeaders; + set + { + if (_innerWebSocket != null) + { + _innerWebSocket.HttpResponseHeaders = value; + } + } + } + public Task ConnectAsync(Uri uri, CancellationToken cancellationToken) { return ConnectAsync(uri, null, cancellationToken); diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index 5f8027abda7bb2..463ccf01c4c4d8 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -28,6 +28,7 @@ public sealed class ClientWebSocketOptions internal List? _requestedSubProtocols; private Version _version = Net.HttpVersion.Version11; private HttpVersionPolicy _versionPolicy = HttpVersionPolicy.RequestVersionOrLower; + private bool _collectHttpResponseDetails; internal ClientWebSocketOptions() { } // prevent external instantiation @@ -232,6 +233,17 @@ public void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment _collectHttpResponseDetails; + set + { + ThrowIfReadOnly(); + _collectHttpResponseDetails = value; + } + } + #endregion WebSocket settings #region Helpers diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/HttpResponseHeadersReadOnlyCollection.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/HttpResponseHeadersReadOnlyCollection.cs new file mode 100644 index 00000000000000..4e874a2bc2335c --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/HttpResponseHeadersReadOnlyCollection.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Net.Http.Headers; + +namespace System.Net.WebSockets +{ + internal sealed class HttpResponseHeadersReadOnlyCollection : IReadOnlyDictionary> + { + private readonly HttpHeadersNonValidated _headers; + + public HttpResponseHeadersReadOnlyCollection(HttpResponseHeaders headers) => _headers = headers.NonValidated; + + public IEnumerable this[string key] => _headers[key]; + + public IEnumerable Keys + { + get + { + foreach (KeyValuePair header in _headers) + { + yield return header.Key; + } + } + } + + public IEnumerable> Values + { + get + { + foreach (KeyValuePair header in _headers) + { + yield return header.Value; + } + } + } + + public int Count => _headers.Count; + + public bool ContainsKey(string key) => _headers.Contains(key); + + public IEnumerator>> GetEnumerator() + { + foreach (KeyValuePair header in _headers) + { + yield return new KeyValuePair>(header.Key, header.Value); + } + } + + public bool TryGetValue(string key, [MaybeNullWhen(false)] out IEnumerable value) + { + if (_headers.TryGetValues(key, out HeaderStringValues values)) + { + value = values; + return true; + } + + value = null; + return false; + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs index 2addc85ea5aed1..f90eb22fa1e4f5 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -10,6 +11,11 @@ namespace System.Net.WebSockets internal sealed class WebSocketHandle { private WebSocketState _state = WebSocketState.Connecting; +#pragma warning disable CA1822 // Mark members as static + public HttpStatusCode HttpStatusCode => (HttpStatusCode)0; +#pragma warning restore CA1822 // Mark members as static + + public IReadOnlyDictionary>? HttpResponseHeaders { get; set; } public WebSocket? WebSocket { get; private set; } public WebSocketState State => WebSocket?.State ?? _state; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 480ea91ce1e3e3..d9dd9c1b1b1771 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -27,6 +27,9 @@ internal sealed class WebSocketHandle public WebSocket? WebSocket { get; private set; } public WebSocketState State => WebSocket?.State ?? _state; + public HttpStatusCode HttpStatusCode { get; private set; } + + public IReadOnlyDictionary>? HttpResponseHeaders { get; set; } public static ClientWebSocketOptions CreateDefaultOptions() => new ClientWebSocketOptions() { Proxy = DefaultWebProxy.Instance }; @@ -48,6 +51,8 @@ public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, Cancellatio invoker ??= new HttpMessageInvoker(SetupHandler(options, out disposeHandler)); HttpResponseMessage? response = null; + bool disposeResponse = false; + bool tryDowngrade = false; try { @@ -187,7 +192,7 @@ public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, Cancellatio } Abort(); - response?.Dispose(); + disposeResponse = true; if (exc is WebSocketException || (exc is OperationCanceledException && cancellationToken.IsCancellationRequested)) @@ -199,6 +204,20 @@ public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, Cancellatio } finally { + if (response is not null) + { + if (options.CollectHttpResponseDetails) + { + HttpStatusCode = response.StatusCode; + HttpResponseHeaders = new HttpResponseHeadersReadOnlyCollection(response.Headers); + } + + if (disposeResponse) + { + response.Dispose(); + } + } + // Disposing the handler will not affect any active stream wrapped in the WebSocket. if (disposeHandler) { diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs index 31c375a9cd499a..416fd020da995f 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net.Test.Common; using System.Threading; using System.Threading.Tasks; @@ -313,5 +314,71 @@ await server.AcceptConnectionAsync(async connection => catch (IOException) { } }, new LoopbackServer.Options { WebSocketEndpoint = true }); } + + [ConditionalFact(nameof(WebSocketsSupported))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] + [SkipOnPlatform(TestPlatforms.Browser, "CollectHttpResponseDetails not supported on Browser")] + public async Task ConnectAsync_HttpResponseDetailsCollectedOnFailure() + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using (var clientWebSocket = new ClientWebSocket()) + using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + { + clientWebSocket.Options.CollectHttpResponseDetails = true; + Task t = clientWebSocket.ConnectAsync(uri, cts.Token); + await Assert.ThrowsAnyAsync(() => t); + + Assert.Equal(HttpStatusCode.Unauthorized, clientWebSocket.HttpStatusCode); + Assert.NotEmpty(clientWebSocket.HttpResponseHeaders); + } + }, server => server.AcceptConnectionSendResponseAndCloseAsync(HttpStatusCode.Unauthorized), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + [ConditionalFact(nameof(WebSocketsSupported))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] + [SkipOnPlatform(TestPlatforms.Browser, "CollectHttpResponseDetails not supported on Browser")] + public async Task ConnectAsync_HttpResponseDetailsCollectedOnFailure_CustomHeader() + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using (var clientWebSocket = new ClientWebSocket()) + using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + { + clientWebSocket.Options.CollectHttpResponseDetails = true; + Task t = clientWebSocket.ConnectAsync(uri, cts.Token); + await Assert.ThrowsAnyAsync(() => t); + + Assert.Equal(HttpStatusCode.Unauthorized, clientWebSocket.HttpStatusCode); + Assert.NotEmpty(clientWebSocket.HttpResponseHeaders); + Assert.Contains("X-CustomHeader1", clientWebSocket.HttpResponseHeaders); + Assert.Contains("X-CustomHeader2", clientWebSocket.HttpResponseHeaders); + Assert.NotNull(clientWebSocket.HttpResponseHeaders.Values); + } + }, server => server.AcceptConnectionSendResponseAndCloseAsync(HttpStatusCode.Unauthorized, "X-CustomHeader1: Value1\r\nX-CustomHeader2: Value2\r\n"), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + [ConditionalFact(nameof(WebSocketsSupported))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] + [SkipOnPlatform(TestPlatforms.Browser, "CollectHttpResponseDetails not supported on Browser")] + public async Task ConnectAsync_HttpResponseDetailsCollectedOnSuccess_Extentions() + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using (var clientWebSocket = new ClientWebSocket()) + using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + { + clientWebSocket.Options.CollectHttpResponseDetails = true; + await clientWebSocket.ConnectAsync(uri, cts.Token); + + Assert.Equal(HttpStatusCode.SwitchingProtocols, clientWebSocket.HttpStatusCode); + Assert.NotEmpty(clientWebSocket.HttpResponseHeaders); + Assert.Contains("Sec-WebSocket-Extensions", clientWebSocket.HttpResponseHeaders); + } + }, server => server.AcceptConnectionAsync(async connection => + { + Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection, "X-CustomHeader1"); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } } }