Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing database init/drop hooks to be a slice of hooks #7804

Merged
merged 3 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions go/cmd/dolt/commands/engine/sqlengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,14 @@ func NewSqlEngine(
pro = pro.WithRemoteDialer(mrEnv.RemoteDialProvider())

config.ClusterController.RegisterStoredProcedures(pro)
pro.InitDatabaseHook = cluster.NewInitDatabaseHook(config.ClusterController, bThreads, pro.InitDatabaseHook)
if config.ClusterController != nil {
pro.InitDatabaseHooks = append(pro.InitDatabaseHooks, cluster.NewInitDatabaseHook(config.ClusterController, bThreads))
pro.DropDatabaseHooks = append(pro.DropDatabaseHooks, config.ClusterController.DropDatabaseHook())
config.ClusterController.SetDropDatabase(pro.DropDatabase)
}

sqlEngine := &SqlEngine{}

pro.DropDatabaseHook = config.ClusterController.DropDatabaseHook()
config.ClusterController.SetDropDatabase(pro.DropDatabase)

// Create the engine
engine := gms.New(analyzer.NewBuilder(pro).WithParallelism(parallelism).Build(), &gms.Config{
IsReadOnly: config.IsReadOnly,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,17 @@ func (j jsonSerializer) serialize(ctx *sql.Context, typ sql.Type, descriptor val
return nil, err
}
if json != nil {
jsonDoc, ok := json.(gmstypes.JSONDocument)
if !ok {
return nil, fmt.Errorf("supported JSON type: %T", json)
var jsonDoc gmstypes.JSONDocument
if lazyJsonDoc, ok := json.(*gmstypes.LazyJSONDocument); ok {
i, err := lazyJsonDoc.ToInterface()
if err != nil {
return nil, err
}
jsonDoc = gmstypes.JSONDocument{Val: i}
} else if _, ok := json.(gmstypes.JSONDocument); ok {
jsonDoc = json.(gmstypes.JSONDocument)
} else {
return nil, fmt.Errorf("unsupported JSON type: %T", json)
}

jsonBuffer, err := encodeJsonDoc(jsonDoc)
Expand Down
9 changes: 3 additions & 6 deletions go/libraries/doltcore/sqle/cluster/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,13 @@ func (c *Controller) SetDropDatabase(dropDatabase func(*sql.Context, string) err
c.dropDatabase = dropDatabase
}

// Our DropDatabaseHook gets called when the database provider drops a
// DropDatabaseHook gets called when the database provider drops a
// database. This is how we learn that we need to replicate a drop database.
func (c *Controller) DropDatabaseHook() func(string) {
if c == nil {
return nil
}
func (c *Controller) DropDatabaseHook() func(*sql.Context, string) {
return c.dropDatabaseHook
}

func (c *Controller) dropDatabaseHook(dbname string) {
func (c *Controller) dropDatabaseHook(_ *sql.Context, dbname string) {
c.mu.Lock()
defer c.mu.Unlock()

Expand Down
11 changes: 1 addition & 10 deletions go/libraries/doltcore/sqle/cluster/initdbhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,8 @@ import (
"github.com/dolthub/dolt/go/store/types"
)

func NewInitDatabaseHook(controller *Controller, bt *sql.BackgroundThreads, orig sqle.InitDatabaseHook) sqle.InitDatabaseHook {
if controller == nil {
return orig
}
func NewInitDatabaseHook(controller *Controller, bt *sql.BackgroundThreads) sqle.InitDatabaseHook {
return func(ctx *sql.Context, pro *sqle.DoltDatabaseProvider, name string, denv *env.DoltEnv, db dsess.SqlDatabase) error {
var err error
err = orig(ctx, pro, name, denv, db)
if err != nil {
return err
}

dialprovider := controller.gRPCDialProvider(denv)
var remoteDBs []func(context.Context) (*doltdb.DoltDB, error)
var remoteUrls []string
Expand Down
26 changes: 14 additions & 12 deletions go/libraries/doltcore/sqle/database_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ type DoltDatabaseProvider struct {
functions map[string]sql.Function
tableFunctions map[string]sql.TableFunction
externalProcedures sql.ExternalStoredProcedureRegistry
InitDatabaseHook InitDatabaseHook
DropDatabaseHook DropDatabaseHook
InitDatabaseHooks []InitDatabaseHook
DropDatabaseHooks []DropDatabaseHook
mu *sync.RWMutex

droppedDatabaseManager *droppedDatabaseManager
Expand Down Expand Up @@ -146,7 +146,7 @@ func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Files
fs: fs,
defaultBranch: defaultBranch,
dbFactoryUrl: dbFactoryUrl,
InitDatabaseHook: ConfigureReplicationDatabaseHook,
InitDatabaseHooks: []InitDatabaseHook{ConfigureReplicationDatabaseHook},
isStandby: new(bool),
droppedDatabaseManager: newDroppedDatabaseManager(fs),
}, nil
Expand Down Expand Up @@ -459,7 +459,7 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
}

type InitDatabaseHook func(ctx *sql.Context, pro *DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error
type DropDatabaseHook func(name string)
type DropDatabaseHook func(ctx *sql.Context, name string)

// ConfigureReplicationDatabaseHook sets up replication for a newly created database as necessary
// TODO: consider the replication heads / all heads setting
Expand Down Expand Up @@ -630,10 +630,10 @@ func (p *DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
return err
}

if p.DropDatabaseHook != nil {
for _, dropHook := range p.DropDatabaseHooks {
// For symmetry with InitDatabaseHook and the names we see in
// MultiEnv initialization, we use `name` here, not `dbKey`.
p.DropDatabaseHook(name)
dropHook(ctx, name)
}

// We not only have to delete tracking metadata for this database, but also for any derivative
Expand Down Expand Up @@ -707,12 +707,14 @@ func (p *DoltDatabaseProvider) registerNewDatabase(ctx *sql.Context, name string
return err
}

// If we have an initialization hook, invoke it. By default, this will
// be ConfigureReplicationDatabaseHook, which will setup replication
// for the new database if a remote url template is set.
err = p.InitDatabaseHook(ctx, p, name, newEnv, db)
if err != nil {
return err
// If we have any initialization hooks, invoke them, until any error is returned.
// By default, this will be ConfigureReplicationDatabaseHook, which will set up
// replication for the new database if a remote url template is set.
for _, initHook := range p.InitDatabaseHooks {
err = initHook(ctx, p, name, newEnv, db)
if err != nil {
return err
}
}

formattedName := formatDbMapKeyName(db.Name())
Expand Down
14 changes: 8 additions & 6 deletions go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3287,12 +3287,14 @@ func TestCreateDatabaseErrorCleansUp(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, e)

dh.provider.(*sqle.DoltDatabaseProvider).InitDatabaseHook = func(_ *sql.Context, _ *sqle.DoltDatabaseProvider, name string, _ *env.DoltEnv, _ dsess.SqlDatabase) error {
if name == "cannot_create" {
return fmt.Errorf("there was an error initializing this database. abort!")
}
return nil
}
doltDatabaseProvider := dh.provider.(*sqle.DoltDatabaseProvider)
doltDatabaseProvider.InitDatabaseHooks = append(doltDatabaseProvider.InitDatabaseHooks,
func(_ *sql.Context, _ *sqle.DoltDatabaseProvider, name string, _ *env.DoltEnv, _ dsess.SqlDatabase) error {
if name == "cannot_create" {
return fmt.Errorf("there was an error initializing this database. abort!")
}
return nil
})

err = dh.provider.CreateDatabase(enginetest.NewContext(dh), "can_create")
require.NoError(t, err)
Expand Down
6 changes: 3 additions & 3 deletions go/libraries/doltcore/sqle/statspro/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
)

func (p *Provider) Configure(ctx context.Context, ctxFactory func(ctx context.Context) (*sql.Context, error), bThreads *sql.BackgroundThreads, dbs []dsess.SqlDatabase) error {
p.SetStarter(NewInitDatabaseHook(p, ctxFactory, bThreads, nil))
p.SetStarter(NewStatsInitDatabaseHook(p, ctxFactory, bThreads))

if _, disabled, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsMemoryOnly); disabled == int8(1) {
return nil
Expand All @@ -53,8 +53,8 @@ func (p *Provider) Configure(ctx context.Context, ctxFactory func(ctx context.Co
intervalSec = time.Second * time.Duration(interval64.(int64))
thresholdf64 = threshold.(float64)

p.pro.InitDatabaseHook = NewInitDatabaseHook(p, ctxFactory, bThreads, p.pro.InitDatabaseHook)
p.pro.DropDatabaseHook = NewDropDatabaseHook(p, ctxFactory, p.pro.DropDatabaseHook)
p.pro.InitDatabaseHooks = append(p.pro.InitDatabaseHooks, NewStatsInitDatabaseHook(p, ctxFactory, bThreads))
p.pro.DropDatabaseHooks = append(p.pro.DropDatabaseHooks, NewStatsDropDatabaseHook(p))
}

eg, ctx := loadCtx.NewErrgroup()
Expand Down
23 changes: 3 additions & 20 deletions go/libraries/doltcore/sqle/statspro/initdbhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
)

func NewInitDatabaseHook(
func NewStatsInitDatabaseHook(
statsProv *Provider,
ctxFactory func(ctx context.Context) (*sql.Context, error),
bThreads *sql.BackgroundThreads,
orig sqle.InitDatabaseHook,
) sqle.InitDatabaseHook {
return func(
ctx *sql.Context,
Expand All @@ -38,15 +37,6 @@ func NewInitDatabaseHook(
denv *env.DoltEnv,
db dsess.SqlDatabase,
) error {
// We assume there is nothing on disk to read. Probably safe and also
// would deadlock with dbProvider if we tried from reading root/session.
if orig != nil {
err := orig(ctx, pro, name, denv, db)
if err != nil {
return err
}
}

statsDb, err := statsProv.sf.Init(ctx, db, statsProv.pro, denv.FS, env.GetCurrentUserHomeDir)
if err != nil {
ctx.GetLogger().Debugf("statistics load error: %s", err.Error())
Expand All @@ -61,15 +51,8 @@ func NewInitDatabaseHook(
}
}

func NewDropDatabaseHook(statsProv *Provider, ctxFactory func(ctx context.Context) (*sql.Context, error), orig sqle.DropDatabaseHook) sqle.DropDatabaseHook {
return func(name string) {
if orig != nil {
orig(name)
}
ctx, err := ctxFactory(context.Background())
if err != nil {
return
}
func NewStatsDropDatabaseHook(statsProv *Provider) sqle.DropDatabaseHook {
return func(ctx *sql.Context, name string) {
statsProv.CancelRefreshThread(name)
statsProv.DropDbStats(ctx, name, false)

Expand Down
Loading