From 705c1715ada76786cb79a073a0c175ac9cf13d8d Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 4 Mar 2021 21:04:08 +0200 Subject: [PATCH 01/47] Prototyped how deflate/inflate should be plugged in. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 252 ++++++++++++++---- 1 file changed, 194 insertions(+), 58 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 9a0142f9c73b36..8ef34f119ac4ce 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -3,6 +3,7 @@ using System.Buffers; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Numerics; using System.Runtime.CompilerServices; @@ -151,6 +152,33 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it + /// + /// Indicates whether compression is enabled for the receiving part of the websocket. + /// + private readonly bool _inflateEnabled; + private byte[]? _inflateBuffer; + + /// + /// The position for the next unconsumed byte in the inflate buffer. + /// + private int _inflateBufferPosition; + + /// + /// How many unconsumed bytes are left in the inflate buffer. + /// + private int _inflateBufferAvailable; + + /// + /// Because of how the underlying zlib inflater works, we may have consumed the entire payload, but still + /// might have data left in the native component, we need to track if inflating has really finished. + /// + private bool _inflateFinished = true; + + /// + /// Indicates whether compression is enabled for the sending part of the websocket. + /// + private readonly bool _deflateEnabled; + /// Initializes the websocket. /// The connected Stream. /// true if this is the server-side of the connection; false if this is the client-side of the connection. @@ -170,6 +198,8 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time _stream = stream; _isServer = isServer; _subprotocol = subprotocol; + _inflateEnabled = false; + _deflateEnabled = false; // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer @@ -530,42 +560,100 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM /// Writes a frame into the send buffer, which can then be sent over the network. private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer) { - // Ensure we have a _sendBuffer. - AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength); - Debug.Assert(_sendBuffer != null); - - // Write the message header data to the buffer. - int headerLength; - int? maskOffset = null; - if (_isServer) + try { - // The server doesn't send a mask, so the mask offset returned by WriteHeader - // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false); + if (_deflateEnabled) + { + payloadBuffer = Deflate(payloadBuffer); + } + + // Ensure we have a _sendBuffer. + AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength); + Debug.Assert(_sendBuffer != null); + + // Write the message header data to the buffer. + int headerLength; + int? maskOffset = null; + if (_isServer) + { + // The server doesn't send a mask, so the mask offset returned by WriteHeader + // is actually the end of the header. + headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false); + } + else + { + // We need to know where the mask starts so that we can use the mask to manipulate the payload data, + // and we need to know the total length for sending it on the wire. + maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); + headerLength = maskOffset.GetValueOrDefault() + MaskLength; + } + + // Write the payload + if (payloadBuffer.Length > 0) + { + payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadBuffer.Length)); + + // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid + // changing the data in the caller-supplied payload buffer. + if (maskOffset.HasValue) + { + ApplyMask(new Span(_sendBuffer, headerLength, payloadBuffer.Length), _sendBuffer, maskOffset.Value, 0); + } + } + + // Return the number of bytes in the send buffer + return headerLength + payloadBuffer.Length; } - else + finally { - // We need to know where the mask starts so that we can use the mask to manipulate the payload data, - // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); - headerLength = maskOffset.GetValueOrDefault() + MaskLength; + if (_deflateEnabled) + { + ReleaseDeflateBuffer(); + } } + } - // Write the payload - if (payloadBuffer.Length > 0) - { - payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadBuffer.Length)); + private ReadOnlySpan Deflate(ReadOnlySpan payload) + { + // This function assumes that we're going to use a single buffer. + throw new NotImplementedException(); + } - // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid - // changing the data in the caller-supplied payload buffer. - if (maskOffset.HasValue) - { - ApplyMask(new Span(_sendBuffer, headerLength, payloadBuffer.Length), _sendBuffer, maskOffset.Value, 0); - } + private void ReleaseDeflateBuffer() + { + throw new NotImplementedException(); + } + + private bool Inflate(Span output, out int bytesWritten) + { + int consumed = 42; + + _inflateBufferPosition += consumed; + _inflateBufferAvailable -= consumed; + + if (_inflateBufferAvailable == 0) + { + ReleaseInflateBuffer(); } - // Return the number of bytes in the send buffer - return headerLength + payloadBuffer.Length; + throw new NotImplementedException(); + } + + [MemberNotNull(nameof(_inflateBuffer))] + private void RentInflateBuffer(long payloadLength) + { + _inflateBufferPosition = 0; + _inflateBuffer = ArrayPool.Shared.Rent((int)Math.Min(payloadLength, 1_000_000)); + + throw new NotImplementedException(); + } + + private void ReleaseInflateBuffer() + { + if (_inflateBuffer is not null) + { + ArrayPool.Shared.Return(_inflateBuffer); + } } private void SendKeepAliveFrameAsync() @@ -707,7 +795,7 @@ private async ValueTask ReceiveAsyncPrivate ReceiveAsyncPrivate ReceiveAsyncPrivate 0) - { - int receiveBufferBytesToCopy = Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); - Debug.Assert(receiveBufferBytesToCopy > 0); - _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(payloadBuffer.Span); - ConsumeFromBuffer(receiveBufferBytesToCopy); - totalBytesReceived += receiveBufferBytesToCopy; - Debug.Assert( - _receiveBufferCount == 0 || - totalBytesReceived == payloadBuffer.Length || - totalBytesReceived == header.PayloadLength); - } - // Then read directly into the payload buffer until we've hit a limit. - while (totalBytesReceived < payloadBuffer.Length && - totalBytesReceived < header.PayloadLength) + // Only start a new receive when we've consumed everything from the inflate buffer. When + // there is no compression, this will always be 0. + if (_inflateBufferAvailable == 0) { - int numBytesRead = await _stream.ReadAsync(payloadBuffer.Slice( - totalBytesReceived, - (int)Math.Min(payloadBuffer.Length, header.PayloadLength) - totalBytesReceived), cancellationToken).ConfigureAwait(false); - if (numBytesRead <= 0) + if (_receiveBufferCount > 0) { - ThrowIfEOFUnexpected(throwOnPrematureClosure: true); - break; + int receiveBufferBytesToCopy = Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); + Debug.Assert(receiveBufferBytesToCopy > 0); + + if (header.Compressed) + { + Debug.Assert(_inflateBufferAvailable == 0); + RentInflateBuffer(header.PayloadLength); + + _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(_inflateBuffer); + _inflateBufferAvailable += receiveBufferBytesToCopy; + ConsumeFromBuffer(receiveBufferBytesToCopy); + totalBytesReceived += receiveBufferBytesToCopy; + Debug.Assert(_receiveBufferCount == 0 || totalBytesReceived == header.PayloadLength); + } + else + { + _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(payloadBuffer.Span); + ConsumeFromBuffer(receiveBufferBytesToCopy); + totalBytesReceived += receiveBufferBytesToCopy; + Debug.Assert( + _receiveBufferCount == 0 || + totalBytesReceived == payloadBuffer.Length || + totalBytesReceived == header.PayloadLength); + } } - totalBytesReceived += numBytesRead; + + // Then read directly into the payload buffer until we've hit a limit. + while (totalBytesReceived < payloadBuffer.Length && + totalBytesReceived < header.PayloadLength) + { + int numBytesRead = await _stream.ReadAsync( + header.Compressed ? + _inflateBuffer.AsMemory(totalBytesReceived, (int)Math.Min(_inflateBuffer!.Length, header.PayloadLength) - totalBytesReceived) : + payloadBuffer.Slice(totalBytesReceived, (int)Math.Min(payloadBuffer.Length, header.PayloadLength) - totalBytesReceived), + cancellationToken).ConfigureAwait(false); + if (numBytesRead <= 0) + { + ThrowIfEOFUnexpected(throwOnPrematureClosure: true); + break; + } + totalBytesReceived += numBytesRead; + } + + if (_isServer) + { + _receivedMaskOffsetOffset = ApplyMask(header.Compressed ? + _inflateBuffer.AsSpan(0, totalBytesReceived) : + payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); + } + header.PayloadLength -= totalBytesReceived; } - if (_isServer) + if (header.Compressed) { - _receivedMaskOffsetOffset = ApplyMask(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); + // In case of compression totalBytesReceived should actually represent how much we've + // inflated, rather than how much we've read from the stream. + _inflateFinished = Inflate(payloadBuffer.Span, out totalBytesReceived); } - header.PayloadLength -= totalBytesReceived; // If this a text message, validate that it contains valid UTF8. if (header.Opcode == MessageOpcode.Text && @@ -828,8 +949,15 @@ private async ValueTask ReceiveAsyncPrivate receiveBufferSpan = _receiveBuffer.Span; header.Fin = (receiveBufferSpan[_receiveBufferOffset] & 0x80) != 0; - bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0x70) != 0; + bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0b_0011_0000) != 0; header.Opcode = (MessageOpcode)(receiveBufferSpan[_receiveBufferOffset] & 0xF); + header.Compressed = (receiveBufferSpan[_receiveBufferOffset] & 0b_0100_0000) != 0; bool masked = (receiveBufferSpan[_receiveBufferOffset + 1] & 0x80) != 0; header.PayloadLength = receiveBufferSpan[_receiveBufferOffset + 1] & 0x7F; @@ -1083,6 +1212,12 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( return SR.net_Websockets_ReservedBitsSet; } + if (header.Compressed && !_inflateEnabled) + { + resultHeader = default; + return "TODO"; + } + if (masked) { if (!_isServer) @@ -1580,6 +1715,7 @@ private struct MessageHeader internal MessageOpcode Opcode; internal bool Fin; internal long PayloadLength; + internal bool Compressed; internal int Mask; } From 740798cc8cba1e08801f7e81551fa7163cfdc9e1 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 4 Mar 2021 23:46:37 +0200 Subject: [PATCH 02/47] Replaced WebSocketReceiveResultGetter abstraction for the receive with a generic method which can be inlined by the jit. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 67 +++++++------------ 1 file changed, 25 insertions(+), 42 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 8ef34f119ac4ce..749cd9dad0f034 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -329,7 +329,7 @@ public override Task ReceiveAsync(ArraySegment buf lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync { ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted); - Task t = ReceiveAsyncPrivate(buffer, cancellationToken).AsTask(); + Task t = ReceiveAsyncPrivate(buffer, cancellationToken).AsTask(); _lastReceiveAsync = t; return t; } @@ -401,7 +401,7 @@ public override ValueTask ReceiveAsync(Memory { ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted); - ValueTask receiveValueTask = ReceiveAsyncPrivate(buffer, cancellationToken); + ValueTask receiveValueTask = ReceiveAsyncPrivate(buffer, cancellationToken); if (receiveValueTask.IsCompletedSuccessfully) { _lastReceiveAsync = receiveValueTask.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask; @@ -430,7 +430,7 @@ private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, Cancellati !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close))) { - ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken); + ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken); receiveTask = vt.IsCompletedSuccessfully ? (vt.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) : vt.AsTask(); @@ -439,13 +439,6 @@ private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, Cancellati return receiveTask; } - /// implementation for . - private readonly struct ValueWebSocketReceiveResultGetter : IWebSocketReceiveResultGetter - { - public ValueWebSocketReceiveResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription) => - new ValueWebSocketReceiveResult(count, messageType, endOfMessage); // closeStatus/closeDescription are ignored - } - /// Sends a websocket frame to the network. /// The opcode for the message. /// The value of the FIN bit for the message. @@ -768,13 +761,8 @@ private static void WriteRandomMask(byte[] buffer, int offset) => /// /// The buffer into which payload data should be written. /// The CancellationToken used to cancel the websocket. - /// Used to get the result. Allows the same method to be used with both WebSocketReceiveResult and ValueWebSocketReceiveResult. /// Information about the received message. - private async ValueTask ReceiveAsyncPrivate( - Memory payloadBuffer, - CancellationToken cancellationToken, - TWebSocketReceiveResultGetter resultGetter = default) - where TWebSocketReceiveResultGetter : struct, IWebSocketReceiveResultGetter // constrained to avoid boxing and enable inlining + private async ValueTask ReceiveAsyncPrivate(Memory payloadBuffer, CancellationToken cancellationToken) { // This is a long method. While splitting it up into pieces would arguably help with readability, doing so would // also result in more allocations, as each async method that yields ends up with multiple allocations. The impact @@ -837,7 +825,7 @@ private async ValueTask ReceiveAsyncPrivate(0, WebSocketMessageType.Close, true); } // If this is a continuation, replace the opcode with the one of the message it's continuing @@ -854,11 +842,10 @@ private async ValueTask ReceiveAsyncPrivate( + count: 0, + messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, + endOfMessage: header.Fin && header.PayloadLength == 0 && _inflateFinished); } // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data @@ -942,11 +929,10 @@ private async ValueTask ReceiveAsyncPrivate( totalBytesReceived, header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0, - null, null); + header.Fin && header.PayloadLength == 0 && _inflateFinished); } } catch (Exception exc) @@ -977,6 +963,20 @@ private async ValueTask ReceiveAsyncPrivate + /// Returns either or . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private TResult GetReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) + { + if (typeof(TResult) == typeof(ValueWebSocketReceiveResult)) + { + return (TResult)(object)new ValueWebSocketReceiveResult(count, messageType, endOfMessage); + } + + return (TResult)(object)new WebSocketReceiveResult(count, messageType, endOfMessage, _closeStatus, _closeStatusDescription); + } + /// Processes a received close message. /// The message header. /// The CancellationToken used to cancel the websocket operation. @@ -1718,22 +1718,5 @@ private struct MessageHeader internal bool Compressed; internal int Mask; } - - /// - /// Interface used by to enable it to return - /// different result types in an efficient manner. - /// - /// The type of the result - private interface IWebSocketReceiveResultGetter - { - TResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription); - } - - /// implementation for . - private readonly struct WebSocketReceiveResultGetter : IWebSocketReceiveResultGetter - { - public WebSocketReceiveResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription) => - new WebSocketReceiveResult(count, messageType, endOfMessage, closeStatus, closeDescription); - } } } From 72a34fd611679a71506b863e4eb86718f9ddc7e5 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 4 Mar 2021 23:54:35 +0200 Subject: [PATCH 03/47] Moved ZLibNative and ZLibNative.ZStream to Common so they can be used in other projects as well. --- .../src/System/IO/Compression}/ZLibNative.ZStream.cs | 0 .../src/System/IO/Compression}/ZLibNative.cs | 0 .../src/System.IO.Compression.csproj | 8 +++++--- 3 files changed, 5 insertions(+), 3 deletions(-) rename src/libraries/{System.IO.Compression/src/System/IO/Compression/DeflateZLib => Common/src/System/IO/Compression}/ZLibNative.ZStream.cs (100%) rename src/libraries/{System.IO.Compression/src/System/IO/Compression/DeflateZLib => Common/src/System/IO/Compression}/ZLibNative.cs (100%) diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs similarity index 100% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs similarity index 100% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.cs diff --git a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj index e2a7adee12f579..60ee6f229fc034 100644 --- a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj +++ b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj @@ -1,4 +1,4 @@ - + true $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser @@ -25,8 +25,10 @@ - - + + From 5128c9f645b26e02dfeaaa6b5f52e0de9fc9ecfb Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 5 Mar 2021 01:09:17 +0200 Subject: [PATCH 04/47] Added compression related implementation. --- .../ref/System.Net.WebSockets.cs | 15 ++ .../src/Resources/Strings.resx | 88 ++++++- .../src/System.Net.WebSockets.csproj | 24 +- .../Compression/WebSocketDeflater.cs | 179 ++++++++++++++ .../Compression/WebSocketInflater.cs | 230 ++++++++++++++++++ .../System/Net/WebSockets/ManagedWebSocket.cs | 128 ++++++---- .../src/System/Net/WebSockets/WebSocket.cs | 31 ++- .../WebSockets/WebSocketCreationOptions.cs | 57 +++++ .../Net/WebSockets/WebSocketDeflateOptions.cs | 78 ++++++ 9 files changed, 780 insertions(+), 50 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index e4ff945bb5b647..e93d55aa5fea42 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -29,6 +29,7 @@ protected WebSocket() { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public static System.Net.WebSockets.WebSocket CreateClientWebSocket(System.IO.Stream innerStream, string? subProtocol, int receiveBufferSize, int sendBufferSize, System.TimeSpan keepAliveInterval, bool useZeroMaskingKey, System.ArraySegment internalBuffer) { throw null; } public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, bool isServer, string? subProtocol, System.TimeSpan keepAliveInterval) { throw null; } + public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, System.Net.WebSockets.WebSocketCreationOptions options) { throw null; } public static System.ArraySegment CreateServerBuffer(int receiveBufferSize) { throw null; } public abstract void Dispose(); [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] @@ -131,4 +132,18 @@ public enum WebSocketState Closed = 5, Aborted = 6, } + public sealed partial class WebSocketCreationOptions + { + public bool IsServer { get { throw null; } set { } } + public string? SubProtocol { get { throw null; } set { } } + public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } + public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + } + public sealed partial class WebSocketDeflateOptions + { + public int ClientMaxWindowBits { get { throw null; } set { } } + public bool ClientContextTakeover { get { throw null; } set { } } + public int ServerMaxWindowBits { get { throw null; } set { } } + public bool ServerContextTakeover { get { throw null; } set { } } + } } diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index a4f630ea24c039..d8963299bd87c6 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -1,4 +1,64 @@ - + + + @@ -138,4 +198,28 @@ The base stream is not writeable. - + + The argument must be a value between {0} and {1}. + + + The WebSocket received a continuation frame with Per-Message Compressed flag set. + + + The WebSocket received compressed frame when compression is not enabled. + + + The underlying compression routine could not be loaded correctly. + + + The stream state of the underlying compression routine is inconsistent. + + + The underlying compression routine could not reserve sufficient memory. + + + The underlying compression routine returned an unexpected error code {0}. + + + The message was compressed using an unsupported compression method. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index d65e6c55737af0..577adb80ecfe95 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -1,14 +1,18 @@ - + True - $(NetCoreAppCurrent) + $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser enable + + + + @@ -17,6 +21,22 @@ + + + + + + + + + + + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs new file mode 100644 index 00000000000000..aa053c4a2cccf8 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -0,0 +1,179 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using static System.IO.Compression.ZLibNative; + +namespace System.Net.WebSockets.Compression +{ + /// + /// Provides a wrapper around the ZLib compression API. + /// + internal sealed class WebSocketDeflater : IDisposable + { + private ZLibStreamHandle? _stream; + private readonly int _windowBits; + private readonly bool _persisted; + + internal WebSocketDeflater(int windowBits, bool persisted) + { + Debug.Assert(windowBits >= 9 && windowBits <= 15); + + // We use negative window bits in order to produce raw deflate data + _windowBits = -windowBits; + _persisted = persisted; + } + + public void Dispose() => _stream?.Dispose(); + + public void Deflate( ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, + out int consumed, out int written, out bool needsMoreOutput) + { + Debug.Assert(!continuation || _stream is not null, "Invalid state. The stream should not be null in continuations."); + + if (_stream is null) + { + Initialize(); + } + + Deflate(payload, output, out consumed, out written, out needsMoreOutput); + if (needsMoreOutput) + { + return; + } + + // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 + // At that point there will be at most a few bits left to write. + // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. + written += Flush(output.Slice(written), out needsMoreOutput); + + if (!needsMoreOutput) + { + Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength) + .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); + + if (endOfMessage) + { + // As per RFC we need to remove the flush markers + written -= WebSocketInflater.FlushMarkerLength; + } + + if (endOfMessage && !_persisted) + { + _stream.Dispose(); + _stream = null; + } + } + } + + private unsafe void Deflate(ReadOnlySpan input, Span output, out int consumed, out int written, out bool needsMoreBuffer) + { + Debug.Assert(_stream is not null); + + fixed (byte* fixedInput = input) + fixed (byte* fixedOutput = output) + { + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)input.Length; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + // If flush is set to Z_BLOCK, a deflate block is completed + // and emitted, as for Z_SYNC_FLUSH, but the output + // is not aligned on a byte boundary, and up to seven bits + // of the current block are held to be written as the next byte after + // the next deflate block is completed. + var errorCode = Deflate(_stream, (FlushCode)5/*Z_BLOCK*/); + + consumed = input.Length - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; + + needsMoreBuffer = errorCode == ErrorCode.BufError; + } + } + + private unsafe int Flush(Span output, out bool needsMoreBuffer) + { + Debug.Assert(_stream is not null); + Debug.Assert(_stream.AvailIn == 0); + Debug.Assert(output.Length >= 6); + + fixed (byte* fixedOutput = output) + { + _stream.NextIn = IntPtr.Zero; + _stream.AvailIn = 0; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + ErrorCode errorCode = Deflate(_stream, (FlushCode)3/*Z_FULL_FLUSH*/); + int writtenBytes = output.Length - (int)_stream.AvailOut; + + needsMoreBuffer = errorCode == ErrorCode.BufError; + return writtenBytes; + } + } + + private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) + { + ErrorCode errorCode; + try + { + errorCode = stream.Deflate(flushCode); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + switch (errorCode) + { + case ErrorCode.Ok: + case ErrorCode.StreamEnd: + return errorCode; + + case ErrorCode.BufError: + return errorCode; // This is a recoverable error + + case ErrorCode.StreamError: + throw new WebSocketException(SR.ZLibErrorInconsistentStream); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + } + } + + [MemberNotNull(nameof(_stream))] + private void Initialize() + { + Debug.Assert(_stream is null); + + var compressionLevel = CompressionLevel.DefaultCompression; + var memLevel = Deflate_DefaultMemLevel; + var strategy = CompressionStrategy.DefaultStrategy; + + ErrorCode errorCode; + try + { + errorCode = CreateZLibStreamForDeflate(out _stream, compressionLevel, _windowBits, memLevel, strategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + switch (errorCode) + { + case ErrorCode.Ok: + return; + case ErrorCode.MemError: + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs new file mode 100644 index 00000000000000..0b4f18ce69940f --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -0,0 +1,230 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using static System.IO.Compression.ZLibNative; + +namespace System.Net.WebSockets.Compression +{ + /// + /// Provides a wrapper around the ZLib decompression API. + /// + internal sealed class WebSocketInflater : IDisposable + { + internal const int FlushMarkerLength = 4; + internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; + + private ZLibStreamHandle? _stream; + private readonly int _windowBits; + private readonly bool _persisted; + + /// + /// There is no way of knowing, when decoding data, if the underlying deflater + /// has flushed all outstanding data to consumer other than to provide a buffer + /// and see whether any bytes are written. There are cases when the consumers + /// provide a buffer exactly the size of the uncompressed data and in this case + /// to avoid requiring another read we will use this field. + /// + private byte? _remainingByte; + + /// + /// When the inflater is persisted we need to manually append the flush marker + /// before finishing the decoding. + /// + private bool _needsFlushMarker; + + internal WebSocketInflater(int windowBits, bool persisted) + { + Debug.Assert(windowBits >= 9 && windowBits <= 15); + + // We use negative window bits to instruct deflater to expect raw deflate data + _windowBits = -windowBits; + _persisted = persisted; + } + + public void Dispose() => _stream?.Dispose(); + + public unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) + { + if (_stream is null) + { + Initialize(); + } + fixed (byte* fixedInput = &MemoryMarshal.GetReference(input)) + fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) + { + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)input.Length; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + Inflate(_stream); + + consumed = input.Length - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; + } + + _needsFlushMarker = _persisted; + } + + /// + /// Finishes the decoding by writing any outstanding data to the output. + /// + /// true if the finish completed, false to indicate that there is more outstanding data. + public bool Finish(Span output, out int written) + { + Debug.Assert(_stream is not null); + + if (_needsFlushMarker) + { + Inflate(FlushMarker, output, out var _, out written); + _needsFlushMarker = false; + + if ( written < output.Length || IsFinished(_stream, out _remainingByte) ) + { + OnFinished(); + return true; + } + } + + written = 0; + + if (output.IsEmpty) + { + if (_remainingByte is not null) + { + return false; + } + if (IsFinished(_stream, out _remainingByte)) + { + OnFinished(); + return true; + } + } + else + { + if (_remainingByte is not null) + { + output[0] = _remainingByte.GetValueOrDefault(); + written = 1; + _remainingByte = null; + } + + written += Inflate(_stream, output[written..]); + + if (written < output.Length || IsFinished(_stream, out _remainingByte)) + { + OnFinished(); + return true; + } + } + + return false; + } + + private void OnFinished() + { + Debug.Assert(_stream is not null); + + if (!_persisted) + { + _stream.Dispose(); + _stream = null; + } + } + + private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) + { + if (stream.AvailIn > 0) + { + remainingByte = null; + return false; + } + + // There is no other way to make sure that we'e consumed all data + // but to try to inflate again with at least one byte of output buffer. + byte b; + if (Inflate(stream, new Span(&b, 1)) == 0) + { + remainingByte = null; + return true; + } + + remainingByte = b; + return false; + } + + private static unsafe int Inflate(ZLibStreamHandle stream, Span destination) + { + fixed (byte* bufPtr = &MemoryMarshal.GetReference(destination)) + { + stream.NextOut = (IntPtr)bufPtr; + stream.AvailOut = (uint)destination.Length; + + Inflate(stream); + return destination.Length - (int)stream.AvailOut; + } + } + + private static void Inflate(ZLibStreamHandle stream) + { + ErrorCode errorCode; + try + { + errorCode = stream.Inflate(FlushCode.NoFlush); + } + catch (Exception cause) // could not load the Zlib DLL correctly + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + switch (errorCode) + { + case ErrorCode.Ok: // progress has been made inflating + case ErrorCode.StreamEnd: // The end of the input stream has been reached + case ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue + break; + + case ErrorCode.MemError: // Not enough memory to complete the operation + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + + case ErrorCode.DataError: // The input data was corrupted (input stream not conforming to the zlib format or incorrect check value) + throw new WebSocketException(SR.ZLibUnsupportedCompression); + + case ErrorCode.StreamError: //the stream structure was inconsistent (for example if next_in or next_out was NULL), + throw new WebSocketException(SR.ZLibErrorInconsistentStream); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + } + } + + [MemberNotNull(nameof(_stream))] + private void Initialize() + { + Debug.Assert(_stream is null); + + ErrorCode error; + try + { + error = CreateZLibStreamForInflate(out _stream, _windowBits); + } + catch (Exception exception) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + switch (error) + { + case ErrorCode.Ok: + return; + case ErrorCode.MemError: + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)error)); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 749cd9dad0f034..c290a5cf336362 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Net.WebSockets.Compression; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -28,14 +29,11 @@ internal sealed partial class ManagedWebSocket : WebSocket { /// Creates a from a connected to a websocket endpoint. /// The connected Stream. - /// true if this is the server-side of the connection; false if this is the client-side of the connection. - /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. + /// The options with which the websocket must be created. /// The created instance. - public static ManagedWebSocket CreateFromConnectedStream( - Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocketCreationOptions options) { - return new ManagedWebSocket(stream, isServer, subprotocol, keepAliveInterval); + return new ManagedWebSocket(stream, options); } /// Thread-safe random number generator used to generate masks for each send. @@ -152,10 +150,7 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it - /// - /// Indicates whether compression is enabled for the receiving part of the websocket. - /// - private readonly bool _inflateEnabled; + private readonly WebSocketInflater? _inflater; private byte[]? _inflateBuffer; /// @@ -174,17 +169,11 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private bool _inflateFinished = true; - /// - /// Indicates whether compression is enabled for the sending part of the websocket. - /// - private readonly bool _deflateEnabled; + private readonly WebSocketDeflater? _deflater; + private byte[]? _deflateBuffer; + private int _deflateBufferPosition; - /// Initializes the websocket. - /// The connected Stream. - /// true if this is the server-side of the connection; false if this is the client-side of the connection. - /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. - private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) { Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null"); @@ -193,13 +182,23 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time Debug.Assert(stream != null, $"Expected non-null stream"); Debug.Assert(stream.CanRead, $"Expected readable stream"); Debug.Assert(stream.CanWrite, $"Expected writeable stream"); - Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid keepalive interval: {keepAliveInterval}"); _stream = stream; - _isServer = isServer; - _subprotocol = subprotocol; - _inflateEnabled = false; - _deflateEnabled = false; + _isServer = options.IsServer; + _subprotocol = options.SubProtocol; + + var deflateOptions = options.DeflateOptions; + + if (deflateOptions is not null) + { + _deflater = options.IsServer ? + new WebSocketDeflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover) : + new WebSocketDeflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); + + _inflater = options.IsServer ? + new WebSocketInflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover) : + new WebSocketInflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); + } // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer @@ -229,7 +228,7 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time // Now that we're opened, initiate the keep alive timer to send periodic pings. // We use a weak reference from the timer to the web socket to avoid a cycle // that could keep the web socket rooted in erroneous cases. - if (keepAliveInterval > TimeSpan.Zero) + if (options.KeepAliveInterval > TimeSpan.Zero) { _keepAliveTimer = new Timer(static s => { @@ -238,7 +237,7 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time { thisRef.SendKeepAliveFrameAsync(); } - }, new WeakReference(this), keepAliveInterval, keepAliveInterval); + }, new WeakReference(this), options.KeepAliveInterval, options.KeepAliveInterval); } } @@ -555,9 +554,9 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { try { - if (_deflateEnabled) + if (_deflater is not null && !payloadBuffer.IsEmpty) { - payloadBuffer = Deflate(payloadBuffer); + payloadBuffer = Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); } // Ensure we have a _sendBuffer. @@ -599,37 +598,74 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read } finally { - if (_deflateEnabled) + if (_deflater is not null) { ReleaseDeflateBuffer(); } } } - private ReadOnlySpan Deflate(ReadOnlySpan payload) + private ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, bool endOfMessage) { - // This function assumes that we're going to use a single buffer. - throw new NotImplementedException(); + Debug.Assert(_deflater is not null); + Debug.Assert(_deflateBuffer is null); + + _deflateBuffer = ArrayPool.Shared.Rent(Math.Min(payload.Length, 1_000_000)); + _deflateBufferPosition = 0; + + while (true) + { + _deflater.Deflate(payload, _deflateBuffer.AsSpan(_deflateBufferPosition), continuation, endOfMessage, + out int consumed, out int written, out bool needsMoreOutput); + _deflateBufferPosition += written; + + if (!needsMoreOutput) + { + break; + } + + payload = payload.Slice(consumed); + + // Rent a 30% bigger buffer + byte[] newBuffer = ArrayPool.Shared.Rent((int)(_deflateBuffer.Length * 1.3)); + _deflateBuffer.AsSpan(0, _deflateBufferPosition).CopyTo(newBuffer); + ArrayPool.Shared.Return(_deflateBuffer); + _deflateBuffer = newBuffer; + } + + return new ReadOnlySpan(_deflateBuffer, 0, _deflateBufferPosition); } private void ReleaseDeflateBuffer() { - throw new NotImplementedException(); + if (_deflateBuffer is not null) + { + ArrayPool.Shared.Return(_deflateBuffer); + _deflateBuffer = null; + } } - private bool Inflate(Span output, out int bytesWritten) + private void Inflate(Span output, bool finish, out int bytesWritten) { - int consumed = 42; + Debug.Assert(_inflater is not null); + + _inflater.Inflate(new ReadOnlySpan(_inflateBuffer, _inflateBufferPosition, _inflateBufferAvailable), output, + out int consumed, out bytesWritten); _inflateBufferPosition += consumed; _inflateBufferAvailable -= consumed; + _inflateFinished = false; if (_inflateBufferAvailable == 0) { ReleaseInflateBuffer(); - } - throw new NotImplementedException(); + if (finish) + { + _inflateFinished = _inflater.Finish(output.Slice(bytesWritten), out int byteCount); + bytesWritten += byteCount; + } + } } [MemberNotNull(nameof(_inflateBuffer))] @@ -637,8 +673,6 @@ private void RentInflateBuffer(long payloadLength) { _inflateBufferPosition = 0; _inflateBuffer = ArrayPool.Shared.Rent((int)Math.Min(payloadLength, 1_000_000)); - - throw new NotImplementedException(); } private void ReleaseInflateBuffer() @@ -646,6 +680,7 @@ private void ReleaseInflateBuffer() if (_inflateBuffer is not null) { ArrayPool.Shared.Return(_inflateBuffer); + _inflateBuffer = null; } } @@ -918,7 +953,8 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo { // In case of compression totalBytesReceived should actually represent how much we've // inflated, rather than how much we've read from the stream. - _inflateFinished = Inflate(payloadBuffer.Span, out totalBytesReceived); + Inflate(payloadBuffer.Span, + finish: header.PayloadLength == 0, out totalBytesReceived); } // If this a text message, validate that it contains valid UTF8. @@ -1212,10 +1248,10 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( return SR.net_Websockets_ReservedBitsSet; } - if (header.Compressed && !_inflateEnabled) + if (header.Compressed && _inflater is null) { resultHeader = default; - return "TODO"; + return SR.net_Websockets_PerMessageCompressedFlagWhenNotEnabled; } if (masked) @@ -1241,6 +1277,12 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( resultHeader = default; return SR.net_Websockets_ContinuationFromFinalFrame; } + if (header.Compressed) + { + // Must not mark continuations as compressed + resultHeader = default; + return SR.net_Websockets_PerMessageCompressedFlagInContinuation; + } break; case MessageOpcode.Binary: diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 3bd6835a16f1d4..0cdb4d3ab67616 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -157,7 +157,29 @@ public static WebSocket CreateFromStream(Stream stream, bool isServer, string? s 0)); } - return ManagedWebSocket.CreateFromConnectedStream(stream, isServer, subProtocol, keepAliveInterval); + return ManagedWebSocket.CreateFromConnectedStream(stream, new WebSocketCreationOptions + { + IsServer = isServer, + SubProtocol = subProtocol, + KeepAliveInterval = keepAliveInterval + }); + } + + /// Creates a that operates on a representing a web socket connection. + /// The for the connection. + /// The options with which the websocket must be created. + public static WebSocket CreateFromStream(Stream stream, WebSocketCreationOptions options) + { + if (stream is null) + throw new ArgumentNullException(nameof(stream)); + + if (options is null) + throw new ArgumentNullException(nameof(options)); + + if (!stream.CanRead || !stream.CanWrite) + throw new ArgumentException(!stream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(stream)); + + return ManagedWebSocket.CreateFromConnectedStream(stream, options); } [EditorBrowsable(EditorBrowsableState.Never)] @@ -209,8 +231,11 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, // Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option. // Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages. - - return ManagedWebSocket.CreateFromConnectedStream(innerStream, false, subProtocol, keepAliveInterval); + return ManagedWebSocket.CreateFromConnectedStream(innerStream, new WebSocketCreationOptions + { + SubProtocol = subProtocol, + KeepAliveInterval = keepAliveInterval + }); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs new file mode 100644 index 00000000000000..cfea9eece572db --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; + +namespace System.Net.WebSockets +{ + public sealed class WebSocketCreationOptions + { + private string? _subProtocol; + private TimeSpan _keepAliveInterval; + + /// + /// Defines if this websocket is the server-side of the connection. The default value is false. + /// + public bool IsServer { get; set; } + + /// + /// The agreed upon sub-protocol that was used when creating the connection. + /// + public string? SubProtocol + { + get => _subProtocol; + set + { + if (value is not null) + { + WebSocketValidate.ValidateSubprotocol(value); + } + _subProtocol = value; + } + } + + /// + /// The keep-alive interval to use, or or to disable keep-alives. + /// The default is . + /// + public TimeSpan KeepAliveInterval + { + get => _keepAliveInterval; + set + { + if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(KeepAliveInterval), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + } + _keepAliveInterval = value; + } + } + + /// + /// The agreed upon options for per message deflate. + /// + public WebSocketDeflateOptions? DeflateOptions { get; set; } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs new file mode 100644 index 00000000000000..6ddb82c1b0c7ed --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + /// + /// Options to enable per-message deflate compression for . + /// + /// + /// Although the WebSocket spec allows window bits from 8 to 15, the current implementation doesn't support 8 bits. + /// For more information refer to the zlib manual https://zlib.net/manual.html. + /// + public sealed class WebSocketDeflateOptions + { + private int _clientMaxWindowBits = 15; + private int _serverMaxWindowBits = 15; + + /// + /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the client context. + /// Must be a value between 9 and 15. The default is 15. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.2 + public int ClientMaxWindowBits + { + get => _clientMaxWindowBits; + set + { + // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. + if (value < 9 || value > 15) + { + throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); + } + _clientMaxWindowBits = value; + } + } + + /// + /// When true the client-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.2 + public bool ClientContextTakeover { get; set; } = true; + + /// + /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the server context. + /// Must be a value between 9 and 15. The default is 15. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.1 + public int ServerMaxWindowBits + { + get => _serverMaxWindowBits; + set + { + // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. + if (value < 9 || value > 15) + { + throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); + } + _serverMaxWindowBits = value; + } + } + + /// + /// When true the server-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.1 + public bool ServerContextTakeover { get; set; } = true; + } +} From 28259731db1eb832abb28c857e8155670651a3f8 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 5 Mar 2021 10:24:54 +0200 Subject: [PATCH 05/47] Added api changes and implementation to ClientWebSocket to light up deflate compression. --- .../ref/System.Net.WebSockets.Client.cs | 2 + .../src/Resources/Strings.resx | 66 +++++---- .../src/System.Net.WebSockets.Client.csproj | 2 + .../ClientWebSocketOptions.cs | 7 + .../ClientWebSocketDeflateConstants.cs | 16 ++ .../Net/WebSockets/ClientWebSocketOptions.cs | 3 + .../Net/WebSockets/WebSocketHandle.Managed.cs | 138 +++++++++++++++++- .../tests/DeflateTests.cs | 99 +++++++++++++ .../tests/LoopbackHelper.cs | 3 +- .../System.Net.WebSockets.Client.Tests.csproj | 1 + 10 files changed, 300 insertions(+), 37 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index cee3a5170b8625..660a2c5fbe7485 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -36,6 +36,8 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.Security.RemoteCertificateValidationCallback? RemoteCertificateValidationCallback { get { throw null; } set { } } diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index 3259b86c99fcba..5649bc52e7653b 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -1,16 +1,17 @@ - - @@ -193,8 +194,11 @@ Connection was aborted. - + WebSocket binary type '{0}' not supported. - - + + + The WebSocket failed to negotiate max {0} window bits. The client requested {1} but the server responded with {2}. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index b74f3d8962be63..e84ea02f895ba8 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -6,6 +6,7 @@ + @@ -37,6 +38,7 @@ + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 85b0f025b46502..2ed5c527421c9f 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -100,6 +100,13 @@ public TimeSpan KeepAliveInterval set => throw new PlatformNotSupportedException(); } + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DeflateOptions + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("browser")] public void SetBuffer(int receiveBufferSize, int sendBufferSize) { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs new file mode 100644 index 00000000000000..75a7a3c7ea9161 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + internal static class ClientWebSocketDeflateConstants + { + public const string Extension = "permessage-deflate"; + + public const string ClientMaxWindowBits = "client_max_window_bits"; + public const string ClientNoContextTakeover = "client_no_context_takeover"; + + public const string ServerMaxWindowBits = "server_max_window_bits"; + public const string ServerNoContextTakeover = "server_no_context_takeover"; + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index a7609a0ff09057..573c3eb8325b70 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -148,6 +148,9 @@ public TimeSpan KeepAliveInterval } } + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DeflateOptions { get; set; } + internal int ReceiveBufferSize => _receiveBufferSize; internal ArraySegment? Buffer => _buffer; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index d61f368e7aae8e..f7eca1af856e66 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.IO; using System.Net.Http; using System.Net.Http.Headers; @@ -183,6 +184,26 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + // Because deflate options are negotiated we need a new object + WebSocketDeflateOptions? deflateOptions = null; + + if (options.DeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out IEnumerable? extensions)) + { + foreach (ReadOnlySpan extension in extensions) + { + if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension)) + { + deflateOptions = ParseDeflateOptions(extension, options.DeflateOptions); + break; + } + } + } + + // Store the negotiated deflate options in the original options, because + // otherwise there is now way of clients to actually check whether we are using + // per message deflate or not. + options.DeflateOptions = deflateOptions; + if (response.Content is null) { throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); @@ -192,11 +213,13 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli Stream connectedStream = response.Content.ReadAsStream(); Debug.Assert(connectedStream.CanWrite); Debug.Assert(connectedStream.CanRead); - WebSocket = WebSocket.CreateFromStream( - connectedStream, - isServer: false, - subprotocol, - options.KeepAliveInterval); + WebSocket = WebSocket.CreateFromStream(connectedStream, new WebSocketCreationOptions + { + IsServer = false, + SubProtocol = subprotocol, + KeepAliveInterval = options.KeepAliveInterval, + DeflateOptions = deflateOptions, + }); } catch (Exception exc) { @@ -226,6 +249,72 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan extension, WebSocketDeflateOptions original) + { + var options = new WebSocketDeflateOptions(); + + while (true) + { + int end = extension.IndexOf(';'); + ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim(); + + if (!value.IsEmpty) + { + if (value.Equals(ClientWebSocketDeflateConstants.ClientNoContextTakeover, StringComparison.Ordinal)) + { + options.ClientContextTakeover = false; + } + else if (value.Equals(ClientWebSocketDeflateConstants.ServerNoContextTakeover, StringComparison.Ordinal)) + { + options.ServerContextTakeover = false; + } + else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits, StringComparison.Ordinal)) + { + options.ClientMaxWindowBits = ParseWindowBits(value); + } + else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits, StringComparison.Ordinal)) + { + options.ServerMaxWindowBits = ParseWindowBits(value); + } + + static int ParseWindowBits(ReadOnlySpan value) + { + var startIndex = value.IndexOf('='); + + if (startIndex < 0 || + !int.TryParse(value.Slice(startIndex + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) || + windowBits < 9 || + windowBits > 15) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.Format(SR.net_WebSockets_InvalidResponseHeader, ClientWebSocketDeflateConstants.Extension, value.ToString())); + } + + return windowBits; + } + } + + if (end < 0) + break; + + extension = extension[(end + 1)..]; + } + + if (options.ClientMaxWindowBits > original.ClientMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, + "client", original.ClientMaxWindowBits, options.ClientMaxWindowBits)); + } + + if (options.ServerMaxWindowBits > original.ServerMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, + "server", original.ServerMaxWindowBits, options.ServerMaxWindowBits)); + } + + return options; + } + /// Adds the necessary headers for the web socket request. /// The request to which the headers should be added. /// The generated security key to send in the Sec-WebSocket-Key header. @@ -240,6 +329,45 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe { request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols)); } + if (options.DeflateOptions is not null) + { + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, string.Join("; ", GetDeflateOptions(options.DeflateOptions))); + + static IEnumerable GetDeflateOptions(WebSocketDeflateOptions options) + { + yield return ClientWebSocketDeflateConstants.Extension; + + if (options.ClientMaxWindowBits != 15) + { + yield return $"{ClientWebSocketDeflateConstants.ClientMaxWindowBits}={options.ClientMaxWindowBits}"; + } + else + { + // Advertise that we support this option + yield return ClientWebSocketDeflateConstants.ClientMaxWindowBits; + } + + if (!options.ClientContextTakeover) + { + yield return ClientWebSocketDeflateConstants.ClientNoContextTakeover; + } + + if (options.ServerMaxWindowBits != 15) + { + yield return $"{ClientWebSocketDeflateConstants.ServerMaxWindowBits}={options.ServerMaxWindowBits}"; + } + else + { + // Advertise that we support this option + yield return ClientWebSocketDeflateConstants.ServerMaxWindowBits; + } + + if (!options.ServerContextTakeover) + { + yield return ClientWebSocketDeflateConstants.ServerNoContextTakeover; + } + } + } } /// diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs new file mode 100644 index 00000000000000..a182830426c307 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Net.Test.Common; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + public class DeflateTests : ClientWebSocketTestBase + { + public DeflateTests(ITestOutputHelper output) : base(output) + { + } + + [ConditionalTheory(nameof(WebSocketsSupported))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/42852", TestPlatforms.Browser)] + [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits; server_max_window_bits")] + [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14; server_max_window_bits")] + [InlineData(15, true, 14, true, "permessage-deflate; client_max_window_bits; server_max_window_bits=14")] + [InlineData(10, true, 11, true, "permessage-deflate; client_max_window_bits=10; server_max_window_bits=11")] + [InlineData(15, false, 15, true, "permessage-deflate; client_max_window_bits; client_no_context_takeover; server_max_window_bits")] + [InlineData(15, true, 15, false, "permessage-deflate; client_max_window_bits; server_max_window_bits; server_no_context_takeover")] + public async Task PerMessageDeflateHeaders(int clientWindowBits, bool clientContextTakeover, + int serverWindowBits, bool serverContextTakover, + string expected) + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var client = new ClientWebSocket(); + using var cancellation = new CancellationTokenSource(TimeOutMilliseconds); + + client.Options.DeflateOptions = new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits, + ServerContextTakeover = serverContextTakover + }; + + await client.ConnectAsync(uri, cancellation.Token); + + Assert.NotNull(client.Options.DeflateOptions); + Assert.Equal(clientWindowBits - 1, client.Options.DeflateOptions.ClientMaxWindowBits); + Assert.Equal(clientContextTakeover, client.Options.DeflateOptions.ClientContextTakeover); + Assert.Equal(serverWindowBits - 1, client.Options.DeflateOptions.ServerMaxWindowBits); + Assert.Equal(serverContextTakover, client.Options.DeflateOptions.ServerContextTakeover); + }, server => server.AcceptConnectionAsync(async connection => + { + var extensionsReply = CreateDeflateOptionsHeader(new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits - 1, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits - 1, + ServerContextTakeover = serverContextTakover + }); + Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); + Assert.NotNull(headers); + Assert.True(headers.TryGetValue("Sec-WebSocket-Extensions", out string extensions)); + Assert.Equal(expected, extensions); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options) + { + var builder = new StringBuilder(); + builder.Append("permessage-deflate"); + + if (options.ClientMaxWindowBits != 15) + { + builder.Append("; client_max_window_bits=").Append(options.ClientMaxWindowBits); + } + + if (!options.ClientContextTakeover) + { + builder.Append("; client_no_context_takeover"); + } + + if (options.ServerMaxWindowBits != 15) + { + builder.Append("; server_max_window_bits=").Append(options.ServerMaxWindowBits); + } + + if (!options.ServerContextTakeover) + { + builder.Append("; server_no_context_takeover"); + } + + return builder.ToString(); + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs index 5726326c6ab8fa..48d167b072f781 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs @@ -11,7 +11,7 @@ namespace System.Net.WebSockets.Client.Tests { public static class LoopbackHelper { - public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection) + public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection, string? extensions = null) { string serverResponse = null; List headers = await connection.ReadRequestHeaderAsync().ConfigureAwait(false); @@ -34,6 +34,7 @@ public static async Task> WebSocketHandshakeAsync(Loo "Content-Length: 0\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + + (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index a1323fa83db1ed..21ba2a12dd5247 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -46,6 +46,7 @@ + From 67a5d7204ce69de08fba41540f183e234e1d44e4 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 5 Mar 2021 14:56:30 +0200 Subject: [PATCH 06/47] Added tests and fixed a few deflate bugs. --- .../Compression/WebSocketDeflater.cs | 41 ++- .../System/Net/WebSockets/ManagedWebSocket.cs | 126 +++++-- .../tests/System.Net.WebSockets.Tests.csproj | 5 +- .../tests/WebSocketDeflateOptionsTests.cs | 48 +++ .../tests/WebSocketDeflateTests.cs | 341 ++++++++++++++++++ .../tests/WebSocketTestStream.cs | 222 ++++++++++++ 6 files changed, 728 insertions(+), 55 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs create mode 100644 src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs create mode 100644 src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index aa053c4a2cccf8..b993e0e52777ca 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -28,7 +28,7 @@ internal WebSocketDeflater(int windowBits, bool persisted) public void Dispose() => _stream?.Dispose(); - public void Deflate( ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, + public void Deflate(ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, out int consumed, out int written, out bool needsMoreOutput) { Debug.Assert(!continuation || _stream is not null, "Invalid state. The stream should not be null in continuations."); @@ -47,24 +47,25 @@ public void Deflate( ReadOnlySpan payload, Span output, bool continu // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 // At that point there will be at most a few bits left to write. // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. - written += Flush(output.Slice(written), out needsMoreOutput); + if (output.Length - written < 6) + { + needsMoreOutput = true; + return; + } + written += Flush(output.Slice(written)); + Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength) + .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); - if (!needsMoreOutput) + if (endOfMessage) { - Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength) - .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); - - if (endOfMessage) - { - // As per RFC we need to remove the flush markers - written -= WebSocketInflater.FlushMarkerLength; - } - - if (endOfMessage && !_persisted) - { - _stream.Dispose(); - _stream = null; - } + // As per RFC we need to remove the flush markers + written -= WebSocketInflater.FlushMarkerLength; + } + + if (endOfMessage && !_persisted) + { + _stream.Dispose(); + _stream = null; } } @@ -95,11 +96,11 @@ private unsafe void Deflate(ReadOnlySpan input, Span output, out int } } - private unsafe int Flush(Span output, out bool needsMoreBuffer) + private unsafe int Flush(Span output) { Debug.Assert(_stream is not null); Debug.Assert(_stream.AvailIn == 0); - Debug.Assert(output.Length >= 6); + Debug.Assert(output.Length >= 6, "We neede at least 6 bytes guarantee the completion of the deflate block."); fixed (byte* fixedOutput = output) { @@ -111,8 +112,8 @@ private unsafe int Flush(Span output, out bool needsMoreBuffer) ErrorCode errorCode = Deflate(_stream, (FlushCode)3/*Z_FULL_FLUSH*/); int writtenBytes = output.Length - (int)_stream.AvailOut; + Debug.Assert(errorCode == ErrorCode.Ok); - needsMoreBuffer = errorCode == ErrorCode.BufError; return writtenBytes; } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index c290a5cf336362..ae26651b1aa666 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -112,7 +112,7 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke /// remaining to be received for that header. As a result, between fragments, the payload /// length in this header should be 0. /// - private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true }; + private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true, Processed = true }; /// The offset of the next available byte in the _receiveBuffer. private int _receiveBufferOffset; /// The number of bytes available in the _receiveBuffer. @@ -163,12 +163,6 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke /// private int _inflateBufferAvailable; - /// - /// Because of how the underlying zlib inflater works, we may have consumed the entire payload, but still - /// might have data left in the native component, we need to track if inflating has really finished. - /// - private bool _inflateFinished = true; - private readonly WebSocketDeflater? _deflater; private byte[]? _deflateBuffer; private int _deflateBufferPosition; @@ -570,13 +564,13 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { // The server doesn't send a mask, so the mask offset returned by WriteHeader // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false); + headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _inflater is not null); } else { // We need to know where the mask starts so that we can use the mask to manipulate the payload data, // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); + maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _inflater is not null); headerLength = maskOffset.GetValueOrDefault() + MaskLength; } @@ -610,7 +604,15 @@ private ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation Debug.Assert(_deflater is not null); Debug.Assert(_deflateBuffer is null); - _deflateBuffer = ArrayPool.Shared.Rent(Math.Min(payload.Length, 1_000_000)); + // Do not try to rent more than 1MB initially, because it will actually allocate + // instead of renting. Be optimistic that what we're sending is actually going to fit. + const int MaxInitialBufferLength = 1024 * 1024; + + // For small payloads there might actually be overhead in the compression and the resulting + // output might be larger than the payload. This is why we rent at least 4KB initially. + const int MinInitialBufferLength = 4 * 1024; + + _deflateBuffer = ArrayPool.Shared.Rent(Math.Min(Math.Max(payload.Length, MinInitialBufferLength), MaxInitialBufferLength)); _deflateBufferPosition = 0; while (true) @@ -645,32 +647,56 @@ private void ReleaseDeflateBuffer() } } - private void Inflate(Span output, bool finish, out int bytesWritten) + /// + /// Inflates the last receive payload into the provided buffer. + /// + /// true if inflate operation finished and no more data needs to be written + private bool Inflate(Span output, bool finish, out int written) { Debug.Assert(_inflater is not null); - _inflater.Inflate(new ReadOnlySpan(_inflateBuffer, _inflateBufferPosition, _inflateBufferAvailable), output, - out int consumed, out bytesWritten); + if (_inflateBufferAvailable > 0) + { + _inflater.Inflate(input: new ReadOnlySpan(_inflateBuffer, _inflateBufferPosition, _inflateBufferAvailable), + output, out int consumed, out written); - _inflateBufferPosition += consumed; - _inflateBufferAvailable -= consumed; - _inflateFinished = false; + _inflateBufferPosition += consumed; + _inflateBufferAvailable -= consumed; + } + else + { + written = 0; + } - if (_inflateBufferAvailable == 0) + if (_inflateBufferAvailable <= 0) { ReleaseInflateBuffer(); if (finish) { - _inflateFinished = _inflater.Finish(output.Slice(bytesWritten), out int byteCount); - bytesWritten += byteCount; + if (_inflater.Finish(output.Slice(written), out int byteCount)) + { + _inflateBufferAvailable = 0; + } + else + { + // Setting this to -1 instructs the receive operation to not try and + // read any more data from the stream. + _inflateBufferAvailable = -1; + } + + written += byteCount; } } + + return _inflateBufferAvailable == 0; } [MemberNotNull(nameof(_inflateBuffer))] private void RentInflateBuffer(long payloadLength) { + Debug.Assert(_inflateBuffer is null); + _inflateBufferPosition = 0; _inflateBuffer = ArrayPool.Shared.Rent((int)Math.Min(payloadLength, 1_000_000)); } @@ -715,7 +741,7 @@ private void SendKeepAliveFrameAsync() } } - private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask) + private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask, bool compressed) { // Client header format: // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 @@ -745,6 +771,11 @@ private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnly { sendBuffer[0] |= 0x80; // 1 bit for FIN } + if (compressed && opcode != MessageOpcode.Continuation) + { + // Per-Message Deflate flag needs to be set only in the first frame + sendBuffer[0] |= 0b_0100_0000; + } // Store the payload length. int maskOffset; @@ -818,7 +849,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo // with it. If instead its payload length is zero, then we've completed the processing of // thta message, and we should read the next header. MessageHeader header = _lastReceiveHeader; - if (header.PayloadLength == 0 && _inflateFinished) + if (header.Processed) { if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength))) { @@ -874,13 +905,13 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}"); // If there's no data to read, return an appropriate result. - if ((header.PayloadLength == 0 && _inflateFinished) || payloadBuffer.Length == 0) + if (header.Processed || payloadBuffer.Length == 0) { _lastReceiveHeader = header; return GetReceiveResult( count: 0, messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - endOfMessage: header.Fin && header.PayloadLength == 0 && _inflateFinished); + endOfMessage: header.EndOfMessage); } // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data @@ -897,7 +928,9 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo { if (_receiveBufferCount > 0) { - int receiveBufferBytesToCopy = Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); + int receiveBufferBytesToCopy = header.Compressed ? + (int)Math.Min(header.PayloadLength, _receiveBufferCount) : + Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); Debug.Assert(receiveBufferBytesToCopy > 0); if (header.Compressed) @@ -906,7 +939,6 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo RentInflateBuffer(header.PayloadLength); _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(_inflateBuffer); - _inflateBufferAvailable += receiveBufferBytesToCopy; ConsumeFromBuffer(receiveBufferBytesToCopy); totalBytesReceived += receiveBufferBytesToCopy; Debug.Assert(_receiveBufferCount == 0 || totalBytesReceived == header.PayloadLength); @@ -922,15 +954,19 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo totalBytesReceived == header.PayloadLength); } } + else if (header.Compressed) + { + RentInflateBuffer(header.PayloadLength); + } - // Then read directly into the payload buffer until we've hit a limit. - while (totalBytesReceived < payloadBuffer.Length && - totalBytesReceived < header.PayloadLength) + // Then read directly into the appropriate buffer until we've hit a limit. + int limit = (int)Math.Min(header.Compressed ? _inflateBuffer!.Length : payloadBuffer.Length, header.PayloadLength); + while (totalBytesReceived < limit) { int numBytesRead = await _stream.ReadAsync( header.Compressed ? - _inflateBuffer.AsMemory(totalBytesReceived, (int)Math.Min(_inflateBuffer!.Length, header.PayloadLength) - totalBytesReceived) : - payloadBuffer.Slice(totalBytesReceived, (int)Math.Min(payloadBuffer.Length, header.PayloadLength) - totalBytesReceived), + _inflateBuffer.AsMemory(totalBytesReceived, limit - totalBytesReceived) : + payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived), cancellationToken).ConfigureAwait(false); if (numBytesRead <= 0) { @@ -946,20 +982,31 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo _inflateBuffer.AsSpan(0, totalBytesReceived) : payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); } + header.PayloadLength -= totalBytesReceived; } + else + { + totalBytesReceived = _inflateBufferAvailable; + } if (header.Compressed) { // In case of compression totalBytesReceived should actually represent how much we've // inflated, rather than how much we've read from the stream. - Inflate(payloadBuffer.Span, - finish: header.PayloadLength == 0, out totalBytesReceived); + _inflateBufferAvailable = totalBytesReceived; + header.Processed = Inflate(payloadBuffer.Span, + finish: header.Fin && header.PayloadLength == 0, out totalBytesReceived); + } + else + { + // Without compression the frame is processed as soon as we've received everything + header.Processed = header.PayloadLength == 0; } // If this a text message, validate that it contains valid UTF8. if (header.Opcode == MessageOpcode.Text && - !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Fin && header.PayloadLength == 0, _utf8TextState)) + !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.EndOfMessage, _utf8TextState)) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false); } @@ -968,7 +1015,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo return GetReceiveResult( totalBytesReceived, header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0 && _inflateFinished); + header.EndOfMessage); } } catch (Exception exc) @@ -1314,6 +1361,7 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( // Return the read header resultHeader = header; + resultHeader.Processed = header.PayloadLength == 0; return null; } @@ -1759,6 +1807,16 @@ private struct MessageHeader internal long PayloadLength; internal bool Compressed; internal int Mask; + + /// + /// Returns if frame has been received and processed. + /// + internal bool Processed { get; set; } + + /// + /// Returns if message has been received and processed. + /// + internal bool EndOfMessage => Fin && Processed && PayloadLength == 0; } } } diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 7cf0328df31ca8..0b606a12e8c446 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent) @@ -7,6 +7,9 @@ + + + (() => options.ClientMaxWindowBits = 8); + Assert.Throws(() => options.ClientMaxWindowBits = 16); + + options.ClientMaxWindowBits = 14; + Assert.Equal(14, options.ClientMaxWindowBits); + } + + [Fact] + public void ServerMaxWindowBits() + { + WebSocketDeflateOptions options = new(); + Assert.Equal(15, options.ServerMaxWindowBits); + + Assert.Throws(() => options.ServerMaxWindowBits = 8); + Assert.Throws(() => options.ServerMaxWindowBits = 16); + + options.ServerMaxWindowBits = 14; + Assert.Equal(14, options.ServerMaxWindowBits); + } + + [Fact] + public void ContextTakeover() + { + WebSocketDeflateOptions options = new(); + + Assert.True(options.ClientContextTakeover); + Assert.True(options.ServerContextTakeover); + + options.ClientContextTakeover = false; + Assert.False(options.ClientContextTakeover); + + options.ServerContextTakeover = false; + Assert.False(options.ServerContextTakeover); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs new file mode 100644 index 00000000000000..fa83a1043cf9ab --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -0,0 +1,341 @@ +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + [PlatformSpecific(~TestPlatforms.Browser)] + public class WebSocketDeflateTests + { + private readonly CancellationTokenSource? _cancellation; + + public WebSocketDeflateTests() + { + if (!Debugger.IsAttached) + { + _cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + } + } + + public CancellationToken CancellationToken => _cancellation?.Token ?? default; + + public static IEnumerable SupportedWindowBits + { + get + { + for (var i = 9; i <= 15; ++i) + { + yield return new object[] { i }; + } + } + } + + [Fact] + public async Task HelloWithContextTakeover() + { + WebSocketTestStream stream = new(); + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new() + }); + + Memory buffer = new byte[5]; + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + + // Because context takeover is set by default if we try to send + // the same message it would take fewer bytes. + stream.Enqueue(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); + + buffer.Span.Clear(); + result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + + [Fact] + public async Task HelloWithoutContextTakeover() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new() + { + ClientContextTakeover = false + } + }); + + Memory buffer = new byte[5]; + + for (var i = 0; i < 100; ++i) + { + // Without context takeover the message should look the same every time + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + buffer.Span.Clear(); + + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + } + + [Fact] + public async Task TwoDeflateBlocksInOneMessage() + { + // Two or more DEFLATE blocks may be used in one message. + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new() + }); + // The first 3 octets(0xf2 0x48 0x05) and the least significant two + // bits of the 4th octet(0x00) constitute one DEFLATE block with + // "BFINAL" set to 0 and "BTYPE" set to 01 containing "He". The rest of + // the 4th octet contains the header bits with "BFINAL" set to 0 and + // "BTYPE" set to 00, and the 3 padding bits of 0. Together with the + // following 4 octets(0x00 0x00 0xff 0xff), the header bits constitute + // an empty DEFLATE block with no compression. A DEFLATE block + // containing "llo" follows the empty DEFLATE block. + stream.Enqueue(0x41, 0x08, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff); + stream.Enqueue(0x80, 0x05, 0xca, 0xc9, 0xc9, 0x07, 0x00); + + Memory buffer = new byte[5]; + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.Equal(2, result.Count); + Assert.False(result.EndOfMessage); + + result = await websocket.ReceiveAsync(buffer.Slice(result.Count), CancellationToken); + + Assert.Equal(3, result.Count); + Assert.True(result.EndOfMessage); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + + [Theory] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + [InlineData(true, false)] + public async Task Duplex(bool clientContextTakover, bool serverContextTakover) + { + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } + }); + + var buffer = new byte[1024]; + + for (var i = 0; i < 10; ++i) + { + string message = $"Sending number {i} from server."; + await SendTextAsync(message, server); + + ValueWebSocketReceiveResult result = await client.ReceiveAsync(buffer.AsMemory(), CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + + for (var i = 0; i < 10; ++i) + { + string message = $"Sending number {i} from client."; + await SendTextAsync(message, client); + + ValueWebSocketReceiveResult result = await server.ReceiveAsync(buffer.AsMemory(), CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + } + + [Theory] + [MemberData(nameof(SupportedWindowBits))] + public async Task LargeMessageSplitInMultipleFrames(int windowBits) + { + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + + Memory testData = new byte[ushort.MaxValue]; + Memory receivedData = new byte[testData.Length]; + + // Make the data incompressible to make sure that the output is larger than the input + var rng = new Random(0); + rng.NextBytes(testData.Span); + + // Test it a few times with different frame sizes + for (var i = 0; i < 10; ++i) + { + var frameSize = rng.Next(1024, 2048); + var position = 0; + + while (position < testData.Length) + { + var currentFrameSize = Math.Min(frameSize, testData.Length - position); + var eof = position + currentFrameSize == testData.Length; + + await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, CancellationToken); + position += currentFrameSize; + } + + Assert.True(testData.Length < stream.Remote.Available, "The compressed data should be bigger."); + Assert.Equal(testData.Length, position); + + // Receive the data from the client side + receivedData.Span.Clear(); + position = 0; + + // Intentionally receive with a frame size that is less than what the sender used + frameSize /= 3; + + while (true) + { + int currentFrameSize = Math.Min(frameSize, testData.Length - position); + ValueWebSocketReceiveResult result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), CancellationToken); + + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + position += result.Count; + + if (result.EndOfMessage) + break; + } + + Assert.Equal(0, stream.Remote.Available); + Assert.Equal(testData.Length, position); + Assert.True(testData.Span.SequenceEqual(receivedData.Span)); + } + } + + [Fact] + public async Task WebSocketWithoutDeflateShouldThrowOnCompressedMessage() + { + WebSocketTestStream stream = new(); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + + var exception = await Assert.ThrowsAsync(() => + client.ReceiveAsync(Memory.Empty, CancellationToken).AsTask()); + + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + } + + [Fact] + public async Task ReceiveUncompressedMessageWhenCompressionEnabled() + { + // We should be able to handle the situation where even if we have + // deflate compression enabled, uncompressed messages are OK + WebSocketTestStream stream = new(); + WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = null + }); + WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions() + }); + + // Server sends uncompressed + await SendTextAsync("Hello", server); + + // Although client has deflate options, it should still be able + // to handle uncompressed messages. + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + // Client sends compressed, but server compression is disabled and should throw on receive + await SendTextAsync("Hello back", client); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(server)); + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + Assert.Equal(WebSocketState.Aborted, server.State); + + // The client should close if we try to receive + ValueWebSocketReceiveResult result = await client.ReceiveAsync(Memory.Empty, CancellationToken); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.ProtocolError, client.CloseStatus); + Assert.Equal(WebSocketState.CloseReceived, client.State); + } + + [Fact] + public async Task ReceiveInvalidCompressedData() + { + WebSocketTestStream stream = new(); + WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions() + }); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + stream.Enqueue(0xc1, 0x07, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(client)); + + Assert.Equal("The message was compressed using an unsupported compression method.", exception.Message); + Assert.Equal(WebSocketState.Aborted, client.State); + } + + private ValueTask SendTextAsync(string text, WebSocket websocket) + { + byte[] bytes = Encoding.UTF8.GetBytes(text); + return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, CancellationToken); + } + + private async Task ReceiveTextAsync(WebSocket websocket) + { + using IMemoryOwner buffer = MemoryPool.Shared.Rent(1024 * 32); + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer.Memory, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + return Encoding.UTF8.GetString(buffer.Memory.Span.Slice(0, result.Count)); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs new file mode 100644 index 00000000000000..f6ec18bfba7abf --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs @@ -0,0 +1,222 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets.Tests +{ + /// + /// A helper stream class that can be used simulate sending / receiving (duplex) data in a websocket. + /// + public class WebSocketTestStream : Stream + { + private readonly SemaphoreSlim _inputLock = new(initialCount: 0); + private readonly Queue _inputQueue = new(); + private readonly CancellationTokenSource _disposed = new(); + + public WebSocketTestStream() + { + GC.SuppressFinalize(this); + Remote = new WebSocketTestStream(this); + } + + private WebSocketTestStream(WebSocketTestStream remote) + { + GC.SuppressFinalize(this); + Remote = remote; + } + + public WebSocketTestStream Remote { get; } + + /// + /// Returns the number of unread bytes. + /// + public int Available + { + get + { + var available = 0; + + lock (_inputQueue) + { + foreach (var x in _inputQueue) + { + available += x.AvailableLength; + } + } + + return available; + } + } + + public Span NextAvailableBytes + { + get + { + lock (_inputQueue) + { + var block = _inputQueue.Peek(); + + if (block is null) + { + return default; + } + return block.Available; + } + } + } + + /// + /// If set, would cause the next send operation to be delayed + /// and complete asynchronously. Can be used to test cancellation tokens + /// and async code branches. + /// + public TimeSpan DelayForNextSend { get; set; } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => -1; + + public override long Position { get => -1; set => throw new NotSupportedException(); } + + protected override void Dispose(bool disposing) + { + if (!_disposed.IsCancellationRequested) + { + _disposed.Cancel(); + + lock (Remote._inputQueue) + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(Block.ConnectionClosed); + } + } + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) + { + using (var cancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposed.Token)) + { + try + { + await _inputLock.WaitAsync(cancellation.Token).ConfigureAwait(false); + } + catch (TaskCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw new OperationCanceledException(cancellationToken); + } + catch (OperationCanceledException) when (_disposed.IsCancellationRequested) + { + return 0; + } + } + + lock (_inputQueue) + { + var block = _inputQueue.Peek(); + if (block == Block.ConnectionClosed) + { + return 0; + } + var count = Math.Min(block.AvailableLength, buffer.Length); + + block.Available.Slice(0, count).CopyTo(buffer.Span); + block.Advance(count); + + if (block.AvailableLength == 0) + { + _inputQueue.Dequeue(); + } + else + { + // Because we haven't fully consumed the buffer + // we should release once the input lock so we can acquire + // it again on consequent receive. + _inputLock.Release(); + } + + return count; + } + } + + /// + /// Enqueues the provided data for receive by the WebSocket. + /// + public void Enqueue(params byte[] data) + { + lock (_inputQueue) + { + _inputLock.Release(); + _inputQueue.Enqueue(new Block(data)); + } + } + + /// + /// Enqueues the provided data for receive by the WebSocket. + /// + public void Enqueue(ReadOnlySpan data) + { + lock (_inputQueue) + { + _inputLock.Release(); + _inputQueue.Enqueue(new Block(data.ToArray())); + } + } + + public override void Write(ReadOnlySpan buffer) + { + lock (Remote._inputQueue) + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(new Block(buffer.ToArray())); + } + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + if (DelayForNextSend > TimeSpan.Zero) + { + await Task.Delay(DelayForNextSend, cancellationToken); + DelayForNextSend = TimeSpan.Zero; + } + + Write(buffer.Span); + } + + public override void Flush() => throw new NotSupportedException(); + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + private sealed class Block + { + public static readonly Block ConnectionClosed = new(Array.Empty()); + + private readonly byte[] _data; + private int _position; + + public Block(byte[] data) + { + _data = data; + } + + public Span Available => _data.AsSpan(_position); + + public int AvailableLength => _data.Length - _position; + + public void Advance(int count) => _position += count; + } + } +} From 9cb024e84fbf33a855e12abde6a7e3a1dd3f0730 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 6 Mar 2021 09:01:56 +0200 Subject: [PATCH 07/47] Forgot to dispose inflater & deflater. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index ae26651b1aa666..3756e8cac6d3f8 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -250,7 +250,10 @@ private void DisposeCore() { _disposed = true; _keepAliveTimer?.Dispose(); - _stream?.Dispose(); + _stream.Dispose(); + _inflater?.Dispose(); + _deflater?.Dispose(); + if (_state < WebSocketState.Aborted) { _state = WebSocketState.Closed; From 92cf8f49f3c5d45d5b65dc46d70ced2a2e3b0691 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 9 Mar 2021 10:07:50 +0200 Subject: [PATCH 08/47] Addressing pr feedback. --- .../src/System/Net/WebSockets/WebSocketHandle.Managed.cs | 4 ++-- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 2 +- .../tests/WebSocketDeflateOptionsTests.cs | 5 ++++- .../System.Net.WebSockets/tests/WebSocketDeflateTests.cs | 6 ++++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index f7eca1af856e66..fb7f1f89541332 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -200,7 +200,7 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } // Store the negotiated deflate options in the original options, because - // otherwise there is now way of clients to actually check whether we are using + // otherwise there is no way of clients to actually check whether we are using // per message deflate or not. options.DeflateOptions = deflateOptions; @@ -218,7 +218,7 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli IsServer = false, SubProtocol = subprotocol, KeepAliveInterval = options.KeepAliveInterval, - DeflateOptions = deflateOptions, + DeflateOptions = deflateOptions }); } catch (Exception exc) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 3756e8cac6d3f8..02b67184982c6a 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -1050,7 +1050,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo } /// - /// Returns either or . + /// Returns either or . /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private TResult GetReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs index 19858fb1b253c0..c698a29d5a922d 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs @@ -1,4 +1,7 @@ -using Xunit; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; namespace System.Net.WebSockets.Tests { diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index fa83a1043cf9ab..edc9b2741ed969 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -1,4 +1,7 @@ -using System.Buffers; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.Text; @@ -8,7 +11,6 @@ namespace System.Net.WebSockets.Tests { - [PlatformSpecific(~TestPlatforms.Browser)] public class WebSocketDeflateTests { private readonly CancellationTokenSource? _cancellation; From 4045305102fa1d2a969193b1bbee5af319e57ae8 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 9 Mar 2021 10:41:57 +0200 Subject: [PATCH 09/47] Moved deflate state into deflater. --- .../Compression/WebSocketDeflater.cs | 59 +++++++++++++++++-- .../System/Net/WebSockets/ManagedWebSocket.cs | 54 +---------------- 2 files changed, 56 insertions(+), 57 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index b993e0e52777ca..3eef219e712213 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -17,6 +17,8 @@ internal sealed class WebSocketDeflater : IDisposable private readonly int _windowBits; private readonly bool _persisted; + private byte[]? _buffer; + internal WebSocketDeflater(int windowBits, bool persisted) { Debug.Assert(windowBits >= 9 && windowBits <= 15); @@ -28,7 +30,54 @@ internal WebSocketDeflater(int windowBits, bool persisted) public void Dispose() => _stream?.Dispose(); - public void Deflate(ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, + public void ReleaseBuffer() + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null; + } + } + + public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, bool endOfMessage) + { + Debug.Assert(_buffer is null, "Invalid state, ReleaseBuffer not called."); + + // Do not try to rent more than 1MB initially, because it will actually allocate + // instead of renting. Be optimistic that what we're sending is actually going to fit. + const int MaxInitialBufferLength = 1024 * 1024; + + // For small payloads there might actually be overhead in the compression and the resulting + // output might be larger than the payload. This is why we rent at least 4KB initially. + const int MinInitialBufferLength = 4 * 1024; + + _buffer = ArrayPool.Shared.Rent(Math.Min(Math.Max(payload.Length, MinInitialBufferLength), MaxInitialBufferLength)); + int position = 0; + + while (true) + { + DeflatePrivate(payload, _buffer.AsSpan(position), continuation, endOfMessage, + out int consumed, out int written, out bool needsMoreOutput); + position += written; + + if (!needsMoreOutput) + { + break; + } + + payload = payload.Slice(consumed); + + // Rent a 30% bigger buffer + byte[] newBuffer = ArrayPool.Shared.Rent((int)(_buffer.Length * 1.3)); + _buffer.AsSpan(0, position).CopyTo(newBuffer); + ArrayPool.Shared.Return(_buffer); + _buffer = newBuffer; + } + + return new ReadOnlySpan(_buffer, 0, position); + } + + private void DeflatePrivate(ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, out int consumed, out int written, out bool needsMoreOutput) { Debug.Assert(!continuation || _stream is not null, "Invalid state. The stream should not be null in continuations."); @@ -38,7 +87,7 @@ public void Deflate(ReadOnlySpan payload, Span output, bool continua Initialize(); } - Deflate(payload, output, out consumed, out written, out needsMoreOutput); + UnsafeDeflate(payload, output, out consumed, out written, out needsMoreOutput); if (needsMoreOutput) { return; @@ -52,7 +101,7 @@ public void Deflate(ReadOnlySpan payload, Span output, bool continua needsMoreOutput = true; return; } - written += Flush(output.Slice(written)); + written += UnsafeFlush(output.Slice(written)); Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength) .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); @@ -69,7 +118,7 @@ public void Deflate(ReadOnlySpan payload, Span output, bool continua } } - private unsafe void Deflate(ReadOnlySpan input, Span output, out int consumed, out int written, out bool needsMoreBuffer) + private unsafe void UnsafeDeflate(ReadOnlySpan input, Span output, out int consumed, out int written, out bool needsMoreBuffer) { Debug.Assert(_stream is not null); @@ -96,7 +145,7 @@ private unsafe void Deflate(ReadOnlySpan input, Span output, out int } } - private unsafe int Flush(Span output) + private unsafe int UnsafeFlush(Span output) { Debug.Assert(_stream is not null); Debug.Assert(_stream.AvailIn == 0); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 02b67184982c6a..73c9649cf20860 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -164,8 +164,6 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke private int _inflateBufferAvailable; private readonly WebSocketDeflater? _deflater; - private byte[]? _deflateBuffer; - private int _deflateBufferPosition; private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) { @@ -553,7 +551,7 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { if (_deflater is not null && !payloadBuffer.IsEmpty) { - payloadBuffer = Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); + payloadBuffer = _deflater.Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); } // Ensure we have a _sendBuffer. @@ -597,59 +595,11 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { if (_deflater is not null) { - ReleaseDeflateBuffer(); + _deflater.ReleaseBuffer(); } } } - private ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, bool endOfMessage) - { - Debug.Assert(_deflater is not null); - Debug.Assert(_deflateBuffer is null); - - // Do not try to rent more than 1MB initially, because it will actually allocate - // instead of renting. Be optimistic that what we're sending is actually going to fit. - const int MaxInitialBufferLength = 1024 * 1024; - - // For small payloads there might actually be overhead in the compression and the resulting - // output might be larger than the payload. This is why we rent at least 4KB initially. - const int MinInitialBufferLength = 4 * 1024; - - _deflateBuffer = ArrayPool.Shared.Rent(Math.Min(Math.Max(payload.Length, MinInitialBufferLength), MaxInitialBufferLength)); - _deflateBufferPosition = 0; - - while (true) - { - _deflater.Deflate(payload, _deflateBuffer.AsSpan(_deflateBufferPosition), continuation, endOfMessage, - out int consumed, out int written, out bool needsMoreOutput); - _deflateBufferPosition += written; - - if (!needsMoreOutput) - { - break; - } - - payload = payload.Slice(consumed); - - // Rent a 30% bigger buffer - byte[] newBuffer = ArrayPool.Shared.Rent((int)(_deflateBuffer.Length * 1.3)); - _deflateBuffer.AsSpan(0, _deflateBufferPosition).CopyTo(newBuffer); - ArrayPool.Shared.Return(_deflateBuffer); - _deflateBuffer = newBuffer; - } - - return new ReadOnlySpan(_deflateBuffer, 0, _deflateBufferPosition); - } - - private void ReleaseDeflateBuffer() - { - if (_deflateBuffer is not null) - { - ArrayPool.Shared.Return(_deflateBuffer); - _deflateBuffer = null; - } - } - /// /// Inflates the last receive payload into the provided buffer. /// From 862e163a905186e866d708acf150b012082a1a91 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 9 Mar 2021 11:28:33 +0200 Subject: [PATCH 10/47] Moved inflater buffer into the inflater. --- .../Compression/WebSocketInflater.cs | 103 +++++++++++++++-- .../System/Net/WebSockets/ManagedWebSocket.cs | 109 +++--------------- 2 files changed, 109 insertions(+), 103 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 0b4f18ce69940f..a315ef12b5917c 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; @@ -35,6 +36,18 @@ internal sealed class WebSocketInflater : IDisposable /// private bool _needsFlushMarker; + private byte[]? _buffer; + + /// + /// The position for the next unconsumed byte in the inflate buffer. + /// + private int _position; + + /// + /// How many unconsumed bytes are left in the inflate buffer. + /// + private int _available; + internal WebSocketInflater(int windowBits, bool persisted) { Debug.Assert(windowBits >= 9 && windowBits <= 15); @@ -44,9 +57,74 @@ internal WebSocketInflater(int windowBits, bool persisted) _persisted = persisted; } - public void Dispose() => _stream?.Dispose(); + public bool Finished { get; private set; } = true; + + public Memory Memory => _buffer; + + public Span Span => _buffer; + + public void Dispose() + { + _stream?.Dispose(); + ReleaseBuffer(); + } + + public void Initialize(long payloadLength) + { + Debug.Assert(_available == 0); + Debug.Assert(_buffer is null); + + // Do not try to rent anythin above 1MB because the array pool + // will not pool the buffer but allocate it. + _buffer = ArrayPool.Shared.Rent((int)Math.Min(payloadLength, 1_000_000)); + _position = 0; + } - public unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) + /// + /// Inflates the last receive payload into the provided buffer. + /// + public void Inflate(int totalBytesReceived, Span output, bool flush, out int written) + { + if (totalBytesReceived > 0) + { + Debug.Assert(_buffer is not null, "Initialize must be called."); + _available = totalBytesReceived; + } + + if (_available > 0) + { + UnsafeInflate(input: new ReadOnlySpan(_buffer, _position, _available), + output, out int consumed, out written); + + _position += consumed; + _available -= consumed; + } + else + { + written = 0; + } + + if (_available == 0) + { + ReleaseBuffer(); + + if (flush) + { + Finished = Flush(output.Slice(written), out int byteCount); + written += byteCount; + } + else + { + Finished = true; + } + } + else + { + Finished = false; + } + } + + private unsafe void UnsafeInflate(ReadOnlySpan input, Span output, out int consumed, out int written) { if (_stream is null) { @@ -71,19 +149,20 @@ public unsafe void Inflate(ReadOnlySpan input, Span output, out int } /// - /// Finishes the decoding by writing any outstanding data to the output. + /// Finishes the decoding by flushing any outstanding data to the output. /// - /// true if the finish completed, false to indicate that there is more outstanding data. - public bool Finish(Span output, out int written) + /// true if the flush completed, false to indicate that there is more outstanding data. + private bool Flush(Span output, out int written) { Debug.Assert(_stream is not null); + Debug.Assert(_available == 0); if (_needsFlushMarker) { - Inflate(FlushMarker, output, out var _, out written); + UnsafeInflate(FlushMarker, output, out var _, out written); _needsFlushMarker = false; - if ( written < output.Length || IsFinished(_stream, out _remainingByte) ) + if (written < output.Length || IsFinished(_stream, out _remainingByte)) { OnFinished(); return true; @@ -136,6 +215,16 @@ private void OnFinished() } } + private void ReleaseBuffer() + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null; + _available = 0; + } + } + private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) { if (stream.AvailIn > 0) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 73c9649cf20860..4ba09a07bb6110 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -3,7 +3,6 @@ using System.Buffers; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.WebSockets.Compression; using System.Numerics; @@ -151,18 +150,6 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it private readonly WebSocketInflater? _inflater; - private byte[]? _inflateBuffer; - - /// - /// The position for the next unconsumed byte in the inflate buffer. - /// - private int _inflateBufferPosition; - - /// - /// How many unconsumed bytes are left in the inflate buffer. - /// - private int _inflateBufferAvailable; - private readonly WebSocketDeflater? _deflater; private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) @@ -600,69 +587,6 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read } } - /// - /// Inflates the last receive payload into the provided buffer. - /// - /// true if inflate operation finished and no more data needs to be written - private bool Inflate(Span output, bool finish, out int written) - { - Debug.Assert(_inflater is not null); - - if (_inflateBufferAvailable > 0) - { - _inflater.Inflate(input: new ReadOnlySpan(_inflateBuffer, _inflateBufferPosition, _inflateBufferAvailable), - output, out int consumed, out written); - - _inflateBufferPosition += consumed; - _inflateBufferAvailable -= consumed; - } - else - { - written = 0; - } - - if (_inflateBufferAvailable <= 0) - { - ReleaseInflateBuffer(); - - if (finish) - { - if (_inflater.Finish(output.Slice(written), out int byteCount)) - { - _inflateBufferAvailable = 0; - } - else - { - // Setting this to -1 instructs the receive operation to not try and - // read any more data from the stream. - _inflateBufferAvailable = -1; - } - - written += byteCount; - } - } - - return _inflateBufferAvailable == 0; - } - - [MemberNotNull(nameof(_inflateBuffer))] - private void RentInflateBuffer(long payloadLength) - { - Debug.Assert(_inflateBuffer is null); - - _inflateBufferPosition = 0; - _inflateBuffer = ArrayPool.Shared.Rent((int)Math.Min(payloadLength, 1_000_000)); - } - - private void ReleaseInflateBuffer() - { - if (_inflateBuffer is not null) - { - ArrayPool.Shared.Return(_inflateBuffer); - _inflateBuffer = null; - } - } - private void SendKeepAliveFrameAsync() { #pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0 @@ -875,9 +799,8 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo // First copy any data lingering in the receive buffer. int totalBytesReceived = 0; - // Only start a new receive when we've consumed everything from the inflate buffer. When - // there is no compression, this will always be 0. - if (_inflateBufferAvailable == 0) + // Only start a new receive when we've consumed everything from the inflater, if present. + if (_inflater is null || _inflater.Finished) { if (_receiveBufferCount > 0) { @@ -888,10 +811,9 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo if (header.Compressed) { - Debug.Assert(_inflateBufferAvailable == 0); - RentInflateBuffer(header.PayloadLength); - - _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(_inflateBuffer); + Debug.Assert(_inflater is not null); + _inflater.Initialize(header.PayloadLength); + _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(_inflater.Span); ConsumeFromBuffer(receiveBufferBytesToCopy); totalBytesReceived += receiveBufferBytesToCopy; Debug.Assert(_receiveBufferCount == 0 || totalBytesReceived == header.PayloadLength); @@ -909,16 +831,17 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo } else if (header.Compressed) { - RentInflateBuffer(header.PayloadLength); + Debug.Assert(_inflater is not null); + _inflater.Initialize(header.PayloadLength); } // Then read directly into the appropriate buffer until we've hit a limit. - int limit = (int)Math.Min(header.Compressed ? _inflateBuffer!.Length : payloadBuffer.Length, header.PayloadLength); + int limit = (int)Math.Min(header.Compressed ? _inflater!.Memory.Length : payloadBuffer.Length, header.PayloadLength); while (totalBytesReceived < limit) { int numBytesRead = await _stream.ReadAsync( header.Compressed ? - _inflateBuffer.AsMemory(totalBytesReceived, limit - totalBytesReceived) : + _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) : payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived), cancellationToken).ConfigureAwait(false); if (numBytesRead <= 0) @@ -932,24 +855,20 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo if (_isServer) { _receivedMaskOffsetOffset = ApplyMask(header.Compressed ? - _inflateBuffer.AsSpan(0, totalBytesReceived) : + _inflater!.Span.Slice(0, totalBytesReceived) : payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); } header.PayloadLength -= totalBytesReceived; } - else - { - totalBytesReceived = _inflateBufferAvailable; - } if (header.Compressed) { // In case of compression totalBytesReceived should actually represent how much we've // inflated, rather than how much we've read from the stream. - _inflateBufferAvailable = totalBytesReceived; - header.Processed = Inflate(payloadBuffer.Span, - finish: header.Fin && header.PayloadLength == 0, out totalBytesReceived); + _inflater!.Inflate(totalBytesReceived, payloadBuffer.Span, + flush: header.Fin && header.PayloadLength == 0, out totalBytesReceived); + header.Processed = _inflater.Finished && header.PayloadLength == 0; } else { @@ -973,8 +892,6 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo } catch (Exception exc) { - ReleaseInflateBuffer(); - if (exc is OperationCanceledException) { throw; From 92520b66f941f138423c8837abe3e821718cdb46 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 14 Mar 2021 10:05:05 +0200 Subject: [PATCH 11/47] Using Math.Clamp instead of Math.Min/Max. --- .../src/System/Net/WebSockets/Compression/WebSocketDeflater.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 3eef219e712213..bd67dd58bca94b 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -51,7 +51,7 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, // output might be larger than the payload. This is why we rent at least 4KB initially. const int MinInitialBufferLength = 4 * 1024; - _buffer = ArrayPool.Shared.Rent(Math.Min(Math.Max(payload.Length, MinInitialBufferLength), MaxInitialBufferLength)); + _buffer = ArrayPool.Shared.Rent(Math.Clamp(payload.Length, MinInitialBufferLength, MaxInitialBufferLength)); int position = 0; while (true) From 4d024facd3abe7875f660f3e997bc99680290bf7 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 14 Mar 2021 10:42:50 +0200 Subject: [PATCH 12/47] Releasing deflater buffer as soon as we're done using it. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 78 +++++++++---------- 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 4ba09a07bb6110..51598105d6273b 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -534,57 +534,51 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM /// Writes a frame into the send buffer, which can then be sent over the network. private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer) { - try + if (_deflater is not null && !payloadBuffer.IsEmpty) { - if (_deflater is not null && !payloadBuffer.IsEmpty) - { - payloadBuffer = _deflater.Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); - } + payloadBuffer = _deflater.Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); + } + int payloadLength = payloadBuffer.Length; - // Ensure we have a _sendBuffer. - AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength); - Debug.Assert(_sendBuffer != null); + // Ensure we have a _sendBuffer + AllocateSendBuffer(payloadLength + MaxMessageHeaderLength); + Debug.Assert(_sendBuffer != null); - // Write the message header data to the buffer. - int headerLength; - int? maskOffset = null; - if (_isServer) - { - // The server doesn't send a mask, so the mask offset returned by WriteHeader - // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _inflater is not null); - } - else - { - // We need to know where the mask starts so that we can use the mask to manipulate the payload data, - // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _inflater is not null); - headerLength = maskOffset.GetValueOrDefault() + MaskLength; - } + // Write the message header data to the buffer. + int headerLength; + int? maskOffset = null; + if (_isServer) + { + // The server doesn't send a mask, so the mask offset returned by WriteHeader + // is actually the end of the header. + headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _inflater is not null); + } + else + { + // We need to know where the mask starts so that we can use the mask to manipulate the payload data, + // and we need to know the total length for sending it on the wire. + maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _inflater is not null); + headerLength = maskOffset.GetValueOrDefault() + MaskLength; + } - // Write the payload - if (payloadBuffer.Length > 0) - { - payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadBuffer.Length)); + // Write the payload + if (!payloadBuffer.IsEmpty) + { + payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadLength)); - // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid - // changing the data in the caller-supplied payload buffer. - if (maskOffset.HasValue) - { - ApplyMask(new Span(_sendBuffer, headerLength, payloadBuffer.Length), _sendBuffer, maskOffset.Value, 0); - } - } + // Release the deflater buffer if any, we're not going to need the payloadBuffer anymore. + _deflater?.ReleaseBuffer(); - // Return the number of bytes in the send buffer - return headerLength + payloadBuffer.Length; - } - finally - { - if (_deflater is not null) + // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid + // changing the data in the caller-supplied payload buffer. + if (maskOffset.HasValue) { - _deflater.ReleaseBuffer(); + ApplyMask(new Span(_sendBuffer, headerLength, payloadLength), _sendBuffer, maskOffset.Value, 0); } } + + // Return the number of bytes in the send buffer + return headerLength + payloadLength; } private void SendKeepAliveFrameAsync() From b8bdfaee498e3812ea723da1631ed9f58d41a0fc Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 14 Mar 2021 19:02:20 +0200 Subject: [PATCH 13/47] Using 2 constructors in ManagedWebSocket to avoid allocations for WebSocketCreationOptions in the existing API. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 57 ++++++++++--------- .../src/System/Net/WebSockets/WebSocket.cs | 15 +---- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 51598105d6273b..d6a5dd3e2fb2b8 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -26,15 +26,6 @@ namespace System.Net.WebSockets /// internal sealed partial class ManagedWebSocket : WebSocket { - /// Creates a from a connected to a websocket endpoint. - /// The connected Stream. - /// The options with which the websocket must be created. - /// The created instance. - public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocketCreationOptions options) - { - return new ManagedWebSocket(stream, options); - } - /// Thread-safe random number generator used to generate masks for each send. private static readonly RandomNumberGenerator s_random = RandomNumberGenerator.Create(); /// Encoding for the payload of text messages: UTF8 encoding that throws if invalid bytes are discovered, per the RFC. @@ -152,7 +143,12 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke private readonly WebSocketInflater? _inflater; private readonly WebSocketDeflater? _deflater; - private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) + /// Initializes the websocket. + /// The connected Stream. + /// true if this is the server-side of the connection; false if this is the client-side of the connection. + /// The agreed upon subprotocol for the connection. + /// The interval to use for keep-alive pings. + internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) { Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null"); @@ -163,21 +159,8 @@ private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) Debug.Assert(stream.CanWrite, $"Expected writeable stream"); _stream = stream; - _isServer = options.IsServer; - _subprotocol = options.SubProtocol; - - var deflateOptions = options.DeflateOptions; - - if (deflateOptions is not null) - { - _deflater = options.IsServer ? - new WebSocketDeflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover) : - new WebSocketDeflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); - - _inflater = options.IsServer ? - new WebSocketInflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover) : - new WebSocketInflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); - } + _isServer = isServer; + _subprotocol = subprotocol; // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer @@ -207,7 +190,7 @@ private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) // Now that we're opened, initiate the keep alive timer to send periodic pings. // We use a weak reference from the timer to the web socket to avoid a cycle // that could keep the web socket rooted in erroneous cases. - if (options.KeepAliveInterval > TimeSpan.Zero) + if (keepAliveInterval > TimeSpan.Zero) { _keepAliveTimer = new Timer(static s => { @@ -216,7 +199,27 @@ private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) { thisRef.SendKeepAliveFrameAsync(); } - }, new WeakReference(this), options.KeepAliveInterval, options.KeepAliveInterval); + }, new WeakReference(this), keepAliveInterval, keepAliveInterval); + } + } + + /// Initializes the websocket. + /// The connected Stream. + /// The options with which the websocket must be created. + internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) + : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval) + { + var deflateOptions = options.DeflateOptions; + + if (deflateOptions is not null) + { + _deflater = options.IsServer ? + new WebSocketDeflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover) : + new WebSocketDeflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); + + _inflater = options.IsServer ? + new WebSocketInflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover) : + new WebSocketInflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 0cdb4d3ab67616..0666fd4edd44f3 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -157,12 +157,7 @@ public static WebSocket CreateFromStream(Stream stream, bool isServer, string? s 0)); } - return ManagedWebSocket.CreateFromConnectedStream(stream, new WebSocketCreationOptions - { - IsServer = isServer, - SubProtocol = subProtocol, - KeepAliveInterval = keepAliveInterval - }); + return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval); } /// Creates a that operates on a representing a web socket connection. @@ -179,7 +174,7 @@ public static WebSocket CreateFromStream(Stream stream, WebSocketCreationOptions if (!stream.CanRead || !stream.CanWrite) throw new ArgumentException(!stream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(stream)); - return ManagedWebSocket.CreateFromConnectedStream(stream, options); + return new ManagedWebSocket(stream, options); } [EditorBrowsable(EditorBrowsableState.Never)] @@ -231,11 +226,7 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, // Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option. // Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages. - return ManagedWebSocket.CreateFromConnectedStream(innerStream, new WebSocketCreationOptions - { - SubProtocol = subProtocol, - KeepAliveInterval = keepAliveInterval - }); + return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval); } } } From 7e8d3c4815490bf6e6ea185835fbfb35b7c5f9f6 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 14 Mar 2021 19:04:57 +0200 Subject: [PATCH 14/47] Added missing assert. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index d6a5dd3e2fb2b8..7de980b30244c4 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -157,6 +157,7 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim Debug.Assert(stream != null, $"Expected non-null stream"); Debug.Assert(stream.CanRead, $"Expected readable stream"); Debug.Assert(stream.CanWrite, $"Expected writeable stream"); + Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid keepalive interval: {keepAliveInterval}"); _stream = stream; _isServer = isServer; From 38693673d179f91399c1ea7dbe5b3bb1a9871488 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 14 Mar 2021 19:09:39 +0200 Subject: [PATCH 15/47] Removed unnecessary comments that visual studio inserted from resx files. --- .../src/Resources/Strings.resx | 62 +------------------ .../src/Resources/Strings.resx | 62 +------------------ 2 files changed, 2 insertions(+), 122 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index 5649bc52e7653b..bf4aea90ecc1e9 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -1,64 +1,4 @@ - - - + diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index d8963299bd87c6..c838d4d188ad45 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -1,64 +1,4 @@ - - - + From b7f58b16e5828fe917d7956c2c03a78e1c294481 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Mar 2021 12:27:52 +0200 Subject: [PATCH 16/47] Better buffer handling in inflater. --- .../Compression/WebSocketInflater.cs | 19 +++++++++++++++---- .../System/Net/WebSockets/ManagedWebSocket.cs | 4 ++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index a315ef12b5917c..92fdfa93318ee6 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -57,6 +57,11 @@ internal WebSocketInflater(int windowBits, bool persisted) _persisted = persisted; } + /// + /// Indicates that there is nothing left for inflating. If there is + /// more data left to be received for the message then call + /// and receive directly into . + /// public bool Finished { get; private set; } = true; public Memory Memory => _buffer; @@ -69,14 +74,20 @@ public void Dispose() ReleaseBuffer(); } - public void Initialize(long payloadLength) + /// + /// Initializes the inflater by allocating a buffer so the websocket can receive directly onto it. + /// + /// the length of the message payload + /// the length of the buffer where the payload will be inflated + public void Initialize(long payloadLength, int clientBufferLength) { Debug.Assert(_available == 0); Debug.Assert(_buffer is null); - // Do not try to rent anythin above 1MB because the array pool - // will not pool the buffer but allocate it. - _buffer = ArrayPool.Shared.Rent((int)Math.Min(payloadLength, 1_000_000)); + // Rent a buffer as close to the size of the client buffer as possible, + // but not try to rent anything above 1MB because the array pool will allocate. + // If the payload is smaller than the client buffer, rent only as much as we need. + _buffer = ArrayPool.Shared.Rent(Math.Min(clientBufferLength, (int)Math.Min(payloadLength, 1_000_000))); _position = 0; } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 7de980b30244c4..d631656cd3c986 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -810,7 +810,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo if (header.Compressed) { Debug.Assert(_inflater is not null); - _inflater.Initialize(header.PayloadLength); + _inflater.Initialize(header.PayloadLength, payloadBuffer.Length); _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(_inflater.Span); ConsumeFromBuffer(receiveBufferBytesToCopy); totalBytesReceived += receiveBufferBytesToCopy; @@ -830,7 +830,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo else if (header.Compressed) { Debug.Assert(_inflater is not null); - _inflater.Initialize(header.PayloadLength); + _inflater.Initialize(header.PayloadLength, payloadBuffer.Length); } // Then read directly into the appropriate buffer until we've hit a limit. From c6c33264cdbc04f065c75d3c3117a769a5fe762e Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Mar 2021 12:45:43 +0200 Subject: [PATCH 17/47] Renaming clientBufferLength to userBufferLength. --- .../Net/WebSockets/Compression/WebSocketInflater.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 92fdfa93318ee6..e53e0b42a971d0 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -78,16 +78,16 @@ public void Dispose() /// Initializes the inflater by allocating a buffer so the websocket can receive directly onto it. /// /// the length of the message payload - /// the length of the buffer where the payload will be inflated - public void Initialize(long payloadLength, int clientBufferLength) + /// the length of the buffer where the payload will be inflated + public void Initialize(long payloadLength, int userBufferLength) { Debug.Assert(_available == 0); Debug.Assert(_buffer is null); - // Rent a buffer as close to the size of the client buffer as possible, + // Rent a buffer as close to the size of the user buffer as possible, // but not try to rent anything above 1MB because the array pool will allocate. - // If the payload is smaller than the client buffer, rent only as much as we need. - _buffer = ArrayPool.Shared.Rent(Math.Min(clientBufferLength, (int)Math.Min(payloadLength, 1_000_000))); + // If the payload is smaller than the user buffer, rent only as much as we need. + _buffer = ArrayPool.Shared.Rent(Math.Min(userBufferLength, (int)Math.Min(payloadLength, 1_000_000))); _position = 0; } From 7ec31514a04d720f07001fa49d0cd7f472c340cf Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 26 Mar 2021 11:53:11 +0200 Subject: [PATCH 18/47] Fixed an issue in inflater / deflater where with context takeover, the sliding window was reset with every message. --- .../Compression/WebSocketDeflater.cs | 56 ++++++---- .../Compression/WebSocketInflater.cs | 102 ++++++++++-------- .../System/Net/WebSockets/ManagedWebSocket.cs | 50 +++------ .../tests/WebSocketDeflateTests.cs | 50 ++++++++- .../tests/WebSocketTestStream.cs | 28 ++++- 5 files changed, 179 insertions(+), 107 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index bd67dd58bca94b..626240df7a5430 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -87,21 +87,29 @@ private void DeflatePrivate(ReadOnlySpan payload, Span output, bool Initialize(); } - UnsafeDeflate(payload, output, out consumed, out written, out needsMoreOutput); - if (needsMoreOutput) + if (payload.IsEmpty) { - return; + consumed = 0; + written = 0; + } + else + { + UnsafeDeflate(payload, output, out consumed, out written, out needsMoreOutput); + + if (needsMoreOutput) + { + Debug.Assert(written == output.Length); + return; + } } - // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 - // At that point there will be at most a few bits left to write. - // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. - if (output.Length - written < 6) + written += UnsafeFlush(output.Slice(written), out needsMoreOutput); + + if (needsMoreOutput) { - needsMoreOutput = true; + Debug.Assert(written == output.Length); return; } - written += UnsafeFlush(output.Slice(written)); Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength) .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); @@ -131,12 +139,10 @@ private unsafe void UnsafeDeflate(ReadOnlySpan input, Span output, o _stream.NextOut = (IntPtr)fixedOutput; _stream.AvailOut = (uint)output.Length; - // If flush is set to Z_BLOCK, a deflate block is completed - // and emitted, as for Z_SYNC_FLUSH, but the output - // is not aligned on a byte boundary, and up to seven bits - // of the current block are held to be written as the next byte after - // the next deflate block is completed. - var errorCode = Deflate(_stream, (FlushCode)5/*Z_BLOCK*/); + // The flush is set to Z_NO_FLUSH, which allows deflate to decide + // how much data to accumulate before producing output, + // in order to maximize compression. + var errorCode = Deflate(_stream, FlushCode.NoFlush); consumed = input.Length - (int)_stream.AvailIn; written = output.Length - (int)_stream.AvailOut; @@ -145,7 +151,7 @@ private unsafe void UnsafeDeflate(ReadOnlySpan input, Span output, o } } - private unsafe int UnsafeFlush(Span output) + private unsafe int UnsafeFlush(Span output, out bool needsMoreBuffer) { Debug.Assert(_stream is not null); Debug.Assert(_stream.AvailIn == 0); @@ -159,11 +165,17 @@ private unsafe int UnsafeFlush(Span output) _stream.NextOut = (IntPtr)fixedOutput; _stream.AvailOut = (uint)output.Length; - ErrorCode errorCode = Deflate(_stream, (FlushCode)3/*Z_FULL_FLUSH*/); - int writtenBytes = output.Length - (int)_stream.AvailOut; - Debug.Assert(errorCode == ErrorCode.Ok); + // The flush is set to Z_SYNC_FLUSH, all pending output is flushed + // to the output buffer and the output is aligned on a byte boundary, + // so that the decompressor can get all input data available so far. + // This completes the current deflate block and follows it with an empty + // stored block that is three bits plus filler bits to the next byte, + // followed by four bytes (00 00 ff ff). + ErrorCode errorCode = Deflate(_stream, FlushCode.SyncFlush); + Debug.Assert(errorCode is ErrorCode.Ok or ErrorCode.BufError); - return writtenBytes; + needsMoreBuffer = errorCode == ErrorCode.BufError; + return output.Length - (int)_stream.AvailOut; } } @@ -183,10 +195,8 @@ private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) { case ErrorCode.Ok: case ErrorCode.StreamEnd: - return errorCode; - case ErrorCode.BufError: - return errorCode; // This is a recoverable error + return errorCode; case ErrorCode.StreamError: throw new WebSocketException(SR.ZLibErrorInconsistentStream); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index e53e0b42a971d0..7cae72a0aafdf5 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -58,15 +58,13 @@ internal WebSocketInflater(int windowBits, bool persisted) } /// - /// Indicates that there is nothing left for inflating. If there is - /// more data left to be received for the message then call - /// and receive directly into . + /// Indicates that there is nothing left for inflating. /// public bool Finished { get; private set; } = true; - public Memory Memory => _buffer; + public Memory Memory => _buffer.AsMemory(_position + _available); - public Span Span => _buffer; + public Span Span => _buffer.AsSpan(_position + _available); public void Dispose() { @@ -79,36 +77,62 @@ public void Dispose() /// /// the length of the message payload /// the length of the buffer where the payload will be inflated - public void Initialize(long payloadLength, int userBufferLength) + public void Prepare(long payloadLength, int userBufferLength) { - Debug.Assert(_available == 0); - Debug.Assert(_buffer is null); + if (_buffer is not null) + { + Debug.Assert(_available > 0); - // Rent a buffer as close to the size of the user buffer as possible, - // but not try to rent anything above 1MB because the array pool will allocate. - // If the payload is smaller than the user buffer, rent only as much as we need. - _buffer = ArrayPool.Shared.Rent(Math.Min(userBufferLength, (int)Math.Min(payloadLength, 1_000_000))); - _position = 0; + _buffer.AsSpan(_position, _available).CopyTo(_buffer); + _position = 0; + } + else + { + // Rent a buffer as close to the size of the user buffer as possible, + // but not try to rent anything above 1MB because the array pool will allocate. + // If the payload is smaller than the user buffer, rent only as much as we need. + _buffer = ArrayPool.Shared.Rent(Math.Min(userBufferLength, (int)Math.Min(payloadLength, 1_000_000))); + } } /// /// Inflates the last receive payload into the provided buffer. /// - public void Inflate(int totalBytesReceived, Span output, bool flush, out int written) + public unsafe void Inflate(int totalBytesReceived, Span output, bool flush, out int written) { if (totalBytesReceived > 0) { - Debug.Assert(_buffer is not null, "Initialize must be called."); - _available = totalBytesReceived; + Debug.Assert(_buffer is not null, "Prepare must be called."); + _available += totalBytesReceived; } if (_available > 0) { - UnsafeInflate(input: new ReadOnlySpan(_buffer, _position, _available), - output, out int consumed, out written); + if (_stream is null) + { + Initialize(); + } + + int consumed; + + fixed (byte* fixedInput = _buffer) + fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) + { + _stream.NextIn = (IntPtr)(fixedInput + _position); + _stream.AvailIn = (uint)_available; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + Inflate(_stream); + + consumed = _available - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; + } _position += consumed; _available -= consumed; + _needsFlushMarker = _persisted; } else { @@ -135,42 +159,31 @@ public void Inflate(int totalBytesReceived, Span output, bool flush, out i } } - private unsafe void UnsafeInflate(ReadOnlySpan input, Span output, out int consumed, out int written) - { - if (_stream is null) - { - Initialize(); - } - fixed (byte* fixedInput = &MemoryMarshal.GetReference(input)) - fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) - { - _stream.NextIn = (IntPtr)fixedInput; - _stream.AvailIn = (uint)input.Length; - - _stream.NextOut = (IntPtr)fixedOutput; - _stream.AvailOut = (uint)output.Length; - - Inflate(_stream); - - consumed = input.Length - (int)_stream.AvailIn; - written = output.Length - (int)_stream.AvailOut; - } - - _needsFlushMarker = _persisted; - } - /// /// Finishes the decoding by flushing any outstanding data to the output. /// /// true if the flush completed, false to indicate that there is more outstanding data. - private bool Flush(Span output, out int written) + private unsafe bool Flush(Span output, out int written) { Debug.Assert(_stream is not null); Debug.Assert(_available == 0); if (_needsFlushMarker) { - UnsafeInflate(FlushMarker, output, out var _, out written); + fixed (byte* fixedInput = &MemoryMarshal.GetReference(FlushMarker)) + fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) + { + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)FlushMarkerLength; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + Inflate(_stream); + + written = output.Length - (int)_stream.AvailOut; + } + _needsFlushMarker = false; if (written < output.Length || IsFinished(_stream, out _remainingByte)) @@ -233,6 +246,7 @@ private void ReleaseBuffer() ArrayPool.Shared.Return(_buffer); _buffer = null; _available = 0; + _position = 0; } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index d631656cd3c986..114c99ac767fd6 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -797,48 +797,32 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo // First copy any data lingering in the receive buffer. int totalBytesReceived = 0; - // Only start a new receive when we've consumed everything from the inflater, if present. - if (_inflater is null || _inflater.Finished) + // Only start a new receive if we haven't received the entire frame. + if (header.PayloadLength > 0) { + if (header.Compressed) + { + Debug.Assert(_inflater is not null); + _inflater.Prepare(header.PayloadLength, payloadBuffer.Length); + } + + // Read directly into the appropriate buffer until we've hit a limit. + int limit = (int)Math.Min(header.Compressed ? _inflater!.Span.Length : payloadBuffer.Length, header.PayloadLength); + if (_receiveBufferCount > 0) { - int receiveBufferBytesToCopy = header.Compressed ? - (int)Math.Min(header.PayloadLength, _receiveBufferCount) : - Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); + int receiveBufferBytesToCopy = Math.Min(limit, _receiveBufferCount); Debug.Assert(receiveBufferBytesToCopy > 0); - if (header.Compressed) - { - Debug.Assert(_inflater is not null); - _inflater.Initialize(header.PayloadLength, payloadBuffer.Length); - _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(_inflater.Span); - ConsumeFromBuffer(receiveBufferBytesToCopy); - totalBytesReceived += receiveBufferBytesToCopy; - Debug.Assert(_receiveBufferCount == 0 || totalBytesReceived == header.PayloadLength); - } - else - { - _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(payloadBuffer.Span); - ConsumeFromBuffer(receiveBufferBytesToCopy); - totalBytesReceived += receiveBufferBytesToCopy; - Debug.Assert( - _receiveBufferCount == 0 || - totalBytesReceived == payloadBuffer.Length || - totalBytesReceived == header.PayloadLength); - } - } - else if (header.Compressed) - { - Debug.Assert(_inflater is not null); - _inflater.Initialize(header.PayloadLength, payloadBuffer.Length); + _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo( + header.Compressed ? _inflater!.Span : payloadBuffer.Span); + ConsumeFromBuffer(receiveBufferBytesToCopy); + totalBytesReceived += receiveBufferBytesToCopy; } - // Then read directly into the appropriate buffer until we've hit a limit. - int limit = (int)Math.Min(header.Compressed ? _inflater!.Memory.Length : payloadBuffer.Length, header.PayloadLength); while (totalBytesReceived < limit) { - int numBytesRead = await _stream.ReadAsync( - header.Compressed ? + int numBytesRead = await _stream.ReadAsync( header.Compressed ? _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) : payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived), cancellationToken).ConfigureAwait(false); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index edc9b2741ed969..6a63ab4c541c2a 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -37,7 +37,7 @@ public static IEnumerable SupportedWindowBits } [Fact] - public async Task HelloWithContextTakeover() + public async Task ReceiveHelloWithContextTakeover() { WebSocketTestStream stream = new(); stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); @@ -67,7 +67,28 @@ public async Task HelloWithContextTakeover() } [Fact] - public async Task HelloWithoutContextTakeover() + public async Task SendHelloWithContextTakeover() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new() + }); + + await websocket.SendAsync(Encoding.UTF8.GetBytes("Hello"), WebSocketMessageType.Text, true, CancellationToken); + Assert.Equal("C107F248CDC9C90700", Convert.ToHexString(stream.NextAvailableBytes)); + + stream.Clear(); + await websocket.SendAsync(Encoding.UTF8.GetBytes("Hello"), WebSocketMessageType.Text, true, CancellationToken); + + // Because context takeover is set by default if we try to send + // the same message it should result in fewer bytes. + Assert.Equal("C105F200110000", Convert.ToHexString(stream.NextAvailableBytes)); + } + + [Fact] + public async Task ReceiveHelloWithoutContextTakeover() { WebSocketTestStream stream = new(); using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions @@ -95,6 +116,31 @@ public async Task HelloWithoutContextTakeover() } } + [Fact] + public async Task SendHelloWithoutContextTakeover() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new() + { + ClientContextTakeover = false + } + }); + + Memory buffer = new byte[5]; + + for (var i = 0; i < 100; ++i) + { + await websocket.SendAsync(Encoding.UTF8.GetBytes("Hello"), WebSocketMessageType.Text, true, CancellationToken); + + // Without context takeover the message should look the same every time + Assert.Equal("C107F248CDC9C90700", Convert.ToHexString(stream.NextAvailableBytes)); + stream.Clear(); + } + } + [Fact] public async Task TwoDeflateBlocksInOneMessage() { diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs index f6ec18bfba7abf..0b3e7b443f239f 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs @@ -58,13 +58,11 @@ public Span NextAvailableBytes { lock (_inputQueue) { - var block = _inputQueue.Peek(); - - if (block is null) + if (_inputQueue.TryPeek(out Block block)) { - return default; + return block.Available; } - return block.Available; + return default; } } } @@ -170,6 +168,26 @@ public void Enqueue(ReadOnlySpan data) } } + public void Clear() + { + lock (_inputQueue) + { + while (_inputQueue.Count > 0) + { + if (_inputQueue.Peek() == Block.ConnectionClosed) + { + break; + } + _inputQueue.Dequeue(); + } + + while (_inputLock.CurrentCount > _inputQueue.Count) + { + _inputLock.Wait(0); + } + } + } + public override void Write(ReadOnlySpan buffer) { lock (Remote._inputQueue) From 271c9cdde08c0d504f30a413718bc0d49d26a141 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 26 Mar 2021 12:02:06 +0200 Subject: [PATCH 19/47] Revert removal of BOM. --- .../System.IO.Compression/src/System.IO.Compression.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj index 60ee6f229fc034..0ffa0044e2a167 100644 --- a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj +++ b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj @@ -1,4 +1,4 @@ - + true $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser From 61c7e5b48320fcc32403502ba8eabd9756fb89bc Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 26 Mar 2021 12:48:31 +0200 Subject: [PATCH 20/47] PR feedback for ClientWebSocket. --- .../src/Resources/Strings.resx | 7 ++- .../ClientWebSocketDeflateConstants.cs | 7 +++ .../Net/WebSockets/ClientWebSocketOptions.cs | 6 +++ .../Net/WebSockets/WebSocketHandle.Managed.cs | 45 +++++++++++++------ 4 files changed, 50 insertions(+), 15 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index bf4aea90ecc1e9..bda2641a36378c 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -138,7 +138,10 @@ WebSocket binary type '{0}' not supported. - - The WebSocket failed to negotiate max {0} window bits. The client requested {1} but the server responded with {2}. + + The WebSocket failed to negotiate max server window bits. The client requested {1} but the server responded with {2}. + + + The WebSocket failed to negotiate max client window bits. The client requested {1} but the server responded with {2}. \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs index 75a7a3c7ea9161..3faa886d5c3068 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs @@ -5,6 +5,13 @@ namespace System.Net.WebSockets { internal static class ClientWebSocketDeflateConstants { + /// + /// The maximum length that this extension can have, assuming that we're not abusing white space. + /// + /// "permessage-deflate; client_max_window_bits=15; client_no_context_takeover; server_max_window_bits=15; server_no_context_takeover" + /// + public const int MaxExtensionLength = 128; + public const string Extension = "permessage-deflate"; public const string ClientMaxWindowBits = "client_max_window_bits"; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index 573c3eb8325b70..00427ea9e2eb10 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -148,6 +148,12 @@ public TimeSpan KeepAliveInterval } } + /// + /// Gets or sets the options for the per-message-deflate extension. + /// When present, the options are sent to the server during the handshake phase. If the server + /// supports per-message-deflate and the options are accepted, the instance + /// will be created with compression enabled by default for all messages. + /// [UnsupportedOSPlatform("browser")] public WebSocketDeflateOptions? DeflateOptions { get; set; } diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index fb7f1f89541332..2b5424a492f985 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -302,14 +302,14 @@ static int ParseWindowBits(ReadOnlySpan value) if (options.ClientMaxWindowBits > original.ClientMaxWindowBits) { - throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, - "client", original.ClientMaxWindowBits, options.ClientMaxWindowBits)); + throw new WebSocketException(string.Format(SR.net_WebSockets_ClientWindowBitsNegotiationFailure, + original.ClientMaxWindowBits, options.ClientMaxWindowBits)); } if (options.ServerMaxWindowBits > original.ServerMaxWindowBits) { - throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, - "server", original.ServerMaxWindowBits, options.ServerMaxWindowBits)); + throw new WebSocketException(string.Format(SR.net_WebSockets_ServerWindowBitsNegotiationFailure, + original.ServerMaxWindowBits, options.ServerMaxWindowBits)); } return options; @@ -331,41 +331,60 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe } if (options.DeflateOptions is not null) { - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, string.Join("; ", GetDeflateOptions(options.DeflateOptions))); + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, GetDeflateOptions(options.DeflateOptions)); - static IEnumerable GetDeflateOptions(WebSocketDeflateOptions options) + static string GetWindowBitsString(int value) => value switch { - yield return ClientWebSocketDeflateConstants.Extension; + 9 => "9", + 10 => "10", + 11 => "11", + 12 => "12", + 13 => "13", + 14 => "14", + 15 => "15", + _ => value.ToString(CultureInfo.InvariantCulture) + }; + static string GetDeflateOptions(WebSocketDeflateOptions options) + { + var builder = new StringBuilder(ClientWebSocketDeflateConstants.MaxExtensionLength); + builder.Append(ClientWebSocketDeflateConstants.Extension).Append("; "); if (options.ClientMaxWindowBits != 15) { - yield return $"{ClientWebSocketDeflateConstants.ClientMaxWindowBits}={options.ClientMaxWindowBits}"; + builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=') + .Append(GetWindowBitsString(options.ClientMaxWindowBits)); } else { // Advertise that we support this option - yield return ClientWebSocketDeflateConstants.ClientMaxWindowBits; + builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits); } if (!options.ClientContextTakeover) { - yield return ClientWebSocketDeflateConstants.ClientNoContextTakeover; + builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientNoContextTakeover); } + builder.Append("; "); + if (options.ServerMaxWindowBits != 15) { - yield return $"{ClientWebSocketDeflateConstants.ServerMaxWindowBits}={options.ServerMaxWindowBits}"; + builder.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=') + .Append(GetWindowBitsString(options.ServerMaxWindowBits)); } else { // Advertise that we support this option - yield return ClientWebSocketDeflateConstants.ServerMaxWindowBits; + builder.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits); } if (!options.ServerContextTakeover) { - yield return ClientWebSocketDeflateConstants.ServerNoContextTakeover; + builder.Append("; ").Append(ClientWebSocketDeflateConstants.ServerNoContextTakeover); } + + Debug.Assert(builder.Length <= ClientWebSocketDeflateConstants.MaxExtensionLength); + return builder.ToString(); } } } From 4fa65f4e39b5639b1a0f382e3d0e663081ddcda0 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 26 Mar 2021 15:19:00 +0200 Subject: [PATCH 21/47] PR feedback for WebSocket. --- .../Compression/WebSocketDeflater.cs | 51 ++---- .../Compression/WebSocketInflater.cs | 166 ++++++------------ .../System/Net/WebSockets/ManagedWebSocket.cs | 9 +- .../WebSockets/WebSocketCreationOptions.cs | 3 + 4 files changed, 75 insertions(+), 154 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 626240df7a5430..7de66c47b0ae43 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -181,58 +181,43 @@ private unsafe int UnsafeFlush(Span output, out bool needsMoreBuffer) private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) { - ErrorCode errorCode; - try - { - errorCode = stream.Deflate(flushCode); - } - catch (Exception cause) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); - } + ErrorCode errorCode = stream.Deflate(flushCode); - switch (errorCode) + if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError) { - case ErrorCode.Ok: - case ErrorCode.StreamEnd: - case ErrorCode.BufError: - return errorCode; - - case ErrorCode.StreamError: - throw new WebSocketException(SR.ZLibErrorInconsistentStream); - - default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + return errorCode; } + + string message = errorCode == ErrorCode.StreamError + ? SR.ZLibErrorInconsistentStream + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); } [MemberNotNull(nameof(_stream))] private void Initialize() { Debug.Assert(_stream is null); - - var compressionLevel = CompressionLevel.DefaultCompression; - var memLevel = Deflate_DefaultMemLevel; - var strategy = CompressionStrategy.DefaultStrategy; - ErrorCode errorCode; try { - errorCode = CreateZLibStreamForDeflate(out _stream, compressionLevel, _windowBits, memLevel, strategy); + errorCode = CreateZLibStreamForDeflate(out _stream, + level: CompressionLevel.DefaultCompression, + windowBits: _windowBits, + memLevel: Deflate_DefaultMemLevel, + strategy: CompressionStrategy.DefaultStrategy); } catch (Exception cause) { throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); } - switch (errorCode) + if (errorCode != ErrorCode.Ok) { - case ErrorCode.Ok: - return; - case ErrorCode.MemError: - throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); - default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 7cae72a0aafdf5..8255a69d106a67 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -3,8 +3,6 @@ using System.Buffers; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Runtime.InteropServices; using static System.IO.Compression.ZLibNative; namespace System.Net.WebSockets.Compression @@ -106,28 +104,19 @@ public unsafe void Inflate(int totalBytesReceived, Span output, bool flush _available += totalBytesReceived; } - if (_available > 0) - { - if (_stream is null) - { - Initialize(); - } + _stream ??= Initialize(_windowBits); + if (_available > 0 && output.Length > 0) + { int consumed; - fixed (byte* fixedInput = _buffer) - fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) + fixed (byte* bufferPtr = _buffer) { - _stream.NextIn = (IntPtr)(fixedInput + _position); + _stream.NextIn = (IntPtr)(bufferPtr + _position); _stream.AvailIn = (uint)_available; - _stream.NextOut = (IntPtr)fixedOutput; - _stream.AvailOut = (uint)output.Length; - - Inflate(_stream); - + written = Inflate(_stream, output); consumed = _available - (int)_stream.AvailIn; - written = output.Length - (int)_stream.AvailOut; } _position += consumed; @@ -142,16 +131,7 @@ public unsafe void Inflate(int totalBytesReceived, Span output, bool flush if (_available == 0) { ReleaseBuffer(); - - if (flush) - { - Finished = Flush(output.Slice(written), out int byteCount); - written += byteCount; - } - else - { - Finished = true; - } + Finished = flush ? Flush(output, ref written) : true; } else { @@ -163,82 +143,55 @@ public unsafe void Inflate(int totalBytesReceived, Span output, bool flush /// Finishes the decoding by flushing any outstanding data to the output. /// /// true if the flush completed, false to indicate that there is more outstanding data. - private unsafe bool Flush(Span output, out int written) + private unsafe bool Flush(Span output, ref int written) { Debug.Assert(_stream is not null); Debug.Assert(_available == 0); if (_needsFlushMarker) { - fixed (byte* fixedInput = &MemoryMarshal.GetReference(FlushMarker)) - fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) - { - _stream.NextIn = (IntPtr)fixedInput; - _stream.AvailIn = (uint)FlushMarkerLength; - - _stream.NextOut = (IntPtr)fixedOutput; - _stream.AvailOut = (uint)output.Length; - - Inflate(_stream); - - written = output.Length - (int)_stream.AvailOut; - } - _needsFlushMarker = false; - if (written < output.Length || IsFinished(_stream, out _remainingByte)) + // It's OK to use the flush marker like this, because it's pointer is unmovable. + fixed (byte* flushMarkerPtr = FlushMarker) { - OnFinished(); - return true; + _stream.NextIn = (IntPtr)flushMarkerPtr; + _stream.AvailIn = FlushMarkerLength; } } - written = 0; - - if (output.IsEmpty) + if (_remainingByte is not null) { - if (_remainingByte is not null) + if (output.Length == written) { return false; } - if (IsFinished(_stream, out _remainingByte)) - { - OnFinished(); - return true; - } + output[written] = _remainingByte.GetValueOrDefault(); + _remainingByte = null; + written += 1; } - else - { - if (_remainingByte is not null) - { - output[0] = _remainingByte.GetValueOrDefault(); - written = 1; - _remainingByte = null; - } + // If we have more space in the output, try to inflate + if (output.Length > written) + { written += Inflate(_stream, output[written..]); + } - if (written < output.Length || IsFinished(_stream, out _remainingByte)) + // After inflate, if we have more space in the output then it means that we + // have finished. Otherwise we need to manually check for more data. + if (written < output.Length || IsFinished(_stream, out _remainingByte)) + { + if (!_persisted) { - OnFinished(); - return true; + _stream.Dispose(); + _stream = null; } + return true; } return false; } - private void OnFinished() - { - Debug.Assert(_stream is not null); - - if (!_persisted) - { - _stream.Dispose(); - _stream = null; - } - } - private void ReleaseBuffer() { if (_buffer is not null) @@ -252,12 +205,6 @@ private void ReleaseBuffer() private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) { - if (stream.AvailIn > 0) - { - remainingByte = null; - return false; - } - // There is no other way to make sure that we'e consumed all data // but to try to inflate again with at least one byte of output buffer. byte b; @@ -273,34 +220,24 @@ private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remaini private static unsafe int Inflate(ZLibStreamHandle stream, Span destination) { - fixed (byte* bufPtr = &MemoryMarshal.GetReference(destination)) + Debug.Assert(destination.Length > 0); + ErrorCode errorCode; + + fixed (byte* bufPtr = destination) { stream.NextOut = (IntPtr)bufPtr; stream.AvailOut = (uint)destination.Length; - Inflate(stream); - return destination.Length - (int)stream.AvailOut; - } - } - - private static void Inflate(ZLibStreamHandle stream) - { - ErrorCode errorCode; - try - { errorCode = stream.Inflate(FlushCode.NoFlush); + + if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError) + { + return destination.Length - (int)stream.AvailOut; + } } - catch (Exception cause) // could not load the Zlib DLL correctly - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); - } + switch (errorCode) { - case ErrorCode.Ok: // progress has been made inflating - case ErrorCode.StreamEnd: // The end of the input stream has been reached - case ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue - break; - case ErrorCode.MemError: // Not enough memory to complete the operation throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); @@ -315,30 +252,31 @@ private static void Inflate(ZLibStreamHandle stream) } } - [MemberNotNull(nameof(_stream))] - private void Initialize() + private static ZLibStreamHandle Initialize(int windowBits) { - Debug.Assert(_stream is null); + ZLibStreamHandle stream; + ErrorCode errorCode; - ErrorCode error; try { - error = CreateZLibStreamForInflate(out _stream, _windowBits); + errorCode = CreateZLibStreamForInflate(out stream, windowBits); } catch (Exception exception) { throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); } - switch (error) + if (errorCode == ErrorCode.Ok) { - case ErrorCode.Ok: - return; - case ErrorCode.MemError: - throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); - default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)error)); + return stream; } + + stream.Dispose(); + + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 114c99ac767fd6..529d4a54b95bc4 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -822,7 +822,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo while (totalBytesReceived < limit) { - int numBytesRead = await _stream.ReadAsync( header.Compressed ? + int numBytesRead = await _stream.ReadAsync(header.Compressed ? _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) : payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived), cancellationToken).ConfigureAwait(false); @@ -872,13 +872,8 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo header.EndOfMessage); } } - catch (Exception exc) + catch (Exception exc) when (exc is not OperationCanceledException) { - if (exc is OperationCanceledException) - { - throw; - } - if (_state == WebSocketState.Aborted) { throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs index cfea9eece572db..7a0c94a9eba0e7 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -5,6 +5,9 @@ namespace System.Net.WebSockets { + /// + /// Options that control how a is created. + /// public sealed class WebSocketCreationOptions { private string? _subProtocol; From 9031be326897be1a598478a8605e230563cda17a Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 26 Mar 2021 15:35:50 +0200 Subject: [PATCH 22/47] Prefer Length over IsEmpty for spans. --- .../System/Net/WebSockets/Compression/WebSocketDeflater.cs | 2 +- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 7de66c47b0ae43..39ac2b2dfdcb43 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -87,7 +87,7 @@ private void DeflatePrivate(ReadOnlySpan payload, Span output, bool Initialize(); } - if (payload.IsEmpty) + if (payload.Length == 0) { consumed = 0; written = 0; diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 529d4a54b95bc4..b7022cce49ce09 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -538,7 +538,7 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM /// Writes a frame into the send buffer, which can then be sent over the network. private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer) { - if (_deflater is not null && !payloadBuffer.IsEmpty) + if (_deflater is not null && payloadBuffer.Length > 0) { payloadBuffer = _deflater.Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); } @@ -566,7 +566,7 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read } // Write the payload - if (!payloadBuffer.IsEmpty) + if (payloadBuffer.Length > 0) { payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadLength)); From 4e70073cd91933c8f14f19dfc64822c46ac05d06 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 29 Mar 2021 15:26:36 +0300 Subject: [PATCH 23/47] Fixed wrong check. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index b7022cce49ce09..9fe48acff9cc53 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -555,13 +555,13 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { // The server doesn't send a mask, so the mask offset returned by WriteHeader // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _inflater is not null); + headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _deflater is not null); } else { // We need to know where the mask starts so that we can use the mask to manipulate the payload data, // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _inflater is not null); + maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _deflater is not null); headerLength = maskOffset.GetValueOrDefault() + MaskLength; } From b7743ef41f21c29381091b3895b97fc862509888 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 31 Mar 2021 13:55:31 +0300 Subject: [PATCH 24/47] New API that allows disabling compression per message basis and renaming DeflateOptions properties to DangerousDeflateOptions. Also addressed some PR feedback and added new tests. --- .../ref/System.Net.WebSockets.Client.cs | 2 +- .../ClientWebSocketOptions.cs | 2 +- .../Net/WebSockets/ClientWebSocketOptions.cs | 2 +- .../Net/WebSockets/WebSocketHandle.Managed.cs | 32 ++----- .../tests/DeflateTests.cs | 14 ++-- .../ref/System.Net.WebSockets.cs | 10 ++- .../src/System.Net.WebSockets.csproj | 3 +- .../Compression/WebSocketDeflater.cs | 2 - .../System/Net/WebSockets/ManagedWebSocket.cs | 70 +++++++++++----- .../src/System/Net/WebSockets/WebSocket.cs | 5 ++ .../WebSockets/WebSocketCreationOptions.cs | 2 +- .../Net/WebSockets/WebSocketDeflateOptions.cs | 31 ++++--- .../Net/WebSockets/WebSocketMessageFlags.cs | 27 ++++++ .../tests/WebSocketDeflateTests.cs | 84 +++++++++++++++---- 14 files changed, 196 insertions(+), 90 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index 660a2c5fbe7485..96cecd9e30f471 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -36,7 +36,7 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] - public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 2ed5c527421c9f..79dd04229b9c33 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -101,7 +101,7 @@ public TimeSpan KeepAliveInterval } [UnsupportedOSPlatform("browser")] - public WebSocketDeflateOptions? DeflateOptions + public WebSocketDeflateOptions? DangerousDeflateOptions { get => throw new PlatformNotSupportedException(); set => throw new PlatformNotSupportedException(); diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index 00427ea9e2eb10..b24abcd368b9fb 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -155,7 +155,7 @@ public TimeSpan KeepAliveInterval /// will be created with compression enabled by default for all messages. /// [UnsupportedOSPlatform("browser")] - public WebSocketDeflateOptions? DeflateOptions { get; set; } + public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; } internal int ReceiveBufferSize => _receiveBufferSize; internal ArraySegment? Buffer => _buffer; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 2b5424a492f985..b6e16077706481 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -185,25 +185,20 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } // Because deflate options are negotiated we need a new object - WebSocketDeflateOptions? deflateOptions = null; + WebSocketDeflateOptions? negotiatedDeflateOptions = null; - if (options.DeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out IEnumerable? extensions)) + if (options.DangerousDeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out IEnumerable? extensions)) { foreach (ReadOnlySpan extension in extensions) { if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension)) { - deflateOptions = ParseDeflateOptions(extension, options.DeflateOptions); + negotiatedDeflateOptions = ParseDeflateOptions(extension, options.DangerousDeflateOptions); break; } } } - // Store the negotiated deflate options in the original options, because - // otherwise there is no way of clients to actually check whether we are using - // per message deflate or not. - options.DeflateOptions = deflateOptions; - if (response.Content is null) { throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); @@ -218,7 +213,7 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli IsServer = false, SubProtocol = subprotocol, KeepAliveInterval = options.KeepAliveInterval, - DeflateOptions = deflateOptions + DangerousDeflateOptions = negotiatedDeflateOptions }); } catch (Exception exc) @@ -329,21 +324,10 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe { request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols)); } - if (options.DeflateOptions is not null) + if (options.DangerousDeflateOptions is not null) { - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, GetDeflateOptions(options.DeflateOptions)); + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, GetDeflateOptions(options.DangerousDeflateOptions)); - static string GetWindowBitsString(int value) => value switch - { - 9 => "9", - 10 => "10", - 11 => "11", - 12 => "12", - 13 => "13", - 14 => "14", - 15 => "15", - _ => value.ToString(CultureInfo.InvariantCulture) - }; static string GetDeflateOptions(WebSocketDeflateOptions options) { var builder = new StringBuilder(ClientWebSocketDeflateConstants.MaxExtensionLength); @@ -352,7 +336,7 @@ static string GetDeflateOptions(WebSocketDeflateOptions options) if (options.ClientMaxWindowBits != 15) { builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=') - .Append(GetWindowBitsString(options.ClientMaxWindowBits)); + .Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture)); } else { @@ -370,7 +354,7 @@ static string GetDeflateOptions(WebSocketDeflateOptions options) if (options.ServerMaxWindowBits != 15) { builder.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=') - .Append(GetWindowBitsString(options.ServerMaxWindowBits)); + .Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture)); } else { diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index a182830426c307..af7827aec25f4f 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.IO; using System.Net.Test.Common; using System.Text; using System.Threading; @@ -37,7 +36,7 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => using var client = new ClientWebSocket(); using var cancellation = new CancellationTokenSource(TimeOutMilliseconds); - client.Options.DeflateOptions = new WebSocketDeflateOptions + client.Options.DangerousDeflateOptions = new WebSocketDeflateOptions { ClientMaxWindowBits = clientWindowBits, ClientContextTakeover = clientContextTakeover, @@ -47,11 +46,12 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => await client.ConnectAsync(uri, cancellation.Token); - Assert.NotNull(client.Options.DeflateOptions); - Assert.Equal(clientWindowBits - 1, client.Options.DeflateOptions.ClientMaxWindowBits); - Assert.Equal(clientContextTakeover, client.Options.DeflateOptions.ClientContextTakeover); - Assert.Equal(serverWindowBits - 1, client.Options.DeflateOptions.ServerMaxWindowBits); - Assert.Equal(serverContextTakover, client.Options.DeflateOptions.ServerContextTakeover); + // Uncomment this if we expose DangerousDeflateOptions directly in the websocket to represent the + // negotiated settings. Otherise we can't verify if compression is negotiated successfully. + // Assert.Equal(clientWindowBits - 1, client.DangerousDeflateOptions.ClientMaxWindowBits); + // Assert.Equal(clientContextTakeover, client.DangerousDeflateOptions.ClientContextTakeover); + // Assert.Equal(serverWindowBits - 1, client.DangerousDeflateOptions.ServerMaxWindowBits); + // Assert.Equal(serverContextTakover, client.DangerousDeflateOptions.ServerContextTakeover); }, server => server.AcceptConnectionAsync(async connection => { var extensionsReply = CreateDeflateOptionsHeader(new WebSocketDeflateOptions diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index e93d55aa5fea42..32ebf5eb1e804e 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -42,6 +42,7 @@ protected WebSocket() { } public static void RegisterPrefixes() { } public abstract System.Threading.Tasks.Task SendAsync(System.ArraySegment buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken); public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; } + public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, System.Net.WebSockets.WebSocketMessageFlags messageFlags, System.Threading.CancellationToken cancellationToken) { throw null; } protected static void ThrowOnInvalidState(System.Net.WebSockets.WebSocketState state, params System.Net.WebSockets.WebSocketState[] validStates) { } } public enum WebSocketCloseStatus @@ -137,7 +138,7 @@ public sealed partial class WebSocketCreationOptions public bool IsServer { get { throw null; } set { } } public string? SubProtocol { get { throw null; } set { } } public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } - public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } } public sealed partial class WebSocketDeflateOptions { @@ -146,4 +147,11 @@ public sealed partial class WebSocketDeflateOptions public int ServerMaxWindowBits { get { throw null; } set { } } public bool ServerContextTakeover { get { throw null; } set { } } } + [Flags] + public enum WebSocketMessageFlags + { + None = 0, + EndOfMessage = 1, + DisableCompression = 2 + } } diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 577adb80ecfe95..d8df225180aefc 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -1,4 +1,4 @@ - + True $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser @@ -16,6 +16,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 39ac2b2dfdcb43..ae5c2490a0530e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -80,8 +80,6 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, private void DeflatePrivate(ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, out int consumed, out int written, out bool needsMoreOutput) { - Debug.Assert(!continuation || _stream is not null, "Invalid state. The stream should not be null in continuations."); - if (_stream is null) { Initialize(); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 9fe48acff9cc53..d2262f15edebd3 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -126,6 +126,10 @@ internal sealed partial class ManagedWebSocket : WebSocket /// private bool _lastSendWasFragment; /// + /// Whether the last SendAsync had flag set. + /// + private bool _lastSendHadDisableCompression; + /// /// The task returned from the last ReceiveAsync(ArraySegment, ...) operation to not complete synchronously. /// If this is not null and not completed when a subsequent ReceiveAsync is issued, an exception occurs. /// @@ -210,7 +214,7 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval) { - var deflateOptions = options.DeflateOptions; + var deflateOptions = options.DangerousDeflateOptions; if (deflateOptions is not null) { @@ -270,10 +274,10 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); - return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken).AsTask(); + return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken).AsTask(); } - private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) { if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { @@ -292,13 +296,25 @@ private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessage return new ValueTask(Task.FromException(exc)); } - MessageOpcode opcode = - _lastSendWasFragment ? MessageOpcode.Continuation : - messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary : - MessageOpcode.Text; + bool endOfMessage = messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage); + bool disableCompression; + MessageOpcode opcode; - ValueTask t = SendFrameAsync(opcode, endOfMessage, buffer, cancellationToken); + if (_lastSendWasFragment) + { + disableCompression = _lastSendHadDisableCompression; + opcode = MessageOpcode.Continuation; + } + else + { + opcode = messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary : MessageOpcode.Text; + disableCompression = messageFlags.HasFlag(WebSocketMessageFlags.DisableCompression); + } + + ValueTask t = SendFrameAsync(opcode, endOfMessage, disableCompression, buffer, cancellationToken); _lastSendWasFragment = !endOfMessage; + _lastSendHadDisableCompression = disableCompression; + return t; } @@ -372,7 +388,12 @@ public override void Abort() public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken); + return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken); + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) + { + return SendPrivateAsync(buffer, messageType, messageFlags, cancellationToken); } public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) @@ -427,9 +448,10 @@ private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, Cancellati /// Sends a websocket frame to the network. /// The opcode for the message. /// The value of the FIN bit for the message. - /// The buffer containing the payload data fro the message. + /// Disables compression for the message. + /// The buffer containing the payload data from the message. /// The CancellationToken to use to cancel the websocket. - private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) + private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) { // If a cancelable cancellation token was provided, that would require registering with it, which means more state we have to // pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path. @@ -438,15 +460,16 @@ private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOn #pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0 return cancellationToken.CanBeCanceled || !_sendFrameAsyncLock.Wait(0, default) ? #pragma warning restore CA1416 - SendFrameFallbackAsync(opcode, endOfMessage, payloadBuffer, cancellationToken) : - SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, payloadBuffer); + SendFrameFallbackAsync(opcode, endOfMessage, disableCompression, payloadBuffer, cancellationToken) : + SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, disableCompression, payloadBuffer); } /// Sends a websocket frame to the network. The caller must hold the sending lock. /// The opcode for the message. /// The value of the FIN bit for the message. + /// Disables compression for the message. /// The buffer containing the payload data fro the message. - private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer) + private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer) { Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock"); @@ -457,7 +480,7 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, try { // Write the payload synchronously to the buffer, then write that buffer out to the network. - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); + int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span); writeTask = _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes)); // If the operation happens to complete synchronously (or, more specifically, by @@ -511,12 +534,12 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask) } } - private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) + private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) { await _sendFrameAsyncLock.WaitAsync(cancellationToken).ConfigureAwait(false); try { - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); + int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span); using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) { await _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false); @@ -536,9 +559,9 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM } /// Writes a frame into the send buffer, which can then be sent over the network. - private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer) + private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlySpan payloadBuffer) { - if (_deflater is not null && payloadBuffer.Length > 0) + if (_deflater is not null && payloadBuffer.Length > 0 && !disableCompression) { payloadBuffer = _deflater.Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); } @@ -555,13 +578,13 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { // The server doesn't send a mask, so the mask offset returned by WriteHeader // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _deflater is not null); + headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _deflater is not null && !disableCompression); } else { // We need to know where the mask starts so that we can use the mask to manipulate the payload data, // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _deflater is not null); + maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _deflater is not null && !disableCompression); headerLength = maskOffset.GetValueOrDefault() + MaskLength; } @@ -595,7 +618,7 @@ private void SendKeepAliveFrameAsync() // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures. // The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response. - ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, true, ReadOnlyMemory.Empty); + ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory.Empty); if (t.IsCompletedSuccessfully) { t.GetAwaiter().GetResult(); @@ -1025,6 +1048,7 @@ private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, Cancel await SendFrameAsync( MessageOpcode.Pong, endOfMessage: true, + disableCompression: true, _receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), cancellationToken).ConfigureAwait(false); } @@ -1320,7 +1344,7 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st buffer[0] = (byte)(closeStatusValue >> 8); buffer[1] = (byte)(closeStatusValue & 0xFF); - await SendFrameAsync(MessageOpcode.Close, true, new Memory(buffer, 0, count), cancellationToken).ConfigureAwait(false); + await SendFrameAsync(MessageOpcode.Close, endOfMessage: true, disableCompression: true, new Memory(buffer, 0, count), cancellationToken).ConfigureAwait(false); } finally { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 0666fd4edd44f3..bd1e5c5186ab0c 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -58,6 +58,11 @@ public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessage new ValueTask(SendAsync(arraySegment, messageType, endOfMessage, cancellationToken)) : SendWithArrayPoolAsync(buffer, messageType, endOfMessage, cancellationToken); + public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags = WebSocketMessageFlags.EndOfMessage, CancellationToken cancellationToken = default) + { + return SendAsync(buffer, messageType, messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage), cancellationToken); + } + private async ValueTask SendWithArrayPoolAsync( ReadOnlyMemory buffer, WebSocketMessageType messageType, diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs index 7a0c94a9eba0e7..4f566cfaf71f56 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -55,6 +55,6 @@ public TimeSpan KeepAliveInterval /// /// The agreed upon options for per message deflate. /// - public WebSocketDeflateOptions? DeflateOptions { get; set; } + public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index 6ddb82c1b0c7ed..a843f9ceb21f9e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -12,6 +12,21 @@ namespace System.Net.WebSockets /// public sealed class WebSocketDeflateOptions { + /// + /// The minimum value for window bits that the websocket can support. + /// The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + /// and https://zlib.net/manual.html). Quote from the manual: + /// "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + /// We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + /// and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. + /// + private const int MinWindowBits = 9; + + /// + /// The maximum value for window bits that the websocket can support. + /// + private const int MaxWindowBits = 15; + private int _clientMaxWindowBits = 15; private int _serverMaxWindowBits = 15; @@ -25,14 +40,10 @@ public int ClientMaxWindowBits get => _clientMaxWindowBits; set { - // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 - // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". - // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream - // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. - if (value < 9 || value > 15) + if (value < MinWindowBits || value > MaxWindowBits) { throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, MinWindowBits, MaxWindowBits)); } _clientMaxWindowBits = value; } @@ -55,14 +66,10 @@ public int ServerMaxWindowBits get => _serverMaxWindowBits; set { - // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 - // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". - // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream - // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. - if (value < 9 || value > 15) + if (value < MinWindowBits || value > MaxWindowBits) { throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, MinWindowBits, MaxWindowBits)); } _serverMaxWindowBits = value; } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs new file mode 100644 index 00000000000000..9ce165d8de8433 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + /// + /// Flags for controlling how the should send a message. + /// + [Flags] + public enum WebSocketMessageFlags + { + /// + /// None + /// + None = 0, + + /// + /// Indicates that the data in "buffer" is the last part of a message. + /// + EndOfMessage = 1, + + /// + /// Disables compression for the message if compression has been enabled for the instance. + /// + DisableCompression = 2 + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 6a63ab4c541c2a..6bfdcc96c6ffaa 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -43,7 +43,7 @@ public async Task ReceiveHelloWithContextTakeover() stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { - DeflateOptions = new() + DangerousDeflateOptions = new() }); Memory buffer = new byte[5]; @@ -73,7 +73,7 @@ public async Task SendHelloWithContextTakeover() using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { IsServer = true, - DeflateOptions = new() + DangerousDeflateOptions = new() }); await websocket.SendAsync(Encoding.UTF8.GetBytes("Hello"), WebSocketMessageType.Text, true, CancellationToken); @@ -87,13 +87,60 @@ public async Task SendHelloWithContextTakeover() Assert.Equal("C105F200110000", Convert.ToHexString(stream.NextAvailableBytes)); } + [Fact] + public async Task SendHelloWithDisableCompression() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + }); + + byte[] bytes = Encoding.UTF8.GetBytes("Hello"); + WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage; + await websocket.SendAsync(bytes, WebSocketMessageType.Text, flags, CancellationToken); + + Assert.Equal(bytes.Length + 2, stream.Available); + Assert.True(stream.NextAvailableBytes.EndsWith(bytes)); + } + + [Fact] + public async Task SendHelloWithEmptyFrame() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + }); + + byte[] bytes = Encoding.UTF8.GetBytes("Hello"); + await websocket.SendAsync(Memory.Empty, WebSocketMessageType.Text, endOfMessage: false, CancellationToken); + await websocket.SendAsync(bytes, WebSocketMessageType.Text, endOfMessage: true, CancellationToken); + + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = false, + DangerousDeflateOptions = new() + }); + + ValueWebSocketReceiveResult result = await client.ReceiveAsync(bytes.AsMemory(), CancellationToken); + Assert.False(result.EndOfMessage); + Assert.Equal(0, result.Count); + + result = await client.ReceiveAsync(bytes.AsMemory(), CancellationToken); + Assert.True(result.EndOfMessage); + Assert.Equal(5, result.Count); + } + [Fact] public async Task ReceiveHelloWithoutContextTakeover() { WebSocketTestStream stream = new(); using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { - DeflateOptions = new() + DangerousDeflateOptions = new() { ClientContextTakeover = false } @@ -123,7 +170,7 @@ public async Task SendHelloWithoutContextTakeover() using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { IsServer = true, - DeflateOptions = new() + DangerousDeflateOptions = new() { ClientContextTakeover = false } @@ -148,7 +195,7 @@ public async Task TwoDeflateBlocksInOneMessage() WebSocketTestStream stream = new(); using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { - DeflateOptions = new() + DangerousDeflateOptions = new() }); // The first 3 octets(0xf2 0x48 0x05) and the least significant two // bits of the 4th octet(0x00) constitute one DEFLATE block with @@ -185,7 +232,7 @@ public async Task Duplex(bool clientContextTakover, bool serverContextTakover) using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { IsServer = true, - DeflateOptions = new WebSocketDeflateOptions + DangerousDeflateOptions = new WebSocketDeflateOptions { ClientContextTakeover = clientContextTakover, ServerContextTakeover = serverContextTakover @@ -193,7 +240,7 @@ public async Task Duplex(bool clientContextTakover, bool serverContextTakover) }); using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { - DeflateOptions = new WebSocketDeflateOptions + DangerousDeflateOptions = new WebSocketDeflateOptions { ClientContextTakeover = clientContextTakover, ServerContextTakeover = serverContextTakover @@ -205,7 +252,7 @@ public async Task Duplex(bool clientContextTakover, bool serverContextTakover) for (var i = 0; i < 10; ++i) { string message = $"Sending number {i} from server."; - await SendTextAsync(message, server); + await SendTextAsync(message, server, disableCompression: i % 2 == 0); ValueWebSocketReceiveResult result = await client.ReceiveAsync(buffer.AsMemory(), CancellationToken); @@ -218,7 +265,7 @@ public async Task Duplex(bool clientContextTakover, bool serverContextTakover) for (var i = 0; i < 10; ++i) { string message = $"Sending number {i} from client."; - await SendTextAsync(message, client); + await SendTextAsync(message, client, disableCompression: i % 2 == 0); ValueWebSocketReceiveResult result = await server.ReceiveAsync(buffer.AsMemory(), CancellationToken); @@ -237,14 +284,14 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { IsServer = true, - DeflateOptions = new() + DangerousDeflateOptions = new() { ClientMaxWindowBits = windowBits } }); using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { - DeflateOptions = new() + DangerousDeflateOptions = new() { ClientMaxWindowBits = windowBits } @@ -323,11 +370,11 @@ public async Task ReceiveUncompressedMessageWhenCompressionEnabled() WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { IsServer = true, - DeflateOptions = null + DangerousDeflateOptions = null }); WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { - DeflateOptions = new WebSocketDeflateOptions() + DangerousDeflateOptions = new WebSocketDeflateOptions() }); // Server sends uncompressed @@ -356,7 +403,7 @@ public async Task ReceiveInvalidCompressedData() WebSocketTestStream stream = new(); WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { - DeflateOptions = new WebSocketDeflateOptions() + DangerousDeflateOptions = new WebSocketDeflateOptions() }); stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); @@ -369,10 +416,15 @@ public async Task ReceiveInvalidCompressedData() Assert.Equal(WebSocketState.Aborted, client.State); } - private ValueTask SendTextAsync(string text, WebSocket websocket) + private ValueTask SendTextAsync(string text, WebSocket websocket, bool disableCompression = false) { + WebSocketMessageFlags flags = WebSocketMessageFlags.EndOfMessage; + if (disableCompression) + { + flags |= WebSocketMessageFlags.DisableCompression; + } byte[] bytes = Encoding.UTF8.GetBytes(text); - return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, CancellationToken); + return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, flags, CancellationToken); } private async Task ReceiveTextAsync(WebSocket websocket) From 49d00eb725b4e915959fef42683b254e9e54aa97 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 31 Mar 2021 14:38:55 +0300 Subject: [PATCH 25/47] Added comment explaining why GetReceiveResult will not result in boxing of a struct. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index d2262f15edebd3..62a01640fcedf6 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -924,6 +924,9 @@ private TResult GetReceiveResult(int count, WebSocketMessageType messag { if (typeof(TResult) == typeof(ValueWebSocketReceiveResult)) { + // Although it might seem that this will incur boxing of the struct, + // the JIT is smart enough to figure out it is unncessessary and will emit + // bytecode that returns the ValueWebSocketReceiveResult directly. return (TResult)(object)new ValueWebSocketReceiveResult(count, messageType, endOfMessage); } From 58a333877803588b2006b9ff0bfcd330cc981315 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 31 Mar 2021 16:28:34 +0300 Subject: [PATCH 26/47] Exposed DeflateReset and InflateReset methods from zlib. --- .../Common/src/Interop/Interop.zlib.cs | 6 +++++ .../src/System/IO/Compression/ZLibNative.cs | 16 ++++++++++++++ .../entrypoints.c | 2 ++ src/libraries/Native/AnyOS/zlib/pal_zlib.c | 22 +++++++++++++++++++ src/libraries/Native/AnyOS/zlib/pal_zlib.h | 16 ++++++++++++++ ...stem.IO.Compression.Native_unixexports.src | 2 ++ .../System.IO.Compression.Native.def | 2 ++ 7 files changed, 66 insertions(+) diff --git a/src/libraries/Common/src/Interop/Interop.zlib.cs b/src/libraries/Common/src/Interop/Interop.zlib.cs index 280c5558667eb9..ad517da4079ca2 100644 --- a/src/libraries/Common/src/Interop/Interop.zlib.cs +++ b/src/libraries/Common/src/Interop/Interop.zlib.cs @@ -20,6 +20,9 @@ internal static extern ZLibNative.ErrorCode DeflateInit2_( [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_Deflate")] internal static extern ZLibNative.ErrorCode Deflate(ref ZLibNative.ZStream stream, ZLibNative.FlushCode flush); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_DeflateReset")] + internal static extern ZLibNative.ErrorCode DeflateReset(ref ZLibNative.ZStream stream); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_DeflateEnd")] internal static extern ZLibNative.ErrorCode DeflateEnd(ref ZLibNative.ZStream stream); @@ -29,6 +32,9 @@ internal static extern ZLibNative.ErrorCode DeflateInit2_( [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_Inflate")] internal static extern ZLibNative.ErrorCode Inflate(ref ZLibNative.ZStream stream, ZLibNative.FlushCode flush); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_InflateReset")] + internal static extern ZLibNative.ErrorCode InflateReset(ref ZLibNative.ZStream stream); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_InflateEnd")] internal static extern ZLibNative.ErrorCode InflateEnd(ref ZLibNative.ZStream stream); diff --git a/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs index 8118aeba0ecb82..9daf208692faea 100644 --- a/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs +++ b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs @@ -281,6 +281,14 @@ public ErrorCode Deflate(FlushCode flush) } + public ErrorCode DeflateReset() + { + EnsureNotDisposed(); + EnsureState(State.InitializedForDeflate); + return Interop.zlib.DeflateReset(ref _zStream); + } + + public ErrorCode DeflateEnd() { EnsureNotDisposed(); @@ -313,6 +321,14 @@ public ErrorCode Inflate(FlushCode flush) } + public ErrorCode InflateReset() + { + EnsureNotDisposed(); + EnsureState(State.InitializedForInflate); + return Interop.zlib.InflateReset(ref _zStream); + } + + public ErrorCode InflateEnd() { EnsureNotDisposed(); diff --git a/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c b/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c index b194b978debe23..f363a91eb1add3 100644 --- a/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c +++ b/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c @@ -28,9 +28,11 @@ static const Entry s_compressionNative[] = DllImportEntry(CompressionNative_Crc32) DllImportEntry(CompressionNative_Deflate) DllImportEntry(CompressionNative_DeflateEnd) + DllImportEntry(CompressionNative_DeflateReset) DllImportEntry(CompressionNative_DeflateInit2_) DllImportEntry(CompressionNative_Inflate) DllImportEntry(CompressionNative_InflateEnd) + DllImportEntry(CompressionNative_InflateReset) DllImportEntry(CompressionNative_InflateInit2_) }; diff --git a/src/libraries/Native/AnyOS/zlib/pal_zlib.c b/src/libraries/Native/AnyOS/zlib/pal_zlib.c index 2c399639d0fa92..aa4dcdca8a29e8 100644 --- a/src/libraries/Native/AnyOS/zlib/pal_zlib.c +++ b/src/libraries/Native/AnyOS/zlib/pal_zlib.c @@ -135,6 +135,17 @@ int32_t CompressionNative_Deflate(PAL_ZStream* stream, int32_t flush) return result; } +int32_t CompressionNative_DeflateReset(PAL_ZStream* stream) +{ + assert(stream != NULL); + + z_stream* zStream = GetCurrentZStream(stream); + int32_t result = deflateReset(zStream); + TransferStateToPalZStream(zStream, stream); + + return result; +} + int32_t CompressionNative_DeflateEnd(PAL_ZStream* stream) { assert(stream != NULL); @@ -172,6 +183,17 @@ int32_t CompressionNative_Inflate(PAL_ZStream* stream, int32_t flush) return result; } +int32_t CompressionNative_InflateReset(PAL_ZStream* stream) +{ + assert(stream != NULL); + + z_stream* zStream = GetCurrentZStream(stream); + int32_t result = inflateReset(zStream); + TransferStateToPalZStream(zStream, stream); + + return result; +} + int32_t CompressionNative_InflateEnd(PAL_ZStream* stream) { assert(stream != NULL); diff --git a/src/libraries/Native/AnyOS/zlib/pal_zlib.h b/src/libraries/Native/AnyOS/zlib/pal_zlib.h index b317091b843f62..1eb1baa6b3846b 100644 --- a/src/libraries/Native/AnyOS/zlib/pal_zlib.h +++ b/src/libraries/Native/AnyOS/zlib/pal_zlib.h @@ -95,6 +95,14 @@ Returns a PAL_ErrorCode indicating success or an error number on failure. */ FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_Deflate(PAL_ZStream* stream, int32_t flush); +/* +This function is equivalent to DeflateEnd followed by DeflateInit, but does not free and reallocate +the internal compression state. The stream will leave the compression level and any other attributes that may have been set unchanged. + +Returns a PAL_ErrorCode indicating success or an error number on failure. +*/ +FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_DeflateReset(PAL_ZStream* stream); + /* All dynamically allocated data structures for this stream are freed. @@ -117,6 +125,14 @@ Returns a PAL_ErrorCode indicating success or an error number on failure. */ FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_Inflate(PAL_ZStream* stream, int32_t flush); +/* +This function is equivalent to InflateEnd followed by InflateInit, but does not free and reallocate +the internal decompression state. The The stream will keep attributes that may have been set by InflateInit. + +Returns a PAL_ErrorCode indicating success or an error number on failure. +*/ +FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_InflateReset(PAL_ZStream* stream); + /* All dynamically allocated data structures for this stream are freed. diff --git a/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src b/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src index 08dd1700a52f21..2ac827035f271b 100644 --- a/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src +++ b/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src @@ -15,7 +15,9 @@ BrotliEncoderSetParameter CompressionNative_Crc32 CompressionNative_Deflate CompressionNative_DeflateEnd +CompressionNative_DeflateReset CompressionNative_DeflateInit2_ CompressionNative_Inflate CompressionNative_InflateEnd +CompressionNative_InflateReset CompressionNative_InflateInit2_ diff --git a/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def b/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def index 6821d0e538f51f..aecd0dd974618a 100644 --- a/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def +++ b/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def @@ -15,7 +15,9 @@ EXPORTS CompressionNative_Crc32 CompressionNative_Deflate CompressionNative_DeflateEnd + CompressionNative_DeflateReset CompressionNative_DeflateInit2_ CompressionNative_Inflate CompressionNative_InflateEnd + CompressionNative_InflateReset CompressionNative_InflateInit2_ From 56fd0ef52073fe45da6a3768b102326278d7dc24 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 31 Mar 2021 19:08:15 +0300 Subject: [PATCH 27/47] Pooling zlib streams. --- .../src/System.Net.WebSockets.csproj | 2 + .../Compression/WebSocketDeflater.cs | 50 +--- .../Compression/WebSocketInflater.cs | 46 +-- .../WebSockets/Compression/ZLibStreamPool.cs | 283 ++++++++++++++++++ .../Net/WebSockets/WebSocketDeflateOptions.cs | 4 +- 5 files changed, 309 insertions(+), 76 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index d8df225180aefc..3419171de4cc71 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -7,6 +7,7 @@ + @@ -42,6 +43,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index ae5c2490a0530e..2e6c51aab82ef3 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -13,22 +13,26 @@ namespace System.Net.WebSockets.Compression /// internal sealed class WebSocketDeflater : IDisposable { + private readonly ZLibStreamPool _streamPool; private ZLibStreamHandle? _stream; - private readonly int _windowBits; private readonly bool _persisted; private byte[]? _buffer; internal WebSocketDeflater(int windowBits, bool persisted) { - Debug.Assert(windowBits >= 9 && windowBits <= 15); - - // We use negative window bits in order to produce raw deflate data - _windowBits = -windowBits; + _streamPool = ZLibStreamPool.GetOrCreate(windowBits); _persisted = persisted; } - public void Dispose() => _stream?.Dispose(); + public void Dispose() + { + if (_stream is not null) + { + _streamPool.ReturnDeflater(_stream); + _stream = null; + } + } public void ReleaseBuffer() { @@ -80,10 +84,7 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, private void DeflatePrivate(ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, out int consumed, out int written, out bool needsMoreOutput) { - if (_stream is null) - { - Initialize(); - } + _stream ??= _streamPool.GetDeflater(); if (payload.Length == 0) { @@ -119,7 +120,7 @@ private void DeflatePrivate(ReadOnlySpan payload, Span output, bool if (endOfMessage && !_persisted) { - _stream.Dispose(); + _streamPool.ReturnDeflater(_stream); _stream = null; } } @@ -191,32 +192,5 @@ private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); throw new WebSocketException(message); } - - [MemberNotNull(nameof(_stream))] - private void Initialize() - { - Debug.Assert(_stream is null); - ErrorCode errorCode; - try - { - errorCode = CreateZLibStreamForDeflate(out _stream, - level: CompressionLevel.DefaultCompression, - windowBits: _windowBits, - memLevel: Deflate_DefaultMemLevel, - strategy: CompressionStrategy.DefaultStrategy); - } - catch (Exception cause) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); - } - - if (errorCode != ErrorCode.Ok) - { - string message = errorCode == ErrorCode.MemError - ? SR.ZLibErrorNotEnoughMemory - : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); - throw new WebSocketException(message); - } - } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 8255a69d106a67..dc5e7dd80568b6 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -15,12 +15,12 @@ internal sealed class WebSocketInflater : IDisposable internal const int FlushMarkerLength = 4; internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; + private readonly ZLibStreamPool _streamPool; private ZLibStreamHandle? _stream; - private readonly int _windowBits; private readonly bool _persisted; /// - /// There is no way of knowing, when decoding data, if the underlying deflater + /// There is no way of knowing, when decoding data, if the underlying inflater /// has flushed all outstanding data to consumer other than to provide a buffer /// and see whether any bytes are written. There are cases when the consumers /// provide a buffer exactly the size of the uncompressed data and in this case @@ -48,10 +48,7 @@ internal sealed class WebSocketInflater : IDisposable internal WebSocketInflater(int windowBits, bool persisted) { - Debug.Assert(windowBits >= 9 && windowBits <= 15); - - // We use negative window bits to instruct deflater to expect raw deflate data - _windowBits = -windowBits; + _streamPool = ZLibStreamPool.GetOrCreate(windowBits); _persisted = persisted; } @@ -66,7 +63,11 @@ internal WebSocketInflater(int windowBits, bool persisted) public void Dispose() { - _stream?.Dispose(); + if (_stream is not null) + { + _streamPool.ReturnInflater(_stream); + _stream = null; + } ReleaseBuffer(); } @@ -104,7 +105,7 @@ public unsafe void Inflate(int totalBytesReceived, Span output, bool flush _available += totalBytesReceived; } - _stream ??= Initialize(_windowBits); + _stream ??= _streamPool.GetInflater(); if (_available > 0 && output.Length > 0) { @@ -183,7 +184,7 @@ private unsafe bool Flush(Span output, ref int written) { if (!_persisted) { - _stream.Dispose(); + _streamPool.ReturnInflater(_stream); _stream = null; } return true; @@ -251,32 +252,5 @@ private static unsafe int Inflate(ZLibStreamHandle stream, Span destinatio throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); } } - - private static ZLibStreamHandle Initialize(int windowBits) - { - ZLibStreamHandle stream; - ErrorCode errorCode; - - try - { - errorCode = CreateZLibStreamForInflate(out stream, windowBits); - } - catch (Exception exception) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); - } - - if (errorCode == ErrorCode.Ok) - { - return stream; - } - - stream.Dispose(); - - string message = errorCode == ErrorCode.MemError - ? SR.ZLibErrorNotEnoughMemory - : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); - throw new WebSocketException(message); - } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs new file mode 100644 index 00000000000000..9e5a8bc466da08 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs @@ -0,0 +1,283 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using static System.IO.Compression.ZLibNative; + +namespace System.Net.WebSockets.Compression +{ + internal sealed class ZLibStreamPool + { + private static readonly ZLibStreamPool?[] s_pools + = new ZLibStreamPool[WebSocketDeflateOptions.MaxWindowBits - WebSocketDeflateOptions.MinWindowBits + 1]; + + /// + /// The maximum number of cached items. + /// + private static readonly int MaximumRetained = Environment.ProcessorCount * 2; + + /// + /// The amount of time after which a cached item will be removed. + /// + private static readonly TimeSpan CacheTimeout = TimeSpan.FromMinutes(1); + + /// + /// The in ticks. + /// + private static readonly long CacheTimeoutRawTicks = (long)(CacheTimeout.Ticks / (double)TimeSpan.TicksPerSecond * Stopwatch.Frequency); + + private readonly int _windowBits; + private readonly List _inflaters = new(MaximumRetained); + private readonly List _deflaters = new(MaximumRetained); + private readonly Timer _cleaningTimer; + + /// + /// The number of cached inflaters and deflaters. + /// + private int _activeCount; + + private ZLibStreamPool(int windowBits) + { + // Use negative window bits to for raw deflate data + _windowBits = -windowBits; + + bool restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + // There is no need to use weak references here, because these pools are kept + // for the entire lifetime of the application. Also we reset the timer on each tick, + // which prevents the object being rooted forever. + _cleaningTimer = new Timer(x => ((ZLibStreamPool)x!).RemoveStaleItems(), + state: this, Timeout.Infinite, Timeout.Infinite); + } + finally + { + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } + } + + public static ZLibStreamPool GetOrCreate(int windowBits) + { + Debug.Assert(windowBits >= WebSocketDeflateOptions.MinWindowBits + && windowBits <= WebSocketDeflateOptions.MaxWindowBits); + + int index = windowBits - WebSocketDeflateOptions.MinWindowBits; + ref ZLibStreamPool? pool = ref s_pools[index]; + + return Volatile.Read(ref pool) ?? EnsureInitialized(windowBits, ref pool); + + static ZLibStreamPool EnsureInitialized(int windowBits, ref ZLibStreamPool? target) + { + Interlocked.CompareExchange(ref target, new ZLibStreamPool(windowBits), null); + + Debug.Assert(target != null); + return target; + } + } + + public ZLibStreamHandle GetInflater() + { + if (TryGet(_inflaters, out ZLibStreamHandle? stream)) + { + return stream; + } + + return CreateInflater(); + } + + public void ReturnInflater(ZLibStreamHandle stream) + { + if (stream.InflateReset() != ErrorCode.Ok) + { + stream.Dispose(); + return; + } + + Return(stream, _inflaters); + } + + public ZLibStreamHandle GetDeflater() + { + if (TryGet(_deflaters, out ZLibStreamHandle? stream)) + { + return stream; + } + + return CreateDeflater(); + } + + public void ReturnDeflater(ZLibStreamHandle stream) + { + if (stream.DeflateReset() != ErrorCode.Ok) + { + stream.Dispose(); + return; + } + + Return(stream, _deflaters); + } + + private void Return(ZLibStreamHandle stream, List cache) + { + lock (cache) + { + if (cache.Count < MaximumRetained) + { + cache.Add(new CacheItem(stream)); + + if (Interlocked.Increment(ref _activeCount) == 1) + { + _cleaningTimer.Change(CacheTimeout, Timeout.InfiniteTimeSpan); + } + return; + } + } + + // If we've reached the maximum retained capacity, we will destroy the stream. + // It is important that we do this outside of the cache lock. + stream.Dispose(); + } + + private bool TryGet(List cache, [NotNullWhen(true)] out ZLibStreamHandle? stream) + { + lock (cache) + { + int count = cache.Count; + + if (count > 0) + { + CacheItem item = cache[count - 1]; + cache.RemoveAt(count - 1); + Interlocked.Decrement(ref _activeCount); + + stream = item.Stream; + return true; + } + } + + stream = null; + return false; + } + + private void RemoveStaleItems() + { + RemoveStaleItems(_inflaters); + RemoveStaleItems(_deflaters); + + // There is a race condition here, were _activeCount could be decremented + // by a rent operation, but it's not big deal to schedule a timer tick that + // would eventually do nothing. + if (_activeCount > 0) + { + _cleaningTimer.Change(CacheTimeout, Timeout.InfiniteTimeSpan); + } + } + + private void RemoveStaleItems(List cache) + { + long currentTimestamp = Stopwatch.GetTimestamp(); + + lock (cache) + { + for (int index = cache.Count; index >= 0;) + { + CacheItem item = cache[index]; + + if (currentTimestamp - item.Timestamp > CacheTimeoutRawTicks) + { + item.Stream.Dispose(); + cache.RemoveAt(index); + Interlocked.Decrement(ref _activeCount); + } + else + { + --index; + } + } + } + } + + private ZLibStreamHandle CreateDeflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + try + { + errorCode = CreateZLibStreamForDeflate(out stream, + level: CompressionLevel.DefaultCompression, + windowBits: _windowBits, + memLevel: Deflate_DefaultMemLevel, + strategy: CompressionStrategy.DefaultStrategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + if (errorCode != ErrorCode.Ok) + { + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } + + return stream; + } + + private ZLibStreamHandle CreateInflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + + try + { + errorCode = CreateZLibStreamForInflate(out stream, _windowBits); + } + catch (Exception exception) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + if (errorCode == ErrorCode.Ok) + { + return stream; + } + + stream.Dispose(); + + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } + + private readonly struct CacheItem + { + public CacheItem(ZLibStreamHandle stream) + { + Stream = stream; + Timestamp = Stopwatch.GetTimestamp(); + } + + public ZLibStreamHandle Stream { get; } + + /// + /// The time when this item was returned to cache. + /// + public long Timestamp { get; } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index a843f9ceb21f9e..43299f9ce98a57 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -20,12 +20,12 @@ public sealed class WebSocketDeflateOptions /// We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream /// and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. /// - private const int MinWindowBits = 9; + internal const int MinWindowBits = 9; /// /// The maximum value for window bits that the websocket can support. /// - private const int MaxWindowBits = 15; + internal const int MaxWindowBits = 15; private int _clientMaxWindowBits = 15; private int _serverMaxWindowBits = 15; From fdf41d677a999cd1f249d37e985025e0181f006c Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 8 Apr 2021 09:08:02 +0300 Subject: [PATCH 28/47] Moved min/max deflate window bits consants to WebSocketValidate so they can be used in System.Net.WebSockets.Client. --- .../Net/WebSockets/WebSocketValidate.cs | 15 ++++++++++++ .../Net/WebSockets/WebSocketHandle.Managed.cs | 8 +++---- .../WebSockets/Compression/ZLibStreamPool.cs | 8 +++---- .../Net/WebSockets/WebSocketDeflateOptions.cs | 23 ++++--------------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs index a092c966483896..a60b159f0db530 100644 --- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs @@ -9,6 +9,21 @@ namespace System.Net.WebSockets { internal static partial class WebSocketValidate { + /// + /// The minimum value for window bits that the websocket per-message-deflate extension can support. + /// The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + /// and https://zlib.net/manual.html). Quote from the manual: + /// "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + /// We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + /// and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. + /// + internal const int MinDeflateWindowBits = 9; + + /// + /// The maximum value for window bits that the websocket per-message-deflate extension can support. + /// + internal const int MaxDeflateWindowBits = 15; + internal const int MaxControlFramePayloadLength = 123; private const int CloseStatusCodeAbort = 1006; private const int CloseStatusCodeFailedTLSHandshake = 1015; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index b6e16077706481..a9e8d327638a7b 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -278,8 +278,8 @@ static int ParseWindowBits(ReadOnlySpan value) if (startIndex < 0 || !int.TryParse(value.Slice(startIndex + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) || - windowBits < 9 || - windowBits > 15) + windowBits < WebSocketValidate.MinDeflateWindowBits || + windowBits > WebSocketValidate.MaxDeflateWindowBits) { throw new WebSocketException(WebSocketError.HeaderError, SR.Format(SR.net_WebSockets_InvalidResponseHeader, ClientWebSocketDeflateConstants.Extension, value.ToString())); @@ -333,7 +333,7 @@ static string GetDeflateOptions(WebSocketDeflateOptions options) var builder = new StringBuilder(ClientWebSocketDeflateConstants.MaxExtensionLength); builder.Append(ClientWebSocketDeflateConstants.Extension).Append("; "); - if (options.ClientMaxWindowBits != 15) + if (options.ClientMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits) { builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=') .Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture)); @@ -351,7 +351,7 @@ static string GetDeflateOptions(WebSocketDeflateOptions options) builder.Append("; "); - if (options.ServerMaxWindowBits != 15) + if (options.ServerMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits) { builder.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=') .Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture)); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs index 9e5a8bc466da08..6477d7d5d23345 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs @@ -12,7 +12,7 @@ namespace System.Net.WebSockets.Compression internal sealed class ZLibStreamPool { private static readonly ZLibStreamPool?[] s_pools - = new ZLibStreamPool[WebSocketDeflateOptions.MaxWindowBits - WebSocketDeflateOptions.MinWindowBits + 1]; + = new ZLibStreamPool[WebSocketValidate.MaxDeflateWindowBits - WebSocketValidate.MinDeflateWindowBits + 1]; /// /// The maximum number of cached items. @@ -70,10 +70,10 @@ private ZLibStreamPool(int windowBits) public static ZLibStreamPool GetOrCreate(int windowBits) { - Debug.Assert(windowBits >= WebSocketDeflateOptions.MinWindowBits - && windowBits <= WebSocketDeflateOptions.MaxWindowBits); + Debug.Assert(windowBits >= WebSocketValidate.MinDeflateWindowBits + && windowBits <= WebSocketValidate.MaxDeflateWindowBits); - int index = windowBits - WebSocketDeflateOptions.MinWindowBits; + int index = windowBits - WebSocketValidate.MinDeflateWindowBits; ref ZLibStreamPool? pool = ref s_pools[index]; return Volatile.Read(ref pool) ?? EnsureInitialized(windowBits, ref pool); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index 43299f9ce98a57..584ec1603234ed 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -12,21 +12,6 @@ namespace System.Net.WebSockets /// public sealed class WebSocketDeflateOptions { - /// - /// The minimum value for window bits that the websocket can support. - /// The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 - /// and https://zlib.net/manual.html). Quote from the manual: - /// "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". - /// We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream - /// and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. - /// - internal const int MinWindowBits = 9; - - /// - /// The maximum value for window bits that the websocket can support. - /// - internal const int MaxWindowBits = 15; - private int _clientMaxWindowBits = 15; private int _serverMaxWindowBits = 15; @@ -40,10 +25,10 @@ public int ClientMaxWindowBits get => _clientMaxWindowBits; set { - if (value < MinWindowBits || value > MaxWindowBits) + if (value < WebSocketValidate.MinDeflateWindowBits || value > WebSocketValidate.MaxDeflateWindowBits) { throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange, MinWindowBits, MaxWindowBits)); + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, WebSocketValidate.MinDeflateWindowBits, WebSocketValidate.MaxDeflateWindowBits)); } _clientMaxWindowBits = value; } @@ -66,10 +51,10 @@ public int ServerMaxWindowBits get => _serverMaxWindowBits; set { - if (value < MinWindowBits || value > MaxWindowBits) + if (value < WebSocketValidate.MinDeflateWindowBits || value > WebSocketValidate.MaxDeflateWindowBits) { throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange, MinWindowBits, MaxWindowBits)); + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, WebSocketValidate.MinDeflateWindowBits, WebSocketValidate.MaxDeflateWindowBits)); } _serverMaxWindowBits = value; } From 4fd18f68bd338a2282b8de770ea495641261854f Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 8 Apr 2021 23:32:02 +0300 Subject: [PATCH 29/47] Testing negotiation of deflate options with reflection. --- .../Net/WebSockets/WebSocketHandle.Managed.cs | 2 ++ .../tests/DeflateTests.cs | 16 ++++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index a9e8d327638a7b..a5198825b9072b 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -23,6 +23,7 @@ internal sealed class WebSocketHandle private readonly CancellationTokenSource _abortSource = new CancellationTokenSource(); private WebSocketState _state = WebSocketState.Connecting; + private WebSocketDeflateOptions? _negotiatedDeflateOptions; public WebSocket? WebSocket { get; private set; } public WebSocketState State => WebSocket?.State ?? _state; @@ -215,6 +216,7 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli KeepAliveInterval = options.KeepAliveInterval, DangerousDeflateOptions = negotiatedDeflateOptions }); + _negotiatedDeflateOptions = negotiatedDeflateOptions; } catch (Exception exc) { diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index af7827aec25f4f..66707c1c4bf6cb 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Net.Test.Common; +using System.Reflection; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -46,12 +47,15 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => await client.ConnectAsync(uri, cancellation.Token); - // Uncomment this if we expose DangerousDeflateOptions directly in the websocket to represent the - // negotiated settings. Otherise we can't verify if compression is negotiated successfully. - // Assert.Equal(clientWindowBits - 1, client.DangerousDeflateOptions.ClientMaxWindowBits); - // Assert.Equal(clientContextTakeover, client.DangerousDeflateOptions.ClientContextTakeover); - // Assert.Equal(serverWindowBits - 1, client.DangerousDeflateOptions.ServerMaxWindowBits); - // Assert.Equal(serverContextTakover, client.DangerousDeflateOptions.ServerContextTakeover); + object webSocketHandle = client.GetType().GetField("_innerWebSocket", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(client); + WebSocketDeflateOptions negotiatedDeflateOptions = (WebSocketDeflateOptions)webSocketHandle.GetType() + .GetField("_negotiatedDeflateOptions", BindingFlags.NonPublic | BindingFlags.Instance) + .GetValue(webSocketHandle); + + Assert.Equal(clientWindowBits - 1, negotiatedDeflateOptions.ClientMaxWindowBits); + Assert.Equal(clientContextTakeover, negotiatedDeflateOptions.ClientContextTakeover); + Assert.Equal(serverWindowBits - 1, negotiatedDeflateOptions.ServerMaxWindowBits); + Assert.Equal(serverContextTakover, negotiatedDeflateOptions.ServerContextTakeover); }, server => server.AcceptConnectionAsync(async connection => { var extensionsReply = CreateDeflateOptionsHeader(new WebSocketDeflateOptions From 29c539a9d3c37e84b64cf3e7e3ae55992e1ca908 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 8 Apr 2021 23:42:25 +0300 Subject: [PATCH 30/47] Removed the default value for messageFlags parameter. --- .../src/System/Net/WebSockets/WebSocket.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index bd1e5c5186ab0c..044c7b95536bef 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -58,7 +58,7 @@ public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessage new ValueTask(SendAsync(arraySegment, messageType, endOfMessage, cancellationToken)) : SendWithArrayPoolAsync(buffer, messageType, endOfMessage, cancellationToken); - public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags = WebSocketMessageFlags.EndOfMessage, CancellationToken cancellationToken = default) + public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken = default) { return SendAsync(buffer, messageType, messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage), cancellationToken); } From 0739640ffb06a3efa06fb789035cc4fc2abd19a1 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 9 Apr 2021 00:36:01 +0300 Subject: [PATCH 31/47] pr feedback --- .../tests/DeflateTests.cs | 1 - .../Compression/WebSocketDeflater.cs | 3 +- .../Compression/WebSocketInflater.cs | 22 +++++-------- .../tests/WebSocketTestStream.cs | 32 +++++++++---------- 4 files changed, 25 insertions(+), 33 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index 66707c1c4bf6cb..ed307e2b93530f 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -21,7 +21,6 @@ public DeflateTests(ITestOutputHelper output) : base(output) [ConditionalTheory(nameof(WebSocketsSupported))] [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] - [ActiveIssue("https://github.com/dotnet/runtime/issues/42852", TestPlatforms.Browser)] [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits; server_max_window_bits")] [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14; server_max_window_bits")] [InlineData(15, true, 14, true, "permessage-deflate; client_max_window_bits; server_max_window_bits=14")] diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 2e6c51aab82ef3..3ec57296b1cd4d 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -66,6 +66,7 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, if (!needsMoreOutput) { + Debug.Assert(consumed == payload.Length); break; } @@ -154,7 +155,7 @@ private unsafe int UnsafeFlush(Span output, out bool needsMoreBuffer) { Debug.Assert(_stream is not null); Debug.Assert(_stream.AvailIn == 0); - Debug.Assert(output.Length >= 6, "We neede at least 6 bytes guarantee the completion of the deflate block."); + Debug.Assert(output.Length >= 6, "We need at least 6 bytes to guarantee the completion of the deflate block."); fixed (byte* fixedOutput = output) { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index dc5e7dd80568b6..5629cb380b3600 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -237,20 +237,14 @@ private static unsafe int Inflate(ZLibStreamHandle stream, Span destinatio } } - switch (errorCode) - { - case ErrorCode.MemError: // Not enough memory to complete the operation - throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); - - case ErrorCode.DataError: // The input data was corrupted (input stream not conforming to the zlib format or incorrect check value) - throw new WebSocketException(SR.ZLibUnsupportedCompression); - - case ErrorCode.StreamError: //the stream structure was inconsistent (for example if next_in or next_out was NULL), - throw new WebSocketException(SR.ZLibErrorInconsistentStream); - - default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); - } + string message = errorCode switch + { + ErrorCode.MemError => SR.ZLibErrorNotEnoughMemory, + ErrorCode.DataError => SR.ZLibUnsupportedCompression, + ErrorCode.StreamError => SR.ZLibErrorInconsistentStream, + _ => string.Format(SR.ZLibErrorUnexpected, (int)errorCode) + }; + throw new WebSocketException(message); } } } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs index 0b3e7b443f239f..b7dfb3ea7f26da 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs @@ -38,11 +38,11 @@ public int Available { get { - var available = 0; + int available = 0; lock (_inputQueue) { - foreach (var x in _inputQueue) + foreach (Block x in _inputQueue) { available += x.AvailableLength; } @@ -100,30 +100,28 @@ protected override void Dispose(bool disposing) public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) { - using (var cancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposed.Token)) + using CancellationTokenSource cancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposed.Token); + try { - try - { - await _inputLock.WaitAsync(cancellation.Token).ConfigureAwait(false); - } - catch (TaskCanceledException) when (cancellationToken.IsCancellationRequested) - { - throw new OperationCanceledException(cancellationToken); - } - catch (OperationCanceledException) when (_disposed.IsCancellationRequested) - { - return 0; - } + await _inputLock.WaitAsync(cancellation.Token).ConfigureAwait(false); + } + catch (TaskCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw new OperationCanceledException(cancellationToken); + } + catch (OperationCanceledException) when (_disposed.IsCancellationRequested) + { + return 0; } lock (_inputQueue) { - var block = _inputQueue.Peek(); + Block block = _inputQueue.Peek(); if (block == Block.ConnectionClosed) { return 0; } - var count = Math.Min(block.AvailableLength, buffer.Length); + int count = Math.Min(block.AvailableLength, buffer.Length); block.Available.Slice(0, count).CopyTo(buffer.Span); block.Advance(count); From 224f7ab2a58910e07d4be9955ea6f35b8d5f7fc4 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 9 Apr 2021 00:37:35 +0300 Subject: [PATCH 32/47] Removed cache capacity limit for deflate streams. Now the cache is only controlled by the inactivity timeout for the streams. --- .../WebSockets/Compression/ZLibStreamPool.cs | 64 +++++++++---------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs index 6477d7d5d23345..cdf758c78bd811 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs @@ -14,24 +14,14 @@ internal sealed class ZLibStreamPool private static readonly ZLibStreamPool?[] s_pools = new ZLibStreamPool[WebSocketValidate.MaxDeflateWindowBits - WebSocketValidate.MinDeflateWindowBits + 1]; - /// - /// The maximum number of cached items. - /// - private static readonly int MaximumRetained = Environment.ProcessorCount * 2; - /// /// The amount of time after which a cached item will be removed. /// - private static readonly TimeSpan CacheTimeout = TimeSpan.FromMinutes(1); - - /// - /// The in ticks. - /// - private static readonly long CacheTimeoutRawTicks = (long)(CacheTimeout.Ticks / (double)TimeSpan.TicksPerSecond * Stopwatch.Frequency); + private const int TimeoutMilliseconds = 60_000; private readonly int _windowBits; - private readonly List _inflaters = new(MaximumRetained); - private readonly List _deflaters = new(MaximumRetained); + private readonly List _inflaters = new(); + private readonly List _deflaters = new(); private readonly Timer _cleaningTimer; /// @@ -133,21 +123,13 @@ private void Return(ZLibStreamHandle stream, List cache) { lock (cache) { - if (cache.Count < MaximumRetained) - { - cache.Add(new CacheItem(stream)); + cache.Add(new CacheItem(stream)); - if (Interlocked.Increment(ref _activeCount) == 1) - { - _cleaningTimer.Change(CacheTimeout, Timeout.InfiniteTimeSpan); - } - return; + if (Interlocked.Increment(ref _activeCount) == 1) + { + _cleaningTimer.Change(TimeoutMilliseconds, Timeout.Infinite); } } - - // If we've reached the maximum retained capacity, we will destroy the stream. - // It is important that we do this outside of the cache lock. - stream.Dispose(); } private bool TryGet(List cache, [NotNullWhen(true)] out ZLibStreamHandle? stream) @@ -181,31 +163,47 @@ private void RemoveStaleItems() // would eventually do nothing. if (_activeCount > 0) { - _cleaningTimer.Change(CacheTimeout, Timeout.InfiniteTimeSpan); + _cleaningTimer.Change(TimeoutMilliseconds, Timeout.Infinite); } } private void RemoveStaleItems(List cache) { - long currentTimestamp = Stopwatch.GetTimestamp(); + long currentTimestamp = Environment.TickCount64; + List? removedStreams = null; lock (cache) { - for (int index = cache.Count; index >= 0;) + for (int index = 0; index < cache.Count; ++index) { CacheItem item = cache[index]; - if (currentTimestamp - item.Timestamp > CacheTimeoutRawTicks) + if (currentTimestamp - item.Timestamp > TimeoutMilliseconds) { - item.Stream.Dispose(); - cache.RemoveAt(index); + removedStreams ??= new List(); + removedStreams.Add(item.Stream); Interlocked.Decrement(ref _activeCount); } else { - --index; + // The freshest streams are in the back of the collection. + // If we've reached a stream that is not timed out, all + // other after it will not be as well. + break; } } + + if (removedStreams is null) + { + return; + } + + cache.RemoveRange(0, removedStreams.Count); + } + + foreach (ZLibStreamHandle stream in removedStreams) + { + stream.Dispose(); } } @@ -269,7 +267,7 @@ private readonly struct CacheItem public CacheItem(ZLibStreamHandle stream) { Stream = stream; - Timestamp = Stopwatch.GetTimestamp(); + Timestamp = Environment.TickCount64; } public ZLibStreamHandle Stream { get; } From 4a7feb473de41ec0912d835dd3028011162d5723 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 9 Apr 2021 00:44:04 +0300 Subject: [PATCH 33/47] Created a test with active issue to track a failing case for the deflate compression. --- .../tests/WebSocketDeflateTests.cs | 79 ++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 6bfdcc96c6ffaa..70b021c587bdff 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -276,6 +276,83 @@ public async Task Duplex(bool clientContextTakover, bool serverContextTakover) } } + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/50235")] + public async Task LargeMessageSplitInMultipleFramesActiveIssue() + { + // This test is exactly the same as LargeMessageSplitInMultipleFrames, but + // for the data seed it uses Random(0) where the other uses Random(10). This is done + // only because it was found that there is a bug in the deflate somewhere and it only appears + // so far when using 10 window bits and data generated using Random(0). Once + // the issue is resolved this test can be deleted and LargeMessageSplitInMultipleFrames should be + // updated to use Random(0). + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = 10 + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = 10 + } + }); + + Memory testData = new byte[ushort.MaxValue]; + Memory receivedData = new byte[testData.Length]; + + // Make the data incompressible to make sure that the output is larger than the input + var rng = new Random(0); + rng.NextBytes(testData.Span); + + // Test it a few times with different frame sizes + for (var i = 0; i < 10; ++i) + { + var frameSize = rng.Next(1024, 2048); + var position = 0; + + while (position < testData.Length) + { + var currentFrameSize = Math.Min(frameSize, testData.Length - position); + var eof = position + currentFrameSize == testData.Length; + + await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, CancellationToken); + position += currentFrameSize; + } + + Assert.True(testData.Length < stream.Remote.Available, "The compressed data should be bigger."); + Assert.Equal(testData.Length, position); + + // Receive the data from the client side + receivedData.Span.Clear(); + position = 0; + + // Intentionally receive with a frame size that is less than what the sender used + frameSize /= 3; + + while (true) + { + int currentFrameSize = Math.Min(frameSize, testData.Length - position); + ValueWebSocketReceiveResult result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), CancellationToken); + + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + position += result.Count; + + if (result.EndOfMessage) + break; + } + + Assert.Equal(0, stream.Remote.Available); + Assert.Equal(testData.Length, position); + Assert.True(testData.Span.SequenceEqual(receivedData.Span)); + } + } + [Theory] [MemberData(nameof(SupportedWindowBits))] public async Task LargeMessageSplitInMultipleFrames(int windowBits) @@ -301,7 +378,7 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) Memory receivedData = new byte[testData.Length]; // Make the data incompressible to make sure that the output is larger than the input - var rng = new Random(0); + var rng = new Random(10); rng.NextBytes(testData.Span); // Test it a few times with different frame sizes From 222c782b2cd622daaeb23276cf43989dcdc6ab21 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 9 Apr 2021 08:13:21 +0300 Subject: [PATCH 34/47] Disabled client websocket deflate tests for browser platform. --- src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index ed307e2b93530f..5759675ed81a39 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -13,6 +13,7 @@ namespace System.Net.WebSockets.Client.Tests { + [PlatformSpecific(~TestPlatforms.Browser)] public class DeflateTests : ClientWebSocketTestBase { public DeflateTests(ITestOutputHelper output) : base(output) From 0b09b5b4b57100d1c54b1634fb1e212add4744de Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 13 Apr 2021 01:22:32 +0300 Subject: [PATCH 35/47] Created tests for ZLibStream pooling. --- .../WebSockets/Compression/ZLibStreamPool.cs | 20 ++-- .../tests/System.Net.WebSockets.Tests.csproj | 3 +- .../tests/ZLibStreamTests.cs | 104 ++++++++++++++++++ 3 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs index cdf758c78bd811..75ff4707adcdad 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs @@ -15,24 +15,30 @@ private static readonly ZLibStreamPool?[] s_pools = new ZLibStreamPool[WebSocketValidate.MaxDeflateWindowBits - WebSocketValidate.MinDeflateWindowBits + 1]; /// - /// The amount of time after which a cached item will be removed. + /// The default amount of time after which a cached item will be removed. /// - private const int TimeoutMilliseconds = 60_000; + private const int DefaultTimeoutMilliseconds = 60_000; private readonly int _windowBits; private readonly List _inflaters = new(); private readonly List _deflaters = new(); private readonly Timer _cleaningTimer; + /// + /// The amount of time after which a cached item will be removed. + /// + private readonly int _timeoutMilliseconds; + /// /// The number of cached inflaters and deflaters. /// private int _activeCount; - private ZLibStreamPool(int windowBits) + private ZLibStreamPool(int windowBits, int timeoutMilliseconds) { // Use negative window bits to for raw deflate data _windowBits = -windowBits; + _timeoutMilliseconds = timeoutMilliseconds; bool restoreFlow = false; try @@ -70,7 +76,7 @@ public static ZLibStreamPool GetOrCreate(int windowBits) static ZLibStreamPool EnsureInitialized(int windowBits, ref ZLibStreamPool? target) { - Interlocked.CompareExchange(ref target, new ZLibStreamPool(windowBits), null); + Interlocked.CompareExchange(ref target, new ZLibStreamPool(windowBits, DefaultTimeoutMilliseconds), null); Debug.Assert(target != null); return target; @@ -127,7 +133,7 @@ private void Return(ZLibStreamHandle stream, List cache) if (Interlocked.Increment(ref _activeCount) == 1) { - _cleaningTimer.Change(TimeoutMilliseconds, Timeout.Infinite); + _cleaningTimer.Change(_timeoutMilliseconds, Timeout.Infinite); } } } @@ -163,7 +169,7 @@ private void RemoveStaleItems() // would eventually do nothing. if (_activeCount > 0) { - _cleaningTimer.Change(TimeoutMilliseconds, Timeout.Infinite); + _cleaningTimer.Change(_timeoutMilliseconds, Timeout.Infinite); } } @@ -178,7 +184,7 @@ private void RemoveStaleItems(List cache) { CacheItem item = cache[index]; - if (currentTimestamp - item.Timestamp > TimeoutMilliseconds) + if (currentTimestamp - item.Timestamp > _timeoutMilliseconds) { removedStreams ??= new List(); removedStreams.Add(item.Stream); diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 0b606a12e8c446..9ce68c7b42f998 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent) @@ -10,6 +10,7 @@ + + { + if (x % 2 == 0) + { + object inflater = pool.GetInflater(); + pool.ReturnInflater(inflater); + } + else + { + object deflater = pool.GetDeflater(); + pool.ReturnDeflater(deflater); + } + }); + + Assert.True(pool.ActiveCount >= 2); + Assert.True(pool.ActiveCount <= parallelOptions.MaxDegreeOfParallelism * 2); + Thread.Sleep(250); + Assert.Equal(0, pool.ActiveCount); + } + + private sealed class Pool + { + private static Type? s_type; + private static ConstructorInfo? s_constructor; + private static FieldInfo? s_activeCount; + private static MethodInfo? s_rentInflater; + private static MethodInfo? s_returnInflater; + private static MethodInfo? s_rentDeflater; + private static MethodInfo? s_returnDeflater; + + private readonly object _instance; + + public Pool(int timeoutMilliseconds) + { + s_type ??= typeof(WebSocket).Assembly.GetType("System.Net.WebSockets.Compression.ZLibStreamPool", throwOnError: true); + s_constructor ??= s_type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic)[0]; + + _instance = s_constructor.Invoke(new object[] { /*windowBits*/9, timeoutMilliseconds }); + } + + public int ActiveCount => (int)(s_activeCount ??= s_type.GetField("_activeCount", BindingFlags.Instance | BindingFlags.NonPublic)).GetValue(_instance); + + public object GetInflater() => GetMethod(ref s_rentInflater).Invoke(_instance, null); + + public void ReturnInflater(object inflater) => GetMethod(ref s_returnInflater).Invoke(_instance, new[] { inflater }); + + public object GetDeflater() => GetMethod(ref s_rentDeflater).Invoke(_instance, null); + + public void ReturnDeflater(object deflater) => GetMethod(ref s_returnDeflater).Invoke(_instance, new[] { deflater }); + + private static MethodInfo GetMethod(ref MethodInfo? method, [CallerMemberName] string? name = null) + { + return method ??= s_type.GetMethod(name) + ?? throw new InvalidProgramException($"Method {name} was not found in {s_type}."); + } + } + } +} From 99869bf41ff75048d2b0999d05cda5b905fdc703 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 14 Apr 2021 09:03:47 +0300 Subject: [PATCH 36/47] Reusing constants in deflate options. --- .../src/System/Net/WebSockets/WebSocketDeflateOptions.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index 584ec1603234ed..d3d0af5d06c18a 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -12,8 +12,8 @@ namespace System.Net.WebSockets /// public sealed class WebSocketDeflateOptions { - private int _clientMaxWindowBits = 15; - private int _serverMaxWindowBits = 15; + private int _clientMaxWindowBits = WebSocketValidate.MaxDeflateWindowBits; + private int _serverMaxWindowBits = WebSocketValidate.MaxDeflateWindowBits; /// /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the client context. From c46a0f37adfc72295ece9529d7276cba156de87b Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 14 Apr 2021 09:04:49 +0300 Subject: [PATCH 37/47] Fixing tests for browser environment. --- .../System.Net.WebSockets/tests/ZLibStreamTests.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs b/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs index 9845e48084ee09..eb0ad93e3770e5 100644 --- a/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs @@ -3,7 +3,6 @@ using System.Reflection; using System.Runtime.CompilerServices; -using System.Threading; using System.Threading.Tasks; using Xunit; @@ -12,7 +11,7 @@ namespace System.Net.WebSockets.Tests public class ZLibStreamTests { [Fact] - public void PoolShouldReuseTheSameInstance() + public async Task PoolShouldReuseTheSameInstance() { var pool = new Pool(timeoutMilliseconds: 100); @@ -30,14 +29,15 @@ public void PoolShouldReuseTheSameInstance() pool.ReturnInflater(inflater); Assert.Equal(1, pool.ActiveCount); - Thread.Sleep(250); + await Task.Delay(250); - // After timeout elapses we should have any active instances + // After timeout elapses we should not have any active instances Assert.Equal(0, pool.ActiveCount); } [Fact] - public void PoolingConcurrently() + [PlatformSpecific(~TestPlatforms.Browser)] // There is no concurrency in browser + public async Task PoolingConcurrently() { var pool = new Pool(timeoutMilliseconds: 100); var parallelOptions = new ParallelOptions @@ -60,7 +60,7 @@ public void PoolingConcurrently() Assert.True(pool.ActiveCount >= 2); Assert.True(pool.ActiveCount <= parallelOptions.MaxDegreeOfParallelism * 2); - Thread.Sleep(250); + await Task.Delay(250); Assert.Equal(0, pool.ActiveCount); } From 3ec748d80e1c294532a549e103e146036d7ef6bf Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 16 Apr 2021 16:10:37 +0300 Subject: [PATCH 38/47] Added Block flush code for zlib, because it's needed for websockets. --- src/libraries/Common/src/System/IO/Compression/ZLibNative.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs index 9daf208692faea..98cc8ad59dbc5a 100644 --- a/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs +++ b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs @@ -23,6 +23,7 @@ public enum FlushCode : int NoFlush = 0, SyncFlush = 2, Finish = 4, + Block = 5 } public enum ErrorCode : int From fb45f2caa1b9ca2fa6371e3c6e72a1166de0b537 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 16 Apr 2021 16:14:58 +0300 Subject: [PATCH 39/47] Bug fixes after running Autobahn WebSocket Testsuite. --- .../src/Resources/Strings.resx | 4 +- .../Net/WebSockets/WebSocketHandle.Managed.cs | 12 +-- .../Compression/WebSocketDeflater.cs | 38 ++++---- .../Compression/WebSocketInflater.cs | 88 ++++++++++--------- .../System/Net/WebSockets/ManagedWebSocket.cs | 42 ++++++--- .../Net/WebSockets/WebSocketDeflateOptions.cs | 6 +- .../tests/WebSocketDeflateTests.cs | 84 +++++++++++++++++- 7 files changed, 193 insertions(+), 81 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index bda2641a36378c..7b4718b554a151 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -139,9 +139,9 @@ WebSocket binary type '{0}' not supported. - The WebSocket failed to negotiate max server window bits. The client requested {1} but the server responded with {2}. + The WebSocket failed to negotiate max server window bits. The client requested {0} but the server responded with {1}. - The WebSocket failed to negotiate max client window bits. The client requested {1} but the server responded with {2}. + The WebSocket failed to negotiate max client window bits. The client requested {0} but the server responded with {1}. \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index a5198825b9072b..3faf524d34b67b 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -338,7 +338,7 @@ static string GetDeflateOptions(WebSocketDeflateOptions options) if (options.ClientMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits) { builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=') - .Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture)); + .Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture)); } else { @@ -351,18 +351,12 @@ static string GetDeflateOptions(WebSocketDeflateOptions options) builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientNoContextTakeover); } - builder.Append("; "); - if (options.ServerMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits) { - builder.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=') + builder.Append("; ") + .Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=') .Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture)); } - else - { - // Advertise that we support this option - builder.Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits); - } if (!options.ServerContextTakeover) { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 3ec57296b1cd4d..36a0049d705fbf 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -3,7 +3,6 @@ using System.Buffers; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using static System.IO.Compression.ZLibNative; namespace System.Net.WebSockets.Compression @@ -43,7 +42,7 @@ public void ReleaseBuffer() } } - public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, bool endOfMessage) + public ReadOnlySpan Deflate(ReadOnlySpan payload, bool endOfMessage) { Debug.Assert(_buffer is null, "Invalid state, ReleaseBuffer not called."); @@ -60,7 +59,7 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, while (true) { - DeflatePrivate(payload, _buffer.AsSpan(position), continuation, endOfMessage, + DeflatePrivate(payload, _buffer.AsSpan(position), endOfMessage, out int consumed, out int written, out bool needsMoreOutput); position += written; @@ -82,7 +81,7 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool continuation, return new ReadOnlySpan(_buffer, 0, position); } - private void DeflatePrivate(ReadOnlySpan payload, Span output, bool continuation, bool endOfMessage, + private void DeflatePrivate(ReadOnlySpan payload, Span output, bool endOfMessage, out int consumed, out int written, out bool needsMoreOutput) { _stream ??= _streamPool.GetDeflater(); @@ -107,7 +106,6 @@ private void DeflatePrivate(ReadOnlySpan payload, Span output, bool if (needsMoreOutput) { - Debug.Assert(written == output.Length); return; } Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength) @@ -147,7 +145,7 @@ private unsafe void UnsafeDeflate(ReadOnlySpan input, Span output, o consumed = input.Length - (int)_stream.AvailIn; written = output.Length - (int)_stream.AvailOut; - needsMoreBuffer = errorCode == ErrorCode.BufError; + needsMoreBuffer = errorCode == ErrorCode.BufError || _stream.AvailIn > 0; } } @@ -155,7 +153,6 @@ private unsafe int UnsafeFlush(Span output, out bool needsMoreBuffer) { Debug.Assert(_stream is not null); Debug.Assert(_stream.AvailIn == 0); - Debug.Assert(output.Length >= 6, "We need at least 6 bytes to guarantee the completion of the deflate block."); fixed (byte* fixedOutput = output) { @@ -165,16 +162,27 @@ private unsafe int UnsafeFlush(Span output, out bool needsMoreBuffer) _stream.NextOut = (IntPtr)fixedOutput; _stream.AvailOut = (uint)output.Length; - // The flush is set to Z_SYNC_FLUSH, all pending output is flushed - // to the output buffer and the output is aligned on a byte boundary, - // so that the decompressor can get all input data available so far. - // This completes the current deflate block and follows it with an empty - // stored block that is three bits plus filler bits to the next byte, - // followed by four bytes (00 00 ff ff). - ErrorCode errorCode = Deflate(_stream, FlushCode.SyncFlush); + // We need to use Z_BLOCK_FLUSH to instruct the zlib to flush all outstanding + // data but also not to emit a deflate block boundary. After we know that there is no + // more data, we can safely proceed to instruct the library to emit the boundary markers. + ErrorCode errorCode = Deflate(_stream, FlushCode.Block); Debug.Assert(errorCode is ErrorCode.Ok or ErrorCode.BufError); - needsMoreBuffer = errorCode == ErrorCode.BufError; + // We need at least 6 bytes to guarantee that we can emit a deflate block boundary. + needsMoreBuffer = _stream.AvailOut < 6; + + if (!needsMoreBuffer) + { + // The flush is set to Z_SYNC_FLUSH, all pending output is flushed + // to the output buffer and the output is aligned on a byte boundary, + // so that the decompressor can get all input data available so far. + // This completes the current deflate block and follows it with an empty + // stored block that is three bits plus filler bits to the next byte, + // followed by four bytes (00 00 ff ff). + errorCode = Deflate(_stream, FlushCode.SyncFlush); + Debug.Assert(errorCode == ErrorCode.Ok); + } + return output.Length - (int)_stream.AvailOut; } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 5629cb380b3600..a6b6432d92004f 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -29,10 +29,10 @@ internal sealed class WebSocketInflater : IDisposable private byte? _remainingByte; /// - /// When the inflater is persisted we need to manually append the flush marker - /// before finishing the decoding. + /// The last added bytes to the inflater were part of the final + /// payload for the message being sent. /// - private bool _needsFlushMarker; + private bool _endOfMessage; private byte[]? _buffer; @@ -52,11 +52,6 @@ internal WebSocketInflater(int windowBits, bool persisted) _persisted = persisted; } - /// - /// Indicates that there is nothing left for inflating. - /// - public bool Finished { get; private set; } = true; - public Memory Memory => _buffer.AsMemory(_position + _available); public Span Span => _buffer.AsSpan(_position + _available); @@ -94,17 +89,45 @@ public void Prepare(long payloadLength, int userBufferLength) } } - /// - /// Inflates the last receive payload into the provided buffer. - /// - public unsafe void Inflate(int totalBytesReceived, Span output, bool flush, out int written) + public void AddBytes(int totalBytesReceived, bool endOfMessage) { - if (totalBytesReceived > 0) + Debug.Assert(totalBytesReceived == 0 || _buffer is not null, "Prepare must be called."); + + _available += totalBytesReceived; + _endOfMessage = endOfMessage; + + if (endOfMessage) { - Debug.Assert(_buffer is not null, "Prepare must be called."); - _available += totalBytesReceived; + if (_buffer is null) + { + Debug.Assert(_available == 0); + + _buffer = ArrayPool.Shared.Rent(FlushMarkerLength); + _available = FlushMarkerLength; + FlushMarker.CopyTo(_buffer); + } + else + { + if (_buffer.Length < _available + FlushMarkerLength) + { + byte[] newBuffer = ArrayPool.Shared.Rent(_available + FlushMarkerLength); + _buffer.AsSpan(0, _available).CopyTo(newBuffer); + ArrayPool.Shared.Return(_buffer); + + _buffer = newBuffer; + } + + FlushMarker.CopyTo(_buffer.AsSpan(_available)); + _available += FlushMarkerLength; + } } + } + /// + /// Inflates the last receive payload into the provided buffer. + /// + public unsafe bool Inflate(Span output, out int written) + { _stream ??= _streamPool.GetInflater(); if (_available > 0 && output.Length > 0) @@ -116,13 +139,12 @@ public unsafe void Inflate(int totalBytesReceived, Span output, bool flush _stream.NextIn = (IntPtr)(bufferPtr + _position); _stream.AvailIn = (uint)_available; - written = Inflate(_stream, output); + written = Inflate(_stream, output, FlushCode.NoFlush); consumed = _available - (int)_stream.AvailIn; } _position += consumed; _available -= consumed; - _needsFlushMarker = _persisted; } else { @@ -132,35 +154,21 @@ public unsafe void Inflate(int totalBytesReceived, Span output, bool flush if (_available == 0) { ReleaseBuffer(); - Finished = flush ? Flush(output, ref written) : true; - } - else - { - Finished = false; + return _endOfMessage ? Finish(output, ref written) : true; } + + return false; } /// /// Finishes the decoding by flushing any outstanding data to the output. /// /// true if the flush completed, false to indicate that there is more outstanding data. - private unsafe bool Flush(Span output, ref int written) + private unsafe bool Finish(Span output, ref int written) { - Debug.Assert(_stream is not null); + Debug.Assert(_stream is not null && _stream.AvailIn == 0); Debug.Assert(_available == 0); - if (_needsFlushMarker) - { - _needsFlushMarker = false; - - // It's OK to use the flush marker like this, because it's pointer is unmovable. - fixed (byte* flushMarkerPtr = FlushMarker) - { - _stream.NextIn = (IntPtr)flushMarkerPtr; - _stream.AvailIn = FlushMarkerLength; - } - } - if (_remainingByte is not null) { if (output.Length == written) @@ -175,7 +183,7 @@ private unsafe bool Flush(Span output, ref int written) // If we have more space in the output, try to inflate if (output.Length > written) { - written += Inflate(_stream, output[written..]); + written += Inflate(_stream, output[written..], FlushCode.SyncFlush); } // After inflate, if we have more space in the output then it means that we @@ -209,7 +217,7 @@ private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remaini // There is no other way to make sure that we'e consumed all data // but to try to inflate again with at least one byte of output buffer. byte b; - if (Inflate(stream, new Span(&b, 1)) == 0) + if (Inflate(stream, new Span(&b, 1), FlushCode.SyncFlush) == 0) { remainingByte = null; return true; @@ -219,7 +227,7 @@ private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remaini return false; } - private static unsafe int Inflate(ZLibStreamHandle stream, Span destination) + private static unsafe int Inflate(ZLibStreamHandle stream, Span destination, FlushCode flushCode) { Debug.Assert(destination.Length > 0); ErrorCode errorCode; @@ -229,7 +237,7 @@ private static unsafe int Inflate(ZLibStreamHandle stream, Span destinatio stream.NextOut = (IntPtr)bufPtr; stream.AvailOut = (uint)destination.Length; - errorCode = stream.Inflate(FlushCode.NoFlush); + errorCode = stream.Inflate(flushCode); if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError) { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 62a01640fcedf6..01d633db8ca791 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -218,13 +218,16 @@ internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) if (deflateOptions is not null) { - _deflater = options.IsServer ? - new WebSocketDeflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover) : - new WebSocketDeflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); - - _inflater = options.IsServer ? - new WebSocketInflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover) : - new WebSocketInflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); + if (options.IsServer) + { + _inflater = new WebSocketInflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); + _deflater = new WebSocketDeflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); + } + else + { + _inflater = new WebSocketInflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); + _deflater = new WebSocketDeflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); + } } } @@ -561,9 +564,9 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM /// Writes a frame into the send buffer, which can then be sent over the network. private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlySpan payloadBuffer) { - if (_deflater is not null && payloadBuffer.Length > 0 && !disableCompression) + if (_deflater is not null && !disableCompression) { - payloadBuffer = _deflater.Deflate(payloadBuffer, opcode == MessageOpcode.Continuation, endOfMessage); + payloadBuffer = _deflater.Deflate(payloadBuffer, endOfMessage); } int payloadLength = payloadBuffer.Length; @@ -776,6 +779,14 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false); } _receivedMaskOffsetOffset = 0; + + if (header.PayloadLength == 0 && header.Compressed) + { + // In the rare case where we receive a compressed message with no payload + // we need to tell the inflater about it, because the receive code bellow will + // not try to do anything when PayloadLength == 0. + _inflater!.AddBytes(0, endOfMessage: header.Fin); + } } // If the header represents a ping or a pong, it's a control message meant @@ -865,15 +876,18 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo } header.PayloadLength -= totalBytesReceived; + + if (header.Compressed) + { + _inflater!.AddBytes(totalBytesReceived, endOfMessage: header.Fin && header.PayloadLength == 0); + } } if (header.Compressed) { // In case of compression totalBytesReceived should actually represent how much we've // inflated, rather than how much we've read from the stream. - _inflater!.Inflate(totalBytesReceived, payloadBuffer.Span, - flush: header.Fin && header.PayloadLength == 0, out totalBytesReceived); - header.Processed = _inflater.Finished && header.PayloadLength == 0; + header.Processed = _inflater!.Inflate(payloadBuffer.Span, out totalBytesReceived) && header.PayloadLength == 0; } else { @@ -1204,6 +1218,10 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( resultHeader = default; return SR.net_Websockets_PerMessageCompressedFlagInContinuation; } + + // Set the compressed flag from the previous header so the receive procedure can use it + // directly without needing to check the previous header in case of continuations. + header.Compressed = _lastReceiveHeader.Compressed; break; case MessageOpcode.Binary: diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index d3d0af5d06c18a..cd3f72a0432165 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -16,7 +16,8 @@ public sealed class WebSocketDeflateOptions private int _serverMaxWindowBits = WebSocketValidate.MaxDeflateWindowBits; /// - /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the client context. + /// This parameter indicates the base-2 logarithm for the LZ77 sliding window size used by + /// the client to compress messages and by the server to decompress them. /// Must be a value between 9 and 15. The default is 15. /// /// https://tools.ietf.org/html/rfc7692#section-7.1.2.2 @@ -42,7 +43,8 @@ public int ClientMaxWindowBits public bool ClientContextTakeover { get; set; } = true; /// - /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the server context. + /// This parameter indicates the base-2 logarithm for the LZ77 sliding window size used by + /// the server to compress messages and by the client to decompress them. /// Must be a value between 9 and 15. The default is 15. /// /// https://tools.ietf.org/html/rfc7692#section-7.1.2.1 diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 70b021c587bdff..74bb00e2ef22a5 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -172,7 +172,7 @@ public async Task SendHelloWithoutContextTakeover() IsServer = true, DangerousDeflateOptions = new() { - ClientContextTakeover = false + ServerContextTakeover = false } }); @@ -493,6 +493,88 @@ public async Task ReceiveInvalidCompressedData() Assert.Equal(WebSocketState.Aborted, client.State); } + [Fact] + public async Task PayloadShouldHaveSimilarSizeWhenSplitIntoSegments() + { + WebSocketTestStream stream = new(); + WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DangerousDeflateOptions = new WebSocketDeflateOptions() + }); + + // We're using a frame size that is close to the sliding window size for the deflate + const int frameSize = 32_000; + + byte[] message = new byte[frameSize * 100]; + Random random = new(0); + + for (int i = 0; i < message.Length; ++i) + { + message[i] = (byte)random.Next(maxValue: 10); + } + + await client.SendAsync(message, WebSocketMessageType.Binary, true, CancellationToken); + + int payloadLength = stream.Remote.Available; + stream.Remote.Clear(); + + for (var i = 0; i < message.Length; i += frameSize) + { + await client.SendAsync(message.AsMemory(i, frameSize), WebSocketMessageType.Binary, i + frameSize == message.Length, CancellationToken); + } + + Assert.Equal(0.999, Math.Round(payloadLength * 1.0 / stream.Remote.Available, 3)); + } + + [Theory] + [InlineData(9, 15)] + [InlineData(15, 9)] + public async Task SendReceiveWithDifferentWindowBits(int clientWindowBits, int serverWindowBits) + { + WebSocketTestStream stream = new(); + WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + { + ClientContextTakeover = false, + ClientMaxWindowBits = clientWindowBits, + ServerContextTakeover = false, + ServerMaxWindowBits = serverWindowBits + } + }); + WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DangerousDeflateOptions = new() + { + ClientContextTakeover = false, + ClientMaxWindowBits = clientWindowBits, + ServerContextTakeover = false, + ServerMaxWindowBits = serverWindowBits + } + }); + + Memory data = new byte[64 * 1024]; + Memory buffer = new byte[data.Length]; + new Random(0).NextBytes(data.Span.Slice(0, data.Length / 2)); + + await server.SendAsync(data, WebSocketMessageType.Binary, true, CancellationToken); + ValueWebSocketReceiveResult result = await client.ReceiveAsync(buffer, CancellationToken); + + Assert.Equal(data.Length, result.Count); + Assert.True(result.EndOfMessage); + Assert.True(data.Span.SequenceEqual(buffer.Span)); + + buffer.Span.Clear(); + + await client.SendAsync(data, WebSocketMessageType.Binary, true, CancellationToken); + result = await server.ReceiveAsync(buffer, CancellationToken); + + Assert.Equal(data.Length, result.Count); + Assert.True(result.EndOfMessage); + Assert.True(data.Span.SequenceEqual(buffer.Span)); + } + private ValueTask SendTextAsync(string text, WebSocket websocket, bool disableCompression = false) { WebSocketMessageFlags flags = WebSocketMessageFlags.EndOfMessage; From 00283e82615c3b119c519f69116d30e570e308e4 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 16 Apr 2021 17:07:50 +0300 Subject: [PATCH 40/47] Addressing PR feedback. --- .../Common/src/System/Net/WebSockets/WebSocketValidate.cs | 6 ++---- .../src/System/Net/WebSockets/ClientWebSocketOptions.cs | 5 ++++- .../src/System/Net/WebSockets/WebSocketHandle.Managed.cs | 2 +- .../System/Net/WebSockets/Compression/WebSocketInflater.cs | 2 +- .../src/System/Net/WebSockets/WebSocketCreationOptions.cs | 5 ++++- .../src/System/Net/WebSockets/WebSocketDeflateOptions.cs | 1 - 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs index a60b159f0db530..d074f618bf16db 100644 --- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs @@ -11,11 +11,9 @@ internal static partial class WebSocketValidate { /// /// The minimum value for window bits that the websocket per-message-deflate extension can support. - /// The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 - /// and https://zlib.net/manual.html). Quote from the manual: - /// "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + /// For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported. /// We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream - /// and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. + /// and thus it needs to know the window bits in advance. /// internal const int MinDeflateWindowBits = 9; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index b24abcd368b9fb..5ab2ad51d94eb3 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -152,7 +152,10 @@ public TimeSpan KeepAliveInterval /// Gets or sets the options for the per-message-deflate extension. /// When present, the options are sent to the server during the handshake phase. If the server /// supports per-message-deflate and the options are accepted, the instance - /// will be created with compression enabled by default for all messages. + /// will be created with compression enabled by default for all messages. + /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks. + /// It is strongly advised to turn off compression when sending data containing secrets by + /// specifying flag for such messages. /// [UnsupportedOSPlatform("browser")] public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; } diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 3faf524d34b67b..ba18af70e8805c 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -255,7 +255,7 @@ private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan ex int end = extension.IndexOf(';'); ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim(); - if (!value.IsEmpty) + if (value.Length > 0) { if (value.Equals(ClientWebSocketDeflateConstants.ClientNoContextTakeover, StringComparison.Ordinal)) { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index a6b6432d92004f..169fa604e235f0 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -214,7 +214,7 @@ private void ReleaseBuffer() private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) { - // There is no other way to make sure that we'e consumed all data + // There is no other way to make sure that we've consumed all data // but to try to inflate again with at least one byte of output buffer. byte b; if (Inflate(stream, new Span(&b, 1), FlushCode.SyncFlush) == 0) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs index 4f566cfaf71f56..d042583da54448 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -53,7 +53,10 @@ public TimeSpan KeepAliveInterval } /// - /// The agreed upon options for per message deflate. + /// The agreed upon options for per message deflate. + /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks. + /// It is strongly advised to turn off compression when sending data containing secrets by + /// specifying flag for such messages. /// public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index cd3f72a0432165..e497751db288e4 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -8,7 +8,6 @@ namespace System.Net.WebSockets /// /// /// Although the WebSocket spec allows window bits from 8 to 15, the current implementation doesn't support 8 bits. - /// For more information refer to the zlib manual https://zlib.net/manual.html. /// public sealed class WebSocketDeflateOptions { From 5a0e4358dc01bd73bc720c362d150dc10eeae062 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 16 Apr 2021 18:06:01 +0300 Subject: [PATCH 41/47] Fixed a test. --- .../System.Net.WebSockets.Client/tests/DeflateTests.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index 5759675ed81a39..e0a0e1e59fd846 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -22,12 +22,12 @@ public DeflateTests(ITestOutputHelper output) : base(output) [ConditionalTheory(nameof(WebSocketsSupported))] [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] - [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits; server_max_window_bits")] - [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14; server_max_window_bits")] + [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits")] + [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14")] [InlineData(15, true, 14, true, "permessage-deflate; client_max_window_bits; server_max_window_bits=14")] [InlineData(10, true, 11, true, "permessage-deflate; client_max_window_bits=10; server_max_window_bits=11")] - [InlineData(15, false, 15, true, "permessage-deflate; client_max_window_bits; client_no_context_takeover; server_max_window_bits")] - [InlineData(15, true, 15, false, "permessage-deflate; client_max_window_bits; server_max_window_bits; server_no_context_takeover")] + [InlineData(15, false, 15, true, "permessage-deflate; client_max_window_bits; client_no_context_takeover")] + [InlineData(15, true, 15, false, "permessage-deflate; client_max_window_bits; server_no_context_takeover")] public async Task PerMessageDeflateHeaders(int clientWindowBits, bool clientContextTakeover, int serverWindowBits, bool serverContextTakover, string expected) From 935969aa9142da6d87d58480d8ebc573827e6607 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 19 Apr 2021 11:20:00 +0300 Subject: [PATCH 42/47] Addressing flakiness of a couple of tests. --- .../tests/WebSocketDeflateTests.cs | 26 ++++++++++++------- .../tests/ZLibStreamTests.cs | 8 +++--- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 74bb00e2ef22a5..33ac5122516cc5 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -493,19 +494,23 @@ public async Task ReceiveInvalidCompressedData() Assert.Equal(WebSocketState.Aborted, client.State); } - [Fact] - public async Task PayloadShouldHaveSimilarSizeWhenSplitIntoSegments() + [Theory] + [MemberData(nameof(SupportedWindowBits))] + public async Task PayloadShouldHaveSimilarSizeWhenSplitIntoSegments(int windowBits) { - WebSocketTestStream stream = new(); + MemoryStream stream = new(); WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { DangerousDeflateOptions = new WebSocketDeflateOptions() + { + ClientMaxWindowBits = windowBits + } }); // We're using a frame size that is close to the sliding window size for the deflate - const int frameSize = 32_000; + int frameSize = 2 << windowBits; - byte[] message = new byte[frameSize * 100]; + byte[] message = new byte[frameSize * 10]; Random random = new(0); for (int i = 0; i < message.Length; ++i) @@ -515,15 +520,18 @@ public async Task PayloadShouldHaveSimilarSizeWhenSplitIntoSegments() await client.SendAsync(message, WebSocketMessageType.Binary, true, CancellationToken); - int payloadLength = stream.Remote.Available; - stream.Remote.Clear(); + long payloadLength = stream.Length; + stream.SetLength(0); - for (var i = 0; i < message.Length; i += frameSize) + for (int i = 0; i < message.Length; i += frameSize) { await client.SendAsync(message.AsMemory(i, frameSize), WebSocketMessageType.Binary, i + frameSize == message.Length, CancellationToken); } - Assert.Equal(0.999, Math.Round(payloadLength * 1.0 / stream.Remote.Available, 3)); + double difference = Math.Round(1 - payloadLength * 1.0 / stream.Length, 3); + + // The difference should not be more than 10% in either direction + Assert.InRange(difference, -0.1, 0.1); } [Theory] diff --git a/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs b/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs index eb0ad93e3770e5..d12bf696f9be9a 100644 --- a/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs @@ -13,7 +13,7 @@ public class ZLibStreamTests [Fact] public async Task PoolShouldReuseTheSameInstance() { - var pool = new Pool(timeoutMilliseconds: 100); + var pool = new Pool(timeoutMilliseconds: 25); object inflater = pool.GetInflater(); for ( var i = 0; i < 10_000; ++i) @@ -29,7 +29,7 @@ public async Task PoolShouldReuseTheSameInstance() pool.ReturnInflater(inflater); Assert.Equal(1, pool.ActiveCount); - await Task.Delay(250); + await Task.Delay(200); // After timeout elapses we should not have any active instances Assert.Equal(0, pool.ActiveCount); @@ -39,7 +39,7 @@ public async Task PoolShouldReuseTheSameInstance() [PlatformSpecific(~TestPlatforms.Browser)] // There is no concurrency in browser public async Task PoolingConcurrently() { - var pool = new Pool(timeoutMilliseconds: 100); + var pool = new Pool(timeoutMilliseconds: 25); var parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = 16 @@ -60,7 +60,7 @@ public async Task PoolingConcurrently() Assert.True(pool.ActiveCount >= 2); Assert.True(pool.ActiveCount <= parallelOptions.MaxDegreeOfParallelism * 2); - await Task.Delay(250); + await Task.Delay(200); Assert.Equal(0, pool.ActiveCount); } From 437b7e70e47027585f11a38abaa4e8dc7639d383 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 20 Apr 2021 11:18:40 +0300 Subject: [PATCH 43/47] Disallowing the usage of different compression options for continuations. --- .../System.Net.WebSockets/src/Resources/Strings.resx | 3 +++ .../src/System/Net/WebSockets/ManagedWebSocket.cs | 8 +++++--- .../System.Net.WebSockets/tests/WebSocketTests.cs | 11 +++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index c838d4d188ad45..693f8d3863fd7f 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -162,4 +162,7 @@ The message was compressed using an unsupported compression method. + + The compression options for a continuation cannot be different than the options used to send the first fragment of the message. + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 01d633db8ca791..971c2ceff82be4 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -300,18 +300,20 @@ private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessage } bool endOfMessage = messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage); - bool disableCompression; + bool disableCompression = messageFlags.HasFlag(WebSocketMessageFlags.DisableCompression); MessageOpcode opcode; if (_lastSendWasFragment) { - disableCompression = _lastSendHadDisableCompression; + if (_lastSendHadDisableCompression != disableCompression) + { + throw new ArgumentException(SR.net_WebSockets_Argument_MessageFlagsHasDifferentCompressionOptions, nameof(messageFlags)); + } opcode = MessageOpcode.Continuation; } else { opcode = messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary : MessageOpcode.Text; - disableCompression = messageFlags.HasFlag(WebSocketMessageFlags.DisableCompression); } ValueTask t = SendFrameAsync(opcode, endOfMessage, disableCompression, buffer, cancellationToken); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index ad738a00ec8645..19b38d8d21760e 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Threading.Tasks; using Xunit; namespace System.Net.WebSockets.Tests @@ -171,6 +172,16 @@ public void ValueWebSocketReceiveResult_Ctor_ValidArguments_Roundtrip(int count, Assert.Equal(endOfMessage, r.EndOfMessage); } + [Fact] + public async Task ThrowWhenContinuationWithDifferentCompressionFlags() + { + using WebSocket client = CreateFromStream(new MemoryStream(), isServer: false, null, TimeSpan.Zero); + + await client.SendAsync(Memory.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default); + Assert.Throws("messageFlags", () => + client.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); + } + public abstract class ExposeProtectedWebSocket : WebSocket { public static new bool IsStateTerminal(WebSocketState state) => From 5edadd9663d80e8df407c1447c10ba3f20163974 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 24 Apr 2021 17:03:47 +0300 Subject: [PATCH 44/47] Removed [ActiveIssue] from websocket deflate tests. Created a new test that replicates Autobahn Test Case 13.3.1 which causes "invalid distance too far" zlib error. --- .../tests/WebSocketDeflateTests.cs | 186 ++++++++++-------- 1 file changed, 102 insertions(+), 84 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 33ac5122516cc5..25efbe94b1d5bd 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -277,83 +277,6 @@ public async Task Duplex(bool clientContextTakover, bool serverContextTakover) } } - [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/50235")] - public async Task LargeMessageSplitInMultipleFramesActiveIssue() - { - // This test is exactly the same as LargeMessageSplitInMultipleFrames, but - // for the data seed it uses Random(0) where the other uses Random(10). This is done - // only because it was found that there is a bug in the deflate somewhere and it only appears - // so far when using 10 window bits and data generated using Random(0). Once - // the issue is resolved this test can be deleted and LargeMessageSplitInMultipleFrames should be - // updated to use Random(0). - WebSocketTestStream stream = new(); - using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions - { - IsServer = true, - DangerousDeflateOptions = new() - { - ClientMaxWindowBits = 10 - } - }); - using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions - { - DangerousDeflateOptions = new() - { - ClientMaxWindowBits = 10 - } - }); - - Memory testData = new byte[ushort.MaxValue]; - Memory receivedData = new byte[testData.Length]; - - // Make the data incompressible to make sure that the output is larger than the input - var rng = new Random(0); - rng.NextBytes(testData.Span); - - // Test it a few times with different frame sizes - for (var i = 0; i < 10; ++i) - { - var frameSize = rng.Next(1024, 2048); - var position = 0; - - while (position < testData.Length) - { - var currentFrameSize = Math.Min(frameSize, testData.Length - position); - var eof = position + currentFrameSize == testData.Length; - - await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, CancellationToken); - position += currentFrameSize; - } - - Assert.True(testData.Length < stream.Remote.Available, "The compressed data should be bigger."); - Assert.Equal(testData.Length, position); - - // Receive the data from the client side - receivedData.Span.Clear(); - position = 0; - - // Intentionally receive with a frame size that is less than what the sender used - frameSize /= 3; - - while (true) - { - int currentFrameSize = Math.Min(frameSize, testData.Length - position); - ValueWebSocketReceiveResult result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), CancellationToken); - - Assert.Equal(WebSocketMessageType.Binary, result.MessageType); - position += result.Count; - - if (result.EndOfMessage) - break; - } - - Assert.Equal(0, stream.Remote.Available); - Assert.Equal(testData.Length, position); - Assert.True(testData.Span.SequenceEqual(receivedData.Span)); - } - } - [Theory] [MemberData(nameof(SupportedWindowBits))] public async Task LargeMessageSplitInMultipleFrames(int windowBits) @@ -379,7 +302,7 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) Memory receivedData = new byte[testData.Length]; // Make the data incompressible to make sure that the output is larger than the input - var rng = new Random(10); + var rng = new Random(0); rng.NextBytes(testData.Span); // Test it a few times with different frame sizes @@ -445,12 +368,12 @@ public async Task ReceiveUncompressedMessageWhenCompressionEnabled() // We should be able to handle the situation where even if we have // deflate compression enabled, uncompressed messages are OK WebSocketTestStream stream = new(); - WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { IsServer = true, DangerousDeflateOptions = null }); - WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { DangerousDeflateOptions = new WebSocketDeflateOptions() }); @@ -479,7 +402,7 @@ public async Task ReceiveUncompressedMessageWhenCompressionEnabled() public async Task ReceiveInvalidCompressedData() { WebSocketTestStream stream = new(); - WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { DangerousDeflateOptions = new WebSocketDeflateOptions() }); @@ -499,7 +422,7 @@ public async Task ReceiveInvalidCompressedData() public async Task PayloadShouldHaveSimilarSizeWhenSplitIntoSegments(int windowBits) { MemoryStream stream = new(); - WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { DangerousDeflateOptions = new WebSocketDeflateOptions() { @@ -540,7 +463,7 @@ public async Task PayloadShouldHaveSimilarSizeWhenSplitIntoSegments(int windowBi public async Task SendReceiveWithDifferentWindowBits(int clientWindowBits, int serverWindowBits) { WebSocketTestStream stream = new(); - WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { IsServer = true, DangerousDeflateOptions = new() @@ -551,7 +474,7 @@ public async Task SendReceiveWithDifferentWindowBits(int clientWindowBits, int s ServerMaxWindowBits = serverWindowBits } }); - WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { DangerousDeflateOptions = new() { @@ -583,6 +506,101 @@ public async Task SendReceiveWithDifferentWindowBits(int clientWindowBits, int s Assert.True(data.Span.SequenceEqual(buffer.Span)); } + [Fact] + public async Task AutobahnTestCase13_3_1() + { + // When running Autobahn Test Suite some tests failed with zlib error "invalid distance too far back". + // Further investigation lead to a bug fix in zlib intel's implementation - https://github.com/dotnet/runtime/issues/50235. + // This test replicates one of the Autobahn tests to make sure this issue doesn't appear again. + byte[][] messages = new[] + { + new byte[] { 0x7B, 0x0A, 0x20, 0x20, 0x20, 0x22, 0x41, 0x75, 0x74, 0x6F, 0x62, 0x61, 0x68, 0x6E, 0x50, 0x79 }, + new byte[] { 0x74, 0x68, 0x6F, 0x6E, 0x2F, 0x30, 0x2E, 0x36, 0x2E, 0x30, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x31, 0x2E, 0x31, 0x2E, 0x31, 0x22, 0x3A, 0x20, 0x7B, 0x0A }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69 }, + new byte[] { 0x6F, 0x72, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F }, + new byte[] { 0x73, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x22, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20 }, + new byte[] { 0x32, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x6D }, + new byte[] { 0x6F, 0x74, 0x65, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x43, 0x6F, 0x64, 0x65, 0x22, 0x3A, 0x20, 0x31 }, + new byte[] { 0x30, 0x30, 0x30, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72 }, + new byte[] { 0x65, 0x70, 0x6F, 0x72, 0x74, 0x66, 0x69, 0x6C, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x61, 0x75, 0x74 }, + new byte[] { 0x6F, 0x62, 0x61, 0x68, 0x6E, 0x70, 0x79, 0x74, 0x68, 0x6F, 0x6E, 0x5F, 0x30, 0x5F, 0x36, 0x5F }, + new byte[] { 0x30, 0x5F, 0x63, 0x61, 0x73, 0x65, 0x5F, 0x31, 0x5F, 0x31, 0x5F, 0x31, 0x2E, 0x6A, 0x73, 0x6F }, + new byte[] { 0x6E, 0x22, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x7D, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x22, 0x31, 0x2E, 0x31, 0x2E, 0x32, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x22 }, + new byte[] { 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x22 }, + new byte[] { 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x22, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20, 0x32, 0x2C, 0x0A }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x6D, 0x6F, 0x74, 0x65 }, + new byte[] { 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x43, 0x6F, 0x64, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0x30 }, + new byte[] { 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x70, 0x6F }, + new byte[] { 0x72, 0x74, 0x66, 0x69, 0x6C, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x61, 0x75, 0x74, 0x6F, 0x62, 0x61 }, + new byte[] { 0x68, 0x6E, 0x70, 0x79, 0x74, 0x68, 0x6F, 0x6E, 0x5F, 0x30, 0x5F, 0x36, 0x5F, 0x30, 0x5F, 0x63 }, + new byte[] { 0x61, 0x73, 0x65, 0x5F, 0x31, 0x5F, 0x31, 0x5F, 0x32, 0x2E, 0x6A, 0x73, 0x6F, 0x6E, 0x22, 0x0A }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x7D, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22 }, + new byte[] { 0x31, 0x2E, 0x31, 0x2E, 0x33, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x22, 0x3A, 0x20, 0x22 }, + new byte[] { 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62 }, + new byte[] { 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x22, 0x3A, 0x20, 0x22 }, + new byte[] { 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x64 }, + new byte[] { 0x75, 0x72, 0x61, 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20, 0x32, 0x2C, 0x0A, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x6D, 0x6F, 0x74, 0x65, 0x43, 0x6C, 0x6F }, + new byte[] { 0x73, 0x65, 0x43, 0x6F, 0x64, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0x30, 0x2C, 0x0A, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x70, 0x6F, 0x72, 0x74, 0x66 }, + new byte[] { 0x69, 0x6C, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x61, 0x75, 0x74, 0x6F, 0x62, 0x61, 0x68, 0x6E, 0x70 }, + new byte[] { 0x79, 0x74, 0x68, 0x6F, 0x6E, 0x5F, 0x30, 0x5F, 0x36, 0x5F, 0x30, 0x5F, 0x63, 0x61, 0x73, 0x65 }, + new byte[] { 0x5F, 0x31, 0x5F, 0x31, 0x5F, 0x33, 0x2E, 0x6A, 0x73, 0x6F, 0x6E, 0x22, 0x0A, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x7D, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x31, 0x2E, 0x31 }, + new byte[] { 0x2E, 0x34, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22 }, + new byte[] { 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61 }, + new byte[] { 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22 }, + new byte[] { 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x64, 0x75, 0x72, 0x61 }, + new byte[] { 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20, 0x32, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 } + }; + + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + KeepAliveInterval = TimeSpan.Zero, + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = 9, + ServerMaxWindowBits = 9 + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + KeepAliveInterval = TimeSpan.Zero, + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = 9, + ServerMaxWindowBits = 9 + } + }); + + foreach (var message in messages) + { + await server.SendAsync(message, WebSocketMessageType.Text, true, CancellationToken); + } + + Memory buffer = new byte[32]; + + for (int i = 0; i < messages.Length; ++i) + { + ValueWebSocketReceiveResult result = await client.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(messages[i].Length, result.Count); + Assert.True(buffer.Span.Slice(0, result.Count).SequenceEqual(messages[i])); + } + } + private ValueTask SendTextAsync(string text, WebSocket websocket, bool disableCompression = false) { WebSocketMessageFlags flags = WebSocketMessageFlags.EndOfMessage; From f6a4b32572b25e5dc7d0061ab0103685ac00e39e Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 28 Apr 2021 12:35:01 +0300 Subject: [PATCH 45/47] Addressing PR feedback. --- .../Common/src/System/IO/Compression/ZLibNative.cs | 2 -- .../System/Net/WebSockets/WebSocketHandle.Managed.cs | 11 ++++++----- .../Net/WebSockets/Compression/WebSocketInflater.cs | 2 +- .../Net/WebSockets/Compression/ZLibStreamPool.cs | 6 +++++- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs index 98cc8ad59dbc5a..f0393ebbf35cb5 100644 --- a/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs +++ b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs @@ -289,7 +289,6 @@ public ErrorCode DeflateReset() return Interop.zlib.DeflateReset(ref _zStream); } - public ErrorCode DeflateEnd() { EnsureNotDisposed(); @@ -329,7 +328,6 @@ public ErrorCode InflateReset() return Interop.zlib.InflateReset(ref _zStream); } - public ErrorCode InflateEnd() { EnsureNotDisposed(); diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index ba18af70e8805c..e0c3902a915909 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -257,19 +257,19 @@ private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan ex if (value.Length > 0) { - if (value.Equals(ClientWebSocketDeflateConstants.ClientNoContextTakeover, StringComparison.Ordinal)) + if (value.SequenceEqual(ClientWebSocketDeflateConstants.ClientNoContextTakeover)) { options.ClientContextTakeover = false; } - else if (value.Equals(ClientWebSocketDeflateConstants.ServerNoContextTakeover, StringComparison.Ordinal)) + else if (value.SequenceEqual(ClientWebSocketDeflateConstants.ServerNoContextTakeover)) { options.ServerContextTakeover = false; } - else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits, StringComparison.Ordinal)) + else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits)) { options.ClientMaxWindowBits = ParseWindowBits(value); } - else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits, StringComparison.Ordinal)) + else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits)) { options.ServerMaxWindowBits = ParseWindowBits(value); } @@ -292,8 +292,9 @@ static int ParseWindowBits(ReadOnlySpan value) } if (end < 0) + { break; - + } extension = extension[(end + 1)..]; } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 169fa604e235f0..a13fadd10f208d 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -85,7 +85,7 @@ public void Prepare(long payloadLength, int userBufferLength) // Rent a buffer as close to the size of the user buffer as possible, // but not try to rent anything above 1MB because the array pool will allocate. // If the payload is smaller than the user buffer, rent only as much as we need. - _buffer = ArrayPool.Shared.Rent(Math.Min(userBufferLength, (int)Math.Min(payloadLength, 1_000_000))); + _buffer = ArrayPool.Shared.Rent(Math.Min(userBufferLength, (int)Math.Min(payloadLength, 1024 * 1024))); } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs index 75ff4707adcdad..12c644bab7d7c5 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs @@ -76,7 +76,11 @@ public static ZLibStreamPool GetOrCreate(int windowBits) static ZLibStreamPool EnsureInitialized(int windowBits, ref ZLibStreamPool? target) { - Interlocked.CompareExchange(ref target, new ZLibStreamPool(windowBits, DefaultTimeoutMilliseconds), null); + ZLibStreamPool newPool = new(windowBits, DefaultTimeoutMilliseconds); + if (Interlocked.CompareExchange(ref target, newPool, null) is not null) + { + newPool._cleaningTimer.Dispose(); + } Debug.Assert(target != null); return target; From f0f09f340491f2cb260f36598506f0d40e565896 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 28 Apr 2021 16:49:24 +0300 Subject: [PATCH 46/47] Removed custom deflate pool. --- .../src/System.Net.WebSockets.csproj | 1 - .../Compression/WebSocketDeflater.cs | 38 ++- .../Compression/WebSocketInflater.cs | 37 ++- .../WebSockets/Compression/ZLibStreamPool.cs | 291 ------------------ .../tests/System.Net.WebSockets.Tests.csproj | 1 - .../tests/ZLibStreamTests.cs | 104 ------- 6 files changed, 65 insertions(+), 407 deletions(-) delete mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs delete mode 100644 src/libraries/System.Net.WebSockets/tests/ZLibStreamTests.cs diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 3419171de4cc71..215cf6b4a91649 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -7,7 +7,6 @@ - diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 36a0049d705fbf..a17fe4f8aeb9a2 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -12,7 +12,7 @@ namespace System.Net.WebSockets.Compression /// internal sealed class WebSocketDeflater : IDisposable { - private readonly ZLibStreamPool _streamPool; + private readonly int _windowBits; private ZLibStreamHandle? _stream; private readonly bool _persisted; @@ -20,7 +20,7 @@ internal sealed class WebSocketDeflater : IDisposable internal WebSocketDeflater(int windowBits, bool persisted) { - _streamPool = ZLibStreamPool.GetOrCreate(windowBits); + _windowBits = -windowBits; // Negative for raw deflate _persisted = persisted; } @@ -28,7 +28,7 @@ public void Dispose() { if (_stream is not null) { - _streamPool.ReturnDeflater(_stream); + _stream.Dispose(); _stream = null; } } @@ -84,7 +84,7 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool endOfMessage) private void DeflatePrivate(ReadOnlySpan payload, Span output, bool endOfMessage, out int consumed, out int written, out bool needsMoreOutput) { - _stream ??= _streamPool.GetDeflater(); + _stream ??= CreateDeflater(); if (payload.Length == 0) { @@ -119,7 +119,7 @@ private void DeflatePrivate(ReadOnlySpan payload, Span output, bool if (endOfMessage && !_persisted) { - _streamPool.ReturnDeflater(_stream); + _stream.Dispose(); _stream = null; } } @@ -201,5 +201,33 @@ private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); throw new WebSocketException(message); } + + private ZLibStreamHandle CreateDeflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + try + { + errorCode = CreateZLibStreamForDeflate(out stream, + level: CompressionLevel.DefaultCompression, + windowBits: _windowBits, + memLevel: Deflate_DefaultMemLevel, + strategy: CompressionStrategy.DefaultStrategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + if (errorCode != ErrorCode.Ok) + { + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } + + return stream; + } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index a13fadd10f208d..6ade12d539a440 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -15,7 +15,7 @@ internal sealed class WebSocketInflater : IDisposable internal const int FlushMarkerLength = 4; internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; - private readonly ZLibStreamPool _streamPool; + private readonly int _windowBits; private ZLibStreamHandle? _stream; private readonly bool _persisted; @@ -48,7 +48,7 @@ internal sealed class WebSocketInflater : IDisposable internal WebSocketInflater(int windowBits, bool persisted) { - _streamPool = ZLibStreamPool.GetOrCreate(windowBits); + _windowBits = -windowBits; // Negative for raw deflate _persisted = persisted; } @@ -60,7 +60,7 @@ public void Dispose() { if (_stream is not null) { - _streamPool.ReturnInflater(_stream); + _stream.Dispose(); _stream = null; } ReleaseBuffer(); @@ -128,7 +128,7 @@ public void AddBytes(int totalBytesReceived, bool endOfMessage) /// public unsafe bool Inflate(Span output, out int written) { - _stream ??= _streamPool.GetInflater(); + _stream ??= CreateInflater(); if (_available > 0 && output.Length > 0) { @@ -192,7 +192,7 @@ private unsafe bool Finish(Span output, ref int written) { if (!_persisted) { - _streamPool.ReturnInflater(_stream); + _stream.Dispose(); _stream = null; } return true; @@ -254,5 +254,32 @@ private static unsafe int Inflate(ZLibStreamHandle stream, Span destinatio }; throw new WebSocketException(message); } + + private ZLibStreamHandle CreateInflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + + try + { + errorCode = CreateZLibStreamForInflate(out stream, _windowBits); + } + catch (Exception exception) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + if (errorCode == ErrorCode.Ok) + { + return stream; + } + + stream.Dispose(); + + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs deleted file mode 100644 index 12c644bab7d7c5..00000000000000 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs +++ /dev/null @@ -1,291 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Threading; -using static System.IO.Compression.ZLibNative; - -namespace System.Net.WebSockets.Compression -{ - internal sealed class ZLibStreamPool - { - private static readonly ZLibStreamPool?[] s_pools - = new ZLibStreamPool[WebSocketValidate.MaxDeflateWindowBits - WebSocketValidate.MinDeflateWindowBits + 1]; - - /// - /// The default amount of time after which a cached item will be removed. - /// - private const int DefaultTimeoutMilliseconds = 60_000; - - private readonly int _windowBits; - private readonly List _inflaters = new(); - private readonly List _deflaters = new(); - private readonly Timer _cleaningTimer; - - /// - /// The amount of time after which a cached item will be removed. - /// - private readonly int _timeoutMilliseconds; - - /// - /// The number of cached inflaters and deflaters. - /// - private int _activeCount; - - private ZLibStreamPool(int windowBits, int timeoutMilliseconds) - { - // Use negative window bits to for raw deflate data - _windowBits = -windowBits; - _timeoutMilliseconds = timeoutMilliseconds; - - bool restoreFlow = false; - try - { - if (!ExecutionContext.IsFlowSuppressed()) - { - ExecutionContext.SuppressFlow(); - restoreFlow = true; - } - - // There is no need to use weak references here, because these pools are kept - // for the entire lifetime of the application. Also we reset the timer on each tick, - // which prevents the object being rooted forever. - _cleaningTimer = new Timer(x => ((ZLibStreamPool)x!).RemoveStaleItems(), - state: this, Timeout.Infinite, Timeout.Infinite); - } - finally - { - if (restoreFlow) - { - ExecutionContext.RestoreFlow(); - } - } - } - - public static ZLibStreamPool GetOrCreate(int windowBits) - { - Debug.Assert(windowBits >= WebSocketValidate.MinDeflateWindowBits - && windowBits <= WebSocketValidate.MaxDeflateWindowBits); - - int index = windowBits - WebSocketValidate.MinDeflateWindowBits; - ref ZLibStreamPool? pool = ref s_pools[index]; - - return Volatile.Read(ref pool) ?? EnsureInitialized(windowBits, ref pool); - - static ZLibStreamPool EnsureInitialized(int windowBits, ref ZLibStreamPool? target) - { - ZLibStreamPool newPool = new(windowBits, DefaultTimeoutMilliseconds); - if (Interlocked.CompareExchange(ref target, newPool, null) is not null) - { - newPool._cleaningTimer.Dispose(); - } - - Debug.Assert(target != null); - return target; - } - } - - public ZLibStreamHandle GetInflater() - { - if (TryGet(_inflaters, out ZLibStreamHandle? stream)) - { - return stream; - } - - return CreateInflater(); - } - - public void ReturnInflater(ZLibStreamHandle stream) - { - if (stream.InflateReset() != ErrorCode.Ok) - { - stream.Dispose(); - return; - } - - Return(stream, _inflaters); - } - - public ZLibStreamHandle GetDeflater() - { - if (TryGet(_deflaters, out ZLibStreamHandle? stream)) - { - return stream; - } - - return CreateDeflater(); - } - - public void ReturnDeflater(ZLibStreamHandle stream) - { - if (stream.DeflateReset() != ErrorCode.Ok) - { - stream.Dispose(); - return; - } - - Return(stream, _deflaters); - } - - private void Return(ZLibStreamHandle stream, List cache) - { - lock (cache) - { - cache.Add(new CacheItem(stream)); - - if (Interlocked.Increment(ref _activeCount) == 1) - { - _cleaningTimer.Change(_timeoutMilliseconds, Timeout.Infinite); - } - } - } - - private bool TryGet(List cache, [NotNullWhen(true)] out ZLibStreamHandle? stream) - { - lock (cache) - { - int count = cache.Count; - - if (count > 0) - { - CacheItem item = cache[count - 1]; - cache.RemoveAt(count - 1); - Interlocked.Decrement(ref _activeCount); - - stream = item.Stream; - return true; - } - } - - stream = null; - return false; - } - - private void RemoveStaleItems() - { - RemoveStaleItems(_inflaters); - RemoveStaleItems(_deflaters); - - // There is a race condition here, were _activeCount could be decremented - // by a rent operation, but it's not big deal to schedule a timer tick that - // would eventually do nothing. - if (_activeCount > 0) - { - _cleaningTimer.Change(_timeoutMilliseconds, Timeout.Infinite); - } - } - - private void RemoveStaleItems(List cache) - { - long currentTimestamp = Environment.TickCount64; - List? removedStreams = null; - - lock (cache) - { - for (int index = 0; index < cache.Count; ++index) - { - CacheItem item = cache[index]; - - if (currentTimestamp - item.Timestamp > _timeoutMilliseconds) - { - removedStreams ??= new List(); - removedStreams.Add(item.Stream); - Interlocked.Decrement(ref _activeCount); - } - else - { - // The freshest streams are in the back of the collection. - // If we've reached a stream that is not timed out, all - // other after it will not be as well. - break; - } - } - - if (removedStreams is null) - { - return; - } - - cache.RemoveRange(0, removedStreams.Count); - } - - foreach (ZLibStreamHandle stream in removedStreams) - { - stream.Dispose(); - } - } - - private ZLibStreamHandle CreateDeflater() - { - ZLibStreamHandle stream; - ErrorCode errorCode; - try - { - errorCode = CreateZLibStreamForDeflate(out stream, - level: CompressionLevel.DefaultCompression, - windowBits: _windowBits, - memLevel: Deflate_DefaultMemLevel, - strategy: CompressionStrategy.DefaultStrategy); - } - catch (Exception cause) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); - } - - if (errorCode != ErrorCode.Ok) - { - string message = errorCode == ErrorCode.MemError - ? SR.ZLibErrorNotEnoughMemory - : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); - throw new WebSocketException(message); - } - - return stream; - } - - private ZLibStreamHandle CreateInflater() - { - ZLibStreamHandle stream; - ErrorCode errorCode; - - try - { - errorCode = CreateZLibStreamForInflate(out stream, _windowBits); - } - catch (Exception exception) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); - } - - if (errorCode == ErrorCode.Ok) - { - return stream; - } - - stream.Dispose(); - - string message = errorCode == ErrorCode.MemError - ? SR.ZLibErrorNotEnoughMemory - : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); - throw new WebSocketException(message); - } - - private readonly struct CacheItem - { - public CacheItem(ZLibStreamHandle stream) - { - Stream = stream; - Timestamp = Environment.TickCount64; - } - - public ZLibStreamHandle Stream { get; } - - /// - /// The time when this item was returned to cache. - /// - public long Timestamp { get; } - } - } -} diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 9ce68c7b42f998..4e0bc74ebdaeca 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -10,7 +10,6 @@ - - { - if (x % 2 == 0) - { - object inflater = pool.GetInflater(); - pool.ReturnInflater(inflater); - } - else - { - object deflater = pool.GetDeflater(); - pool.ReturnDeflater(deflater); - } - }); - - Assert.True(pool.ActiveCount >= 2); - Assert.True(pool.ActiveCount <= parallelOptions.MaxDegreeOfParallelism * 2); - await Task.Delay(200); - Assert.Equal(0, pool.ActiveCount); - } - - private sealed class Pool - { - private static Type? s_type; - private static ConstructorInfo? s_constructor; - private static FieldInfo? s_activeCount; - private static MethodInfo? s_rentInflater; - private static MethodInfo? s_returnInflater; - private static MethodInfo? s_rentDeflater; - private static MethodInfo? s_returnDeflater; - - private readonly object _instance; - - public Pool(int timeoutMilliseconds) - { - s_type ??= typeof(WebSocket).Assembly.GetType("System.Net.WebSockets.Compression.ZLibStreamPool", throwOnError: true); - s_constructor ??= s_type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic)[0]; - - _instance = s_constructor.Invoke(new object[] { /*windowBits*/9, timeoutMilliseconds }); - } - - public int ActiveCount => (int)(s_activeCount ??= s_type.GetField("_activeCount", BindingFlags.Instance | BindingFlags.NonPublic)).GetValue(_instance); - - public object GetInflater() => GetMethod(ref s_rentInflater).Invoke(_instance, null); - - public void ReturnInflater(object inflater) => GetMethod(ref s_returnInflater).Invoke(_instance, new[] { inflater }); - - public object GetDeflater() => GetMethod(ref s_rentDeflater).Invoke(_instance, null); - - public void ReturnDeflater(object deflater) => GetMethod(ref s_returnDeflater).Invoke(_instance, new[] { deflater }); - - private static MethodInfo GetMethod(ref MethodInfo? method, [CallerMemberName] string? name = null) - { - return method ??= s_type.GetMethod(name) - ?? throw new InvalidProgramException($"Method {name} was not found in {s_type}."); - } - } - } -} From ea049eaca5d7c9189e043bd628b1f0fd181b6fbb Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 28 Apr 2021 17:03:16 +0300 Subject: [PATCH 47/47] Added missing dispose when creating deflater. --- .../WebSockets/Compression/WebSocketDeflater.cs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index a17fe4f8aeb9a2..e7f18072842433 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -219,15 +219,17 @@ private ZLibStreamHandle CreateDeflater() throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); } - if (errorCode != ErrorCode.Ok) + if (errorCode == ErrorCode.Ok) { - string message = errorCode == ErrorCode.MemError - ? SR.ZLibErrorNotEnoughMemory - : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); - throw new WebSocketException(message); + return stream; } - return stream; + stream.Dispose(); + + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); } } }