Skip to content

Commit

Permalink
Use DatabaseConnectionManager to manage connections
Browse files Browse the repository at this point in the history
  • Loading branch information
OoLunar committed Feb 10, 2024
1 parent 00f139e commit 7de745d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 21 deletions.
54 changes: 54 additions & 0 deletions src/Database/DatabaseConnectionManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Npgsql;

namespace OoLunar.Tomoe.Database
{
public sealed class DatabaseConnectionManager : IAsyncDisposable, IDisposable
{
private readonly List<NpgsqlConnection> _connections = [];
private readonly string _connectionString;

public DatabaseConnectionManager(IConfiguration configuration)
{
ArgumentNullException.ThrowIfNull(configuration, nameof(configuration));
_connectionString = new NpgsqlConnectionStringBuilder()
{
Host = configuration.GetValue("database:host", "localhost"),
Port = configuration.GetValue("database:port", 5432),
Username = configuration.GetValue("database:username", "postgres"),
Password = configuration.GetValue("database:password", "postgres"),
Database = configuration.GetValue("database:database", "tomoe"),
CommandTimeout = configuration.GetValue("database:timeout", 5),
#if DEBUG
IncludeErrorDetail = true
#endif
}.ConnectionString;
}

public NpgsqlConnection GetConnection()
{
NpgsqlConnection connection = new(_connectionString);
_connections.Add(connection);
return connection;
}

public void Dispose()
{
foreach (NpgsqlConnection connection in _connections)
{
connection.Dispose();
}
}

public async ValueTask DisposeAsync()
{
foreach (NpgsqlConnection connection in _connections)
{
await connection.DisposeAsync();
}
}
}
}
29 changes: 9 additions & 20 deletions src/Database/DatabaseHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Npgsql;
Expand All @@ -14,28 +13,18 @@ namespace OoLunar.Tomoe.Database
public sealed class DatabaseHandler
{
private delegate ValueTask PrepareAsyncDelegate(NpgsqlConnection connection);
private readonly Dictionary<NpgsqlConnection, (SemaphoreSlim, PrepareAsyncDelegate)> _tableTypes;
private readonly Dictionary<NpgsqlConnection, (SemaphoreSlim, PrepareAsyncDelegate)> _tableTypes = [];
private readonly DatabaseConnectionManager _connectionManager;
private readonly ILogger<DatabaseHandler> _logger;
private readonly IConfiguration _configuration;

public DatabaseHandler(ILogger<DatabaseHandler> logger, IConfiguration configuration)
public DatabaseHandler(DatabaseConnectionManager connectionManager, ILogger<DatabaseHandler>? logger = null)
{
_connectionManager = connectionManager ?? throw new ArgumentNullException(nameof(connectionManager));
_logger = logger ?? NullLogger<DatabaseHandler>.Instance;
_configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
_tableTypes = [];
NpgsqlConnectionStringBuilder connectionStringBuilder = new()
{
Host = _configuration.GetValue("database:host", "localhost"),
Port = _configuration.GetValue("database:port", 5432),
Username = _configuration.GetValue("database:username", "postgres"),
Password = _configuration.GetValue("database:password", "postgres"),
Database = _configuration.GetValue("database:database", "tomoe"),
CommandTimeout = _configuration.GetValue("database:timeout", 5),
#if DEBUG
IncludeErrorDetail = true
#endif
};
}

public async ValueTask InitializeAsync(CancellationToken cancellationToken = default)
{
foreach (Type type in typeof(Program).Assembly.GetTypes())
{
if (type.GetCustomAttribute<DatabaseModelAttribute>() is null)
Expand All @@ -58,10 +47,10 @@ public DatabaseHandler(ILogger<DatabaseHandler> logger, IConfiguration configura
continue;
}

NpgsqlConnection connection = new(connectionStringBuilder.ToString());
NpgsqlConnection connection = _connectionManager.GetConnection();
_tableTypes.Add(connection, (semaphore, prepareAsyncDelegate));
connection.StateChange += StateChangedEventHandlerAsync;
connection.Open();
await connection.OpenAsync(cancellationToken);
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public static async Task Main(string[] args)
logger.AddSerilog(loggerConfiguration.CreateLogger());
});

serviceCollection.AddSingleton<DatabaseConnectionManager>();
serviceCollection.AddSingleton<DatabaseHandler>();

Assembly currentAssembly = typeof(Program).Assembly;
Expand Down Expand Up @@ -146,10 +147,12 @@ await extension.AddProcessorsAsync(
return eventManager;
});

// Start the database before connecting to Discord
IServiceProvider serviceProvider = serviceCollection.BuildServiceProvider();
await serviceProvider.GetRequiredService<DatabaseHandler>().InitializeAsync(); // Init the db connection

DiscordShardedClient shardedClient = await serviceProvider.GetRequiredService<Task<DiscordShardedClient>>();
DiscordEventManager eventManager = serviceProvider.GetRequiredService<DiscordEventManager>();
serviceProvider.GetRequiredService<DatabaseHandler>(); // Init the db connection
eventManager.RegisterEventHandlers(shardedClient);
foreach (CommandsExtension extension in shardedClient.GetCommandsExtensions().Values)
{
Expand Down

0 comments on commit 7de745d

Please sign in to comment.