Skip to content

Commit

Permalink
feat(besql): add dbcontext initialization hook #9720 (#9721)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysmoradi authored Jan 23, 2025
1 parent 8a08d49 commit f200dab
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 61 deletions.
42 changes: 13 additions & 29 deletions src/Besql/Bit.Besql/BesqlPooledDbContextFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@

namespace Bit.Besql;

public class BesqlPooledDbContextFactory<TDbContext> : PooledDbContextFactory<TDbContext>
public class BesqlPooledDbContextFactory<TDbContext> : PooledDbContextFactoryBase<TDbContext>
where TDbContext : DbContext
{
private readonly string _fileName;
private readonly IBesqlStorage _storage;
private readonly string _connectionString;
private readonly TaskCompletionSource _initTcs = new();

public BesqlPooledDbContextFactory(
IBesqlStorage storage,
DbContextOptions<TDbContext> options)
: base(options)
DbContextOptions<TDbContext> options,
Func<IServiceProvider, TDbContext, Task> dbContextInitializer)
: base(options, dbContextInitializer)
{
_connectionString = options.Extensions
.OfType<RelationalOptionsExtension>()
Expand All @@ -28,33 +28,17 @@ public BesqlPooledDbContextFactory(
}["Data Source"].ToString()!.Trim('/');

_storage = storage;
_ = InitAsync();
}

public override async Task<TDbContext> CreateDbContextAsync(CancellationToken cancellationToken = default)
protected override async Task InitializeDbContext()
{
await _initTcs.Task.ConfigureAwait(false);

var ctx = await base.CreateDbContextAsync(cancellationToken).ConfigureAwait(false);

return ctx;
}

private async Task InitAsync()
{
try
{
await _storage.Init(_fileName).ConfigureAwait(false);
await using var connection = new SqliteConnection(_connectionString);
await connection.OpenAsync().ConfigureAwait(false);
await using var command = connection.CreateCommand();
command.CommandText = "PRAGMA synchronous = FULL;";
await command.ExecuteNonQueryAsync().ConfigureAwait(false);
_initTcs.SetResult();
}
catch (Exception exp)
{
_initTcs.SetException(exp);
}
await _storage.Init(_fileName).ConfigureAwait(false);
await using var connection = new SqliteConnection(_connectionString);
await connection.OpenAsync().ConfigureAwait(false);
await using var command = connection.CreateCommand();
command.CommandText = "PRAGMA synchronous = FULL;";
await command.ExecuteNonQueryAsync().ConfigureAwait(false);

await base.InitializeDbContext().ConfigureAwait(false);
}
}
26 changes: 19 additions & 7 deletions src/Besql/Bit.Besql/IServiceCollectionBesqlExtentions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,24 @@ namespace Microsoft.Extensions.DependencyInjection;

public static class IServiceCollectionBesqlExtentions
{
public static IServiceCollection AddBesqlDbContextFactory<TContext>(this IServiceCollection services, Action<IServiceProvider, DbContextOptionsBuilder> optionsAction)
where TContext : DbContext
public static IServiceCollection AddBesqlDbContextFactory<TDbContext>(this IServiceCollection services,
Action<IServiceProvider, DbContextOptionsBuilder>? optionsAction = null,
Func<IServiceProvider, TDbContext, Task>? dbContextInitializer = null)
where TDbContext : DbContext
{
optionsAction ??= (_, _) => { };
dbContextInitializer ??= async (_, _) => { };

services.AddSingleton(dbContextInitializer);

if (OperatingSystem.IsBrowser())
{
services.AddSingleton<BesqlDbContextInterceptor>();
services.TryAddSingleton<IBesqlStorage, BrowserCacheBesqlStorage>();
// To make optimized db context work in blazor wasm: https://github.com/dotnet/efcore/issues/31751
// https://learn.microsoft.com/en-us/ef/core/performance/advanced-performance-topics?tabs=with-di%2Cexpression-api-with-constant#compiled-models
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue31751", true);
services.AddDbContextFactory<TContext, BesqlPooledDbContextFactory<TContext>>((serviceProvider, options) =>
services.AddDbContextFactory<TDbContext, BesqlPooledDbContextFactory<TDbContext>>((serviceProvider, options) =>
{
options.AddInterceptors(serviceProvider.GetRequiredService<BesqlDbContextInterceptor>());
#if NET9_0_OR_GREATER
Expand All @@ -31,15 +38,20 @@ public static IServiceCollection AddBesqlDbContextFactory<TContext>(this IServic
else
{
services.TryAddSingleton<IBesqlStorage, NoopBesqlStorage>();
services.AddPooledDbContextFactory<TContext>(optionsAction);
services.AddDbContextFactory<TDbContext, PooledDbContextFactoryBase<TDbContext>>(optionsAction);
}

return services;
}

public static IServiceCollection AddBesqlDbContextFactory<TContext>(this IServiceCollection services, Action<DbContextOptionsBuilder>? optionsAction)
where TContext : DbContext
public static IServiceCollection AddBesqlDbContextFactory<TDbContext>(this IServiceCollection services,
Action<DbContextOptionsBuilder>? optionsAction = null,
Func<TDbContext, Task>? dbContextInitializer = null)
where TDbContext : DbContext
{
return services.AddBesqlDbContextFactory<TContext>((serviceProvider, options) => optionsAction?.Invoke(options));
optionsAction ??= _ => { };
dbContextInitializer ??= async _ => { };

return services.AddBesqlDbContextFactory<TDbContext>((serviceProvider, options) => optionsAction.Invoke(options), (serviceProvider, dbContext) => dbContextInitializer.Invoke(dbContext));
}
}
50 changes: 50 additions & 0 deletions src/Besql/Bit.Besql/PooledDbContextFactoryBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;

namespace Bit.Besql;

public class PooledDbContextFactoryBase<TDbContext>(DbContextOptions<TDbContext> options,
Func<IServiceProvider, TDbContext, Task> dbContextInitializer) : PooledDbContextFactory<TDbContext>(options)
where TDbContext : DbContext
{
private TaskCompletionSource? dbContextInitializerTcs;

public override async Task<TDbContext> CreateDbContextAsync(CancellationToken cancellationToken = default)
{
if (dbContextInitializerTcs is null)
{
await StartRunningDbContextInitializer();
}

await dbContextInitializerTcs!.Task.ConfigureAwait(false);

return await base.CreateDbContextAsync(cancellationToken).ConfigureAwait(false);
}

private async Task StartRunningDbContextInitializer()
{
if (dbContextInitializerTcs is not null)
return;

dbContextInitializerTcs = new();

try
{
await InitializeDbContext().ConfigureAwait(false);
dbContextInitializerTcs.SetResult();
}
catch (Exception ex)
{
dbContextInitializerTcs.SetException(ex);
}
}

protected virtual async Task InitializeDbContext()
{
if (dbContextInitializer is not null)
{
await using var dbContext = await base.CreateDbContextAsync().ConfigureAwait(false);
await dbContextInitializer(dbContext.GetService<IServiceProvider>(), dbContext).ConfigureAwait(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public static IServiceCollection AddAppServices(this IServiceCollection services
optionsBuilder
.UseModel(OfflineDbContextModel.Instance) // use generated compiled model in order to make db context optimized
.UseSqlite($"Data Source=Offline-Client.db");
});
}, dbContextInitializer: async (sp, dbContext) => await dbContext.Database.MigrateAsync());

return services;
}
Expand Down
12 changes: 0 additions & 12 deletions src/Besql/Demo/Bit.Besql.Demo.Client/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,4 @@

var app = builder.Build();

// To Create database and apply migrations
await using (var scope = app.Services.CreateAsyncScope())
{
// Create db context
await using var dbContext = await scope.ServiceProvider
.GetRequiredService<IDbContextFactory<OfflineDbContext>>()
.CreateDbContextAsync();

// migrate database
await dbContext.Database.MigrateAsync();
}

await app.RunAsync();
12 changes: 0 additions & 12 deletions src/Besql/Demo/Bit.Besql.Demo/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,4 @@
.AddInteractiveWebAssemblyRenderMode()
.AddAdditionalAssemblies(typeof(Counter).Assembly);

// To Create database and apply migrations
await using (var scope = app.Services.CreateAsyncScope())
{
// Create db context
await using var dbContext = await scope.ServiceProvider
.GetRequiredService<IDbContextFactory<OfflineDbContext>>()
.CreateDbContextAsync();

// migrate database
await dbContext.Database.MigrateAsync();
}

app.Run();

0 comments on commit f200dab

Please sign in to comment.