From 50010977f58bfc1ffa60d2e83abfc8e08caa6d27 Mon Sep 17 00:00:00 2001 From: Lawrence Cheung <31262254+lcheunglci@users.noreply.github.com> Date: Fri, 31 Mar 2023 12:37:35 -0400 Subject: [PATCH] Fix | Adding disposable stack temp ref struct and use (#1980) --- .../src/Microsoft.Data.SqlClient.csproj | 3 + .../Microsoft/Data/SqlClient/SqlDataReader.cs | 136 +++++++++--------- .../netfx/src/Microsoft.Data.SqlClient.csproj | 3 + .../Microsoft/Data/SqlClient/SqlDataReader.cs | 112 ++++++++------- .../SqlClient/DisposableTemporaryOnStack.cs | 40 ++++++ 5 files changed, 174 insertions(+), 120 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 5b72e28c18..a7552f847f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -115,6 +115,9 @@ Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs + + Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs + Microsoft\Data\SqlClient\EnclaveDelegate.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index f24f374644..4668d69d27 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4397,6 +4397,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn public override Task NextResultAsync(CancellationToken cancellationToken) { using (TryEventScope.Create("SqlDataReader.NextResultAsync | API | Object Id {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { TaskCompletionSource source = new TaskCompletionSource(); @@ -4406,7 +4407,6 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; if (cancellationToken.CanBeCanceled) { if (cancellationToken.IsCancellationRequested) @@ -4414,7 +4414,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); @@ -4432,7 +4432,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration)); + return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take())); } } @@ -4727,6 +4727,7 @@ out bytesRead public override Task ReadAsync(CancellationToken cancellationToken) { using (TryEventScope.Create("SqlDataReader.ReadAsync | API | Object Id {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { if (IsClosed) { @@ -4734,10 +4735,9 @@ public override Task ReadAsync(CancellationToken cancellationToken) } // Register first to catch any already expired tokens to be able to trigger cancellation event. - IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } // If user's token is canceled, return a canceled task @@ -4850,7 +4850,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ReadAsyncCallContext was not properly disposed"); - context.Set(this, source, registration); + context.Set(this, source, registrationHolder.Take()); context._hasMoreData = more; context._hasReadRowToken = rowTokenRead; @@ -4988,49 +4988,51 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo return Task.FromException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - IsDBNullAsyncCallContext context = null; - if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection) - { - context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null); - } - if (context is null) - { - context = new IsDBNullAsyncCallContext(); - } + IsDBNullAsyncCallContext context = null; + if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null); + } + if (context is null) + { + context = new IsDBNullAsyncCallContext(); + } - Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ISDBNullAsync context not properly disposed"); + Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed"); - context.Set(this, source, registration); - context._columnIndex = i; + context.Set(this, source, registrationHolder.Take()); + context._columnIndex = i; - // Setup async - PrepareAsyncInvocation(useSnapshot: true); + // Setup async + PrepareAsyncInvocation(useSnapshot: true); - return InvokeAsyncCall(context); + return InvokeAsyncCall(context); + } } } @@ -5135,37 +5137,39 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat return Task.FromException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - // Setup async - PrepareAsyncInvocation(useSnapshot: true); + // Setup async + PrepareAsyncInvocation(useSnapshot: true); - GetFieldValueAsyncCallContext context = new GetFieldValueAsyncCallContext(this, source, registration); - context._columnIndex = i; + GetFieldValueAsyncCallContext context = new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take()); + context._columnIndex = i; - return InvokeAsyncCall(context); + return InvokeAsyncCall(context); + } } private static Task GetFieldValueAsyncExecute(Task task, object state) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 12323365bf..617c8c3131 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -185,6 +185,9 @@ Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs + + Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs + Microsoft\Data\SqlClient\EnclaveDelegate.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 090b0f7bfb..7c920fe4b6 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4983,6 +4983,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn public override Task NextResultAsync(CancellationToken cancellationToken) { using (TryEventScope.Create(" {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { TaskCompletionSource source = new TaskCompletionSource(); @@ -4993,7 +4994,6 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; if (cancellationToken.CanBeCanceled) { if (cancellationToken.IsCancellationRequested) @@ -5001,7 +5001,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); @@ -5019,7 +5019,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration)); + return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take())); } } @@ -5320,6 +5320,7 @@ out bytesRead public override Task ReadAsync(CancellationToken cancellationToken) { using (TryEventScope.Create(" {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { if (IsClosed) { @@ -5327,10 +5328,9 @@ public override Task ReadAsync(CancellationToken cancellationToken) } // Register first to catch any already expired tokens to be able to trigger cancellation event. - IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } // If user's token is canceled, return a canceled task @@ -5436,7 +5436,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed"); - context.Set(this, source, registration); + context.Set(this, source, registrationHolder.Take()); context._hasMoreData = more; context._hasReadRowToken = rowTokenRead; @@ -5568,41 +5568,43 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo return ADP.CreatedTaskWithException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext(); + IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext(); - Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed"); + Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed"); - context.Set(this, source, registration); - context._columnIndex = i; + context.Set(this, source, registrationHolder.Take()); + context._columnIndex = i; - // Setup async - PrepareAsyncInvocation(useSnapshot: true); + // Setup async + PrepareAsyncInvocation(useSnapshot: true); - return InvokeAsyncCall(context); + return InvokeAsyncCall(context); + } } } @@ -5704,31 +5706,33 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat return ADP.CreatedTaskWithException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registration, i)); + return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take(), i)); + } } private static Task GetFieldValueAsyncExecute(Task task, object state) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs new file mode 100644 index 0000000000..b57b80d78f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +using System; + +namespace Microsoft.Data.SqlClient +{ + internal ref struct DisposableTemporaryOnStack + where T : IDisposable + { + private T _value; + private bool _hasValue; + + public void Set(T value) + { + _value = value; + _hasValue = true; + } + + public T Take() + { + T value = _value; + _value = default; + _hasValue = false; + return value; + } + + public void Dispose() + { + if (_hasValue) + { + _value.Dispose(); + _value = default; + _hasValue = false; + } + } + } +}