Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test coverage for read-only routing #2897

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<Compile Include="SqlClientMetaDataCollectionNamesTest.cs" />
<Compile Include="SqlDataAdapterTest.cs" />
<Compile Include="SqlConnectionBasicTests.cs" />
<Compile Include="SqlConnectionReadOnlyRoutingTests.cs" />
<Compile Include="SqlCommandTest.cs" />
<Compile Include="SqlConnectionTest.cs" />
<Compile Include="AADAuthenticationTests.cs" />
Expand All @@ -67,6 +68,7 @@
<Compile Include="SqlConnectionStringBuilderTest.cs" />
<Compile Include="SerializeSqlTypesTest.cs" />
<Compile Include="TestTdsServer.cs" />
<Compile Include="TestRoutingTdsServer.cs" />
<Compile Include="SqlHelperTest.cs" />
<Compile Include="..\..\src\Microsoft\Data\Common\MultipartIdentifier.cs" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Net;
using System.Threading.Tasks;
using Microsoft.SqlServer.TDS.Servers;
using Xunit;

namespace Microsoft.Data.SqlClient.Tests
{
public class SqlConnectionReadOnlyRoutingTests
{
[Fact]
public void NonRoutedConnection()
{
using TestTdsServer server = TestTdsServer.StartTestServer();
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly };
using SqlConnection connection = new SqlConnection(builder.ConnectionString);
connection.Open();
}

[Fact]
public async Task NonRoutedAsyncConnection()
{
using TestTdsServer server = TestTdsServer.StartTestServer();
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly };
using SqlConnection connection = new SqlConnection(builder.ConnectionString);
await connection.OpenAsync();
}

[Fact]
public void RoutedConnection()
=> RecursivelyRoutedConnection(1);

[Fact]
public async Task RoutedAsyncConnection()
=> await RecursivelyRoutedAsyncConnection(1);

[Theory]
[InlineData(2)]
[InlineData(9)]
[InlineData(11)] // The driver rejects more than 10 redirects (11 layers of redirecting servers)
public void RecursivelyRoutedConnection(int layers)
{
TestTdsServer innerServer = TestTdsServer.StartTestServer();
IPEndPoint lastEndpoint = innerServer.Endpoint;
Stack<GenericTDSServer> routingLayers = new(layers + 1);
string lastConnectionString = innerServer.ConnectionString;

try
{
routingLayers.Push(innerServer);
for (int i = 0; i < layers; i++)
{
TestRoutingTdsServer router = TestRoutingTdsServer.StartTestServer(lastEndpoint);

routingLayers.Push(router);
lastEndpoint = router.Endpoint;
lastConnectionString = router.ConnectionString;
}

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly };
using SqlConnection connection = new SqlConnection(builder.ConnectionString);
connection.Open();
}
finally
{
while (routingLayers.Count > 0)
{
GenericTDSServer layer = routingLayers.Pop();

if (layer is IDisposable disp)
{
disp.Dispose();
}
}
}
}

[Theory]
[InlineData(2)]
[InlineData(9)]
[InlineData(11)] // The driver rejects more than 10 redirects (11 layers of redirecting servers)
public async Task RecursivelyRoutedAsyncConnection(int layers)
{
TestTdsServer innerServer = TestTdsServer.StartTestServer();
IPEndPoint lastEndpoint = innerServer.Endpoint;
Stack<GenericTDSServer> routingLayers = new(layers + 1);
string lastConnectionString = innerServer.ConnectionString;

try
{
routingLayers.Push(innerServer);
for (int i = 0; i < layers; i++)
{
TestRoutingTdsServer router = TestRoutingTdsServer.StartTestServer(lastEndpoint);

routingLayers.Push(router);
lastEndpoint = router.Endpoint;
lastConnectionString = router.ConnectionString;
}

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly };
using SqlConnection connection = new SqlConnection(builder.ConnectionString);
await connection.OpenAsync();
}
finally
{
while (routingLayers.Count > 0)
{
GenericTDSServer layer = routingLayers.Pop();

if (layer is IDisposable disp)
{
disp.Dispose();
}
}
}
}

[Fact]
public void ConnectionRoutingLimit()
{
SqlException sqlEx = Assert.Throws<SqlException>(() => RecursivelyRoutedConnection(12)); // This will fail on the 11th redirect

Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase);
}

[Fact]
public async Task AsyncConnectionRoutingLimit()
{
SqlException sqlEx = await Assert.ThrowsAsync<SqlException>(() => RecursivelyRoutedAsyncConnection(12)); // This will fail on the 11th redirect

Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Runtime.CompilerServices;
using Microsoft.SqlServer.TDS.EndPoint;
using Microsoft.SqlServer.TDS.Servers;

namespace Microsoft.Data.SqlClient.Tests
{
internal class TestRoutingTdsServer : RoutingTDSServer, IDisposable
{
private const int DefaultConnectionTimeout = 5;

private TDSServerEndPoint _endpoint = null;

private SqlConnectionStringBuilder _connectionStringBuilder;

public TestRoutingTdsServer(RoutingTDSServerArguments args) : base(args) { }

public static TestRoutingTdsServer StartTestServer(IPEndPoint destinationEndpoint, bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, bool excludeEncryption = false, [CallerMemberName] string methodName = "")
{
RoutingTDSServerArguments args = new RoutingTDSServerArguments()
{
Log = enableLog ? Console.Out : null,
RoutingTCPHost = destinationEndpoint.Address.ToString() == IPAddress.Any.ToString() ? IPAddress.Loopback.ToString() : destinationEndpoint.Address.ToString(),
RoutingTCPPort = (ushort)destinationEndpoint.Port,
};

if (enableFedAuth)
{
args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired;
}
if (excludeEncryption)
{
args.Encryption = SqlServer.TDS.PreLogin.TDSPreLoginTokenEncryptionType.None;
}

TestRoutingTdsServer server = new TestRoutingTdsServer(args);
server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) };
server._endpoint.EndpointName = methodName;
// The server EventLog should be enabled as it logs the exceptions.
server._endpoint.EventLog = Console.Out;
server._endpoint.Start();

int port = server._endpoint.ServerEndPoint.Port;
server._connectionStringBuilder = excludeEncryption
// Allow encryption to be set when encryption is to be excluded from pre-login response.
? new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Mandatory }
: new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Optional };
server.ConnectionString = server._connectionStringBuilder.ConnectionString;
server.Endpoint = server._endpoint.ServerEndPoint;
return server;
}

public void Dispose() => _endpoint?.Stop();

public string ConnectionString { get; private set; }

public IPEndPoint Endpoint { get; private set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool
? new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Mandatory }
: new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Optional };
server.ConnectionString = server._connectionStringBuilder.ConnectionString;
server.Endpoint = server._endpoint.ServerEndPoint;
return server;
}

Expand All @@ -65,5 +66,7 @@ public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool ena
public void Dispose() => _endpoint?.Stop();

public string ConnectionString { get; private set; }

public IPEndPoint Endpoint { get; private set; }
}
}
Loading