diff --git a/pkg/util/workloadrepo/BUILD.bazel b/pkg/util/workloadrepo/BUILD.bazel index b885213d5662a..65fbac6f8bbce 100644 --- a/pkg/util/workloadrepo/BUILD.bazel +++ b/pkg/util/workloadrepo/BUILD.bazel @@ -50,7 +50,7 @@ go_test( srcs = ["worker_test.go"], embed = [":workloadrepo"], flaky = True, - shard_count = 13, + shard_count = 14, deps = [ "//pkg/domain", "//pkg/infoschema", diff --git a/pkg/util/workloadrepo/const.go b/pkg/util/workloadrepo/const.go index 6752ef473c17a..15868c0b63efa 100644 --- a/pkg/util/workloadrepo/const.go +++ b/pkg/util/workloadrepo/const.go @@ -15,6 +15,7 @@ package workloadrepo import ( + "errors" "time" "github.com/pingcap/tidb/pkg/errno" @@ -48,4 +49,6 @@ var ( errUnsupportedEtcdRequired = dbterror.ClassUtil.NewStdErr(errno.ErrNotSupportedYet, mysql.Message("etcd client required for workload repository", nil)) errWorkloadNotStarted = dbterror.ClassUtil.NewStdErr(errno.ErrNotSupportedYet, mysql.Message("Workload repository is not enabled", nil)) errCouldNotStartSnapshot = dbterror.ClassUtil.NewStdErr(errno.ErrUnknown, mysql.Message("Snapshot initiation failed", nil)) + + errKeyNotFound = errors.New("key not found") ) diff --git a/pkg/util/workloadrepo/snapshot.go b/pkg/util/workloadrepo/snapshot.go index 6e25555ae1b75..8a9673e0d79d4 100644 --- a/pkg/util/workloadrepo/snapshot.go +++ b/pkg/util/workloadrepo/snapshot.go @@ -77,14 +77,28 @@ func (w *worker) etcdCAS(ctx context.Context, key, oval, nval string) error { return nil } +func queryMaxSnapID(ctx context.Context, sctx sessionctx.Context) (uint64, error) { + query := sqlescape.MustEscapeSQL("SELECT MAX(`SNAP_ID`) FROM %n.%n", WorkloadSchema, histSnapshotsTable) + rs, err := runQuery(ctx, sctx, query) + if err != nil { + return 0, err + } + if len(rs) > 0 { + if rs[0].IsNull(0) { + return 0, nil + } + return rs[0].GetUint64(0), nil + } + return 0, errors.New("no rows returned when querying max snap id") +} + func (w *worker) getSnapID(ctx context.Context) (uint64, error) { snapIDStr, err := w.etcdGet(ctx, snapIDKey) if err != nil { return 0, err } if snapIDStr == "" { - // return zero when the key does not exist - return 0, nil + return 0, errKeyNotFound } return strconv.ParseUint(snapIDStr, 10, 64) } @@ -140,7 +154,12 @@ func (w *worker) takeSnapshot(ctx context.Context) (uint64, error) { var snapID uint64 var err error for range snapshotRetries { + isEmpty := false snapID, err = w.getSnapID(ctx) + if stderrors.Is(err, errKeyNotFound) { + snapID, err = queryMaxSnapID(ctx, sess) + isEmpty = true + } if err != nil { err = fmt.Errorf("cannot get current snapid: %w", err) continue @@ -158,7 +177,7 @@ func (w *worker) takeSnapshot(ctx context.Context) (uint64, error) { continue } - if snapID == 0 { + if isEmpty { err = w.etcdCreate(ctx, snapIDKey, strconv.FormatUint(snapID+1, 10)) } else { err = w.etcdCAS(ctx, snapIDKey, strconv.FormatUint(snapID, 10), strconv.FormatUint(snapID+1, 10)) diff --git a/pkg/util/workloadrepo/worker_test.go b/pkg/util/workloadrepo/worker_test.go index 50c44c0c89fc1..97272d8b5a33f 100644 --- a/pkg/util/workloadrepo/worker_test.go +++ b/pkg/util/workloadrepo/worker_test.go @@ -59,21 +59,7 @@ func setupWorkerForTest(ctx context.Context, etcdCli *clientv3.Client, dom *doma return wrk } -func setupDomainAndContext(t *testing.T) (context.Context, kv.Storage, *domain.Domain, string) { - ctx := context.Background() - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) - var cancel context.CancelFunc = nil - if ddl, ok := t.Deadline(); ok { - ctx, cancel = context.WithDeadline(ctx, ddl) - } - t.Cleanup(func() { - if cancel != nil { - cancel() - } - }) - - store, dom := testkit.CreateMockStoreAndDomain(t) - +func setupEtcd(t *testing.T) string { cfg := embed.NewConfig() cfg.Dir = t.TempDir() @@ -97,7 +83,25 @@ func setupDomainAndContext(t *testing.T) (context.Context, kv.Storage, *domain.D require.False(t, true, "server took too long to start") } - return ctx, store, dom, embedEtcd.Clients[0].Addr().String() + return embedEtcd.Clients[0].Addr().String() +} + +func setupDomainAndContext(t *testing.T) (context.Context, kv.Storage, *domain.Domain, string) { + ctx := context.Background() + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + var cancel context.CancelFunc = nil + if ddl, ok := t.Deadline(); ok { + ctx, cancel = context.WithDeadline(ctx, ddl) + } + t.Cleanup(func() { + if cancel != nil { + cancel() + } + }) + + store, dom := testkit.CreateMockStoreAndDomain(t) + etcdAddr := setupEtcd(t) + return ctx, store, dom, etcdAddr } func setupWorker(ctx context.Context, t *testing.T, addr string, dom *domain.Domain, id string, testWorker bool) *worker { @@ -857,3 +861,44 @@ func TestCalcNextTick(t *testing.T) { require.True(t, calcNextTick(time.Date(2024, 12, 7, 2, 0, 0, 1, loc)) == time.Hour*24-time.Nanosecond) require.True(t, calcNextTick(time.Date(2024, 12, 7, 1, 59, 59, 999999999, loc)) == time.Nanosecond) } + +func TestRecoverSnapID(t *testing.T) { + ctx, store, dom, addr := setupDomainAndContext(t) + worker := setupWorker(ctx, t, addr, dom, "worker1", true) + require.NoError(t, worker.setRepositoryDest(ctx, "table")) + now := time.Now() + + require.Eventually(t, func() bool { + return worker.checkTablesExists(ctx, now) + }, time.Minute, time.Second) + tk := testkit.NewTestKit(t, store) + prevSnapID := uint64(0) + require.Eventually(t, func() bool { + res := tk.MustQuery("select max(snap_id) from workload_schema.hist_snapshots").Rows() + if len(res) == 0 || len(res[0]) == 0 { + return false + } + snapID, err := strconv.ParseUint(res[0][0].(string), 10, 64) + prevSnapID = snapID + return err == nil && snapID > 0 + }, time.Minute, time.Second) + worker.stop() + + etcd2 := setupEtcd(t) + worker2 := setupWorker(ctx, t, etcd2, dom, "worker2", true) + snapIDStr, err := worker2.etcdGet(ctx, snapIDKey) + require.Nil(t, err) + require.Equal(t, "", snapIDStr) + + _, err = worker2.getSnapID(ctx) + require.EqualError(t, errKeyNotFound, err.Error()) + newSnapID, err := queryMaxSnapID(ctx, worker2.getSessionWithRetry().(sessionctx.Context)) + require.Nil(t, err) + require.Equal(t, prevSnapID, newSnapID) + + require.NoError(t, worker2.setRepositoryDest(ctx, "table")) + require.Eventually(t, func() bool { + newSnapID, err = worker2.getSnapID(ctx) + return err == nil && newSnapID >= prevSnapID + }, time.Minute, time.Second) +}