Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use an IBufferWriter<byte> to write the outgoing SSPI blob #2452

Merged
merged 25 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1ec663f
Use an IBufferWriter<byte> to write the outgoing SSPI blob
twsouthwick Apr 8, 2024
7c3335e
Merge branch 'main' into sspi-writer
twsouthwick Jun 27, 2024
4c3c3f2
Merge remote-tracking branch 'upstream/main' into sspi-writer
twsouthwick Jul 10, 2024
d3bd2b9
fix
twsouthwick Jul 10, 2024
266cf7f
switch to span
twsouthwick May 7, 2024
7d99053
add return
twsouthwick Aug 19, 2024
7e4d15f
Merge remote-tracking branch 'upstream/main' into sspi-writer
twsouthwick Aug 19, 2024
09db047
revert
twsouthwick Aug 19, 2024
a49abc3
use return
twsouthwick Aug 20, 2024
a1703d8
inline
twsouthwick Aug 20, 2024
1781d92
Merge remote-tracking branch 'upstream/main' into sspi-writer
twsouthwick Nov 14, 2024
c41dec3
merge main
twsouthwick Nov 14, 2024
4def406
remove unneeded if/def
twsouthwick Nov 15, 2024
669cc6a
Merge remote-tracking branch 'origin/main' into sspi-writer
twsouthwick Jan 31, 2025
1ba12e6
react to ISniNativeMethods
twsouthwick Jan 31, 2025
c4d3dbe
add to strings.resx
twsouthwick Jan 31, 2025
b00a17c
revert other changes to string designer
twsouthwick Jan 31, 2025
43ec92d
move write methods that are the same to shared file
twsouthwick Jan 31, 2025
7b7ee36
make sure to use correct length
twsouthwick Feb 3, 2025
4f95969
Merge branch 'main' into sspi-writer
twsouthwick Feb 11, 2025
e1986ac
Merge remote-tracking branch 'origin/main' into sspi-writer
twsouthwick Feb 12, 2025
130cdef
put method back
twsouthwick Feb 12, 2025
f40013e
Add comment for pool
twsouthwick Feb 13, 2025
31e3ec9
Add note about file origin and editing restrictions
twsouthwick Feb 13, 2025
e17288c
Merge branch 'main' into sspi-writer
twsouthwick Feb 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8500,8 +8500,7 @@ private void WriteLoginData(SqlLogin rec,
int length,
int featureExOffset,
string clientInterfaceName,
byte[] outSSPIBuff,
uint outSSPILength)
ReadOnlySpan<byte> outSSPI)
{
try
{
Expand Down Expand Up @@ -8673,8 +8672,8 @@ private void WriteLoginData(SqlLogin rec,
WriteShort(offset, _physicalStateObj); // ibSSPI offset
if (rec.useSSPI)
{
WriteShort((int)outSSPILength, _physicalStateObj);
offset += (int)outSSPILength;
WriteShort(outSSPI.Length, _physicalStateObj);
offset += outSSPI.Length;
}
else
{
Expand Down Expand Up @@ -8729,7 +8728,7 @@ private void WriteLoginData(SqlLogin rec,

// send over SSPI data if we are using SSPI
if (rec.useSSPI)
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
_physicalStateObj.WriteByteSpan(outSSPI);

WriteString(rec.attachDBFilename, _physicalStateObj);
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -793,143 +793,6 @@ internal void WriteByte(byte b)
_outBuff[_outBytesUsed++] = b;
}

internal Task WriteByteSpan(ReadOnlySpan<byte> span, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is removed in the netcore/netfx specific files, and the netcore implementation is now in the shared TdsParserStateObject class

{
return WriteBytes(span, span.Length, 0, canAccumulate, completion);
}

internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
{
return WriteBytes(ReadOnlySpan<byte>.Empty, len, offsetBuffer, canAccumulate, completion, b);
}

//
// Takes a span or a byte array and writes it to the buffer
// If you pass in a span and a null array then the span wil be used.
// If you pass in a non-null array then the array will be used and the span is ignored.
// if the span cannot be written into the current packet then the remaining contents of the span are copied to a
// new heap allocated array that will used to callback into the method to continue the write operation.
private Task WriteBytes(ReadOnlySpan<byte> b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null, byte[] array = null)
{
if (array != null)
{
b = new ReadOnlySpan<byte>(array, offsetBuffer, len);
}
try
{
bool async = _parser._asyncWrite; // NOTE: We are capturing this now for the assert after the Task is returned, since WritePacket will turn off async if there is an exception
Debug.Assert(async || _asyncWriteCount == 0);
// Do we have to send out in packet size chunks, or can we rely on netlib layer to break it up?
// would prefer to do something like:
//
// if (len > what we have room for || len > out buf)
// flush buffer
// UnsafeNativeMethods.Write(b)
//

int offset = offsetBuffer;

Debug.Assert(b.Length >= len, "Invalid length sent to WriteBytes()!");

// loop through and write the entire array
do
{
if ((_outBytesUsed + len) > _outBuff.Length)
{
// If the remainder of the data won't fit into the buffer, then we have to put
// whatever we can into the buffer, and flush that so we can then put more into
// the buffer on the next loop of the while.

int remainder = _outBuff.Length - _outBytesUsed;

// write the remainder
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, remainder);
ReadOnlySpan<byte> copyFrom = b.Slice(0, remainder);

Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length{copyFrom.Length:D} should be the same");

copyFrom.CopyTo(copyTo);

offset += remainder;
_outBytesUsed += remainder;
len -= remainder;
b = b.Slice(remainder, len);

Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);

if (packetTask != null)
{
Task task = null;
Debug.Assert(async, "Returned task in sync mode");
if (completion == null)
{
completion = new TaskCompletionSource<object>();
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
}

if (array == null)
{
byte[] tempArray = new byte[len];
Span<byte> copyTempTo = tempArray.AsSpan();

Debug.Assert(copyTempTo.Length == b.Length, $"copyTempTo.Length:{copyTempTo.Length} and copyTempFrom.Length:{b.Length:D} should be the same");

b.CopyTo(copyTempTo);
array = tempArray;
offset = 0;
}

WriteBytesSetupContinuation(array, len, completion, offset, packetTask);
return task;
}
}
else
{
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
// Else the remainder of the string will fit into the buffer, so copy it into the
// buffer and then break out of the loop.

Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, len);
ReadOnlySpan<byte> copyFrom = b.Slice(0, len);

Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length:{copyFrom.Length:D} should be the same");

copyFrom.CopyTo(copyTo);

// handle out buffer bytes used counter
_outBytesUsed += len;
break;
}
} while (len > 0);

if (completion != null)
{
completion.SetResult(null);
}
return null;
}
catch (Exception e)
{
if (completion != null)
{
completion.SetException(e);
return null;
}
else
{
throw;
}
}
}

// This is in its own method to avoid always allocating the lambda in WriteBytes
private void WriteBytesSetupContinuation(byte[] array, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
{
AsyncHelper.ContinueTask(packetTask, completion,
onSuccess: () => WriteBytes(ReadOnlySpan<byte>.Empty, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion, array)
);
}

// Dumps contents of buffer to SNI for network write.
internal Task WritePacket(byte flushMode, bool canAccumulate = false)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,10 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
<Link>Microsoft\Data\SqlClient\AAsyncCallContext.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
Expand All @@ -361,6 +364,9 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AlwaysEncryptedKeyConverter.cs">
<Link>Microsoft\Data\SqlClient\AlwaysEncryptedKeyConverter.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AzureAttestationBasedEnclaveProvider.cs">
<Link>Microsoft\Data\SqlClient\AzureAttestationBasedEnclaveProvider.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8972,8 +8972,7 @@ private void WriteLoginData(SqlLogin rec,
int length,
int featureExOffset,
string clientInterfaceName,
byte[] outSSPIBuff,
uint outSSPILength)
ReadOnlySpan<byte> outSSPI)
{
try
{
Expand Down Expand Up @@ -9145,8 +9144,8 @@ private void WriteLoginData(SqlLogin rec,
WriteShort(offset, _physicalStateObj); // ibSSPI offset
if (rec.useSSPI)
{
WriteShort((int)outSSPILength, _physicalStateObj);
offset += (int)outSSPILength;
WriteShort(outSSPI.Length, _physicalStateObj);
offset += outSSPI.Length;
}
else
{
Expand Down Expand Up @@ -9205,7 +9204,7 @@ private void WriteLoginData(SqlLogin rec,

// send over SSPI data if we are using SSPI
if (rec.useSSPI)
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
_physicalStateObj.WriteByteSpan(outSSPI);

WriteString(rec.attachDBFilename, _physicalStateObj);
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -868,102 +868,6 @@ internal void WriteByte(byte b)
_outBuff[_outBytesUsed++] = b;
}

internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
{
try
{
bool async = _parser._asyncWrite; // NOTE: We are capturing this now for the assert after the Task is returned, since WritePacket will turn off async if there is an exception
Debug.Assert(async || _asyncWriteCount == 0);
// Do we have to send out in packet size chunks, or can we rely on netlib layer to break it up?
// would prefer to do something like:
//
// if (len > what we have room for || len > out buf)
// flush buffer
// UnsafeNativeMethods.Write(b)
//

int offset = offsetBuffer;

Debug.Assert(b.Length >= len, "Invalid length sent to WriteByteArray()!");

// loop through and write the entire array
do
{
if ((_outBytesUsed + len) > _outBuff.Length)
{
// If the remainder of the data won't fit into the buffer, then we have to put
// whatever we can into the buffer, and flush that so we can then put more into
// the buffer on the next loop of the while.

int remainder = _outBuff.Length - _outBytesUsed;

// write the remainder
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, remainder);

// handle counters
offset += remainder;
_outBytesUsed += remainder;
len -= remainder;

Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);

if (packetTask != null)
{
Task task = null;
Debug.Assert(async, "Returned task in sync mode");
if (completion == null)
{
completion = new TaskCompletionSource<object>();
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
}
WriteByteArraySetupContinuation(b, len, completion, offset, packetTask);
return task;
}

}
else
{
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
// Else the remainder of the string will fit into the buffer, so copy it into the
// buffer and then break out of the loop.

Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, len);

// handle out buffer bytes used counter
_outBytesUsed += len;
break;
}
} while (len > 0);

if (completion != null)
{
completion.SetResult(null);
}
return null;
}
catch (Exception e)
{
if (completion != null)
{
completion.SetException(e);
return null;
}
else
{
throw;
}
}
}

// This is in its own method to avoid always allocating the lambda in WriteByteArray
private void WriteByteArraySetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
{
AsyncHelper.ContinueTask(packetTask, completion,
() => WriteByteArray(b, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion),
connectionToDoom: _parser.Connection
);
}

// Dumps contents of buffer to SNI for network write.
internal Task WritePacket(byte flushMode, bool canAccumulate = false)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ unsafe uint SniSecGenClientContextWrapper(
SNIHandle pConn,
byte* pIn,
uint cbIn,
byte[] pOut,
byte* pOut,
ref uint pcbOut,
out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public unsafe uint SniSecGenClientContextWrapper(
SNIHandle pConn,
byte* pIn,
uint cbIn,
byte[] pOut,
byte* pOut,
ref uint pcbOut,
out bool pfDone,
byte* szServerInfo,
Expand Down Expand Up @@ -265,7 +265,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAs(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public unsafe uint SniSecGenClientContextWrapper(
SNIHandle pConn,
byte* pIn,
uint cbIn,
byte[] pOut,
byte* pOut,
ref uint pcbOut,
out bool pfDone,
byte* szServerInfo,
Expand Down Expand Up @@ -265,7 +265,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAs(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public unsafe uint SniSecGenClientContextWrapper(
SNIHandle pConn,
byte* pIn,
uint cbIn,
byte[] pOut,
byte* pOut,
ref uint pcbOut,
out bool pfDone,
byte* szServerInfo,
Expand Down
Loading
Loading