Skip to content

Commit

Permalink
Fix | Adding disposable stack temp ref struct and use (#1980)
Browse files Browse the repository at this point in the history
  • Loading branch information
lcheunglci authored Mar 31, 2023
1 parent 9bd90e6 commit 5001097
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs">
<Link>Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs">
<Link>Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\EnclaveDelegate.cs">
<Link>Microsoft\Data\SqlClient\EnclaveDelegate.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4397,6 +4397,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn
public override Task<bool> NextResultAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create("SqlDataReader.NextResultAsync | API | Object Id {0}", ObjectID))
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
TaskCompletionSource<bool> source = new TaskCompletionSource<bool>();

Expand All @@ -4406,15 +4407,14 @@ public override Task<bool> NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}

IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
if (cancellationToken.IsCancellationRequested)
{
source.SetCanceled();
return source.Task;
}
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
Expand All @@ -4432,7 +4432,7 @@ public override Task<bool> NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}

return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration));
return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take()));
}
}

Expand Down Expand Up @@ -4727,17 +4727,17 @@ out bytesRead
public override Task<bool> ReadAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create("SqlDataReader.ReadAsync | API | Object Id {0}", ObjectID))
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
if (IsClosed)
{
return Task.FromException<bool>(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed()));
}

// Register first to catch any already expired tokens to be able to trigger cancellation event.
IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

// If user's token is canceled, return a canceled task
Expand Down Expand Up @@ -4850,7 +4850,7 @@ public override Task<bool> ReadAsync(CancellationToken cancellationToken)

Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ReadAsyncCallContext was not properly disposed");

context.Set(this, source, registration);
context.Set(this, source, registrationHolder.Take());
context._hasMoreData = more;
context._hasReadRowToken = rowTokenRead;

Expand Down Expand Up @@ -4988,49 +4988,51 @@ override public Task<bool> IsDBNullAsync(int i, CancellationToken cancellationTo
return Task.FromException<bool>(ex);
}

// Setup and check for pending task
TaskCompletionSource<bool> source = new TaskCompletionSource<bool>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}
// Setup and check for pending task
TaskCompletionSource<bool> source = new TaskCompletionSource<bool>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}

// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}
// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}

// Setup cancellations
IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
}
// Setup cancellations
if (cancellationToken.CanBeCanceled)
{
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

IsDBNullAsyncCallContext context = null;
if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null);
}
if (context is null)
{
context = new IsDBNullAsyncCallContext();
}
IsDBNullAsyncCallContext context = null;
if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null);
}
if (context is null)
{
context = new IsDBNullAsyncCallContext();
}

Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ISDBNullAsync context not properly disposed");
Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed");

context.Set(this, source, registration);
context._columnIndex = i;
context.Set(this, source, registrationHolder.Take());
context._columnIndex = i;

// Setup async
PrepareAsyncInvocation(useSnapshot: true);
// Setup async
PrepareAsyncInvocation(useSnapshot: true);

return InvokeAsyncCall(context);
return InvokeAsyncCall(context);
}
}
}

Expand Down Expand Up @@ -5135,37 +5137,39 @@ override public Task<T> GetFieldValueAsync<T>(int i, CancellationToken cancellat
return Task.FromException<T>(ex);
}

// Setup and check for pending task
TaskCompletionSource<T> source = new TaskCompletionSource<T>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}
// Setup and check for pending task
TaskCompletionSource<T> source = new TaskCompletionSource<T>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}

// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}
// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}

// Setup cancellations
IDisposable registration = null;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
}
// Setup cancellations
if (cancellationToken.CanBeCanceled)
{
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

// Setup async
PrepareAsyncInvocation(useSnapshot: true);
// Setup async
PrepareAsyncInvocation(useSnapshot: true);

GetFieldValueAsyncCallContext<T> context = new GetFieldValueAsyncCallContext<T>(this, source, registration);
context._columnIndex = i;
GetFieldValueAsyncCallContext<T> context = new GetFieldValueAsyncCallContext<T>(this, source, registrationHolder.Take());
context._columnIndex = i;

return InvokeAsyncCall(context);
return InvokeAsyncCall(context);
}
}

private static Task<T> GetFieldValueAsyncExecute<T>(Task task, object state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs">
<Link>Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs">
<Link>Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\EnclaveDelegate.cs">
<Link>Microsoft\Data\SqlClient\EnclaveDelegate.cs</Link>
</Compile>
Expand Down
Loading

0 comments on commit 5001097

Please sign in to comment.