Skip to content

Commit

Permalink
Merge pull request #7804 from dolthub/fulghum/hooks
Browse files Browse the repository at this point in the history
Changing database init/drop hooks to be a slice of hooks
  • Loading branch information
fulghum authored May 2, 2024
2 parents 5ee894f + 6056d23 commit 76da4e5
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 61 deletions.
9 changes: 5 additions & 4 deletions go/cmd/dolt/commands/engine/sqlengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,14 @@ func NewSqlEngine(
pro = pro.WithRemoteDialer(mrEnv.RemoteDialProvider())

config.ClusterController.RegisterStoredProcedures(pro)
pro.InitDatabaseHook = cluster.NewInitDatabaseHook(config.ClusterController, bThreads, pro.InitDatabaseHook)
if config.ClusterController != nil {
pro.InitDatabaseHooks = append(pro.InitDatabaseHooks, cluster.NewInitDatabaseHook(config.ClusterController, bThreads))
pro.DropDatabaseHooks = append(pro.DropDatabaseHooks, config.ClusterController.DropDatabaseHook())
config.ClusterController.SetDropDatabase(pro.DropDatabase)
}

sqlEngine := &SqlEngine{}

pro.DropDatabaseHook = config.ClusterController.DropDatabaseHook()
config.ClusterController.SetDropDatabase(pro.DropDatabase)

// Create the engine
engine := gms.New(analyzer.NewBuilder(pro).WithParallelism(parallelism).Build(), &gms.Config{
IsReadOnly: config.IsReadOnly,
Expand Down
9 changes: 3 additions & 6 deletions go/libraries/doltcore/sqle/cluster/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,13 @@ func (c *Controller) SetDropDatabase(dropDatabase func(*sql.Context, string) err
c.dropDatabase = dropDatabase
}

// Our DropDatabaseHook gets called when the database provider drops a
// DropDatabaseHook gets called when the database provider drops a
// database. This is how we learn that we need to replicate a drop database.
func (c *Controller) DropDatabaseHook() func(string) {
if c == nil {
return nil
}
func (c *Controller) DropDatabaseHook() func(*sql.Context, string) {
return c.dropDatabaseHook
}

func (c *Controller) dropDatabaseHook(dbname string) {
func (c *Controller) dropDatabaseHook(_ *sql.Context, dbname string) {
c.mu.Lock()
defer c.mu.Unlock()

Expand Down
11 changes: 1 addition & 10 deletions go/libraries/doltcore/sqle/cluster/initdbhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,8 @@ import (
"github.com/dolthub/dolt/go/store/types"
)

func NewInitDatabaseHook(controller *Controller, bt *sql.BackgroundThreads, orig sqle.InitDatabaseHook) sqle.InitDatabaseHook {
if controller == nil {
return orig
}
func NewInitDatabaseHook(controller *Controller, bt *sql.BackgroundThreads) sqle.InitDatabaseHook {
return func(ctx *sql.Context, pro *sqle.DoltDatabaseProvider, name string, denv *env.DoltEnv, db dsess.SqlDatabase) error {
var err error
err = orig(ctx, pro, name, denv, db)
if err != nil {
return err
}

dialprovider := controller.gRPCDialProvider(denv)
var remoteDBs []func(context.Context) (*doltdb.DoltDB, error)
var remoteUrls []string
Expand Down
26 changes: 14 additions & 12 deletions go/libraries/doltcore/sqle/database_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ type DoltDatabaseProvider struct {
functions map[string]sql.Function
tableFunctions map[string]sql.TableFunction
externalProcedures sql.ExternalStoredProcedureRegistry
InitDatabaseHook InitDatabaseHook
DropDatabaseHook DropDatabaseHook
InitDatabaseHooks []InitDatabaseHook
DropDatabaseHooks []DropDatabaseHook
mu *sync.RWMutex

droppedDatabaseManager *droppedDatabaseManager
Expand Down Expand Up @@ -146,7 +146,7 @@ func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Files
fs: fs,
defaultBranch: defaultBranch,
dbFactoryUrl: dbFactoryUrl,
InitDatabaseHook: ConfigureReplicationDatabaseHook,
InitDatabaseHooks: []InitDatabaseHook{ConfigureReplicationDatabaseHook},
isStandby: new(bool),
droppedDatabaseManager: newDroppedDatabaseManager(fs),
}, nil
Expand Down Expand Up @@ -459,7 +459,7 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
}

type InitDatabaseHook func(ctx *sql.Context, pro *DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error
type DropDatabaseHook func(name string)
type DropDatabaseHook func(ctx *sql.Context, name string)

// ConfigureReplicationDatabaseHook sets up replication for a newly created database as necessary
// TODO: consider the replication heads / all heads setting
Expand Down Expand Up @@ -630,10 +630,10 @@ func (p *DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
return err
}

if p.DropDatabaseHook != nil {
for _, dropHook := range p.DropDatabaseHooks {
// For symmetry with InitDatabaseHook and the names we see in
// MultiEnv initialization, we use `name` here, not `dbKey`.
p.DropDatabaseHook(name)
dropHook(ctx, name)
}

// We not only have to delete tracking metadata for this database, but also for any derivative
Expand Down Expand Up @@ -707,12 +707,14 @@ func (p *DoltDatabaseProvider) registerNewDatabase(ctx *sql.Context, name string
return err
}

// If we have an initialization hook, invoke it. By default, this will
// be ConfigureReplicationDatabaseHook, which will setup replication
// for the new database if a remote url template is set.
err = p.InitDatabaseHook(ctx, p, name, newEnv, db)
if err != nil {
return err
// If we have any initialization hooks, invoke them, until any error is returned.
// By default, this will be ConfigureReplicationDatabaseHook, which will set up
// replication for the new database if a remote url template is set.
for _, initHook := range p.InitDatabaseHooks {
err = initHook(ctx, p, name, newEnv, db)
if err != nil {
return err
}
}

formattedName := formatDbMapKeyName(db.Name())
Expand Down
14 changes: 8 additions & 6 deletions go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3287,12 +3287,14 @@ func TestCreateDatabaseErrorCleansUp(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, e)

dh.provider.(*sqle.DoltDatabaseProvider).InitDatabaseHook = func(_ *sql.Context, _ *sqle.DoltDatabaseProvider, name string, _ *env.DoltEnv, _ dsess.SqlDatabase) error {
if name == "cannot_create" {
return fmt.Errorf("there was an error initializing this database. abort!")
}
return nil
}
doltDatabaseProvider := dh.provider.(*sqle.DoltDatabaseProvider)
doltDatabaseProvider.InitDatabaseHooks = append(doltDatabaseProvider.InitDatabaseHooks,
func(_ *sql.Context, _ *sqle.DoltDatabaseProvider, name string, _ *env.DoltEnv, _ dsess.SqlDatabase) error {
if name == "cannot_create" {
return fmt.Errorf("there was an error initializing this database. abort!")
}
return nil
})

err = dh.provider.CreateDatabase(enginetest.NewContext(dh), "can_create")
require.NoError(t, err)
Expand Down
6 changes: 3 additions & 3 deletions go/libraries/doltcore/sqle/statspro/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
)

func (p *Provider) Configure(ctx context.Context, ctxFactory func(ctx context.Context) (*sql.Context, error), bThreads *sql.BackgroundThreads, dbs []dsess.SqlDatabase) error {
p.SetStarter(NewInitDatabaseHook(p, ctxFactory, bThreads, nil))
p.SetStarter(NewStatsInitDatabaseHook(p, ctxFactory, bThreads))

if _, disabled, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsMemoryOnly); disabled == int8(1) {
return nil
Expand All @@ -53,8 +53,8 @@ func (p *Provider) Configure(ctx context.Context, ctxFactory func(ctx context.Co
intervalSec = time.Second * time.Duration(interval64.(int64))
thresholdf64 = threshold.(float64)

p.pro.InitDatabaseHook = NewInitDatabaseHook(p, ctxFactory, bThreads, p.pro.InitDatabaseHook)
p.pro.DropDatabaseHook = NewDropDatabaseHook(p, ctxFactory, p.pro.DropDatabaseHook)
p.pro.InitDatabaseHooks = append(p.pro.InitDatabaseHooks, NewStatsInitDatabaseHook(p, ctxFactory, bThreads))
p.pro.DropDatabaseHooks = append(p.pro.DropDatabaseHooks, NewStatsDropDatabaseHook(p))
}

eg, ctx := loadCtx.NewErrgroup()
Expand Down
23 changes: 3 additions & 20 deletions go/libraries/doltcore/sqle/statspro/initdbhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
)

func NewInitDatabaseHook(
func NewStatsInitDatabaseHook(
statsProv *Provider,
ctxFactory func(ctx context.Context) (*sql.Context, error),
bThreads *sql.BackgroundThreads,
orig sqle.InitDatabaseHook,
) sqle.InitDatabaseHook {
return func(
ctx *sql.Context,
Expand All @@ -38,15 +37,6 @@ func NewInitDatabaseHook(
denv *env.DoltEnv,
db dsess.SqlDatabase,
) error {
// We assume there is nothing on disk to read. Probably safe and also
// would deadlock with dbProvider if we tried from reading root/session.
if orig != nil {
err := orig(ctx, pro, name, denv, db)
if err != nil {
return err
}
}

statsDb, err := statsProv.sf.Init(ctx, db, statsProv.pro, denv.FS, env.GetCurrentUserHomeDir)
if err != nil {
ctx.GetLogger().Debugf("statistics load error: %s", err.Error())
Expand All @@ -61,15 +51,8 @@ func NewInitDatabaseHook(
}
}

func NewDropDatabaseHook(statsProv *Provider, ctxFactory func(ctx context.Context) (*sql.Context, error), orig sqle.DropDatabaseHook) sqle.DropDatabaseHook {
return func(name string) {
if orig != nil {
orig(name)
}
ctx, err := ctxFactory(context.Background())
if err != nil {
return
}
func NewStatsDropDatabaseHook(statsProv *Provider) sqle.DropDatabaseHook {
return func(ctx *sql.Context, name string) {
statsProv.CancelRefreshThread(name)
statsProv.DropDbStats(ctx, name, false)

Expand Down

0 comments on commit 76da4e5

Please sign in to comment.