diff --git a/src/Common/src/System/Net/Http/NoWriteNoSeekStreamContent.cs b/src/Common/src/System/Net/Http/NoWriteNoSeekStreamContent.cs index 2d55520f2364..202782cc1031 100644 --- a/src/Common/src/System/Net/Http/NoWriteNoSeekStreamContent.cs +++ b/src/Common/src/System/Net/Http/NoWriteNoSeekStreamContent.cs @@ -14,10 +14,9 @@ namespace System.Net.Http internal sealed class NoWriteNoSeekStreamContent : HttpContent { private readonly Stream _content; - private readonly CancellationToken _cancellationToken; private bool _contentConsumed; - internal NoWriteNoSeekStreamContent(Stream content, CancellationToken cancellationToken) + internal NoWriteNoSeekStreamContent(Stream content) { Debug.Assert(content != null); Debug.Assert(content.CanRead); @@ -25,10 +24,16 @@ internal NoWriteNoSeekStreamContent(Stream content, CancellationToken cancellati Debug.Assert(!content.CanSeek); _content = content; - _cancellationToken = cancellationToken; } - protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) + protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) => + SerializeToStreamAsync(stream, context, CancellationToken.None); + + internal +#if HTTP_DLL + override +#endif + Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) { Debug.Assert(stream != null); @@ -39,7 +44,7 @@ protected override Task SerializeToStreamAsync(Stream stream, TransportContext c _contentConsumed = true; const int BufferSize = 8192; - Task copyTask = _content.CopyToAsync(stream, BufferSize, _cancellationToken); + Task copyTask = _content.CopyToAsync(stream, BufferSize, cancellationToken); if (copyTask.IsCompleted) { try { _content.Dispose(); } catch { } // same as StreamToStreamCopy behavior @@ -75,6 +80,10 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } - protected override Task CreateContentReadStreamAsync() => Task.FromResult(_content); + protected override Task CreateContentReadStreamAsync() => Task.FromResult(_content); + +#if HTTP_DLL + internal override Stream TryCreateContentReadStream() => _content; +#endif } } diff --git a/src/Common/src/System/Net/Logging/NetEventSource.Common.cs b/src/Common/src/System/Net/Logging/NetEventSource.Common.cs index 685a8fe4f1ea..0948ec744878 100644 --- a/src/Common/src/System/Net/Logging/NetEventSource.Common.cs +++ b/src/Common/src/System/Net/Logging/NetEventSource.Common.cs @@ -395,7 +395,9 @@ private static void DebugValidateArg(FormattableString arg) Debug.Assert(IsEnabled || arg == null, $"Should not be formatting FormattableString \"{arg}\" if tracing isn't enabled"); } - public static new bool IsEnabled => Log.IsEnabled(); + public static new bool IsEnabled => + Log.IsEnabled(); + //true; // uncomment for debugging only [NonEvent] public static string IdOf(object value) => value != null ? value.GetType().Name + "#" + GetHashCode(value) : NullInstance; diff --git a/src/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs b/src/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs index 16af2adeccdb..bc6ff3f10353 100644 --- a/src/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs +++ b/src/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs @@ -92,7 +92,7 @@ public static HttpResponseMessage CreateResponseMessage( } } - response.Content = new NoWriteNoSeekStreamContent(decompressedStream, state.CancellationToken); + response.Content = new NoWriteNoSeekStreamContent(decompressedStream); response.RequestMessage = request; // Parse raw response headers and place them into response message. diff --git a/src/System.Net.Http/src/System.Net.Http.csproj b/src/System.Net.Http/src/System.Net.Http.csproj index 21040941755b..ef88d164f39a 100644 --- a/src/System.Net.Http/src/System.Net.Http.csproj +++ b/src/System.Net.Http/src/System.Net.Http.csproj @@ -1,4 +1,4 @@ - + @@ -137,11 +137,11 @@ - + @@ -464,4 +464,4 @@ - \ No newline at end of file + diff --git a/src/System.Net.Http/src/System/Net/Http/CurlHandler/CurlHandler.CurlResponseMessage.cs b/src/System.Net.Http/src/System/Net/Http/CurlHandler/CurlHandler.CurlResponseMessage.cs index 1fafd9681d20..55d2228ccb43 100644 --- a/src/System.Net.Http/src/System/Net/Http/CurlHandler/CurlHandler.CurlResponseMessage.cs +++ b/src/System.Net.Http/src/System/Net/Http/CurlHandler/CurlHandler.CurlResponseMessage.cs @@ -22,7 +22,7 @@ internal CurlResponseMessage(EasyRequest easy) Debug.Assert(easy != null, "Expected non-null EasyRequest"); RequestMessage = easy._requestMessage; ResponseStream = new CurlResponseStream(easy); - Content = new NoWriteNoSeekStreamContent(ResponseStream, CancellationToken.None); + Content = new NoWriteNoSeekStreamContent(ResponseStream); // On Windows, we pass the equivalent of the easy._cancellationToken // in to StreamContent's ctor. This in turn passes that token through diff --git a/src/System.Net.Http/src/System/Net/Http/HttpClient.cs b/src/System.Net.Http/src/System/Net/Http/HttpClient.cs index 1bcac19a93b7..3c8eca38eeef 100644 --- a/src/System.Net.Http/src/System/Net/Http/HttpClient.cs +++ b/src/System.Net.Http/src/System/Net/Http/HttpClient.cs @@ -475,7 +475,7 @@ private async Task FinishSendAsyncBuffered( // Buffer the response content if we've been asked to and we have a Content to buffer. if (response.Content != null) { - await response.Content.LoadIntoBufferAsync(_maxResponseContentBufferSize).ConfigureAwait(false); + await response.Content.LoadIntoBufferAsync(_maxResponseContentBufferSize, cts.Token).ConfigureAwait(false); } if (NetEventSource.IsEnabled) NetEventSource.ClientSendCompleted(this, response, request); diff --git a/src/System.Net.Http/src/System/Net/Http/HttpContent.cs b/src/System.Net.Http/src/System/Net/Http/HttpContent.cs index 191c5f422bd4..f1c15e3cb774 100644 --- a/src/System.Net.Http/src/System/Net/Http/HttpContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/HttpContent.cs @@ -299,7 +299,18 @@ internal Stream TryReadAsStream() protected abstract Task SerializeToStreamAsync(Stream stream, TransportContext context); - public Task CopyToAsync(Stream stream, TransportContext context) + // TODO #9071: Expose this publicly. Until it's public, only sealed or internal types should override it, and then change + // their SerializeToStreamAsync implementation to delegate to this one. They need to be sealed as otherwise an external + // type could derive from it and override SerializeToStreamAsync(stream, context) further, at which point when + // HttpClient calls SerializeToStreamAsync(stream, context, cancellationToken), their custom override will be skipped. + internal virtual Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + SerializeToStreamAsync(stream, context); + + public Task CopyToAsync(Stream stream, TransportContext context) => + CopyToAsync(stream, context, CancellationToken.None); + + // TODO #9071: Expose this publicly. + internal Task CopyToAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) { CheckDisposed(); if (stream == null) @@ -313,11 +324,11 @@ public Task CopyToAsync(Stream stream, TransportContext context) ArraySegment buffer; if (TryGetBuffer(out buffer)) { - task = stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count); + task = stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); } else { - task = SerializeToStreamAsync(stream, context); + task = SerializeToStreamAsync(stream, context, cancellationToken); CheckTaskNotNull(task); } @@ -354,7 +365,10 @@ public Task LoadIntoBufferAsync() // No "CancellationToken" parameter needed since canceling the CTS will close the connection, resulting // in an exception being thrown while we're buffering. // If buffering is used without a connection, it is supposed to be fast, thus no cancellation required. - public Task LoadIntoBufferAsync(long maxBufferSize) + public Task LoadIntoBufferAsync(long maxBufferSize) => + LoadIntoBufferAsync(maxBufferSize, CancellationToken.None); + + internal Task LoadIntoBufferAsync(long maxBufferSize, CancellationToken cancellationToken) { CheckDisposed(); if (maxBufferSize > HttpContent.MaxBufferSize) @@ -382,7 +396,7 @@ public Task LoadIntoBufferAsync(long maxBufferSize) try { - Task task = SerializeToStreamAsync(tempBuffer, null); + Task task = SerializeToStreamAsync(tempBuffer, null, cancellationToken); CheckTaskNotNull(task); return LoadIntoBufferAsyncCore(task, tempBuffer); } diff --git a/src/System.Net.Http/src/System/Net/Http/NetEventSource.Http.cs b/src/System.Net.Http/src/System/Net/Http/NetEventSource.Http.cs index 1924dc631d87..8735970c8e69 100644 --- a/src/System.Net.Http/src/System/Net/Http/NetEventSource.Http.cs +++ b/src/System.Net.Http/src/System/Net/Http/NetEventSource.Http.cs @@ -63,6 +63,7 @@ public void HeadersInvalidValue(string name, string rawValue) => [Event(HandlerMessageId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] public void HandlerMessage(int handlerId, int workerId, int requestId, string memberName, string message) => WriteEvent(HandlerMessageId, handlerId, workerId, requestId, memberName, message); + //Console.WriteLine($"{handlerId}/{workerId}/{requestId}: ({memberName}): {message}"); // uncomment for debugging only [NonEvent] private unsafe void WriteEvent(int eventId, int arg1, int arg2, int arg3, string arg4, string arg5) diff --git a/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs b/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs index 0a9588d71ac5..11b98776e67f 100644 --- a/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs @@ -4,6 +4,7 @@ using System.IO; using System.Runtime.InteropServices; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Http @@ -26,6 +27,9 @@ public ReadOnlyMemoryContent(ReadOnlyMemory content) protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) => stream.WriteAsync(_content); + internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + stream.WriteAsync(_content, cancellationToken); + protected internal override bool TryComputeLength(out long length) { length = _content.Length; diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs index ab8865e81d34..32bd40b4d8a1 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs @@ -30,13 +30,13 @@ public ChunkedEncodingReadStream(HttpConnection connection) : base(connection) { } - private async Task TryGetNextChunkAsync(CancellationToken cancellationToken) + private async Task TryGetNextChunkAsync() { Debug.Assert(_chunkBytesRemaining == 0); // Read the start of the chunk line. _connection._allowedReadLineBytes = MaxChunkBytesAllowed; - ArraySegment line = await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false); + ArraySegment line = await _connection.ReadNextLineAsync().ConfigureAwait(false); // Parse the hex value. if (!Utf8Parser.TryParse(line.AsReadOnlySpan(), out ulong chunkSize, out int bytesConsumed, 'X')) @@ -73,7 +73,7 @@ private async Task TryGetNextChunkAsync(CancellationToken cancellationToke while (true) { _connection._allowedReadLineBytes = MaxTrailingHeaderLength; - if (LineIsEmpty(await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false))) + if (LineIsEmpty(await _connection.ReadNextLineAsync().ConfigureAwait(false))) { break; } @@ -84,59 +84,77 @@ private async Task TryGetNextChunkAsync(CancellationToken cancellationToke return false; } - private async Task ConsumeChunkBytesAsync(ulong bytesConsumed, CancellationToken cancellationToken) + private Task ConsumeChunkBytesAsync(ulong bytesConsumed) { Debug.Assert(bytesConsumed <= _chunkBytesRemaining); _chunkBytesRemaining -= bytesConsumed; - if (_chunkBytesRemaining == 0) + return _chunkBytesRemaining != 0 ? + Task.CompletedTask : + ReadNextLineAndThrowIfNotEmptyAsync(); + } + + private async Task ReadNextLineAndThrowIfNotEmptyAsync() + { + _connection._allowedReadLineBytes = 2; // \r\n + if (!LineIsEmpty(await _connection.ReadNextLineAsync().ConfigureAwait(false))) { - _connection._allowedReadLineBytes = 2; // \r\n - if (!LineIsEmpty(await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false))) - { - ThrowInvalidHttpResponse(); - } + ThrowInvalidHttpResponse(); } } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { ValidateBufferArgs(buffer, offset, count); - return ReadAsync(new Memory(buffer, offset, count)).AsTask(); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); } - public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (_connection == null || destination.Length == 0) { // Response body fully consumed or the caller didn't ask for any data return 0; } - if (_chunkBytesRemaining == 0) + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try { - if (!await TryGetNextChunkAsync(cancellationToken).ConfigureAwait(false)) + if (_chunkBytesRemaining == 0) { - // End of response body - return 0; + if (!await TryGetNextChunkAsync().ConfigureAwait(false)) + { + // End of response body + return 0; + } } - } - if (_chunkBytesRemaining < (ulong)destination.Length) - { - destination = destination.Slice(0, (int)_chunkBytesRemaining); - } + if (_chunkBytesRemaining < (ulong)destination.Length) + { + destination = destination.Slice(0, (int)_chunkBytesRemaining); + } - int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false); + int bytesRead = await _connection.ReadAsync(destination).ConfigureAwait(false); - if (bytesRead <= 0) - { - // Unexpected end of response stream - throw new IOException(SR.net_http_invalid_response); - } + if (bytesRead <= 0) + { + // Unexpected end of response stream + throw new IOException(SR.net_http_invalid_response); + } - await ConsumeChunkBytesAsync((ulong)bytesRead, cancellationToken).ConfigureAwait(false); + await ConsumeChunkBytesAsync((ulong)bytesRead).ConfigureAwait(false); - return bytesRead; + return bytesRead; + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw new OperationCanceledException(s_cancellationMessage, exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } } public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) @@ -145,6 +163,12 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance { throw new ArgumentNullException(nameof(destination)); } + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize)); + } + + cancellationToken.ThrowIfCancellationRequested(); if (_connection == null) { @@ -152,16 +176,28 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance return; } - if (_chunkBytesRemaining > 0) + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try { - await _connection.CopyToAsync(destination, _chunkBytesRemaining, cancellationToken).ConfigureAwait(false); - await ConsumeChunkBytesAsync(_chunkBytesRemaining, cancellationToken).ConfigureAwait(false); - } + if (_chunkBytesRemaining > 0) + { + await _connection.CopyToAsync(destination, _chunkBytesRemaining).ConfigureAwait(false); + await ConsumeChunkBytesAsync(_chunkBytesRemaining).ConfigureAwait(false); + } - while (await TryGetNextChunkAsync(cancellationToken).ConfigureAwait(false)) + while (await TryGetNextChunkAsync().ConfigureAwait(false)) + { + await _connection.CopyToAsync(destination, _chunkBytesRemaining).ConfigureAwait(false); + await ConsumeChunkBytesAsync(_chunkBytesRemaining).ConfigureAwait(false); + } + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally { - await _connection.CopyToAsync(destination, _chunkBytesRemaining, cancellationToken).ConfigureAwait(false); - await ConsumeChunkBytesAsync(_chunkBytesRemaining, cancellationToken).ConfigureAwait(false); + ctr.Dispose(); } } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs index 04bcdc194c37..82682acd6db0 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; @@ -13,8 +14,7 @@ private sealed class ChunkedEncodingWriteStream : HttpContentWriteStream { private static readonly byte[] s_finalChunkBytes = { (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' }; - public ChunkedEncodingWriteStream(HttpConnection connection, CancellationToken cancellationToken) : - base(connection, cancellationToken) + public ChunkedEncodingWriteStream(HttpConnection connection) : base(connection) { } @@ -24,13 +24,17 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return WriteAsync(new Memory(buffer, offset, count), ignored); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + public override Task WriteAsync(ReadOnlyMemory source, CancellationToken ignored) { + // The token is ignored because it's coming from SendAsync and the only operations + // here are those that are already covered by the token having been registered with + // to close the connection. + if (source.Length == 0) { // Don't write if nothing was given, especially since we don't want to accidentally send a 0 chunk, // which would indicate end of body. Instead, just ensure no content is stuck in the buffer. - return _connection.FlushAsync(RequestCancellationToken); + return _connection.FlushAsync(); } if (_connection._currentRequest == null) @@ -54,17 +58,17 @@ private async Task WriteChunkAsync(ReadOnlyMemory source) int digit = (source.Length & mask) >> shift; if (digitWritten || digit != 0) { - await _connection.WriteByteAsync((byte)(digit < 10 ? '0' + digit : 'A' + digit - 10), RequestCancellationToken).ConfigureAwait(false); + await _connection.WriteByteAsync((byte)(digit < 10 ? '0' + digit : 'A' + digit - 10)).ConfigureAwait(false); digitWritten = true; } } // End chunk length - await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n', RequestCancellationToken).ConfigureAwait(false); + await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); // Write chunk contents - await _connection.WriteAsync(source, RequestCancellationToken).ConfigureAwait(false); - await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n', RequestCancellationToken).ConfigureAwait(false); + await _connection.WriteAsync(source).ConfigureAwait(false); + await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); // Flush the chunk. This is reasonable from the standpoint of having just written a standalone piece // of data, but is also necessary to support duplex communication, where a CopyToAsync is taking the @@ -72,18 +76,16 @@ private async Task WriteChunkAsync(ReadOnlyMemory source) // source was empty, and it might be kept open to enable subsequent communication. And it's necessary // in general for at least the first write, as we need to ensure if it's the entirety of the content // and if all of the headers and content fit in the write buffer that we've actually sent the request. - await _connection.FlushAsync(RequestCancellationToken).ConfigureAwait(false); + await _connection.FlushAsync().ConfigureAwait(false); } - public override Task FlushAsync(CancellationToken ignored) - { - return _connection.FlushAsync(RequestCancellationToken); - } - + public override Task FlushAsync(CancellationToken ignored) => // see comment on WriteAsync about "ignored" + _connection.FlushAsync(); + public override async Task FinishAsync() { // Send 0 byte chunk to indicate end, then final CrLf - await _connection.WriteBytesAsync(s_finalChunkBytes, RequestCancellationToken).ConfigureAwait(false); + await _connection.WriteBytesAsync(s_finalChunkBytes).ConfigureAwait(false); _connection = null; } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs index ab0a28bb1cc3..2c000f63c420 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs @@ -22,17 +22,49 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); } - public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (_connection == null || destination.Length == 0) { // Response body fully consumed or the caller didn't ask for any data return 0; } - int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false); + ValueTask readTask = _connection.ReadAsync(destination); + int bytesRead; + if (readTask.IsCompletedSuccessfully) + { + bytesRead = readTask.Result; + } + else + { + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try + { + bytesRead = await readTask.ConfigureAwait(false); + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } + } + if (bytesRead == 0) { + // If cancellation is requested and tears down the connection, it could cause the read + // to return 0, which would otherwise signal the end of the data, but that would lead + // the caller to think that it actually received all of the data, rather than it ending + // early due to cancellation. So we prioritize cancellation in this race condition, and + // if we read 0 bytes and then find that cancellation has requested, we assume cancellation + // was the cause and throw. + cancellationToken.ThrowIfCancellationRequested(); + // We cannot reuse this connection, so close it. _connection.Dispose(); _connection = null; @@ -48,15 +80,46 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance { throw new ArgumentNullException(nameof(destination)); } + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize)); + } - if (_connection != null) // null if response body fully consumed + cancellationToken.ThrowIfCancellationRequested(); + + if (_connection == null) { - await _connection.CopyToAsync(destination, cancellationToken).ConfigureAwait(false); + // Response body fully consumed + return; + } - // We cannot reuse this connection, so close it. - _connection.Dispose(); - _connection = null; + Task copyTask = _connection.CopyToAsync(destination); + if (!copyTask.IsCompletedSuccessfully) + { + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try + { + await copyTask.ConfigureAwait(false); + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } } + + // If cancellation is requested and tears down the connection, it could cause the copy + // to end early but think it ended successfully. So we prioritize cancellation in this + // race condition, and if we find after the copy has completed that cancellation has + // been requested, we assume the copy completed due to cancellation and throw. + cancellationToken.ThrowIfCancellationRequested(); + + // We cannot reuse this connection, so close it. + _connection.Dispose(); + _connection = null; } } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs index 65fd114950ef..e6daf285e287 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs @@ -15,8 +15,7 @@ private sealed class ContentLengthReadStream : HttpContentReadStream { private ulong _contentBytesRemaining; - public ContentLengthReadStream(HttpConnection connection, ulong contentLength) - : base(connection) + public ContentLengthReadStream(HttpConnection connection, ulong contentLength) : base(connection) { Debug.Assert(contentLength > 0, "Caller should have checked for 0."); _contentBytesRemaining = contentLength; @@ -28,8 +27,10 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); } - public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (_connection == null || destination.Length == 0) { // Response body fully consumed or the caller didn't ask for any data @@ -43,11 +44,35 @@ public override async ValueTask ReadAsync(Memory destination, Cancell destination = destination.Slice(0, (int)_contentBytesRemaining); } - int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false); + ValueTask readTask = _connection.ReadAsync(destination); + int bytesRead; + if (readTask.IsCompletedSuccessfully) + { + bytesRead = readTask.Result; + } + else + { + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try + { + bytesRead = await readTask.ConfigureAwait(false); + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } + } if (bytesRead <= 0) { - // Unexpected end of response stream + // A cancellation request may have caused the EOF. + cancellationToken.ThrowIfCancellationRequested(); + + // Unexpected end of response stream. throw new IOException(SR.net_http_invalid_response); } @@ -70,6 +95,12 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance { throw new ArgumentNullException(nameof(destination)); } + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize)); + } + + cancellationToken.ThrowIfCancellationRequested(); if (_connection == null) { @@ -77,7 +108,23 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance return; } - await _connection.CopyToAsync(destination, _contentBytesRemaining, cancellationToken).ConfigureAwait(false); + Task copyTask = _connection.CopyToAsync(destination, _contentBytesRemaining); + if (!copyTask.IsCompletedSuccessfully) + { + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try + { + await copyTask.ConfigureAwait(false); + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } + } _contentBytesRemaining = 0; _connection.ReturnConnectionToPool(); diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs index 0329f43ea9cf..507f883bb9c3 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs @@ -11,18 +11,17 @@ internal partial class HttpConnection : IDisposable { private sealed class ContentLengthWriteStream : HttpContentWriteStream { - public ContentLengthWriteStream(HttpConnection connection, CancellationToken cancellationToken) : - base(connection, cancellationToken) + public ContentLengthWriteStream(HttpConnection connection) : base(connection) { } - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ignored) + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ignored) // token ignored as it comes from SendAsync { ValidateBufferArgs(buffer, offset, count); return WriteAsync(new ReadOnlyMemory(buffer, offset, count), ignored); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + public override Task WriteAsync(ReadOnlyMemory source, CancellationToken ignored) // token ignored as it comes from SendAsync { if (_connection._currentRequest == null) { @@ -34,13 +33,11 @@ public override Task WriteAsync(ReadOnlyMemory source, CancellationToken c // Have the connection write the data, skipping the buffer. Importantly, this will // force a flush of anything already in the buffer, i.e. any remaining request headers // that are still buffered. - return _connection.WriteWithoutBufferingAsync(source, RequestCancellationToken); + return _connection.WriteWithoutBufferingAsync(source); } - public override Task FlushAsync(CancellationToken ignored) - { - return _connection.FlushAsync(RequestCancellationToken); - } + public override Task FlushAsync(CancellationToken ignored) => // token ignored as it comes from SendAsync + _connection.FlushAsync(); public override Task FinishAsync() { diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DecompressionHandler.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DecompressionHandler.cs index 261ae26af761..41843de14158 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DecompressionHandler.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DecompressionHandler.cs @@ -108,11 +108,14 @@ public DecompressedContent(HttpContent originalContent) protected abstract Stream GetDecompressedStream(Stream originalStream); - protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context) + protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) => + SerializeToStreamAsync(stream, context, CancellationToken.None); + + internal override async Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) { using (Stream decompressedStream = await CreateContentReadStreamAsync().ConfigureAwait(false)) { - await decompressedStream.CopyToAsync(stream).ConfigureAwait(false); + await decompressedStream.CopyToAsync(stream, cancellationToken).ConfigureAwait(false); } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/EmptyReadStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/EmptyReadStream.cs index bcc4e604348b..6320af3c4637 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/EmptyReadStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/EmptyReadStream.cs @@ -23,10 +23,13 @@ public override void Close() { /* nop */ } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { ValidateBufferArgs(buffer, offset, count); - return s_zeroTask; + return cancellationToken.IsCancellationRequested ? + Task.FromCanceled(cancellationToken) : + s_zeroTask; } - public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) => + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : new ValueTask(0); } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index dfbedcb66ccd..35dd90b72288 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -43,12 +43,14 @@ internal partial class HttpConnection : IDisposable private static readonly byte[] s_spaceHttp11NewlineAsciiBytes = Encoding.ASCII.GetBytes(" HTTP/1.1\r\n"); private static readonly byte[] s_hostKeyAndSeparator = Encoding.ASCII.GetBytes(HttpKnownHeaderNames.Host + ": "); private static readonly byte[] s_httpSchemeAndDelimiter = Encoding.ASCII.GetBytes(Uri.UriSchemeHttp + Uri.SchemeDelimiter); + private static readonly string s_cancellationMessage = new OperationCanceledException().Message; // use same message as the default ctor private readonly HttpConnectionPool _pool; private readonly Stream _stream; private readonly TransportContext _transportContext; private readonly bool _usingProxy; private readonly byte[] _idnHostAsciiBytes; + private readonly WeakReference _weakThisRef; private HttpRequestMessage _currentRequest; private Task _sendRequestContentTask; @@ -83,6 +85,8 @@ public HttpConnection( _writeBuffer = new byte[InitialWriteBufferSize]; _readBuffer = new byte[InitialReadBufferSize]; + _weakThisRef = new WeakReference(this); + if (NetEventSource.IsEnabled) { if (_stream is SslStream sslStream) @@ -152,64 +156,64 @@ public bool CanRetry public DateTimeOffset CreationTime { get; } = DateTimeOffset.UtcNow; - private async Task WriteHeadersAsync(HttpHeaders headers, string cookiesFromContainer, CancellationToken cancellationToken) + private async Task WriteHeadersAsync(HttpHeaders headers, string cookiesFromContainer) { foreach (KeyValuePair> header in headers) { - await WriteAsciiStringAsync(header.Key, cancellationToken).ConfigureAwait(false); - await WriteTwoBytesAsync((byte)':', (byte)' ', cancellationToken).ConfigureAwait(false); + await WriteAsciiStringAsync(header.Key).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)':', (byte)' ').ConfigureAwait(false); var values = (string[])header.Value; // typed as IEnumerable, but always a string[] Debug.Assert(values.Length > 0, "No values for header??"); if (values.Length > 0) { - await WriteStringAsync(values[0], cancellationToken).ConfigureAwait(false); + await WriteStringAsync(values[0]).ConfigureAwait(false); if (cookiesFromContainer != null && header.Key == HttpKnownHeaderNames.Cookie) { - await WriteTwoBytesAsync((byte)';', (byte)' ', cancellationToken).ConfigureAwait(false); - await WriteStringAsync(cookiesFromContainer, cancellationToken).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)';', (byte)' ').ConfigureAwait(false); + await WriteStringAsync(cookiesFromContainer).ConfigureAwait(false); cookiesFromContainer = null; } for (int i = 1; i < values.Length; i++) { - await WriteTwoBytesAsync((byte)',', (byte)' ', cancellationToken).ConfigureAwait(false); - await WriteStringAsync(values[i], cancellationToken).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)',', (byte)' ').ConfigureAwait(false); + await WriteStringAsync(values[i]).ConfigureAwait(false); } } - await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); } if (cookiesFromContainer != null) { - await WriteAsciiStringAsync(HttpKnownHeaderNames.Cookie, cancellationToken).ConfigureAwait(false); - await WriteTwoBytesAsync((byte)':', (byte)' ', cancellationToken).ConfigureAwait(false); - await WriteAsciiStringAsync(cookiesFromContainer, cancellationToken).ConfigureAwait(false); - await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false); + await WriteAsciiStringAsync(HttpKnownHeaderNames.Cookie).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)':', (byte)' ').ConfigureAwait(false); + await WriteAsciiStringAsync(cookiesFromContainer).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); } } - private async Task WriteHostHeaderAsync(Uri uri, CancellationToken cancellationToken) + private async Task WriteHostHeaderAsync(Uri uri) { - await WriteBytesAsync(s_hostKeyAndSeparator, cancellationToken).ConfigureAwait(false); + await WriteBytesAsync(s_hostKeyAndSeparator).ConfigureAwait(false); await (_idnHostAsciiBytes != null ? - WriteBytesAsync(_idnHostAsciiBytes, cancellationToken) : - WriteAsciiStringAsync(uri.IdnHost, cancellationToken)).ConfigureAwait(false); + WriteBytesAsync(_idnHostAsciiBytes) : + WriteAsciiStringAsync(uri.IdnHost)).ConfigureAwait(false); if (!uri.IsDefaultPort) { - await WriteByteAsync((byte)':', cancellationToken).ConfigureAwait(false); - await WriteFormattedInt32Async(uri.Port, cancellationToken).ConfigureAwait(false); + await WriteByteAsync((byte)':').ConfigureAwait(false); + await WriteFormattedInt32Async(uri.Port).ConfigureAwait(false); } - await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); } - private Task WriteFormattedInt32Async(int value, CancellationToken cancellationToken) + private Task WriteFormattedInt32Async(int value) { // Try to format into our output buffer directly. if (Utf8Formatter.TryFormat(value, new Span(_writeBuffer, _writeOffset, _writeBuffer.Length - _writeOffset), out int bytesWritten)) @@ -219,7 +223,7 @@ private Task WriteFormattedInt32Async(int value, CancellationToken cancellationT } // If we don't have enough room, do it the slow way. - return WriteAsciiStringAsync(value.ToString(CultureInfo.InvariantCulture), cancellationToken); + return WriteAsciiStringAsync(value.ToString(CultureInfo.InvariantCulture)); } public async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) @@ -233,43 +237,43 @@ public async Task SendAsync(HttpRequestMessage request, Can // Send the request. if (NetEventSource.IsEnabled) Trace($"Sending request: {request}"); + CancellationTokenRegistration cancellationRegistration = RegisterCancellation(cancellationToken); try { // Write request line - await WriteStringAsync(request.Method.Method, cancellationToken).ConfigureAwait(false); - await WriteByteAsync((byte)' ', cancellationToken).ConfigureAwait(false); + await WriteStringAsync(request.Method.Method).ConfigureAwait(false); + await WriteByteAsync((byte)' ').ConfigureAwait(false); if (_usingProxy) { // Proxied requests contain full URL Debug.Assert(request.RequestUri.Scheme == Uri.UriSchemeHttp); - await WriteBytesAsync(s_httpSchemeAndDelimiter, cancellationToken).ConfigureAwait(false); - await WriteAsciiStringAsync(request.RequestUri.IdnHost, cancellationToken).ConfigureAwait(false); + await WriteBytesAsync(s_httpSchemeAndDelimiter).ConfigureAwait(false); + await WriteAsciiStringAsync(request.RequestUri.IdnHost).ConfigureAwait(false); } - await WriteStringAsync(request.RequestUri.PathAndQuery, cancellationToken).ConfigureAwait(false); + await WriteStringAsync(request.RequestUri.PathAndQuery).ConfigureAwait(false); // Fall back to 1.1 for all versions other than 1.0 Debug.Assert(request.Version.Major >= 0 && request.Version.Minor >= 0); // guaranteed by Version class bool isHttp10 = request.Version.Minor == 0 && request.Version.Major == 1; - await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11NewlineAsciiBytes, - cancellationToken).ConfigureAwait(false); + await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11NewlineAsciiBytes).ConfigureAwait(false); // Determine cookies to send - string cookies = null; + string cookiesFromContainer = null; if (_pool.Pools.Settings._useCookies) { - cookies = _pool.Pools.Settings._cookieContainer.GetCookieHeader(request.RequestUri); - if (cookies == "") + cookiesFromContainer = _pool.Pools.Settings._cookieContainer.GetCookieHeader(request.RequestUri); + if (cookiesFromContainer == "") { - cookies = null; + cookiesFromContainer = null; } } // Write request headers - if (request.HasHeaders || cookies != null) + if (request.HasHeaders || cookiesFromContainer != null) { - await WriteHeadersAsync(request.Headers, cookies, cancellationToken).ConfigureAwait(false); + await WriteHeadersAsync(request.Headers, cookiesFromContainer).ConfigureAwait(false); } if (request.Content == null) @@ -278,30 +282,30 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N // unless this is a method that never has a body. if (request.Method != HttpMethod.Get && request.Method != HttpMethod.Head) { - await WriteBytesAsync(s_contentLength0NewlineAsciiBytes, cancellationToken).ConfigureAwait(false); + await WriteBytesAsync(s_contentLength0NewlineAsciiBytes).ConfigureAwait(false); } } else { // Write content headers - await WriteHeadersAsync(request.Content.Headers, null, cancellationToken).ConfigureAwait(false); + await WriteHeadersAsync(request.Content.Headers, cookiesFromContainer: null).ConfigureAwait(false); } // Write special additional headers. If a host isn't in the headers list, then a Host header // wasn't sent, so as it's required by HTTP 1.1 spec, send one based on the Request Uri. if (!request.HasHeaders || request.Headers.Host == null) { - await WriteHostHeaderAsync(request.RequestUri, cancellationToken).ConfigureAwait(false); + await WriteHostHeaderAsync(request.RequestUri).ConfigureAwait(false); } // CRLF for end of headers. - await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false); + await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); Debug.Assert(_sendRequestContentTask == null); if (request.Content == null) { // We have nothing more to send, so flush out any headers we haven't yet sent. - await FlushAsync(cancellationToken).ConfigureAwait(false); + await FlushAsync().ConfigureAwait(false); } else { @@ -310,13 +314,18 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N // to ensure the headers and content are sent. bool transferEncodingChunked = request.HasHeaders && request.Headers.TransferEncodingChunked == true; HttpContentWriteStream stream = transferEncodingChunked ? (HttpContentWriteStream) - new ChunkedEncodingWriteStream(this, cancellationToken) : - new ContentLengthWriteStream(this, cancellationToken); + new ChunkedEncodingWriteStream(this) : + new ContentLengthWriteStream(this); if (!request.HasHeaders || request.Headers.ExpectContinue != true) { - // Send the request content asynchronously. - Task sendTask = _sendRequestContentTask = SendRequestContentAsync(request, stream); + // Send the request content asynchronously. Note that elsewhere in SendAsync we don't pass + // the cancellation token around, as we simply register with it for the duration of the + // method in order to dispose of this connection and wake up any operations. But SendRequestContentAsync + // is special in that it ends up dealing with an external entity, the request HttpContent provided + // by the caller to this handler, and we could end up blocking as part of getting that content, + // which won't be affected by disposing this connection. Thus, we do pass the token in here. + Task sendTask = _sendRequestContentTask = SendRequestContentAsync(request, stream, cancellationToken); if (sendTask.IsFaulted) { // Technically this isn't necessary: if the task failed, it will have stored the exception @@ -333,7 +342,7 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N // We're sending an Expect: 100-continue header. We need to flush headers so that the server receives // all of them, and we need to do so before initiating the send, as once we do that, it effectively // owns the right to write, and we don't want to concurrently be accessing the write buffer. - await FlushAsync(cancellationToken).ConfigureAwait(false); + await FlushAsync().ConfigureAwait(false); // Create a TCS we'll use to block the request content from being sent, and create a timer that's used // as a fail-safe to unblock the request content if we don't hear back from the server in a timely manner. @@ -342,7 +351,8 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N var expect100Timer = new Timer( s => ((TaskCompletionSource)s).TrySetResult(true), allowExpect100ToContinue, TimeSpan.FromMilliseconds(Expect100TimeoutMilliseconds), Timeout.InfiniteTimeSpan); - _sendRequestContentTask = SendRequestContentWithExpect100ContinueAsync(request, allowExpect100ToContinue.Task, stream, expect100Timer); + _sendRequestContentTask = SendRequestContentWithExpect100ContinueAsync( + request, allowExpect100ToContinue.Task, stream, expect100Timer, cancellationToken); } } @@ -380,9 +390,9 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N _canRetry = false; // Parse the response status line. - var response = new HttpResponseMessage() { RequestMessage = request, Content = new HttpConnectionContent(CancellationToken.None) }; - ParseStatusLine(await ReadNextLineAsync(cancellationToken).ConfigureAwait(false), response); - + var response = new HttpResponseMessage() { RequestMessage = request, Content = new HttpConnectionResponseContent() }; + ParseStatusLine(await ReadNextLineAsync().ConfigureAwait(false), response); + // If we sent an Expect: 100-continue header, handle the response accordingly. if (allowExpect100ToContinue != null) { @@ -409,12 +419,12 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N if (response.StatusCode == HttpStatusCode.Continue) { // We got our continue header. Read the subsequent empty line and parse the additional status line. - if (!LineIsEmpty(await ReadNextLineAsync(cancellationToken).ConfigureAwait(false))) + if (!LineIsEmpty(await ReadNextLineAsync().ConfigureAwait(false))) { ThrowInvalidHttpResponse(); } - ParseStatusLine(await ReadNextLineAsync(cancellationToken).ConfigureAwait(false), response); + ParseStatusLine(await ReadNextLineAsync().ConfigureAwait(false), response); } } } @@ -422,7 +432,7 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N // Parse the response headers. while (true) { - ArraySegment line = await ReadNextLineAsync(cancellationToken).ConfigureAwait(false); + ArraySegment line = await ReadNextLineAsync().ConfigureAwait(false); if (LineIsEmpty(line)) { break; @@ -447,6 +457,13 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N _sendRequestContentTask = null; } + // We're about to create the response stream, at which point responsibility for canceling + // the remainder of the response lies with the stream. Thus we dispose of our registration + // here (if an exception has occurred or does occur while creating/returning the stream, + // we'll still dispose of it in the catch below as part of Dispose'ing the connection). + cancellationRegistration.Dispose(); + cancellationToken.ThrowIfCancellationRequested(); // in case cancellation may have disposed of the stream + // Create the response stream. HttpContentStream responseStream; if (request.Method == HttpMethod.Head || (int)response.StatusCode == 204 || (int)response.StatusCode == 304) @@ -479,7 +496,7 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N { responseStream = new ConnectionCloseReadStream(this); } - ((HttpConnectionContent)response.Content).SetStream(responseStream); + ((HttpConnectionResponseContent)response.Content).SetStream(responseStream); if (NetEventSource.IsEnabled) Trace($"Received response: {response}"); @@ -493,33 +510,82 @@ await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11N } catch (Exception error) { + // Clean up the cancellation registration in case we're still registered. + cancellationRegistration.Dispose(); + // Make sure to complete the allowExpect100ToContinue task if it exists. allowExpect100ToContinue?.TrySetResult(false); if (NetEventSource.IsEnabled) Trace($"Error sending request: {error}"); Dispose(); - if (_pendingException != null) + // At this point, we're going to throw an exception; we just need to + // determine which exception to throw. + + if (ShouldWrapInOperationCanceledException(error, cancellationToken)) + { + // Cancellation was requested, so assume that the failure is due to + // the cancellation request. This is a bit unorthodox, as usually we'd + // prioritize a non-OperationCanceledException over a cancellation + // request to avoid losing potentially pertinent information. But given + // the cancellation design where we tear down the underlying connection upon + // a cancellation request, which can then result in a myriad of different + // exceptions (argument exceptions, object disposed exceptions, socket exceptions, + // etc.), as a middle ground we treat it as cancellation, but still propagate the + // original information as the inner exception, for diagnostic purposes. + throw CreateOperationCanceledException(_pendingException ?? error, cancellationToken); + } + else if (_pendingException != null) { // If we incurred an exception in non-linear control flow such that // the exception didn't bubble up here (e.g. concurrent sending of // the request content), use that error instead. throw new HttpRequestException(SR.net_http_client_execution_error, _pendingException); } - - // Otherwise, propagate this exception, wrapping it if necessary to - // match exception type expectations. - if (error is InvalidOperationException || error is IOException) + else if (error is InvalidOperationException || error is IOException) { + // If it's an InvalidOperationException or an IOException, for consistency + // with other handlers we wrap the exception in an HttpRequestException. throw new HttpRequestException(SR.net_http_client_execution_error, error); } - throw; + else + { + // Otherwise, just allow the original exception to propagate. + throw; + } } } + private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken) + { + // Cancellation design: + // - We register with the SendAsync CancellationToken for the duration of the SendAsync operation. + // - We register with the Read/Write/CopyToAsync methods on the response stream for each such individual operation. + // - The registration disposes of the connection, tearing it down and causing any pending operations to wake up. + // - Because such a tear down can result in a variety of different exception types, we check for a cancellation + // request and prioritize that over other exceptions, wrapping the actual exception as an inner of an OCE. + // - A weak reference to this HttpConnection is stored in the cancellation token, to prevent the token from + // artificially keeping this connection alive. + return cancellationToken.Register(s => + { + var weakThisRef = (WeakReference)s; + if (weakThisRef.TryGetTarget(out HttpConnection strongThisRef)) + { + if (NetEventSource.IsEnabled) strongThisRef.Trace("Cancellation requested. Disposing of the connection."); + strongThisRef.Dispose(); + } + }, _weakThisRef); + } + + private static bool ShouldWrapInOperationCanceledException(Exception error, CancellationToken cancellationToken) => + !(error is OperationCanceledException) && cancellationToken.IsCancellationRequested; + + private static Exception CreateOperationCanceledException(Exception error, CancellationToken cancellationToken) => + new OperationCanceledException(s_cancellationMessage, error, cancellationToken); + private static bool LineIsEmpty(ArraySegment line) => line.Count == 0; - private async Task SendRequestContentAsync(HttpRequestMessage request, HttpContentWriteStream stream) + private async Task SendRequestContentAsync(HttpRequestMessage request, HttpContentWriteStream stream, CancellationToken cancellationToken) { // Now that we're sending content, prohibit retries on this connection. _canRetry = false; @@ -527,13 +593,13 @@ private async Task SendRequestContentAsync(HttpRequestMessage request, HttpConte try { // Copy all of the data to the server. - await request.Content.CopyToAsync(stream, _transportContext).ConfigureAwait(false); + await request.Content.CopyToAsync(stream, _transportContext, cancellationToken).ConfigureAwait(false); // Finish the content; with a chunked upload, this includes writing the terminating chunk. await stream.FinishAsync().ConfigureAwait(false); // Flush any content that might still be buffered. - await FlushAsync(stream.RequestCancellationToken).ConfigureAwait(false); + await FlushAsync().ConfigureAwait(false); } catch (Exception e) { @@ -545,7 +611,7 @@ private async Task SendRequestContentAsync(HttpRequestMessage request, HttpConte } private async Task SendRequestContentWithExpect100ContinueAsync( - HttpRequestMessage request, Task allowExpect100ToContinueTask, HttpContentWriteStream stream, Timer expect100Timer) + HttpRequestMessage request, Task allowExpect100ToContinueTask, HttpContentWriteStream stream, Timer expect100Timer, CancellationToken cancellationToken) { // Wait until we receive a trigger notification that it's ok to continue sending content. // This will come either when the timer fires or when we receive a response status line from the server. @@ -558,7 +624,7 @@ private async Task SendRequestContentWithExpect100ContinueAsync( if (sendRequestContent) { if (NetEventSource.IsEnabled) Trace($"Sending request content for Expect: 100-continue."); - await SendRequestContentAsync(request, stream).ConfigureAwait(false); + await SendRequestContentAsync(request, stream, cancellationToken).ConfigureAwait(false); } else { @@ -708,7 +774,7 @@ private void WriteToBuffer(ReadOnlyMemory source) _writeOffset += source.Length; } - private async Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + private async Task WriteAsync(ReadOnlyMemory source) { int remaining = _writeBuffer.Length - _writeOffset; @@ -724,14 +790,14 @@ private async Task WriteAsync(ReadOnlyMemory source, CancellationToken can // Fit what we can in the current write buffer and flush it. WriteToBuffer(source.Slice(0, remaining)); source = source.Slice(remaining); - await FlushAsync(cancellationToken).ConfigureAwait(false); + await FlushAsync().ConfigureAwait(false); } if (source.Length >= _writeBuffer.Length) { // Large write. No sense buffering this. Write directly to stream. // CONSIDER: May want to be a bit smarter here? Think about how large writes should work... - await WriteToStreamAsync(source, cancellationToken).ConfigureAwait(false); + await WriteToStreamAsync(source).ConfigureAwait(false); } else { @@ -740,13 +806,13 @@ private async Task WriteAsync(ReadOnlyMemory source, CancellationToken can } } - private Task WriteWithoutBufferingAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + private Task WriteWithoutBufferingAsync(ReadOnlyMemory source) { if (_writeOffset == 0) { // There's nothing in the write buffer we need to flush. // Just write the supplied data out to the stream. - return WriteToStreamAsync(source, cancellationToken); + return WriteToStreamAsync(source); } int remaining = _writeBuffer.Length - _writeOffset; @@ -757,40 +823,40 @@ private Task WriteWithoutBufferingAsync(ReadOnlyMemory source, Cancellatio // the content to the write buffer and then flush it, so that we // can do a single send rather than two. WriteToBuffer(source); - return FlushAsync(cancellationToken); + return FlushAsync(); } // There's data in the write buffer and the data we're writing doesn't fit after it. // Do two writes, one to flush the buffer and then another to write the supplied content. - return FlushThenWriteWithoutBufferingAsync(source, cancellationToken); + return FlushThenWriteWithoutBufferingAsync(source); } - private async Task FlushThenWriteWithoutBufferingAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + private async Task FlushThenWriteWithoutBufferingAsync(ReadOnlyMemory source) { - await FlushAsync(cancellationToken).ConfigureAwait(false); - await WriteToStreamAsync(source, cancellationToken).ConfigureAwait(false); + await FlushAsync().ConfigureAwait(false); + await WriteToStreamAsync(source).ConfigureAwait(false); } - private Task WriteByteAsync(byte b, CancellationToken cancellationToken) + private Task WriteByteAsync(byte b) { if (_writeOffset < _writeBuffer.Length) { _writeBuffer[_writeOffset++] = b; return Task.CompletedTask; } - return WriteByteSlowAsync(b, cancellationToken); + return WriteByteSlowAsync(b); } - private async Task WriteByteSlowAsync(byte b, CancellationToken cancellationToken) + private async Task WriteByteSlowAsync(byte b) { Debug.Assert(_writeOffset == _writeBuffer.Length); - await WriteToStreamAsync(_writeBuffer, cancellationToken).ConfigureAwait(false); + await WriteToStreamAsync(_writeBuffer).ConfigureAwait(false); _writeBuffer[0] = b; _writeOffset = 1; } - private Task WriteTwoBytesAsync(byte b1, byte b2, CancellationToken cancellationToken) + private Task WriteTwoBytesAsync(byte b1, byte b2) { if (_writeOffset <= _writeBuffer.Length - 2) { @@ -799,16 +865,16 @@ private Task WriteTwoBytesAsync(byte b1, byte b2, CancellationToken cancellation buffer[_writeOffset++] = b2; return Task.CompletedTask; } - return WriteTwoBytesSlowAsync(b1, b2, cancellationToken); + return WriteTwoBytesSlowAsync(b1, b2); } - private async Task WriteTwoBytesSlowAsync(byte b1, byte b2, CancellationToken cancellationToken) + private async Task WriteTwoBytesSlowAsync(byte b1, byte b2) { - await WriteByteAsync(b1, cancellationToken).ConfigureAwait(false); - await WriteByteAsync(b2, cancellationToken).ConfigureAwait(false); + await WriteByteAsync(b1).ConfigureAwait(false); + await WriteByteAsync(b2).ConfigureAwait(false); } - private Task WriteBytesAsync(byte[] bytes, CancellationToken cancellationToken) + private Task WriteBytesAsync(byte[] bytes) { if (_writeOffset <= _writeBuffer.Length - bytes.Length) { @@ -816,10 +882,10 @@ private Task WriteBytesAsync(byte[] bytes, CancellationToken cancellationToken) _writeOffset += bytes.Length; return Task.CompletedTask; } - return WriteBytesSlowAsync(bytes, cancellationToken); + return WriteBytesSlowAsync(bytes); } - private async Task WriteBytesSlowAsync(byte[] bytes, CancellationToken cancellationToken) + private async Task WriteBytesSlowAsync(byte[] bytes) { int offset = 0; while (true) @@ -838,13 +904,13 @@ private async Task WriteBytesSlowAsync(byte[] bytes, CancellationToken cancellat } else if (_writeOffset == _writeBuffer.Length) { - await WriteToStreamAsync(_writeBuffer, cancellationToken).ConfigureAwait(false); + await WriteToStreamAsync(_writeBuffer).ConfigureAwait(false); _writeOffset = 0; } } } - private Task WriteStringAsync(string s, CancellationToken cancellationToken) + private Task WriteStringAsync(string s) { // If there's enough space in the buffer to just copy all of the string's bytes, do so. // Unlike WriteAsciiStringAsync, validate each char along the way. @@ -866,10 +932,10 @@ private Task WriteStringAsync(string s, CancellationToken cancellationToken) // Otherwise, fall back to doing a normal slow string write; we could optimize away // the extra checks later, but the case where we cross a buffer boundary should be rare. - return WriteStringAsyncSlow(s, cancellationToken); + return WriteStringAsyncSlow(s); } - private Task WriteAsciiStringAsync(string s, CancellationToken cancellationToken) + private Task WriteAsciiStringAsync(string s) { // If there's enough space in the buffer to just copy all of the string's bytes, do so. int offset = _writeOffset; @@ -886,10 +952,10 @@ private Task WriteAsciiStringAsync(string s, CancellationToken cancellationToken // Otherwise, fall back to doing a normal slow string write; we could optimize away // the extra checks later, but the case where we cross a buffer boundary should be rare. - return WriteStringAsyncSlow(s, cancellationToken); + return WriteStringAsyncSlow(s); } - private async Task WriteStringAsyncSlow(string s, CancellationToken cancellationToken) + private async Task WriteStringAsyncSlow(string s) { for (int i = 0; i < s.Length; i++) { @@ -898,28 +964,28 @@ private async Task WriteStringAsyncSlow(string s, CancellationToken cancellation { throw new HttpRequestException(SR.net_http_request_invalid_char_encoding); } - await WriteByteAsync((byte)c, cancellationToken).ConfigureAwait(false); + await WriteByteAsync((byte)c).ConfigureAwait(false); } } - private Task FlushAsync(CancellationToken cancellationToken) + private Task FlushAsync() { if (_writeOffset > 0) { - Task t = WriteToStreamAsync(new ReadOnlyMemory(_writeBuffer, 0, _writeOffset), cancellationToken); + Task t = WriteToStreamAsync(new ReadOnlyMemory(_writeBuffer, 0, _writeOffset)); _writeOffset = 0; return t; } return Task.CompletedTask; } - private Task WriteToStreamAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + private Task WriteToStreamAsync(ReadOnlyMemory source) { if (NetEventSource.IsEnabled) Trace($"Writing {source.Length} bytes."); - return _stream.WriteAsync(source, cancellationToken); + return _stream.WriteAsync(source); } - private async ValueTask> ReadNextLineAsync(CancellationToken cancellationToken) + private async ValueTask> ReadNextLineAsync() { int previouslyScannedBytes = 0; while (true) @@ -954,12 +1020,12 @@ private async ValueTask> ReadNextLineAsync(CancellationToken { ThrowInvalidHttpResponse(); } - await FillAsync(cancellationToken).ConfigureAwait(false); + await FillAsync().ConfigureAwait(false); } } // Throws IOException on EOF. This is only called when we expect more data. - private async Task FillAsync(CancellationToken cancellationToken) + private async Task FillAsync() { Debug.Assert(_readAheadTask == null); @@ -994,7 +1060,7 @@ private async Task FillAsync(CancellationToken cancellationToken) _readLength = remaining; } - int bytesRead = await _stream.ReadAsync(new Memory(_readBuffer, _readLength, _readBuffer.Length - _readLength), cancellationToken).ConfigureAwait(false); + int bytesRead = await _stream.ReadAsync(new Memory(_readBuffer, _readLength, _readBuffer.Length - _readLength)).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace($"Received {bytesRead} bytes."); if (bytesRead == 0) @@ -1013,7 +1079,7 @@ private void ReadFromBuffer(Span buffer) _readOffset += buffer.Length; } - private async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken) + private async ValueTask ReadAsync(Memory destination) { // This is called when reading the response body @@ -1036,28 +1102,28 @@ private async ValueTask ReadAsync(Memory destination, CancellationTok // No data in read buffer. // Do an unbuffered read directly against the underlying stream. Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers."); - int count = await _stream.ReadAsync(destination, cancellationToken).ConfigureAwait(false); + int count = await _stream.ReadAsync(destination).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace($"Received {count} bytes."); return count; } - private async Task CopyFromBufferAsync(Stream destination, int count, CancellationToken cancellationToken) + private async Task CopyFromBufferAsync(Stream destination, int count) { Debug.Assert(count <= _readLength - _readOffset); if (NetEventSource.IsEnabled) Trace($"Copying {count} bytes to stream."); - await destination.WriteAsync(_readBuffer, _readOffset, count, cancellationToken).ConfigureAwait(false); + await destination.WriteAsync(_readBuffer, _readOffset, count).ConfigureAwait(false); _readOffset += count; } - private async Task CopyToAsync(Stream destination, CancellationToken cancellationToken) + private async Task CopyToAsync(Stream destination) { Debug.Assert(destination != null); int remaining = _readLength - _readOffset; if (remaining > 0) { - await CopyFromBufferAsync(destination, remaining, cancellationToken).ConfigureAwait(false); + await CopyFromBufferAsync(destination, remaining).ConfigureAwait(false); } while (true) @@ -1066,19 +1132,19 @@ private async Task CopyToAsync(Stream destination, CancellationToken cancellatio // Don't use FillAsync here as it will throw on EOF. Debug.Assert(_readAheadTask == null); - _readLength = await _stream.ReadAsync(_readBuffer, cancellationToken).ConfigureAwait(false); + _readLength = await _stream.ReadAsync(_readBuffer).ConfigureAwait(false); if (_readLength == 0) { // End of stream break; } - await CopyFromBufferAsync(destination, _readLength, cancellationToken).ConfigureAwait(false); + await CopyFromBufferAsync(destination, _readLength).ConfigureAwait(false); } } // Copy *exactly* [length] bytes into destination; throws on end of stream. - private async Task CopyToAsync(Stream destination, ulong length, CancellationToken cancellationToken) + private async Task CopyToAsync(Stream destination, ulong length) { Debug.Assert(destination != null); Debug.Assert(length > 0); @@ -1090,7 +1156,7 @@ private async Task CopyToAsync(Stream destination, ulong length, CancellationTok { remaining = (int)length; } - await CopyFromBufferAsync(destination, remaining, cancellationToken).ConfigureAwait(false); + await CopyFromBufferAsync(destination, remaining).ConfigureAwait(false); length -= (ulong)remaining; if (length == 0) @@ -1101,10 +1167,10 @@ private async Task CopyToAsync(Stream destination, ulong length, CancellationTok while (true) { - await FillAsync(cancellationToken).ConfigureAwait(false); + await FillAsync().ConfigureAwait(false); remaining = (ulong)_readLength < length ? _readLength : (int)length; - await CopyFromBufferAsync(destination, remaining, cancellationToken).ConfigureAwait(false); + await CopyFromBufferAsync(destination, remaining).ConfigureAwait(false); length -= (ulong)remaining; if (length == 0) @@ -1177,8 +1243,7 @@ private void ReturnConnectionToPoolCore() { try { - // Null out the associated request before the connection is potentially reused by another. - _currentRequest = null; + // Any remaining request content has completed successfully. Drop it. _sendRequestContentTask = null; // When putting a connection back into the pool, we initiate a pre-emptive diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionContent.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionResponseContent.cs similarity index 70% rename from src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionContent.cs rename to src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionResponseContent.cs index e77ff037f694..7cca2c9f7e23 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionResponseContent.cs @@ -11,16 +11,10 @@ namespace System.Net.Http { internal partial class HttpConnection : IDisposable { - private sealed class HttpConnectionContent : HttpContent + private sealed class HttpConnectionResponseContent : HttpContent { - private readonly CancellationToken _cancellationToken; private HttpContentStream _stream; - public HttpConnectionContent(CancellationToken cancellationToken) - { - _cancellationToken = cancellationToken; - } - public void SetStream(HttpContentStream stream) { Debug.Assert(stream != null); @@ -41,30 +35,33 @@ private HttpContentStream ConsumeStream() return stream; } - protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context) + protected sealed override Task SerializeToStreamAsync(Stream stream, TransportContext context) => + SerializeToStreamAsync(stream, context, CancellationToken.None); + + internal sealed override async Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) { Debug.Assert(stream != null); using (HttpContentStream contentStream = ConsumeStream()) { const int BufferSize = 8192; - await contentStream.CopyToAsync(stream, BufferSize, _cancellationToken).ConfigureAwait(false); + await contentStream.CopyToAsync(stream, BufferSize, cancellationToken).ConfigureAwait(false); } } - protected internal override bool TryComputeLength(out long length) + protected internal sealed override bool TryComputeLength(out long length) { length = 0; return false; } - protected override Task CreateContentReadStreamAsync() => + protected sealed override Task CreateContentReadStreamAsync() => Task.FromResult(ConsumeStream()); - internal override Stream TryCreateContentReadStream() => + internal sealed override Stream TryCreateContentReadStream() => ConsumeStream(); - protected override void Dispose(bool disposing) + protected sealed override void Dispose(bool disposing) { if (disposing) { diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs index 022e87606751..73e04edec110 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Diagnostics; using System.IO; using System.Threading; @@ -14,21 +13,21 @@ public HttpContentDuplexStream(HttpConnection connection) : base(connection) { } - public override bool CanRead => true; - public override bool CanWrite => true; + public sealed override bool CanRead => true; + public sealed override bool CanWrite => true; - public override void Flush() => FlushAsync().GetAwaiter().GetResult(); + public sealed override void Flush() => FlushAsync().GetAwaiter().GetResult(); - public override int Read(byte[] buffer, int offset, int count) + public sealed override int Read(byte[] buffer, int offset, int count) { ValidateBufferArgs(buffer, offset, count); return ReadAsync(new Memory(buffer, offset, count), CancellationToken.None).GetAwaiter().GetResult(); } - public override void Write(byte[] buffer, int offset, int count) => + public sealed override void Write(byte[] buffer, int offset, int count) => WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult(); - public override void CopyTo(Stream destination, int bufferSize) => + public sealed override void CopyTo(Stream destination, int bufferSize) => CopyToAsync(destination, bufferSize, CancellationToken.None).GetAwaiter().GetResult(); } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs index ed5cce884a6d..fdad5898d738 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs @@ -13,20 +13,20 @@ public HttpContentReadStream(HttpConnection connection) : base(connection) { } - public override bool CanRead => true; - public override bool CanWrite => false; + public sealed override bool CanRead => true; + public sealed override bool CanWrite => false; - public override void Flush() { } + public sealed override void Flush() { } - public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public sealed override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); - public override int Read(byte[] buffer, int offset, int count) + public sealed override int Read(byte[] buffer, int offset, int count) { ValidateBufferArgs(buffer, offset, count); return ReadAsync(new Memory(buffer, offset, count), CancellationToken.None).GetAwaiter().GetResult(); } - public override void CopyTo(Stream destination, int bufferSize) => + public sealed override void CopyTo(Stream destination, int bufferSize) => CopyToAsync(destination, bufferSize, CancellationToken.None).GetAwaiter().GetResult(); } } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs index b54f6c45dd3b..ad329ff019f3 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs @@ -31,27 +31,27 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } - public override bool CanSeek => false; + public sealed override bool CanSeek => false; - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => + public sealed override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => TaskToApm.Begin(ReadAsync(buffer, offset, count, default(CancellationToken)), callback, state); - public override int EndRead(IAsyncResult asyncResult) => + public sealed override int EndRead(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => + public sealed override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => TaskToApm.Begin(WriteAsync(buffer, offset, count, default(CancellationToken)), callback, state); - public override void EndWrite(IAsyncResult asyncResult) => + public sealed override void EndWrite(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public sealed override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) => throw new NotSupportedException(); + public sealed override void SetLength(long value) => throw new NotSupportedException(); - public override long Length => throw new NotSupportedException(); + public sealed override long Length => throw new NotSupportedException(); - public override long Position + public sealed override long Position { get { throw new NotSupportedException(); } set { throw new NotSupportedException(); } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs index 54c68df478b5..b302e89c070a 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System.Diagnostics; -using System.IO; using System.Threading; using System.Threading.Tasks; @@ -11,28 +10,17 @@ namespace System.Net.Http { internal abstract class HttpContentWriteStream : HttpContentStream { - public HttpContentWriteStream(HttpConnection connection, CancellationToken cancellationToken) : base(connection) - { + public HttpContentWriteStream(HttpConnection connection) : base(connection) => Debug.Assert(connection != null); - RequestCancellationToken = cancellationToken; - } - /// Cancellation token associated with the send operation. - /// - /// Because of how this write stream is used, the CancellationToken passed into the individual - /// stream operations will be the default non-cancelable token and can be ignored. Instead, - /// this token is used. - /// - internal CancellationToken RequestCancellationToken { get; } + public sealed override bool CanRead => false; + public sealed override bool CanWrite => true; - public override bool CanRead => false; - public override bool CanWrite => true; + public sealed override void Flush() => FlushAsync().GetAwaiter().GetResult(); - public override void Flush() => FlushAsync().GetAwaiter().GetResult(); + public sealed override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); - public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); - - public override void Write(byte[] buffer, int offset, int count) => + public sealed override void Write(byte[] buffer, int offset, int count) => WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult(); public abstract Task FinishAsync(); diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs index 7073afc8d8d3..e30f76db601a 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs @@ -22,17 +22,44 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); } - public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (_connection == null || destination.Length == 0) { // Response body fully consumed or the caller didn't ask for any data return 0; } - int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false); + ValueTask readTask = _connection.ReadAsync(destination); + int bytesRead; + if (readTask.IsCompletedSuccessfully) + { + bytesRead = readTask.Result; + } + else + { + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try + { + bytesRead = await readTask.ConfigureAwait(false); + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } + } + if (bytesRead == 0) { + // A cancellation request may have caused the EOF. + cancellationToken.ThrowIfCancellationRequested(); + // We cannot reuse this connection, so close it. _connection.Dispose(); _connection = null; @@ -48,15 +75,40 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance { throw new ArgumentNullException(nameof(destination)); } + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize)); + } - if (_connection != null) // null if response body fully consumed + cancellationToken.ThrowIfCancellationRequested(); + + if (_connection == null) { - await _connection.CopyToAsync(destination, cancellationToken).ConfigureAwait(false); + // Response body fully consumed + return; + } - // We cannot reuse this connection, so close it. - _connection.Dispose(); - _connection = null; + Task copyTask = _connection.CopyToAsync(destination); + if (!copyTask.IsCompletedSuccessfully) + { + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try + { + await copyTask.ConfigureAwait(false); + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } } + + // We cannot reuse this connection, so close it. + _connection.Dispose(); + _connection = null; } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -65,14 +117,63 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) => - _connection == null ? Task.FromException(new IOException(SR.net_http_io_write)) : - source.Length > 0 ? _connection.WriteWithoutBufferingAsync(source, cancellationToken) : - Task.CompletedTask; + public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + if (_connection == null) + { + return Task.FromException(new IOException(SR.net_http_io_write)); + } + + if (source.Length == 0) + { + return Task.CompletedTask; + } + + Task writeTask = _connection.WriteWithoutBufferingAsync(source); + return writeTask.IsCompleted ? + writeTask : + WaitWithConnectionCancellationAsync(writeTask, cancellationToken); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + if (_connection == null) + { + return Task.CompletedTask; + } + + Task flushTask = _connection.FlushAsync(); + return flushTask.IsCompleted ? + flushTask : + WaitWithConnectionCancellationAsync(flushTask, cancellationToken); + } - public override Task FlushAsync(CancellationToken cancellationToken) => - _connection != null ? _connection.FlushAsync(cancellationToken) : - Task.CompletedTask; + private async Task WaitWithConnectionCancellationAsync(Task task, CancellationToken cancellationToken) + { + CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + try + { + await task.ConfigureAwait(false); + } + catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken)) + { + throw CreateOperationCanceledException(exc, cancellationToken); + } + finally + { + ctr.Dispose(); + } + } } } } diff --git a/src/System.Net.Http/tests/FunctionalTests/CancellationTest.cs b/src/System.Net.Http/tests/FunctionalTests/CancellationTest.cs deleted file mode 100644 index f29922b17c54..000000000000 --- a/src/System.Net.Http/tests/FunctionalTests/CancellationTest.cs +++ /dev/null @@ -1,162 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Diagnostics; -using System.IO; -using System.Net.Test.Common; -using System.Threading; -using System.Threading.Tasks; - -using Xunit; -using Xunit.Abstractions; - -namespace System.Net.Http.Functional.Tests -{ - public class CancellationTest : HttpClientTestBase - { - private readonly ITestOutputHelper _output; - - public CancellationTest(ITestOutputHelper output) - { - _output = output; - } - - [OuterLoop] // includes seconds of delay - [Theory] - [InlineData(false, false)] - [InlineData(false, true)] - [InlineData(true, false)] - [InlineData(true, true)] - [ActiveIssue("dotnet/corefx #20010", TargetFrameworkMonikers.Uap)] - [ActiveIssue("dotnet/corefx #19038", TargetFrameworkMonikers.NetFramework)] - public async Task GetAsync_ResponseContentRead_CancelUsingTimeoutOrToken_TaskCanceledQuickly( - bool useTimeout, bool startResponseBody) - { - var cts = new CancellationTokenSource(); // ignored if useTimeout==true - TimeSpan timeout = useTimeout ? new TimeSpan(0, 0, 1) : Timeout.InfiniteTimeSpan; - CancellationToken cancellationToken = useTimeout ? CancellationToken.None : cts.Token; - - using (HttpClient client = CreateHttpClient()) - { - client.Timeout = timeout; - - var triggerResponseWrite = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var triggerRequestCancel = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - await LoopbackServer.CreateServerAsync(async (server, url) => - { - Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) => - { - while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ; - await writer.WriteAsync( - "HTTP/1.1 200 OK\r\n" + - $"Date: {DateTimeOffset.UtcNow:R}\r\n" + - "Content-Length: 16000\r\n" + - "\r\n" + - (startResponseBody ? "less than 16000 bytes" : "")); - - await Task.Delay(1000); - triggerRequestCancel.SetResult(true); // allow request to cancel - await triggerResponseWrite.Task; // pause until we're released - - return null; - }); - - var stopwatch = Stopwatch.StartNew(); - if (PlatformDetection.IsFullFramework) - { - // .NET Framework throws WebException instead of OperationCanceledException. - await Assert.ThrowsAnyAsync(async () => - { - Task getResponse = client.GetAsync(url, HttpCompletionOption.ResponseContentRead, cancellationToken); - await triggerRequestCancel.Task; - cts.Cancel(); - await getResponse; - }); - } - else - { - await Assert.ThrowsAnyAsync(async () => - { - Task getResponse = client.GetAsync(url, HttpCompletionOption.ResponseContentRead, cancellationToken); - await triggerRequestCancel.Task; - cts.Cancel(); - await getResponse; - }); - } - stopwatch.Stop(); - _output.WriteLine("GetAsync() completed at: {0}", stopwatch.Elapsed.ToString()); - - triggerResponseWrite.SetResult(true); - Assert.True(stopwatch.Elapsed < new TimeSpan(0, 0, 30), $"Elapsed time {stopwatch.Elapsed} should be less than 30 seconds, was {stopwatch.Elapsed.TotalSeconds}"); - }); - } - } - - [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, "dotnet/corefx #18864")] // Hangs on NETFX - [ActiveIssue(9075, TestPlatforms.AnyUnix)] // recombine this test into the subsequent one when issue is fixed - [OuterLoop] // includes seconds of delay - [Fact] - public Task ReadAsStreamAsync_ReadAsync_Cancel_BodyNeverStarted_TaskCanceledQuickly() - { - return ReadAsStreamAsync_ReadAsync_Cancel_TaskCanceledQuickly(false); - } - - [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, "dotnet/corefx #18864")] // Hangs on NETFX - [OuterLoop] // includes seconds of delay - [Theory] - [InlineData(true)] - public async Task ReadAsStreamAsync_ReadAsync_Cancel_TaskCanceledQuickly(bool startResponseBody) - { - using (HttpClient client = CreateHttpClient()) - { - await LoopbackServer.CreateServerAsync(async (server, url) => - { - var triggerResponseWrite = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) => - { - while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ; - await writer.WriteAsync( - "HTTP/1.1 200 OK\r\n" + - $"Date: {DateTimeOffset.UtcNow:R}\r\n" + - "Content-Length: 16000\r\n" + - "\r\n" + - (startResponseBody ? "20 bytes of the body" : "")); - - await triggerResponseWrite.Task; // pause until we're released - - return null; - }); - - using (HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead)) - using (Stream responseStream = await response.Content.ReadAsStreamAsync()) - { - // Read all expected content - byte[] buffer = new byte[20]; - if (startResponseBody) - { - int totalRead = 0; - int bytesRead; - while (totalRead < 20 && (bytesRead = await responseStream.ReadAsync(buffer, 0, buffer.Length)) > 0) - { - totalRead += bytesRead; - } - } - - // Now do a read that'll need to be canceled - var stopwatch = Stopwatch.StartNew(); - await Assert.ThrowsAnyAsync( - () => responseStream.ReadAsync(buffer, 0, buffer.Length, new CancellationTokenSource(1000).Token)); - stopwatch.Stop(); - - triggerResponseWrite.SetResult(true); - _output.WriteLine("ReadAsync() completed at: {0}", stopwatch.Elapsed.ToString()); - Assert.True(stopwatch.Elapsed < new TimeSpan(0, 0, 30), $"Elapsed time {stopwatch.Elapsed} should be less than 30 seconds, was {stopwatch.Elapsed.TotalSeconds}"); - } - }); - } - } - } -} diff --git a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs new file mode 100644 index 000000000000..db3f4a8f7457 --- /dev/null +++ b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs @@ -0,0 +1,478 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.IO; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Http.Functional.Tests +{ + public class HttpClientHandler_Cancellation_Test : HttpClientTestBase + { + [Theory] + [MemberData(nameof(TwoBoolsAndCancellationMode))] + public async Task PostAsync_CancelDuringRequestContentSend_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, CancellationMode mode) + { + if (IsWinHttpHandler || IsNetfxHandler) + { + // Issue #27063: hangs / doesn't cancel + return; + } + + using (HttpClient client = CreateHttpClient()) + { + client.Timeout = Timeout.InfiniteTimeSpan; + var cts = new CancellationTokenSource(); + + await LoopbackServer.CreateServerAsync(async (server, url) => + { + Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) => + { + // Since we won't receive all of the request, just read everything we do get + byte[] ignored = new byte[100]; + while (await stream.ReadAsync(ignored, 0, ignored.Length) > 0); + return null; + }); + + var preContentSent = new TaskCompletionSource(); + var sendPostContent = new TaskCompletionSource(); + + await ValidateClientCancellationAsync(async () => + { + var req = new HttpRequestMessage(HttpMethod.Post, url); + req.Content = new DelayedByteContent(2000, 3000, preContentSent, sendPostContent.Task); + req.Headers.TransferEncodingChunked = chunkedTransfer; + req.Headers.ConnectionClose = connectionClose; + + Task postResponse = client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, cts.Token); + await preContentSent.Task; + Cancel(mode, client, cts); + await postResponse; + }); + + try + { + sendPostContent.SetResult(true); + await serverTask; + } catch { } + }); + } + } + + [Theory] + [MemberData(nameof(TwoBoolsAndCancellationMode))] + public async Task GetAsync_CancelDuringResponseHeadersReceived_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, CancellationMode mode) + { + using (HttpClient client = CreateHttpClient()) + { + client.Timeout = Timeout.InfiniteTimeSpan; + var cts = new CancellationTokenSource(); + + await LoopbackServer.CreateServerAsync(async (server, url) => + { + var partialResponseHeadersSent = new TaskCompletionSource(); + var clientFinished = new TaskCompletionSource(); + + Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) => + { + while (!string.IsNullOrEmpty(await reader.ReadLineAsync())); + + await writer.WriteAsync($"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\n"); // missing final \r\n so headers don't complete + + partialResponseHeadersSent.TrySetResult(true); + await clientFinished.Task; + + return null; + }); + + await ValidateClientCancellationAsync(async () => + { + var req = new HttpRequestMessage(HttpMethod.Get, url); + req.Headers.ConnectionClose = connectionClose; + + Task getResponse = client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, cts.Token); + await partialResponseHeadersSent.Task; + Cancel(mode, client, cts); + await getResponse; + }); + + try + { + clientFinished.SetResult(true); + await serverTask; + } catch { } + }); + } + } + + [Theory] + [MemberData(nameof(TwoBoolsAndCancellationMode))] + public async Task GetAsync_CancelDuringResponseBodyReceived_Buffered_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, CancellationMode mode) + { + using (HttpClient client = CreateHttpClient()) + { + client.Timeout = Timeout.InfiniteTimeSpan; + var cts = new CancellationTokenSource(); + + await LoopbackServer.CreateServerAsync(async (server, url) => + { + var responseHeadersSent = new TaskCompletionSource(); + var clientFinished = new TaskCompletionSource(); + + Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) => + { + while (!string.IsNullOrEmpty(await reader.ReadLineAsync())); + + await writer.WriteAsync( + $"HTTP/1.1 200 OK\r\n" + + $"Date: {DateTimeOffset.UtcNow:R}\r\n" + + (!chunkedTransfer ? "Content-Length: 20\r\n" : "") + + (connectionClose ? "Connection: close\r\n" : "") + + $"\r\n123"); // "123" is part of body and could either be chunked size or part of content-length bytes, both incomplete + + responseHeadersSent.TrySetResult(true); + await clientFinished.Task; + + return null; + }); + + await ValidateClientCancellationAsync(async () => + { + var req = new HttpRequestMessage(HttpMethod.Get, url); + req.Headers.ConnectionClose = connectionClose; + + Task getResponse = client.SendAsync(req, HttpCompletionOption.ResponseContentRead, cts.Token); + await responseHeadersSent.Task; + await Task.Delay(1); // make it more likely that client will have started processing response body + Cancel(mode, client, cts); + await getResponse; + }); + + try + { + clientFinished.SetResult(true); + await serverTask; + } catch { } + }); + } + } + + [Theory] + [MemberData(nameof(ThreeBools))] + public async Task GetAsync_CancelDuringResponseBodyReceived_Unbuffered_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, bool readOrCopyToAsync) + { + if (IsNetfxHandler || IsCurlHandler) + { + // doesn't cancel + return; + } + + using (HttpClient client = CreateHttpClient()) + { + client.Timeout = Timeout.InfiniteTimeSpan; + var cts = new CancellationTokenSource(); + + await LoopbackServer.CreateServerAsync(async (server, url) => + { + var clientFinished = new TaskCompletionSource(); + + Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) => + { + while (!string.IsNullOrEmpty(await reader.ReadLineAsync())); + + await writer.WriteAsync( + $"HTTP/1.1 200 OK\r\n" + + $"Date: {DateTimeOffset.UtcNow:R}\r\n" + + (!chunkedTransfer ? "Content-Length: 20\r\n" : "") + + (connectionClose ? "Connection: close\r\n" : "") + + $"\r\n"); + + await clientFinished.Task; + + return null; + }); + + var req = new HttpRequestMessage(HttpMethod.Get, url); + req.Headers.ConnectionClose = connectionClose; + Task getResponse = client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, cts.Token); + await ValidateClientCancellationAsync(async () => + { + HttpResponseMessage resp = await getResponse; + Stream respStream = await resp.Content.ReadAsStreamAsync(); + Task readTask = readOrCopyToAsync ? + respStream.ReadAsync(new byte[1], 0, 1, cts.Token) : + respStream.CopyToAsync(Stream.Null, 10, cts.Token); + cts.Cancel(); + await readTask; + }); + + try + { + clientFinished.SetResult(true); + await serverTask; + } catch { } + }); + } + } + + [Theory] + [InlineData(CancellationMode.CancelPendingRequests, false)] + [InlineData(CancellationMode.DisposeHttpClient, true)] + [InlineData(CancellationMode.CancelPendingRequests, false)] + [InlineData(CancellationMode.DisposeHttpClient, true)] + public async Task GetAsync_CancelPendingRequests_DoesntCancelReadAsyncOnResponseStream(CancellationMode mode, bool copyToAsync) + { + if (IsNetfxHandler) + { + // throws ObjectDisposedException as part of Stream.CopyToAsync/ReadAsync + return; + } + if (IsCurlHandler) + { + // Issue #27065 + // throws OperationCanceledException from Stream.CopyToAsync/ReadAsync + return; + } + + using (HttpClient client = CreateHttpClient()) + { + client.Timeout = Timeout.InfiniteTimeSpan; + + await LoopbackServer.CreateServerAsync(async (server, url) => + { + var clientReadSomeBody = new TaskCompletionSource(); + var clientFinished = new TaskCompletionSource(); + + var responseContentSegment = new string('s', 3000); + int responseSegments = 4; + int contentLength = responseContentSegment.Length * responseSegments; + + Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) => + { + while (!string.IsNullOrEmpty(await reader.ReadLineAsync())); + + await writer.WriteAsync( + $"HTTP/1.1 200 OK\r\n" + + $"Date: {DateTimeOffset.UtcNow:R}\r\n" + + $"Content-Length: {contentLength}\r\n" + + $"\r\n"); + + for (int i = 0; i < responseSegments; i++) + { + await writer.WriteAsync(responseContentSegment); + if (i == 0) + { + await clientReadSomeBody.Task; + } + } + + await clientFinished.Task; + + return null; + }); + + + using (HttpResponseMessage resp = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead)) + using (Stream respStream = await resp.Content.ReadAsStreamAsync()) + { + var result = new MemoryStream(); + int b = respStream.ReadByte(); + Assert.NotEqual(-1, b); + result.WriteByte((byte)b); + + Cancel(mode, client, null); // should not cancel the operation, as using ResponseHeadersRead + clientReadSomeBody.SetResult(true); + + if (copyToAsync) + { + await respStream.CopyToAsync(result, 10, new CancellationTokenSource().Token); + } + else + { + byte[] buffer = new byte[10]; + int bytesRead; + while ((bytesRead = await respStream.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + result.Write(buffer, 0, bytesRead); + } + } + + Assert.Equal(contentLength, result.Length); + } + + clientFinished.SetResult(true); + await serverTask; + }); + } + } + + [Fact] + public async Task MaxConnectionsPerServer_WaitingConnectionsAreCancelable() + { + if (IsWinHttpHandler) + { + // Issue #27064: + // Throws WinHttpException ("The server returned an invalid or unrecognized response") + // while parsing headers. + return; + } + if (IsNetfxHandler) + { + // Throws HttpRequestException wrapping a WebException for the canceled request + // instead of throwing an OperationCanceledException or a canceled WebException directly. + return; + } + + using (HttpClientHandler handler = CreateHttpClientHandler()) + using (HttpClient client = new HttpClient(handler)) + { + handler.MaxConnectionsPerServer = 1; + client.Timeout = Timeout.InfiniteTimeSpan; + + await LoopbackServer.CreateServerAsync(async (server, url) => + { + var serverAboutToBlock = new TaskCompletionSource(); + var blockServerResponse = new TaskCompletionSource(); + + Task serverTask1 = LoopbackServer.AcceptSocketAsync(server, async (socket1, stream1, reader1, writer1) => + { + while (!string.IsNullOrEmpty(await reader1.ReadLineAsync())); + await writer1.WriteAsync($"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\n"); + serverAboutToBlock.SetResult(true); + await blockServerResponse.Task; + await writer1.WriteAsync("Content-Length: 5\r\n\r\nhello"); + return null; + }); + + Task get1 = client.GetAsync(url); + await serverAboutToBlock.Task; + + var cts = new CancellationTokenSource(); + Task get2 = ValidateClientCancellationAsync(() => client.GetAsync(url, cts.Token)); + Task get3 = ValidateClientCancellationAsync(() => client.GetAsync(url, cts.Token)); + + Task get4 = client.GetAsync(url); + + cts.Cancel(); + await get2; + await get3; + + blockServerResponse.SetResult(true); + await new[] { get1, serverTask1 }.WhenAllOrAnyFailed(); + + Task serverTask4 = LoopbackServer.AcceptSocketAsync(server, async (socket2, stream2, reader2, writer2) => + { + while (!string.IsNullOrEmpty(await reader2.ReadLineAsync())); + await writer2.WriteAsync($"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\nContent-Length: 0\r\n\r\n"); + return null; + }); + + await new[] { get4, serverTask4 }.WhenAllOrAnyFailed(); + }); + } + } + + private async Task ValidateClientCancellationAsync(Func clientBodyAsync) + { + var stopwatch = Stopwatch.StartNew(); + Exception error = await Record.ExceptionAsync(clientBodyAsync); + stopwatch.Stop(); + + Assert.NotNull(error); + + if (IsNetfxHandler) + { + Assert.True( + error is WebException we && we.Status == WebExceptionStatus.RequestCanceled || + error is OperationCanceledException, + "Expected cancellation exception, got:" + Environment.NewLine + error); + } + else + { + Assert.True( + error is OperationCanceledException, + "Expected cancellation exception, got:" + Environment.NewLine + error); + } + + Assert.True(stopwatch.Elapsed < new TimeSpan(0, 0, 30), $"Elapsed time {stopwatch.Elapsed} should be less than 30 seconds, was {stopwatch.Elapsed.TotalSeconds}"); + } + + private static void Cancel(CancellationMode mode, HttpClient client, CancellationTokenSource cts) + { + if ((mode & CancellationMode.Token) != 0) + { + cts?.Cancel(); + } + + if ((mode & CancellationMode.CancelPendingRequests) != 0) + { + client?.CancelPendingRequests(); + } + + if ((mode & CancellationMode.DisposeHttpClient) != 0) + { + client?.Dispose(); + } + } + + [Flags] + public enum CancellationMode + { + Token = 0x1, + CancelPendingRequests = 0x2, + DisposeHttpClient = 0x4 + } + + private static readonly bool[] s_bools = new[] { true, false }; + + public static IEnumerable TwoBoolsAndCancellationMode() => + from first in s_bools + from second in s_bools + from mode in new[] { CancellationMode.Token, CancellationMode.CancelPendingRequests, CancellationMode.DisposeHttpClient, CancellationMode.Token | CancellationMode.CancelPendingRequests } + select new object[] { first, second, mode }; + + public static IEnumerable ThreeBools() => + from first in s_bools + from second in s_bools + from third in s_bools + select new object[] { first, second, third }; + + private sealed class DelayedByteContent : HttpContent + { + private readonly TaskCompletionSource _preContentSent; + private readonly Task _waitToSendPostContent; + + public DelayedByteContent(int preTriggerLength, int postTriggerLength, TaskCompletionSource preContentSent, Task waitToSendPostContent) + { + PreTriggerLength = preTriggerLength; + _preContentSent = preContentSent; + _waitToSendPostContent = waitToSendPostContent; + Content = new byte[preTriggerLength + postTriggerLength]; + new Random().NextBytes(Content); + } + + public byte[] Content { get; } + public int PreTriggerLength { get; } + + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context) + { + await stream.WriteAsync(Content, 0, PreTriggerLength); + _preContentSent.TrySetResult(true); + await _waitToSendPostContent; + await stream.WriteAsync(Content, PreTriggerLength, Content.Length - PreTriggerLength); + } + + protected override bool TryComputeLength(out long length) + { + length = Content.Length; + return true; + } + } + } +} diff --git a/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index fa8b3a1634cc..eafb1b554798 100644 --- a/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -125,16 +125,10 @@ public sealed class SocketsHttpHandler_HttpCookieProtocolTests : HttpCookieProto protected override bool UseSocketsHttpHandler => true; } - // TODO #23141: Socket's don't support canceling individual operations, so ReadStream on NetworkStream - // isn't cancelable once the operation has started. We either need to wrap the operation with one that's - // "cancelable", meaning that the underlying operation will still be running even though we've returned "canceled", - // or we need to just recognize that cancellation in such situations can be left up to the caller to do the - // same thing if it's really important. - //public sealed class SocketsHttpHandler_CancellationTest : CancellationTest - //{ - // public SocketsHttpHandler_CancellationTest(ITestOutputHelper output) : base(output) { } - // protected override bool UseSocketsHttpHandler => true; - //} + public sealed class SocketsHttpHandler_HttpClientHandler_Cancellation_Test : HttpClientHandler_Cancellation_Test + { + protected override bool UseSocketsHttpHandler => true; + } public sealed class SocketsHttpHandler_HttpClientHandler_MaxResponseHeadersLength_Test : HttpClientHandler_MaxResponseHeadersLength_Test { diff --git a/src/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj b/src/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj index 78bd6656dcf3..578e3fa00da7 100644 --- a/src/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj +++ b/src/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj @@ -73,13 +73,13 @@ Common\System\Threading\Tasks\TaskTimeoutExtensions.cs - + @@ -150,4 +150,4 @@ - \ No newline at end of file +