Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
Enable SocketsHttpHandler cancellation support (#27029)
Browse files Browse the repository at this point in the history
* Enable SocketsHttpHandler cancellation support

This change significantly improves the cancellation support in SocketsHttpHandler.  Previously we were passing the CancellationToken around to every method, eventually bottoming out in calls to the underlying Stream which then ends up passing them down to the underlying Socket.  But today Socket's support for cancellation is minimal, only doing up-front checks; if cancellation is requested during the socket operation rather than before, the request will be ignored.  Since HttpClient implements features like timeouts on top of cancellation support, it's important to do better than this.

The change implements cancellation by registering with the CancellationToken to dispose of the connection.  This will cause any reads/writes to wake up.  We then translate resulting exceptions into cancellation exceptions.  When in the main SendAsync method, we register once for the whole body of the operation until the point that we're returning the response message.  For individual operations on the response content stream, we register per operation; however, when feasible we try to avoid the registration costs by only registering if operations don't complete synchronously.  We also account for the case that on Unix, closing the connection may result in read operations waking up not with an exception but rather with EOF, which we also need to translate into cancellation when appropriate.

Along the way I cleaned up a few minor issues as well.

I also added a bunch of cancellation-related tests:
- Test cancellation occurring while sending request content
- Test cancellation occurring while receiving response headers
- Test cancellation occurring while receiving response body and using a buffered operation
- Test that all of the above are triggerable with CancellationTokenSource.Cancel, HttpClient.CancelPendingRequests, and HttpClient.Dispose
- Test cancellation occurring while receiving response body and using an unbuffered operation, either a ReadAsync or CopyToAsync on the response stream
- Test that a CancelPendingRequests doesn't affect unbuffered operations on the response stream

There are deficiencies here in the existing handlers, and tests have been selectively disabled accordingly (I also fixed a couple cases that naturally fell out of the changes I was making for SocketsHttpHandler).  SocketsHttpHandler passes now for all of them.

* Add test that Dispose doesn't cancel response stream
  • Loading branch information
stephentoub authored Feb 13, 2018
1 parent c705032 commit 53be85c
Show file tree
Hide file tree
Showing 27 changed files with 1,092 additions and 451 deletions.
21 changes: 15 additions & 6 deletions src/Common/src/System/Net/Http/NoWriteNoSeekStreamContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,26 @@ namespace System.Net.Http
internal sealed class NoWriteNoSeekStreamContent : HttpContent
{
private readonly Stream _content;
private readonly CancellationToken _cancellationToken;
private bool _contentConsumed;

internal NoWriteNoSeekStreamContent(Stream content, CancellationToken cancellationToken)
internal NoWriteNoSeekStreamContent(Stream content)
{
Debug.Assert(content != null);
Debug.Assert(content.CanRead);
Debug.Assert(!content.CanWrite);
Debug.Assert(!content.CanSeek);

_content = content;
_cancellationToken = cancellationToken;
}

protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
SerializeToStreamAsync(stream, context, CancellationToken.None);

internal
#if HTTP_DLL
override
#endif
Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken)
{
Debug.Assert(stream != null);

Expand All @@ -39,7 +44,7 @@ protected override Task SerializeToStreamAsync(Stream stream, TransportContext c
_contentConsumed = true;

const int BufferSize = 8192;
Task copyTask = _content.CopyToAsync(stream, BufferSize, _cancellationToken);
Task copyTask = _content.CopyToAsync(stream, BufferSize, cancellationToken);
if (copyTask.IsCompleted)
{
try { _content.Dispose(); } catch { } // same as StreamToStreamCopy behavior
Expand Down Expand Up @@ -75,6 +80,10 @@ protected override void Dispose(bool disposing)
base.Dispose(disposing);
}

protected override Task<Stream> CreateContentReadStreamAsync() => Task.FromResult<Stream>(_content);
protected override Task<Stream> CreateContentReadStreamAsync() => Task.FromResult(_content);

#if HTTP_DLL
internal override Stream TryCreateContentReadStream() => _content;
#endif
}
}
4 changes: 3 additions & 1 deletion src/Common/src/System/Net/Logging/NetEventSource.Common.cs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,9 @@ private static void DebugValidateArg(FormattableString arg)
Debug.Assert(IsEnabled || arg == null, $"Should not be formatting FormattableString \"{arg}\" if tracing isn't enabled");
}

public static new bool IsEnabled => Log.IsEnabled();
public static new bool IsEnabled =>
Log.IsEnabled();
//true; // uncomment for debugging only

[NonEvent]
public static string IdOf(object value) => value != null ? value.GetType().Name + "#" + GetHashCode(value) : NullInstance;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public static HttpResponseMessage CreateResponseMessage(
}
}

response.Content = new NoWriteNoSeekStreamContent(decompressedStream, state.CancellationToken);
response.Content = new NoWriteNoSeekStreamContent(decompressedStream);
response.RequestMessage = request;

// Parse raw response headers and place them into response message.
Expand Down
6 changes: 3 additions & 3 deletions src/System.Net.Http/src/System.Net.Http.csproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
<PropertyGroup>
Expand Down Expand Up @@ -137,11 +137,11 @@
<Compile Include="System\Net\Http\SocketsHttpHandler\DecompressionHandler.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\EmptyReadStream.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnection.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionContent.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionHandler.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionKey.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionPool.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionPools.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionResponseContent.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionSettings.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpContentDuplexStream.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\HttpContentReadStream.cs" />
Expand Down Expand Up @@ -464,4 +464,4 @@
<Reference Include="System.Security.Cryptography.Primitives" />
</ItemGroup>
<Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
</Project>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ internal CurlResponseMessage(EasyRequest easy)
Debug.Assert(easy != null, "Expected non-null EasyRequest");
RequestMessage = easy._requestMessage;
ResponseStream = new CurlResponseStream(easy);
Content = new NoWriteNoSeekStreamContent(ResponseStream, CancellationToken.None);
Content = new NoWriteNoSeekStreamContent(ResponseStream);

// On Windows, we pass the equivalent of the easy._cancellationToken
// in to StreamContent's ctor. This in turn passes that token through
Expand Down
2 changes: 1 addition & 1 deletion src/System.Net.Http/src/System/Net/Http/HttpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ private async Task<HttpResponseMessage> FinishSendAsyncBuffered(
// Buffer the response content if we've been asked to and we have a Content to buffer.
if (response.Content != null)
{
await response.Content.LoadIntoBufferAsync(_maxResponseContentBufferSize).ConfigureAwait(false);
await response.Content.LoadIntoBufferAsync(_maxResponseContentBufferSize, cts.Token).ConfigureAwait(false);
}

if (NetEventSource.IsEnabled) NetEventSource.ClientSendCompleted(this, response, request);
Expand Down
24 changes: 19 additions & 5 deletions src/System.Net.Http/src/System/Net/Http/HttpContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,18 @@ internal Stream TryReadAsStream()

protected abstract Task SerializeToStreamAsync(Stream stream, TransportContext context);

public Task CopyToAsync(Stream stream, TransportContext context)
// TODO #9071: Expose this publicly. Until it's public, only sealed or internal types should override it, and then change
// their SerializeToStreamAsync implementation to delegate to this one. They need to be sealed as otherwise an external
// type could derive from it and override SerializeToStreamAsync(stream, context) further, at which point when
// HttpClient calls SerializeToStreamAsync(stream, context, cancellationToken), their custom override will be skipped.
internal virtual Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
SerializeToStreamAsync(stream, context);

public Task CopyToAsync(Stream stream, TransportContext context) =>
CopyToAsync(stream, context, CancellationToken.None);

// TODO #9071: Expose this publicly.
internal Task CopyToAsync(Stream stream, TransportContext context, CancellationToken cancellationToken)
{
CheckDisposed();
if (stream == null)
Expand All @@ -313,11 +324,11 @@ public Task CopyToAsync(Stream stream, TransportContext context)
ArraySegment<byte> buffer;
if (TryGetBuffer(out buffer))
{
task = stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count);
task = stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken);
}
else
{
task = SerializeToStreamAsync(stream, context);
task = SerializeToStreamAsync(stream, context, cancellationToken);
CheckTaskNotNull(task);
}

Expand Down Expand Up @@ -354,7 +365,10 @@ public Task LoadIntoBufferAsync()
// No "CancellationToken" parameter needed since canceling the CTS will close the connection, resulting
// in an exception being thrown while we're buffering.
// If buffering is used without a connection, it is supposed to be fast, thus no cancellation required.
public Task LoadIntoBufferAsync(long maxBufferSize)
public Task LoadIntoBufferAsync(long maxBufferSize) =>
LoadIntoBufferAsync(maxBufferSize, CancellationToken.None);

internal Task LoadIntoBufferAsync(long maxBufferSize, CancellationToken cancellationToken)
{
CheckDisposed();
if (maxBufferSize > HttpContent.MaxBufferSize)
Expand Down Expand Up @@ -382,7 +396,7 @@ public Task LoadIntoBufferAsync(long maxBufferSize)

try
{
Task task = SerializeToStreamAsync(tempBuffer, null);
Task task = SerializeToStreamAsync(tempBuffer, null, cancellationToken);
CheckTaskNotNull(task);
return LoadIntoBufferAsyncCore(task, tempBuffer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public void HeadersInvalidValue(string name, string rawValue) =>
[Event(HandlerMessageId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
public void HandlerMessage(int handlerId, int workerId, int requestId, string memberName, string message) =>
WriteEvent(HandlerMessageId, handlerId, workerId, requestId, memberName, message);
//Console.WriteLine($"{handlerId}/{workerId}/{requestId}: ({memberName}): {message}"); // uncomment for debugging only

[NonEvent]
private unsafe void WriteEvent(int eventId, int arg1, int arg2, int arg3, string arg4, string arg5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Http
Expand All @@ -26,6 +27,9 @@ public ReadOnlyMemoryContent(ReadOnlyMemory<byte> content)
protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
stream.WriteAsync(_content);

internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
stream.WriteAsync(_content, cancellationToken);

protected internal override bool TryComputeLength(out long length)
{
length = _content.Length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ public ChunkedEncodingReadStream(HttpConnection connection) : base(connection)
{
}

private async Task<bool> TryGetNextChunkAsync(CancellationToken cancellationToken)
private async Task<bool> TryGetNextChunkAsync()
{
Debug.Assert(_chunkBytesRemaining == 0);

// Read the start of the chunk line.
_connection._allowedReadLineBytes = MaxChunkBytesAllowed;
ArraySegment<byte> line = await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false);
ArraySegment<byte> line = await _connection.ReadNextLineAsync().ConfigureAwait(false);

// Parse the hex value.
if (!Utf8Parser.TryParse(line.AsReadOnlySpan(), out ulong chunkSize, out int bytesConsumed, 'X'))
Expand Down Expand Up @@ -73,7 +73,7 @@ private async Task<bool> TryGetNextChunkAsync(CancellationToken cancellationToke
while (true)
{
_connection._allowedReadLineBytes = MaxTrailingHeaderLength;
if (LineIsEmpty(await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false)))
if (LineIsEmpty(await _connection.ReadNextLineAsync().ConfigureAwait(false)))
{
break;
}
Expand All @@ -84,59 +84,77 @@ private async Task<bool> TryGetNextChunkAsync(CancellationToken cancellationToke
return false;
}

private async Task ConsumeChunkBytesAsync(ulong bytesConsumed, CancellationToken cancellationToken)
private Task ConsumeChunkBytesAsync(ulong bytesConsumed)
{
Debug.Assert(bytesConsumed <= _chunkBytesRemaining);
_chunkBytesRemaining -= bytesConsumed;
if (_chunkBytesRemaining == 0)
return _chunkBytesRemaining != 0 ?
Task.CompletedTask :
ReadNextLineAndThrowIfNotEmptyAsync();
}

private async Task ReadNextLineAndThrowIfNotEmptyAsync()
{
_connection._allowedReadLineBytes = 2; // \r\n
if (!LineIsEmpty(await _connection.ReadNextLineAsync().ConfigureAwait(false)))
{
_connection._allowedReadLineBytes = 2; // \r\n
if (!LineIsEmpty(await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false)))
{
ThrowInvalidHttpResponse();
}
ThrowInvalidHttpResponse();
}
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateBufferArgs(buffer, offset, count);
return ReadAsync(new Memory<byte>(buffer, offset, count)).AsTask();
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

if (_connection == null || destination.Length == 0)
{
// Response body fully consumed or the caller didn't ask for any data
return 0;
}

if (_chunkBytesRemaining == 0)
CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
try
{
if (!await TryGetNextChunkAsync(cancellationToken).ConfigureAwait(false))
if (_chunkBytesRemaining == 0)
{
// End of response body
return 0;
if (!await TryGetNextChunkAsync().ConfigureAwait(false))
{
// End of response body
return 0;
}
}
}

if (_chunkBytesRemaining < (ulong)destination.Length)
{
destination = destination.Slice(0, (int)_chunkBytesRemaining);
}
if (_chunkBytesRemaining < (ulong)destination.Length)
{
destination = destination.Slice(0, (int)_chunkBytesRemaining);
}

int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false);
int bytesRead = await _connection.ReadAsync(destination).ConfigureAwait(false);

if (bytesRead <= 0)
{
// Unexpected end of response stream
throw new IOException(SR.net_http_invalid_response);
}
if (bytesRead <= 0)
{
// Unexpected end of response stream
throw new IOException(SR.net_http_invalid_response);
}

await ConsumeChunkBytesAsync((ulong)bytesRead, cancellationToken).ConfigureAwait(false);
await ConsumeChunkBytesAsync((ulong)bytesRead).ConfigureAwait(false);

return bytesRead;
return bytesRead;
}
catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
{
throw new OperationCanceledException(s_cancellationMessage, exc, cancellationToken);
}
finally
{
ctr.Dispose();
}
}

public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
Expand All @@ -145,23 +163,41 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance
{
throw new ArgumentNullException(nameof(destination));
}
if (bufferSize <= 0)
{
throw new ArgumentOutOfRangeException(nameof(bufferSize));
}

cancellationToken.ThrowIfCancellationRequested();

if (_connection == null)
{
// Response body fully consumed
return;
}

if (_chunkBytesRemaining > 0)
CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
try
{
await _connection.CopyToAsync(destination, _chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
await ConsumeChunkBytesAsync(_chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
}
if (_chunkBytesRemaining > 0)
{
await _connection.CopyToAsync(destination, _chunkBytesRemaining).ConfigureAwait(false);
await ConsumeChunkBytesAsync(_chunkBytesRemaining).ConfigureAwait(false);
}

while (await TryGetNextChunkAsync(cancellationToken).ConfigureAwait(false))
while (await TryGetNextChunkAsync().ConfigureAwait(false))
{
await _connection.CopyToAsync(destination, _chunkBytesRemaining).ConfigureAwait(false);
await ConsumeChunkBytesAsync(_chunkBytesRemaining).ConfigureAwait(false);
}
}
catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
{
throw CreateOperationCanceledException(exc, cancellationToken);
}
finally
{
await _connection.CopyToAsync(destination, _chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
await ConsumeChunkBytesAsync(_chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
ctr.Dispose();
}
}
}
Expand Down
Loading

0 comments on commit 53be85c

Please sign in to comment.