Skip to content

Commit

Permalink
workloadrepo: Simplify the snapshot code. (#59236)
Browse files Browse the repository at this point in the history
ref #58247
  • Loading branch information
wddevries authored Feb 19, 2025
1 parent d94bcfb commit 0ac0618
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 80 deletions.
6 changes: 3 additions & 3 deletions pkg/executor/workloadrepo.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ import (
)

// TakeSnapshot is a hook from workload repo that may trigger manual snapshot.
var TakeSnapshot func() error
var TakeSnapshot func(context.Context) error

// WorkloadRepoCreateExec indicates WorkloadRepoCreate executor.
type WorkloadRepoCreateExec struct {
exec.BaseExecutor
}

// Next implements the Executor Next interface.
func (*WorkloadRepoCreateExec) Next(context.Context, *chunk.Chunk) error {
func (*WorkloadRepoCreateExec) Next(ctx context.Context, _ *chunk.Chunk) error {
if TakeSnapshot != nil {
return TakeSnapshot()
return TakeSnapshot(ctx)
}
return nil
}
11 changes: 5 additions & 6 deletions pkg/util/workloadrepo/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ import (
)

const (
ownerKey = "/tidb/workloadrepo/owner"
promptKey = "workloadrepo"
snapIDKey = "/tidb/workloadrepo/snap_id"
snapCommandKey = "/tidb/workloadrepo/snap_command"

snapCommandTake = "take_snapshot"
ownerKey = "/tidb/workloadrepo/owner"
promptKey = "workloadrepo"
snapIDKey = "/tidb/workloadrepo/snap_id"

etcdOpTimeout = 5 * time.Second
snapshotRetries = 5
Expand All @@ -49,4 +46,6 @@ var (

errWrongValueForVar = dbterror.ClassUtil.NewStd(errno.ErrWrongValueForVar)
errUnsupportedEtcdRequired = dbterror.ClassUtil.NewStdErr(errno.ErrNotSupportedYet, mysql.Message("etcd client required for workload repository", nil))
errWorkloadNotStarted = dbterror.ClassUtil.NewStdErr(errno.ErrNotSupportedYet, mysql.Message("Workload repository is not enabled", nil))
errCouldNotStartSnapshot = dbterror.ClassUtil.NewStdErr(errno.ErrUnknown, mysql.Message("Snapshot initiation failed", nil))
)
94 changes: 32 additions & 62 deletions pkg/util/workloadrepo/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,6 @@ func (w *worker) getSnapID(ctx context.Context) (uint64, error) {
return strconv.ParseUint(snapIDStr, 10, 64)
}

func (w *worker) updateSnapID(ctx context.Context, oid, nid uint64) error {
return w.etcdCAS(ctx, snapIDKey,
strconv.FormatUint(oid, 10),
strconv.FormatUint(nid, 10))
}

func (w *worker) createSnapID(ctx context.Context, nid uint64) error {
return w.etcdCreate(ctx, snapIDKey, strconv.FormatUint(nid, 10))
}

func upsertHistSnapshot(ctx context.Context, sctx sessionctx.Context, snapID uint64) error {
// TODO: fill DB_VER, WR_VER
snapshotsInsert := sqlescape.MustEscapeSQL("INSERT INTO %n.%n (`BEGIN_TIME`, `SNAP_ID`) VALUES (now(), %%?) ON DUPLICATE KEY UPDATE `BEGIN_TIME` = now()",
Expand All @@ -107,7 +97,11 @@ func upsertHistSnapshot(ctx context.Context, sctx sessionctx.Context, snapID uin
return err
}

func updateHistSnapshot(ctx context.Context, sctx sessionctx.Context, snapID uint64, errs []error) error {
func (w *worker) updateHistSnapshot(ctx context.Context, snapID uint64, errs []error) error {
_sessctx := w.getSessionWithRetry()
defer w.sesspool.Put(_sessctx)
sctx := _sessctx.(sessionctx.Context)

var nerr any
if err := stderrors.Join(errs...); err != nil {
nerr = err.Error()
Expand Down Expand Up @@ -136,34 +130,19 @@ func (w *worker) snapshotTable(ctx context.Context, snapID uint64, rt *repositor
return nil
}

func (w *worker) takeSnapshot(ctx context.Context, sess sessionctx.Context, sendCommand bool) {
// coordination logic
if !w.owner.IsOwner() {
if sendCommand {
command, err := w.etcdGet(ctx, snapCommandKey)
if err != nil {
logutil.BgLogger().Info("workload repository cannot get current snap command value", zap.NamedError("err", err))
return
}

if command == "" {
err = w.etcdCreate(ctx, snapCommandKey, snapCommandTake)
} else {
err = w.etcdCAS(ctx, snapCommandKey, command, snapCommandTake)
}

if err != nil {
logutil.BgLogger().Info("workload repository cannot send snapshot command", zap.NamedError("err", err))
return
}
}
return
}
// takeSnapshot increments the value of snapIDKey, which triggers the tidb
// nodes to run the snapshot process. See the code in startSnapshot().
func (w *worker) takeSnapshot(ctx context.Context) (uint64, error) {
_sessctx := w.getSessionWithRetry()
defer w.sesspool.Put(_sessctx)
sess := _sessctx.(sessionctx.Context)

var snapID uint64
var err error
for range snapshotRetries {
snapID, err := w.getSnapID(ctx)
snapID, err = w.getSnapID(ctx)
if err != nil {
logutil.BgLogger().Info("workload repository cannot get current snapid", zap.NamedError("err", err))
err = fmt.Errorf("cannot get current snapid: %w", err)
continue
}

Expand All @@ -174,58 +153,45 @@ func (w *worker) takeSnapshot(ctx context.Context, sess sessionctx.Context, send
// due to another owner winning the etcd CAS loop.
// While undesirable, this scenario is acceptable since both owners would
// likely share similar datetime values and same cluster version.
if err := upsertHistSnapshot(ctx, sess, snapID+1); err != nil {
logutil.BgLogger().Info("workload repository could not insert into hist_snapshots", zap.NamedError("err", err))
if err = upsertHistSnapshot(ctx, sess, snapID+1); err != nil {
err = fmt.Errorf("could not insert into hist_snapshots: %w", err)
continue
}

if snapID == 0 {
err = w.createSnapID(ctx, snapID+1)
err = w.etcdCreate(ctx, snapIDKey, strconv.FormatUint(snapID+1, 10))
} else {
err = w.updateSnapID(ctx, snapID, snapID+1)
err = w.etcdCAS(ctx, snapIDKey, strconv.FormatUint(snapID, 10), strconv.FormatUint(snapID+1, 10))
}

if err != nil {
logutil.BgLogger().Info("workload repository cannot update current snapid", zap.Uint64("new_id", snapID), zap.NamedError("err", err))
err = fmt.Errorf("cannot update current snapid to %d: %w", snapID, err)
continue
}

logutil.BgLogger().Info("workload repository fired snapshot", zap.String("owner", w.instanceID), zap.Uint64("snapID", snapID+1))
break
}

// return the last error seen, if it ended on an error
return snapID, err
}

func (w *worker) startSnapshot(_ctx context.Context) func() {
return func() {
w.resetSnapshotInterval(w.snapshotInterval)

_sessctx := w.getSessionWithRetry()
defer w.sesspool.Put(_sessctx)
sess := _sessctx.(sessionctx.Context)

// this is for etcd watch
// other wise wch won't be collected after the exit of this function
ctx, cancel := context.WithCancel(_ctx)
defer cancel()
snapIDCh := w.etcdClient.Watch(ctx, snapIDKey)
snapCmdCh := w.etcdClient.Watch(ctx, snapCommandKey)

for {
select {
case <-ctx.Done():
return
case resp := <-snapCmdCh:
if len(resp.Events) < 1 {
continue
}

// same as snapID events
// we only catch the last event if possible
snapCommandStr := string(resp.Events[len(resp.Events)-1].Kv.Value)
if snapCommandStr == snapCommandTake {
w.takeSnapshot(ctx, sess, false)
}
case resp := <-snapIDCh:
// This case is triggered by both by w.snapshotInterval and the SQL command, which calls w.takeSnapshot() directly.
if len(resp.Events) < 1 {
// since there is no event, we don't know the latest snapid either
// really should not happen except creation
Expand Down Expand Up @@ -261,13 +227,17 @@ func (w *worker) startSnapshot(_ctx context.Context) func() {
}
wg.Wait()

if err := updateHistSnapshot(ctx, sess, snapID, errs); err != nil {
if err := w.updateHistSnapshot(ctx, snapID, errs); err != nil {
logutil.BgLogger().Info("workload repository snapshot failed: could not update hist_snapshots", zap.NamedError("err", err))
}
case <-w.snapshotChan:
w.takeSnapshot(ctx, sess, true)
case <-w.snapshotTicker.C:
w.takeSnapshot(ctx, sess, false)
if w.owner.IsOwner() {
if snapID, err := w.takeSnapshot(ctx); err != nil {
logutil.BgLogger().Info("workload repository snapshot failed", zap.NamedError("err", err))
} else {
logutil.BgLogger().Info("workload repository ran snapshot", zap.String("owner", w.instanceID), zap.Uint64("snapID", snapID))
}
}
}
}
}
Expand Down
21 changes: 14 additions & 7 deletions pkg/util/workloadrepo/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,26 @@ type worker struct {
samplingTicker *time.Ticker
snapshotInterval int32
snapshotTicker *time.Ticker
snapshotChan chan struct{}
retentionDays int32
}

var workerCtx = worker{}

func takeSnapshot() error {
if workerCtx.snapshotChan == nil {
return errors.New("Workload repository is not enabled yet")
func takeSnapshot(ctx context.Context) error {
workerCtx.Lock()
defer workerCtx.Unlock()

if !workerCtx.enabled {
return errWorkloadNotStarted.GenWithStackByArgs()
}
workerCtx.snapshotChan <- struct{}{}

snapID, err := workerCtx.takeSnapshot(ctx)
if err != nil {
logutil.BgLogger().Info("workload repository manual snapshot failed", zap.String("owner", workerCtx.instanceID), zap.NamedError("err", err))
return errCouldNotStartSnapshot.GenWithStackByArgs()
}

logutil.BgLogger().Info("workload repository ran manual snapshot", zap.String("owner", workerCtx.instanceID), zap.Uint64("snapID", snapID))
return nil
}

Expand Down Expand Up @@ -360,7 +369,6 @@ func (w *worker) start() error {
}

_ = stmtsummary.StmtSummaryByDigestMap.SetHistoryEnabled(false)
w.snapshotChan = make(chan struct{}, 1)
ctx, cancel := context.WithCancel(context.Background())
w.cancel = cancel
w.wg.RunWithRecover(w.startRepository(ctx), func(err any) {
Expand Down Expand Up @@ -388,7 +396,6 @@ func (w *worker) stop() {
}

w.cancel = nil
w.snapshotChan = nil
}

// setRepositoryDest will change the dest of workload snapshot.
Expand Down
4 changes: 2 additions & 2 deletions pkg/util/workloadrepo/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ func TestRaceToCreateTablesWorker(t *testing.T) {
require.Len(t, res, 0)

// manually trigger snapshot by sending a tick to all workers
wrk1.snapshotChan <- struct{}{}
wrk1.takeSnapshot(ctx)
require.Eventually(t, func() bool {
res := tk.MustQuery("select snap_id, count(*) from workload_schema.hist_snapshots group by snap_id").Rows()
return len(res) == 1
}, time.Minute, time.Second)

wrk2.snapshotChan <- struct{}{}
wrk2.takeSnapshot(ctx)
require.Eventually(t, func() bool {
res := tk.MustQuery("select snap_id, count(*) from workload_schema.hist_snapshots group by snap_id").Rows()
return len(res) == 2
Expand Down

0 comments on commit 0ac0618

Please sign in to comment.