diff --git a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj index c60ca96e13c2..d227f5494786 100644 --- a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj +++ b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj @@ -270,6 +270,10 @@ + + + + diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs new file mode 100644 index 000000000000..6e5cab47e390 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Data.SqlClient.SNI +{ + internal partial class SNIPacket + { + /// + /// Read data from a stream asynchronously + /// + /// Stream to read from + /// Completion callback + public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback) + { + // Treat local function as a static and pass all params otherwise as async will allocate + async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask valueTask) + { + bool error = false; + try + { + packet._length = await valueTask.ConfigureAwait(false); + if (packet._length == 0) + { + SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty); + error = true; + } + } + catch (Exception ex) + { + SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, ex); + error = true; + } + + if (error) + { + packet.Release(); + } + + cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); + } + + ValueTask vt = stream.ReadAsync(new Memory(_data, 0, _capacity), CancellationToken.None); + + if (vt.IsCompletedSuccessfully) + { + _length = vt.Result; + // Zero length to go via async local function as is error condition + if (_length > 0) + { + callback(this, TdsEnums.SNI_SUCCESS); + + // Completed + return; + } + } + + // Not complete or error call the async local function to complete + _ = ReadFromStreamAsync(this, callback, vt); + } + + /// + /// Write data to a stream asynchronously + /// + /// Stream to write to + public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false) + { + // Treat local function as a static and pass all params otherwise as async will allocate + async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, ValueTask valueTask) + { + uint status = TdsEnums.SNI_SUCCESS; + try + { + await valueTask.ConfigureAwait(false); + } + catch (Exception e) + { + SNILoadHandle.SingletonInstance.LastError = new SNIError(providers, SNICommon.InternalExceptionError, e); + status = TdsEnums.SNI_ERROR; + } + + cb(packet, status); + + if (disposeAfter) + { + packet.Dispose(); + } + } + + ValueTask vt = stream.WriteAsync(new Memory(_data, 0, _length), CancellationToken.None); + + if (vt.IsCompletedSuccessfully) + { + // Read the result to register as complete for the ValueTask + vt.GetAwaiter().GetResult(); + + callback(this, TdsEnums.SNI_SUCCESS); + + if (disposeAfterWriteAsync) + { + Dispose(); + } + + // Completed + return; + } + + // Not complete or error call the async local function to complete + _ = WriteToStreamAsync(this, callback, provider, disposeAfterWriteAsync, vt); + } + } +} diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs new file mode 100644 index 000000000000..bfa48ac17b61 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Data.SqlClient.SNI +{ + internal partial class SNIPacket + { + /// + /// Read data from a stream asynchronously + /// + /// Stream to read from + /// Completion callback + public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback) + { + // Treat local function as a static and pass all params otherwise as async will allocate + async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task task) + { + bool error = false; + try + { + packet._length = await task.ConfigureAwait(false); + if (packet._length == 0) + { + SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty); + error = true; + } + } + catch (Exception ex) + { + SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, ex); + error = true; + } + + if (error) + { + packet.Release(); + } + + cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); + } + + Task t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None); + + if ((t.Status & TaskStatus.RanToCompletion) != 0) + { + _length = t.Result; + // Zero length to go via async local function as is error condition + if (_length > 0) + { + callback(this, TdsEnums.SNI_SUCCESS); + + // Completed + return; + } + } + + // Not complete or error call the async local function to complete + _ = ReadFromStreamAsync(this, callback, t); + } + + /// + /// Write data to a stream asynchronously + /// + /// Stream to write to + public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false) + { + // Treat local function as a static and pass all params otherwise as async will allocate + async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, Task task) + { + uint status = TdsEnums.SNI_SUCCESS; + try + { + await task.ConfigureAwait(false); + } + catch (Exception e) + { + SNILoadHandle.SingletonInstance.LastError = new SNIError(providers, SNICommon.InternalExceptionError, e); + status = TdsEnums.SNI_ERROR; + } + + cb(packet, status); + + if (disposeAfter) + { + packet.Dispose(); + } + } + + Task t = stream.WriteAsync(_data, 0, _length, CancellationToken.None); + + if ((t.Status & TaskStatus.RanToCompletion) != 0) + { + // Read the result to register as complete for the Task + t.GetAwaiter().GetResult(); + + callback(this, TdsEnums.SNI_SUCCESS); + + if (disposeAfterWriteAsync) + { + Dispose(); + } + + // Completed + return; + } + + // Not complete or error call the async local function to complete + _ = WriteToStreamAsync(this, callback, provider, disposeAfterWriteAsync, t); + } + } +} diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs index f7ba249b066f..931d064b4cef 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs @@ -12,7 +12,7 @@ namespace System.Data.SqlClient.SNI /// /// SNI Packet /// - internal class SNIPacket : IDisposable, IEquatable + internal partial class SNIPacket : IDisposable, IEquatable { private byte[] _data; private int _length; @@ -240,46 +240,6 @@ public void Reset() _completionCallback = null; } - /// - /// Read data from a stream asynchronously - /// - /// Stream to read from - /// Completion callback - public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback) - { - bool error = false; - - stream.ReadAsync(_data, 0, _capacity, CancellationToken.None).ContinueWith(t => - { - Exception e = t.Exception?.InnerException; - if (e != null) - { - SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, e); - error = true; - } - else - { - _length = t.Result; - - if (_length == 0) - { - SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty); - error = true; - } - } - - if (error) - { - Release(); - } - - callback(this, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); - }, - CancellationToken.None, - TaskContinuationOptions.DenyChildAttach, - TaskScheduler.Default); - } - /// /// Read data from a stream synchronously /// @@ -298,30 +258,6 @@ public void WriteToStream(Stream stream) stream.Write(_data, 0, _length); } - /// - /// Write data to a stream asynchronously - /// - /// Stream to write to - public async void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false) - { - uint status = TdsEnums.SNI_SUCCESS; - try - { - await stream.WriteAsync(_data, 0, _length, CancellationToken.None).ConfigureAwait(false); - } - catch (Exception e) - { - SNILoadHandle.SingletonInstance.LastError = new SNIError(provider, SNICommon.InternalExceptionError, e); - status = TdsEnums.SNI_ERROR; - } - callback(this, status); - - if (disposeAfterWriteAsync) - { - Dispose(); - } - } - /// /// Get hash code ///