From 50010977f58bfc1ffa60d2e83abfc8e08caa6d27 Mon Sep 17 00:00:00 2001
From: Lawrence Cheung <31262254+lcheunglci@users.noreply.github.com>
Date: Fri, 31 Mar 2023 12:37:35 -0400
Subject: [PATCH] Fix | Adding disposable stack temp ref struct and use (#1980)
---
.../src/Microsoft.Data.SqlClient.csproj | 3 +
.../Microsoft/Data/SqlClient/SqlDataReader.cs | 136 +++++++++---------
.../netfx/src/Microsoft.Data.SqlClient.csproj | 3 +
.../Microsoft/Data/SqlClient/SqlDataReader.cs | 112 ++++++++-------
.../SqlClient/DisposableTemporaryOnStack.cs | 40 ++++++
5 files changed, 174 insertions(+), 120 deletions(-)
create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
index 5b72e28c18..a7552f847f 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
@@ -115,6 +115,9 @@
Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs
+
+ Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs
+
Microsoft\Data\SqlClient\EnclaveDelegate.cs
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
index f24f374644..4668d69d27 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
@@ -4397,6 +4397,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn
public override Task NextResultAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create("SqlDataReader.NextResultAsync | API | Object Id {0}", ObjectID))
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
TaskCompletionSource source = new TaskCompletionSource();
@@ -4406,7 +4407,6 @@ public override Task NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}
- IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
if (cancellationToken.IsCancellationRequested)
@@ -4414,7 +4414,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken)
source.SetCanceled();
return source.Task;
}
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
@@ -4432,7 +4432,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}
- return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration));
+ return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take()));
}
}
@@ -4727,6 +4727,7 @@ out bytesRead
public override Task ReadAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create("SqlDataReader.ReadAsync | API | Object Id {0}", ObjectID))
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
if (IsClosed)
{
@@ -4734,10 +4735,9 @@ public override Task ReadAsync(CancellationToken cancellationToken)
}
// Register first to catch any already expired tokens to be able to trigger cancellation event.
- IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}
// If user's token is canceled, return a canceled task
@@ -4850,7 +4850,7 @@ public override Task ReadAsync(CancellationToken cancellationToken)
Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ReadAsyncCallContext was not properly disposed");
- context.Set(this, source, registration);
+ context.Set(this, source, registrationHolder.Take());
context._hasMoreData = more;
context._hasReadRowToken = rowTokenRead;
@@ -4988,49 +4988,51 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo
return Task.FromException(ex);
}
- // Setup and check for pending task
- TaskCompletionSource source = new TaskCompletionSource();
- Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
- if (original != null)
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
- source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
- return source.Task;
- }
+ // Setup and check for pending task
+ TaskCompletionSource source = new TaskCompletionSource();
+ Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
+ if (original != null)
+ {
+ source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
+ return source.Task;
+ }
- // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
- if (_cancelAsyncOnCloseToken.IsCancellationRequested)
- {
- source.SetCanceled();
- _currentTask = null;
- return source.Task;
- }
+ // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
+ if (_cancelAsyncOnCloseToken.IsCancellationRequested)
+ {
+ source.SetCanceled();
+ _currentTask = null;
+ return source.Task;
+ }
- // Setup cancellations
- IDisposable registration = null;
- if (cancellationToken.CanBeCanceled)
- {
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
- }
+ // Setup cancellations
+ if (cancellationToken.CanBeCanceled)
+ {
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
+ }
- IsDBNullAsyncCallContext context = null;
- if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
- {
- context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null);
- }
- if (context is null)
- {
- context = new IsDBNullAsyncCallContext();
- }
+ IsDBNullAsyncCallContext context = null;
+ if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
+ {
+ context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null);
+ }
+ if (context is null)
+ {
+ context = new IsDBNullAsyncCallContext();
+ }
- Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ISDBNullAsync context not properly disposed");
+ Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed");
- context.Set(this, source, registration);
- context._columnIndex = i;
+ context.Set(this, source, registrationHolder.Take());
+ context._columnIndex = i;
- // Setup async
- PrepareAsyncInvocation(useSnapshot: true);
+ // Setup async
+ PrepareAsyncInvocation(useSnapshot: true);
- return InvokeAsyncCall(context);
+ return InvokeAsyncCall(context);
+ }
}
}
@@ -5135,37 +5137,39 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat
return Task.FromException(ex);
}
- // Setup and check for pending task
- TaskCompletionSource source = new TaskCompletionSource();
- Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
- if (original != null)
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
- source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
- return source.Task;
- }
+ // Setup and check for pending task
+ TaskCompletionSource source = new TaskCompletionSource();
+ Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
+ if (original != null)
+ {
+ source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
+ return source.Task;
+ }
- // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
- if (_cancelAsyncOnCloseToken.IsCancellationRequested)
- {
- source.SetCanceled();
- _currentTask = null;
- return source.Task;
- }
+ // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
+ if (_cancelAsyncOnCloseToken.IsCancellationRequested)
+ {
+ source.SetCanceled();
+ _currentTask = null;
+ return source.Task;
+ }
- // Setup cancellations
- IDisposable registration = null;
- if (cancellationToken.CanBeCanceled)
- {
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
- }
+ // Setup cancellations
+ if (cancellationToken.CanBeCanceled)
+ {
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
+ }
- // Setup async
- PrepareAsyncInvocation(useSnapshot: true);
+ // Setup async
+ PrepareAsyncInvocation(useSnapshot: true);
- GetFieldValueAsyncCallContext context = new GetFieldValueAsyncCallContext(this, source, registration);
- context._columnIndex = i;
+ GetFieldValueAsyncCallContext context = new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take());
+ context._columnIndex = i;
- return InvokeAsyncCall(context);
+ return InvokeAsyncCall(context);
+ }
}
private static Task GetFieldValueAsyncExecute(Task task, object state)
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
index 12323365bf..617c8c3131 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
@@ -185,6 +185,9 @@
Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs
+
+ Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs
+
Microsoft\Data\SqlClient\EnclaveDelegate.cs
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs
index 090b0f7bfb..7c920fe4b6 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs
@@ -4983,6 +4983,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn
public override Task NextResultAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create(" {0}", ObjectID))
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
TaskCompletionSource source = new TaskCompletionSource();
@@ -4993,7 +4994,6 @@ public override Task NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}
- IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
if (cancellationToken.IsCancellationRequested)
@@ -5001,7 +5001,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken)
source.SetCanceled();
return source.Task;
}
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
@@ -5019,7 +5019,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}
- return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration));
+ return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take()));
}
}
@@ -5320,6 +5320,7 @@ out bytesRead
public override Task ReadAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create(" {0}", ObjectID))
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
if (IsClosed)
{
@@ -5327,10 +5328,9 @@ public override Task ReadAsync(CancellationToken cancellationToken)
}
// Register first to catch any already expired tokens to be able to trigger cancellation event.
- IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}
// If user's token is canceled, return a canceled task
@@ -5436,7 +5436,7 @@ public override Task ReadAsync(CancellationToken cancellationToken)
Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed");
- context.Set(this, source, registration);
+ context.Set(this, source, registrationHolder.Take());
context._hasMoreData = more;
context._hasReadRowToken = rowTokenRead;
@@ -5568,41 +5568,43 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo
return ADP.CreatedTaskWithException(ex);
}
- // Setup and check for pending task
- TaskCompletionSource source = new TaskCompletionSource();
- Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
- if (original != null)
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
- source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
- return source.Task;
- }
+ // Setup and check for pending task
+ TaskCompletionSource source = new TaskCompletionSource();
+ Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
+ if (original != null)
+ {
+ source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
+ return source.Task;
+ }
- // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
- if (_cancelAsyncOnCloseToken.IsCancellationRequested)
- {
- source.SetCanceled();
- _currentTask = null;
- return source.Task;
- }
+ // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
+ if (_cancelAsyncOnCloseToken.IsCancellationRequested)
+ {
+ source.SetCanceled();
+ _currentTask = null;
+ return source.Task;
+ }
- // Setup cancellations
- IDisposable registration = null;
- if (cancellationToken.CanBeCanceled)
- {
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
- }
+ // Setup cancellations
+ if (cancellationToken.CanBeCanceled)
+ {
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
+ }
- IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext();
+ IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext();
- Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed");
+ Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed");
- context.Set(this, source, registration);
- context._columnIndex = i;
+ context.Set(this, source, registrationHolder.Take());
+ context._columnIndex = i;
- // Setup async
- PrepareAsyncInvocation(useSnapshot: true);
+ // Setup async
+ PrepareAsyncInvocation(useSnapshot: true);
- return InvokeAsyncCall(context);
+ return InvokeAsyncCall(context);
+ }
}
}
@@ -5704,31 +5706,33 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat
return ADP.CreatedTaskWithException(ex);
}
- // Setup and check for pending task
- TaskCompletionSource source = new TaskCompletionSource();
- Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
- if (original != null)
+ using (var registrationHolder = new DisposableTemporaryOnStack())
{
- source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
- return source.Task;
- }
+ // Setup and check for pending task
+ TaskCompletionSource source = new TaskCompletionSource();
+ Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
+ if (original != null)
+ {
+ source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
+ return source.Task;
+ }
- // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
- if (_cancelAsyncOnCloseToken.IsCancellationRequested)
- {
- source.SetCanceled();
- _currentTask = null;
- return source.Task;
- }
+ // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
+ if (_cancelAsyncOnCloseToken.IsCancellationRequested)
+ {
+ source.SetCanceled();
+ _currentTask = null;
+ return source.Task;
+ }
- // Setup cancellations
- IDisposable registration = null;
- if (cancellationToken.CanBeCanceled)
- {
- registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
- }
+ // Setup cancellations
+ if (cancellationToken.CanBeCanceled)
+ {
+ registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
+ }
- return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registration, i));
+ return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take(), i));
+ }
}
private static Task GetFieldValueAsyncExecute(Task task, object state)
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs
new file mode 100644
index 0000000000..b57b80d78f
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs
@@ -0,0 +1,40 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+
+using System;
+
+namespace Microsoft.Data.SqlClient
+{
+ internal ref struct DisposableTemporaryOnStack
+ where T : IDisposable
+ {
+ private T _value;
+ private bool _hasValue;
+
+ public void Set(T value)
+ {
+ _value = value;
+ _hasValue = true;
+ }
+
+ public T Take()
+ {
+ T value = _value;
+ _value = default;
+ _hasValue = false;
+ return value;
+ }
+
+ public void Dispose()
+ {
+ if (_hasValue)
+ {
+ _value.Dispose();
+ _value = default;
+ _hasValue = false;
+ }
+ }
+ }
+}