Skip to content

Commit

Permalink
Fix | Fix possible server connection leak if an exception occurs in p…
Browse files Browse the repository at this point in the history
…ooling layer (#890)
  • Loading branch information
cheenamalhotra authored Apr 14, 2021
1 parent 6a95ad4 commit c34e8a4
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,14 @@ private DbConnectionInternal CreateObject(DbConnection owningObject, DbConnectio

CheckPoolBlockingPeriod(e);

// Close associated Parser if connection already established.
if (newObj?.IsConnectionAlive() == true)
{
newObj.Dispose();
}

newObj = null; // set to null, so we do not return bad new object

// Failed to create instance
_resError = e;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,14 @@ private DbConnectionInternal CreateObject(DbConnection owningObject, DbConnectio
throw;
}

// Close associated Parser if connection already established.
if (newObj?.IsConnectionAlive() == true)
{
newObj.Dispose();
}

newObj = null; // set to null, so we do not return bad new object

// Failed to create instance
_resError = e;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Data;
using System.Data.Common;
using System.Reflection;
using System.Security;
using Microsoft.SqlServer.TDS.Servers;
using Xunit;

namespace Microsoft.Data.SqlClient.Tests
Expand Down Expand Up @@ -40,6 +42,37 @@ public void IntegratedAuthConnectionTest()
}
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotArmProcess))]
[PlatformSpecific(TestPlatforms.Windows)]
public void TransientFaultTest()
{
using (TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, true, 40613))
{
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder()
{
DataSource = "localhost," + server.Port,
IntegratedSecurity = true
};

using (SqlConnection connection = new SqlConnection(builder.ConnectionString))
{
try
{
connection.Open();
Assert.Equal(ConnectionState.Open, connection.State);
}
catch (Exception e)
{
if (null != connection)
{
Assert.Equal(ConnectionState.Closed, connection.State);
}
Assert.False(true, e.Message);
}
}
}
}

[Fact]
public void SqlConnectionDbProviderFactoryTest()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ namespace Microsoft.Data.SqlClient.Tests
{
internal class TestTdsServer : GenericTDSServer, IDisposable
{
private const int DefaultConnectionTimeout = 5;

private TDSServerEndPoint _endpoint = null;

private SqlConnectionStringBuilder connectionStringBuilder;
private SqlConnectionStringBuilder _connectionStringBuilder;

public TestTdsServer(TDSServerArguments args) : base(args) { }

public TestTdsServer(QueryEngine engine, TDSServerArguments args) : base(args)
{
this.Engine = engine;
Engine = engine;
}

public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, [CallerMemberName] string methodName = "")
public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "")
{
TDSServerArguments args = new TDSServerArguments()
{
Expand All @@ -32,7 +34,7 @@ public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool

if (enableFedAuth)
{
args.FedAuthRequiredPreLoginOption = Microsoft.SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired;
args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired;
}

TestTdsServer server = engine == null ? new TestTdsServer(args) : new TestTdsServer(engine, args);
Expand All @@ -43,14 +45,14 @@ public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool
server._endpoint.Start();

int port = server._endpoint.ServerEndPoint.Port;
server.connectionStringBuilder = new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = 5, Encrypt = false };
server.ConnectionString = server.connectionStringBuilder.ConnectionString;
server._connectionStringBuilder = new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = false };
server.ConnectionString = server._connectionStringBuilder.ConnectionString;
return server;
}

public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, [CallerMemberName] string methodName = "")
public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "")
{
return StartServerWithQueryEngine(null, false, false, methodName);
return StartServerWithQueryEngine(null, enableFedAuth, enableLog, connectionTimeout, methodName);
}

public void Dispose() => _endpoint?.Stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
<Compile Include="RoutingTDSServerArguments.cs" />
<Compile Include="ServerNameFilterType.cs" />
<Compile Include="TDSServerArguments.cs" />
<Compile Include="TransientFaultTDSServer.cs" />
<Compile Include="TransientFaultTDSServerArguments.cs" />
<None Include="TdsServerCertificate.pfx">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.SqlServer.TDS.Done;
using Microsoft.SqlServer.TDS.EndPoint;
using Microsoft.SqlServer.TDS.Error;
using Microsoft.SqlServer.TDS.Login7;

namespace Microsoft.SqlServer.TDS.Servers
{
/// <summary>
/// TDS Server that authenticates clients according to the requested parameters
/// </summary>
public class TransientFaultTDSServer : GenericTDSServer, IDisposable
{
private static int RequestCounter = 0;

public int Port { get; set; }

/// <summary>
/// Constructor
/// </summary>
public TransientFaultTDSServer() => new TransientFaultTDSServer(new TransientFaultTDSServerArguments());

/// <summary>
/// Constructor
/// </summary>
/// <param name="arguments"></param>
public TransientFaultTDSServer(TransientFaultTDSServerArguments arguments) :
base(arguments)
{ }

/// <summary>
/// Constructor
/// </summary>
/// <param name="engine"></param>
/// <param name="args"></param>
public TransientFaultTDSServer(QueryEngine engine, TransientFaultTDSServerArguments args) : base(args)
{
Engine = engine;
}

private TDSServerEndPoint _endpoint = null;

private static string GetErrorMessage(uint errorNumber)
{
switch (errorNumber)
{
case 40613:
return "Database on server is not currently available. Please retry the connection later. " +
"If the problem persists, contact customer support, and provide them the session tracing ID.";
}
return "Unknown server error occurred";
}

/// <summary>
/// Handler for login request
/// </summary>
public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request)
{
// Inflate login7 request from the message
TDSLogin7Token loginRequest = request[0] as TDSLogin7Token;

// Check if arguments are of the transient fault TDS server
if (Arguments is TransientFaultTDSServerArguments)
{
// Cast to transient fault TDS server arguments
TransientFaultTDSServerArguments ServerArguments = Arguments as TransientFaultTDSServerArguments;

// Check if we're still going to raise transient error
if (ServerArguments.IsEnabledTransientError && RequestCounter < 1) // Fail first time, then connect
{
uint errorNumber = ServerArguments.Number;
string errorMessage = ServerArguments.Message;

// Log request to which we're about to send a failure
TDSUtilities.Log(Arguments.Log, "Request", loginRequest);

// Prepare ERROR token with the denial details
TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage);

// Log response
TDSUtilities.Log(Arguments.Log, "Response", errorToken);

// Serialize the error token into the response packet
TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken);

// Create DONE token
TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error);

// Log response
TDSUtilities.Log(Arguments.Log, "Response", doneToken);

// Serialize DONE token into the response packet
responseMessage.Add(doneToken);

RequestCounter++;

// Put a single message into the collection and return it
return new TDSMessageCollection(responseMessage);
}
}

// Return login response from the base class
return base.OnLogin7Request(session, request);
}

public static TransientFaultTDSServer StartTestServer(bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "")
=> StartServerWithQueryEngine(null, isEnabledTransientFault, enableLog, errorNumber, methodName);

public static TransientFaultTDSServer StartServerWithQueryEngine(QueryEngine engine, bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "")
{
TransientFaultTDSServerArguments args = new TransientFaultTDSServerArguments()
{
Log = enableLog ? Console.Out : null,
IsEnabledTransientError = isEnabledTransientFault,
Number = errorNumber,
Message = GetErrorMessage(errorNumber)
};

TransientFaultTDSServer server = engine == null ? new TransientFaultTDSServer(args) : new TransientFaultTDSServer(engine, args);
server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) };
server._endpoint.EndpointName = methodName;

// The server EventLog should be enabled as it logs the exceptions.
server._endpoint.EventLog = Console.Out;
server._endpoint.Start();

server.Port = server._endpoint.ServerEndPoint.Port;
return server;
}

public void Dispose() => Dispose(true);

private void Dispose(bool isDisposing)
{
if (isDisposing)
{
_endpoint?.Stop();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.SqlServer.TDS.Servers
{
public class TransientFaultTDSServerArguments : TDSServerArguments
{
/// <summary>
/// Transient error number to be raised by server.
/// </summary>
public uint Number { get; set; }

/// <summary>
/// Transient error message to be raised by server.
/// </summary>
public string Message { get; set; }

/// <summary>
/// Flag to consider when raising Transient error.
/// </summary>
public bool IsEnabledTransientError { get; set; }

/// <summary>
/// Constructor to initialize
/// </summary>
public TransientFaultTDSServerArguments()
{
Number = 0;
Message = string.Empty;
IsEnabledTransientError = false;
}
}
}

0 comments on commit c34e8a4

Please sign in to comment.