diff --git a/go/libraries/doltcore/sqle/statsnoms/database.go b/go/libraries/doltcore/sqle/statsnoms/database.go index 4565062f700..42bfc083648 100644 --- a/go/libraries/doltcore/sqle/statsnoms/database.go +++ b/go/libraries/doltcore/sqle/statsnoms/database.go @@ -151,6 +151,9 @@ func (n *NomsStatsDatabase) LoadBranchStats(ctx *sql.Context, branch string) err } func (n *NomsStatsDatabase) getBranchStats(branch string) dbStats { + n.mu.Lock() + defer n.mu.Unlock() + for i, b := range n.branches { if strings.EqualFold(b, branch) { return n.stats[i] @@ -174,7 +177,7 @@ func (n *NomsStatsDatabase) ListStatQuals(branch string) []sql.StatQualifier { return ret } -func (n *NomsStatsDatabase) SetStat(ctx context.Context, branch string, qual sql.StatQualifier, stats *statspro.DoltStats) error { +func (n *NomsStatsDatabase) setStat(ctx context.Context, branch string, qual sql.StatQualifier, stats *statspro.DoltStats) error { var statsMap *prolly.MutableMap for i, b := range n.branches { if strings.EqualFold(branch, b) { @@ -195,6 +198,12 @@ func (n *NomsStatsDatabase) SetStat(ctx context.Context, branch string, qual sql return n.replaceStats(ctx, statsMap, stats) } +func (n *NomsStatsDatabase) SetStat(ctx context.Context, branch string, qual sql.StatQualifier, stats *statspro.DoltStats) error { + n.mu.Lock() + defer n.mu.Unlock() + + return n.setStat(ctx, branch, qual, stats) +} func (n *NomsStatsDatabase) trackBranch(ctx context.Context, branch string) error { n.branches = append(n.branches, branch) @@ -220,6 +229,9 @@ func (n *NomsStatsDatabase) initMutable(ctx context.Context, i int) error { } func (n *NomsStatsDatabase) DeleteStats(branch string, quals ...sql.StatQualifier) { + n.mu.Lock() + defer n.mu.Unlock() + for i, b := range n.branches { if strings.EqualFold(b, branch) { for _, qual := range quals { @@ -230,6 +242,9 @@ func (n *NomsStatsDatabase) DeleteStats(branch string, quals ...sql.StatQualifie } func (n *NomsStatsDatabase) DeleteBranchStats(ctx context.Context, branch string, flush bool) error { + n.mu.Lock() + defer n.mu.Unlock() + for i, b := range n.branches { if strings.EqualFold(b, branch) { n.branches = append(n.branches[:i], n.branches[i+1:]...) @@ -245,6 +260,9 @@ func (n *NomsStatsDatabase) DeleteBranchStats(ctx context.Context, branch string } func (n *NomsStatsDatabase) ReplaceChunks(ctx context.Context, branch string, qual sql.StatQualifier, targetHashes []hash.Hash, dropChunks, newChunks []sql.HistogramBucket) error { + n.mu.Lock() + defer n.mu.Unlock() + var dbStat dbStats for i, b := range n.branches { if strings.EqualFold(b, branch) { @@ -274,10 +292,13 @@ func (n *NomsStatsDatabase) ReplaceChunks(ctx context.Context, branch string, qu dbStat[qual].UpdateActive() // let |n.SetStats| update memory and disk - return n.SetStat(ctx, branch, qual, dbStat[qual]) + return n.setStat(ctx, branch, qual, dbStat[qual]) } func (n *NomsStatsDatabase) Flush(ctx context.Context, branch string) error { + n.mu.Lock() + defer n.mu.Unlock() + for i, b := range n.branches { if strings.EqualFold(b, branch) { if n.dirty[i] != nil { diff --git a/go/libraries/doltcore/sqle/statspro/analyze.go b/go/libraries/doltcore/sqle/statspro/analyze.go index 1c8f2969010..d34e20e71d4 100644 --- a/go/libraries/doltcore/sqle/statspro/analyze.go +++ b/go/libraries/doltcore/sqle/statspro/analyze.go @@ -84,6 +84,11 @@ func (p *Provider) BootstrapDatabaseStats(ctx *sql.Context, db string) error { } func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table, db string, branch string) error { + if !p.TryLockForUpdate(table.Name(), db, branch) { + return fmt.Errorf("already updating statistics") + } + defer p.UnlockTable(table.Name(), db, branch) + dSess := dsess.DSessFromSess(ctx.Session) sqlDb, err := dSess.Provider().Database(ctx, p.branchQualifiedDatabase(db, branch)) @@ -92,8 +97,6 @@ func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table } // lock only after accessing DatabaseProvider - p.mu.Lock() - defer p.mu.Unlock() tableName := strings.ToLower(table.Name()) dbName := strings.ToLower(db) diff --git a/go/libraries/doltcore/sqle/statspro/auto_refresh.go b/go/libraries/doltcore/sqle/statspro/auto_refresh.go index 775a945a871..e87e2d46772 100644 --- a/go/libraries/doltcore/sqle/statspro/auto_refresh.go +++ b/go/libraries/doltcore/sqle/statspro/auto_refresh.go @@ -107,8 +107,10 @@ func (p *Provider) InitAutoRefreshWithParams(ctxFactory func(ctx context.Context } func (p *Provider) checkRefresh(ctx *sql.Context, sqlDb sql.Database, dbName, branch string, updateThresh float64) error { - p.mu.Lock() - defer p.mu.Unlock() + if !p.TryLockForUpdate("", dbName, branch) { + return nil + } + defer p.UnlockTable("", dbName, branch) // Iterate all dbs, tables, indexes. Each db will collect // []indexMeta above refresh threshold. We read and process those @@ -131,6 +133,10 @@ func (p *Provider) checkRefresh(ctx *sql.Context, sqlDb sql.Database, dbName, br } for _, table := range tables { + if !p.TryLockForUpdate(table, dbName, branch) { + continue + } + defer p.UnlockTable(table, dbName, branch) sqlTable, dTab, err := GetLatestTable(ctx, table, sqlDb) if err != nil { return err diff --git a/go/libraries/doltcore/sqle/statspro/stats_provider.go b/go/libraries/doltcore/sqle/statspro/stats_provider.go index 09ea404bc66..4cf3e201c44 100644 --- a/go/libraries/doltcore/sqle/statspro/stats_provider.go +++ b/go/libraries/doltcore/sqle/statspro/stats_provider.go @@ -49,12 +49,13 @@ type updateOrdinal struct { func NewProvider(pro *sqle.DoltDatabaseProvider, sf StatsFactory) *Provider { return &Provider{ - pro: pro, - sf: sf, - mu: &sync.Mutex{}, - statDbs: make(map[string]Database), - cancelers: make(map[string]context.CancelFunc), - status: make(map[string]string), + pro: pro, + sf: sf, + mu: &sync.Mutex{}, + statDbs: make(map[string]Database), + cancelers: make(map[string]context.CancelFunc), + status: make(map[string]string), + lockedTables: make(map[string]bool), } } @@ -62,13 +63,14 @@ func NewProvider(pro *sqle.DoltDatabaseProvider, sf StatsFactory) *Provider { // Each database has its own statistics table that all tables/indexes in a db // share. type Provider struct { - mu *sync.Mutex - pro *sqle.DoltDatabaseProvider - sf StatsFactory - statDbs map[string]Database - cancelers map[string]context.CancelFunc - starter sqle.InitDatabaseHook - status map[string]string + mu *sync.Mutex + pro *sqle.DoltDatabaseProvider + sf StatsFactory + statDbs map[string]Database + cancelers map[string]context.CancelFunc + starter sqle.InitDatabaseHook + status map[string]string + lockedTables map[string]bool } // each database has one statistics table that is a collection of the @@ -92,10 +94,27 @@ func newDbStats(dbName string) *dbToStats { var _ sql.StatsProvider = (*Provider)(nil) -func (p *Provider) StartRefreshThread(ctx *sql.Context, pro dsess.DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error { - err := p.starter(ctx, pro.(*sqle.DoltDatabaseProvider), name, env, db) +func (p *Provider) TryLockForUpdate(table string, db string, branch string) bool { p.mu.Lock() defer p.mu.Unlock() + lockId := fmt.Sprintf("%s.%s.%s", db, branch, table) + if ok := p.lockedTables[lockId]; ok { + return false + } + p.lockedTables[lockId] = true + return true +} + +func (p *Provider) UnlockTable(table string, db string, branch string) { + p.mu.Lock() + defer p.mu.Unlock() + lockId := fmt.Sprintf("%s.%s.%s", db, branch, table) + p.lockedTables[lockId] = false + return +} + +func (p *Provider) StartRefreshThread(ctx *sql.Context, pro dsess.DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error { + err := p.starter(ctx, pro.(*sqle.DoltDatabaseProvider), name, env, db) if err != nil { p.UpdateStatus(name, fmt.Sprintf("error restarting thread %s: %s", name, err.Error())) @@ -111,11 +130,12 @@ func (p *Provider) SetStarter(hook sqle.InitDatabaseHook) { func (p *Provider) CancelRefreshThread(dbName string) { p.mu.Lock() - defer p.mu.Unlock() if cancel, ok := p.cancelers[dbName]; ok { cancel() - p.UpdateStatus(dbName, fmt.Sprintf("cancelled thread: %s", dbName)) } + p.mu.Unlock() + p.UpdateStatus(dbName, fmt.Sprintf("cancelled thread: %s", dbName)) + } func (p *Provider) ThreadStatus(dbName string) string { @@ -140,9 +160,6 @@ func (p *Provider) GetTableStats(ctx *sql.Context, db string, table sql.Table) ( } func (p *Provider) GetTableDoltStats(ctx *sql.Context, branch, db, table string) ([]sql.Statistic, error) { - p.mu.Lock() - defer p.mu.Unlock() - statDb, ok := p.getStatDb(db) if !ok || statDb == nil { return nil, nil @@ -173,14 +190,13 @@ func (p *Provider) setStatDb(name string, db Database) { } func (p *Provider) getStatDb(name string) (Database, bool) { + p.mu.Lock() + defer p.mu.Unlock() statDb, ok := p.statDbs[strings.ToLower(name)] return statDb, ok } func (p *Provider) SetStats(ctx *sql.Context, s sql.Statistic) error { - p.mu.Lock() - defer p.mu.Unlock() - statDb, ok := p.getStatDb(s.Qualifier().Db()) if !ok { return nil @@ -218,9 +234,6 @@ func (p *Provider) getQualStats(ctx *sql.Context, qual sql.StatQualifier) (*Dolt } func (p *Provider) GetStats(ctx *sql.Context, qual sql.StatQualifier, _ []string) (sql.Statistic, bool) { - p.mu.Lock() - defer p.mu.Unlock() - stat, ok := p.getQualStats(ctx, qual) if !ok { return nil, false @@ -229,9 +242,6 @@ func (p *Provider) GetStats(ctx *sql.Context, qual sql.StatQualifier, _ []string } func (p *Provider) DropDbStats(ctx *sql.Context, db string, flush bool) error { - p.mu.Lock() - defer p.mu.Unlock() - statDb, ok := p.getStatDb(db) if !ok { return nil @@ -243,6 +253,9 @@ func (p *Provider) DropDbStats(ctx *sql.Context, db string, flush bool) error { return err } + p.mu.Lock() + defer p.mu.Unlock() + // remove provider access if err := statDb.DeleteBranchStats(ctx, branch, flush); err != nil { return nil @@ -254,9 +267,6 @@ func (p *Provider) DropDbStats(ctx *sql.Context, db string, flush bool) error { } func (p *Provider) DropStats(ctx *sql.Context, qual sql.StatQualifier, _ []string) error { - p.mu.Lock() - defer p.mu.Unlock() - statDb, ok := p.getStatDb(qual.Db()) if !ok { return nil @@ -277,13 +287,13 @@ func (p *Provider) DropStats(ctx *sql.Context, qual sql.StatQualifier, _ []strin } func (p *Provider) UpdateStatus(db string, msg string) { + p.mu.Lock() + defer p.mu.Unlock() + p.status[db] = msg } func (p *Provider) RowCount(ctx *sql.Context, db string, table sql.Table) (uint64, error) { - p.mu.Lock() - defer p.mu.Unlock() - statDb, ok := p.getStatDb(db) if !ok { return 0, sql.ErrDatabaseNotFound.New(db) @@ -305,9 +315,6 @@ func (p *Provider) RowCount(ctx *sql.Context, db string, table sql.Table) (uint6 } func (p *Provider) DataLength(ctx *sql.Context, db string, table sql.Table) (uint64, error) { - p.mu.Lock() - defer p.mu.Unlock() - statDb, ok := p.getStatDb(db) if !ok { return 0, sql.ErrDatabaseNotFound.New(db)