Skip to content

Commit

Permalink
Abstract the SSPI context generation
Browse files Browse the repository at this point in the history
This change introduces SSPIContextProvider that can generate payloads for SSPI. Specifically, this change plumbs the current SSPI context generation into this object, while later changes will continue to update the shape to be a more general purpose, public API.
  • Loading branch information
twsouthwick committed Jan 19, 2024
1 parent b12b15d commit e9992fa
Show file tree
Hide file tree
Showing 14 changed files with 518 additions and 617 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,18 @@
<Compile Include="..\..\src\System\Diagnostics\CodeAnalysis.cs">
<Link>Common\System\Diagnostics\CodeAnalysis.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\TdsParser.cs">
<Link>Microsoft\Data\SqlClient\TdsParser.cs</Link>
</Compile>
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)' == 'netstandard' OR '$(TargetGroup)' == 'netcoreapp' OR '$(IsUAPAssembly)' == 'true'">
<Compile Include="Microsoft.Data.SqlClient.TypeForwards.cs" />
Expand Down Expand Up @@ -580,7 +592,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlConnectionFactory.AssemblyLoadContext.cs" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\src\Resources\StringsHelper.cs">
<Compile Include="..\..\src\Resources\StringsHelper.cs">
<Link>Resources\StringsHelper.cs</Link>
</Compile>
<Compile Include="..\..\src\Resources\Strings.Designer.cs">
Expand Down Expand Up @@ -767,7 +779,8 @@
<ItemGroup Condition="'$(TargetsWindows)' == 'true' and '$(IsUAPAssembly)' != 'true'">
<Compile Include="$(CommonPath)\Interop\Windows\kernel32\Interop.LoadLibraryEx.cs">
<Link>Common\Interop\Windows\kernel32\Interop.LoadLibraryEx.cs</Link>
</Compile>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs" Link="Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs" />
<Compile Include="Interop\SNINativeMethodWrapper.Windows.cs" />
<Compile Include="Microsoft\Data\ProviderBase\DbConnectionPoolIdentity.Windows.cs" />
<Compile Include="Microsoft\Data\SqlClient\LocalDBAPI.Windows.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ namespace Microsoft.Data.SqlClient
{
internal sealed partial class TdsParser
{
private static volatile bool s_fSSPILoaded = false; // bool to indicate whether library has been loaded

internal void PostReadAsyncForMars()
{
if (TdsParserStateObjectFactory.UseManagedSNI)
Expand Down Expand Up @@ -43,37 +41,7 @@ internal void PostReadAsyncForMars()
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
ThrowExceptionAndWarning(_physicalStateObj);
}
}

private void LoadSSPILibrary()
{
if (TdsParserStateObjectFactory.UseManagedSNI)
return;
// Outer check so we don't acquire lock once it's loaded.
if (!s_fSSPILoaded)
{
lock (s_tdsParserLock)
{
// re-check inside lock
if (!s_fSSPILoaded)
{
// use local for ref param to defer setting s_maxSSPILength until we know the call succeeded.
uint maxLength = 0;

if (0 != SNINativeMethodWrapper.SNISecInitPackage(ref maxLength))
SSPIError(SQLMessage.SSPIInitializeError(), TdsEnums.INIT_SSPI_PACKAGE);

s_maxSSPILength = maxLength;
s_fSSPILoaded = true;
}
}
}

if (s_maxSSPILength > int.MaxValue)
{
throw SQL.InvalidSSPIPacketSize(); // SqlBu 332503
}
}
}

private void WaitForSSLHandShakeToComplete(ref uint error, ref int protocolVersion)
{
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ internal TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCon
AddError(parser.ProcessSNIError(this));
ThrowExceptionAndWarning();
}

// we post a callback that represents the call to dispose; once the
// object is disposed, the next callback will cause the GC Handle to
// be released.
Expand All @@ -64,6 +64,7 @@ internal TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCon
////////////////
internal abstract uint DisableSsl();

internal abstract SSPIContextProvider CreateSSPIContextProvider();

internal abstract uint EnableMars(ref uint info);

Expand All @@ -72,6 +73,8 @@ internal abstract uint Status
get;
}

internal abstract Guid? SessionId { get; }

internal abstract SessionHandle SessionHandle
{
get;
Expand Down Expand Up @@ -253,8 +256,6 @@ internal abstract void CreatePhysicalSNIHandle(

protected abstract void RemovePacketFromPendingList(PacketHandle pointer);

internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);

internal int DecrementPendingCallbacks(bool release)
{
int remaining = Interlocked.Decrement(ref _pendingCallbacks);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ internal sealed class TdsParserStateObjectManaged : TdsParserStateObject
{
private SNIMarsConnection? _marsConnection;
private SNIHandle? _sessionHandle;
#if NET7_0_OR_GREATER
private NegotiateAuthentication? _negotiateAuth = null;
#else
private SspiClientContextStatus? _sspiClientContextStatus;
#endif

public TdsParserStateObjectManaged(TdsParser parser) : base(parser) { }

internal TdsParserStateObjectManaged(TdsParser parser, TdsParserStateObject physicalConnection, bool async) :
Expand Down Expand Up @@ -232,6 +228,8 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint

protected override PacketHandle EmptyReadPacket => PacketHandle.FromManagedPacket(null);

internal override Guid? SessionId => _sessionHandle?.ConnectionId;

internal override bool IsPacketEmpty(PacketHandle packet) => packet.ManagedPacket == null;

internal override void ReleasePacket(PacketHandle syncReadPacket)
Expand Down Expand Up @@ -389,30 +387,6 @@ internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
return TdsEnums.SNI_SUCCESS;
}

internal override uint GenerateSspiClientContext(byte[] receivedBuff,
uint receivedLength,
ref byte[] sendBuff,
ref uint sendLength,
byte[][] _sniSpnBuffer)
{
#if NET7_0_OR_GREATER
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = Encoding.Unicode.GetString(_sniSpnBuffer[0]) });
sendBuff = _negotiateAuth.GetOutgoingBlob(receivedBuff, out NegotiateAuthenticationStatusCode statusCode)!;
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}, StatusCode={1}", _sessionHandle?.ConnectionId, statusCode);
if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded)
{
throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
}
#else
_sspiClientContextStatus ??= new SspiClientContextStatus();

SNIProxy.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer);
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _sessionHandle?.ConnectionId);
#endif
sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0);
return 0;
}

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
protocolVersion = GetSessionSNIHandleHandleOrThrow().ProtocolVersion;
Expand All @@ -432,5 +406,12 @@ private SNIHandle GetSessionSNIHandleHandleOrThrow()
[DoesNotReturn]
[MethodImpl(MethodImplOptions.NoInlining)] // this forces the exception throwing code not to be inlined for performance
private void ThrowClosedConnection() => throw ADP.ClosedConnectionError();

internal override SSPIContextProvider CreateSSPIContextProvider()
#if NET7_0_OR_GREATER
=> new NegotiateSSPIContextProvider();
#else
=> new ManagedSSPIContextProvider();
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ internal override void CreatePhysicalSNIHandle(
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
spnBuffer[0] = srvSPN;
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.",nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
}
else
{
Expand Down Expand Up @@ -272,6 +272,8 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint

protected override PacketHandle EmptyReadPacket => PacketHandle.FromNativePointer(default);

internal override Guid? SessionId => default;

internal override bool IsPacketEmpty(PacketHandle readPacket)
{
Debug.Assert(readPacket.Type == PacketHandle.NativePointerType || readPacket.Type == 0, "unexpected packet type when requiring NativePointer");
Expand Down Expand Up @@ -398,9 +400,6 @@ internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCert
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
=> SNINativeMethodWrapper.SNISetInfo(Handle, SNINativeMethodWrapper.QTypes.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]);

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
uint returnValue = SNINativeMethodWrapper.SNIWaitForSSLHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion);
Expand Down Expand Up @@ -451,6 +450,8 @@ internal override void DisposePacketCache()
}
}

internal override SSPIContextProvider CreateSSPIContextProvider() => new NativeSSPIContextProvider();

internal sealed class WritePacketCache : IDisposable
{
private bool _disposed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,21 @@
<Compile Include="..\..\src\Microsoft\Data\ProviderBase\TimeoutTimer.cs">
<Link>Microsoft\Data\ProviderBase\TimeoutTimer.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\TdsParser.cs">
<Link>Microsoft\Data\SqlClient\TdsParser.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
</Compile>
Expand Down
Loading

0 comments on commit e9992fa

Please sign in to comment.