diff --git a/br/cmd/br/stream.go b/br/cmd/br/stream.go index c59ae6d859af0..f452e38917ea5 100644 --- a/br/cmd/br/stream.go +++ b/br/cmd/br/stream.go @@ -16,6 +16,7 @@ package main import ( "github.com/pingcap/errors" + advancercfg "github.com/pingcap/tidb/br/pkg/streamhelper/config" "github.com/pingcap/tidb/br/pkg/task" "github.com/pingcap/tidb/br/pkg/trace" "github.com/pingcap/tidb/br/pkg/utils" @@ -49,6 +50,7 @@ func NewStreamCommand() *cobra.Command { newStreamStatusCommand(), newStreamTruncateCommand(), newStreamCheckCommand(), + newStreamAdvancerCommand(), ) command.SetHelpFunc(func(command *cobra.Command, strings []string) { task.HiddenFlagsForStream(command.Root().PersistentFlags()) @@ -157,6 +159,21 @@ func newStreamCheckCommand() *cobra.Command { return command } +func newStreamAdvancerCommand() *cobra.Command { + command := &cobra.Command{ + Use: "advancer", + Short: "Start a central worker for advancing the checkpoint. (only for debuging, this subcommand should be integrated to TiDB)", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + return streamCommand(cmd, task.StreamCtl) + }, + Hidden: true, + } + task.DefineStreamCommonFlags(command.Flags()) + advancercfg.DefineFlagsForCheckpointAdvancerConfig(command.Flags()) + return command +} + func streamCommand(command *cobra.Command, cmdName string) error { var cfg task.StreamConfig var err error @@ -192,6 +209,13 @@ func streamCommand(command *cobra.Command, cmdName string) error { if err = cfg.ParseStreamPauseFromFlags(command.Flags()); err != nil { return errors.Trace(err) } + case task.StreamCtl: + if err = cfg.ParseStreamCommonFromFlags(command.Flags()); err != nil { + return errors.Trace(err) + } + if err = cfg.AdvancerCfg.GetFromFlags(command.Flags()); err != nil { + return errors.Trace(err) + } default: if err = cfg.ParseStreamCommonFromFlags(command.Flags()); err != nil { return errors.Trace(err) diff --git a/br/pkg/conn/conn.go b/br/pkg/conn/conn.go index 75eef2c1555ab..f90743e1bd3d5 100755 --- a/br/pkg/conn/conn.go +++ b/br/pkg/conn/conn.go @@ -9,16 +9,14 @@ import ( "fmt" "net/http" "net/url" - "os" "strings" - "sync" - "time" "github.com/docker/go-units" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" backuppb "github.com/pingcap/kvproto/pkg/brpb" + logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" berrors "github.com/pingcap/tidb/br/pkg/errors" @@ -35,9 +33,7 @@ import ( pd "github.com/tikv/pd/client" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/backoff" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" ) @@ -49,83 +45,17 @@ const ( // DefaultMergeRegionKeyCount is the default region key count, 960000. DefaultMergeRegionKeyCount uint64 = 960000 - - dialTimeout = 30 * time.Second - - resetRetryTimes = 3 ) -// Pool is a lazy pool of gRPC channels. -// When `Get` called, it lazily allocates new connection if connection not full. -// If it's full, then it will return allocated channels round-robin. -type Pool struct { - mu sync.Mutex - - conns []*grpc.ClientConn - next int - cap int - newConn func(ctx context.Context) (*grpc.ClientConn, error) -} - -func (p *Pool) takeConns() (conns []*grpc.ClientConn) { - p.mu.Lock() - defer p.mu.Unlock() - p.conns, conns = nil, p.conns - p.next = 0 - return conns -} - -// Close closes the conn pool. 
-func (p *Pool) Close() { - for _, c := range p.takeConns() { - if err := c.Close(); err != nil { - log.Warn("failed to close clientConn", zap.String("target", c.Target()), zap.Error(err)) - } - } -} - -// Get tries to get an existing connection from the pool, or make a new one if the pool not full. -func (p *Pool) Get(ctx context.Context) (*grpc.ClientConn, error) { - p.mu.Lock() - defer p.mu.Unlock() - if len(p.conns) < p.cap { - c, err := p.newConn(ctx) - if err != nil { - return nil, err - } - p.conns = append(p.conns, c) - return c, nil - } - - conn := p.conns[p.next] - p.next = (p.next + 1) % p.cap - return conn, nil -} - -// NewConnPool creates a new Pool by the specified conn factory function and capacity. -func NewConnPool(capacity int, newConn func(ctx context.Context) (*grpc.ClientConn, error)) *Pool { - return &Pool{ - cap: capacity, - conns: make([]*grpc.ClientConn, 0, capacity), - newConn: newConn, - - mu: sync.Mutex{}, - } -} - // Mgr manages connections to a TiDB cluster. type Mgr struct { *pdutil.PdController - tlsConf *tls.Config - dom *domain.Domain - storage kv.Storage // Used to access SQL related interfaces. - tikvStore tikv.Storage // Used to access TiKV specific interfaces. - grpcClis struct { - mu sync.Mutex - clis map[uint64]*grpc.ClientConn - } - keepalive keepalive.ClientParameters + dom *domain.Domain + storage kv.Storage // Used to access SQL related interfaces. + tikvStore tikv.Storage // Used to access TiKV specific interfaces. ownsStorage bool + + *utils.StoreManager } // StoreBehavior is the action to do in GetAllTiKVStores when a non-TiKV @@ -298,122 +228,31 @@ func NewMgr( storage: storage, tikvStore: tikvStorage, dom: dom, - tlsConf: tlsConf, ownsStorage: g.OwnsStorage(), - grpcClis: struct { - mu sync.Mutex - clis map[uint64]*grpc.ClientConn - }{clis: make(map[uint64]*grpc.ClientConn)}, - keepalive: keepalive, + StoreManager: utils.NewStoreManager(controller.GetPDClient(), keepalive, tlsConf), } return mgr, nil } -func (mgr *Mgr) getGrpcConnLocked(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - failpoint.Inject("hint-get-backup-client", func(v failpoint.Value) { - log.Info("failpoint hint-get-backup-client injected, "+ - "process will notify the shell.", zap.Uint64("store", storeID)) - if sigFile, ok := v.(string); ok { - file, err := os.Create(sigFile) - if err != nil { - log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) - } - if file != nil { - file.Close() - } - } - time.Sleep(3 * time.Second) - }) - store, err := mgr.GetPDClient().GetStore(ctx, storeID) - if err != nil { - return nil, errors.Trace(err) - } - opt := grpc.WithInsecure() - if mgr.tlsConf != nil { - opt = grpc.WithTransportCredentials(credentials.NewTLS(mgr.tlsConf)) - } - ctx, cancel := context.WithTimeout(ctx, dialTimeout) - bfConf := backoff.DefaultConfig - bfConf.MaxDelay = time.Second * 3 - addr := store.GetPeerAddress() - if addr == "" { - addr = store.GetAddress() - } - conn, err := grpc.DialContext( - ctx, - addr, - opt, - grpc.WithBlock(), - grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), - grpc.WithKeepaliveParams(mgr.keepalive), - ) - cancel() - if err != nil { - return nil, berrors.ErrFailedToConnect.Wrap(err).GenWithStack("failed to make connection to store %d", storeID) - } - return conn, nil -} - // GetBackupClient get or create a backup client. 
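For context: the per-store connection caching that `Mgr` used to do itself (the removed `Pool` helper and the `grpcClis` map) is now delegated to `utils.StoreManager`, which this diff does not show. Inferred only from the call sites in this patch (`NewStoreManager`, `WithConn`, `TLSConfig`, `Close`), the surface it needs to provide looks roughly like the sketch below; treat the exact signatures as assumptions.

```go
// Hypothetical outline, reconstructed from how conn.Mgr uses the store manager
// in this patch; the real br/pkg/utils implementation may differ.
package sketch

import (
	"context"
	"crypto/tls"

	"google.golang.org/grpc"
)

type storeConnProvider interface {
	// WithConn resolves (or lazily dials and caches) the gRPC connection to the
	// given store and hands it to f, so callers such as GetBackupClient below
	// never hold raw connections themselves.
	WithConn(ctx context.Context, storeID uint64, f func(*grpc.ClientConn)) error
	// TLSConfig exposes the dialing TLS settings (used by Mgr.GetTLSConfig).
	TLSConfig() *tls.Config
	// Close releases every cached connection (used by Mgr.Close).
	Close()
}
```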
func (mgr *Mgr) GetBackupClient(ctx context.Context, storeID uint64) (backuppb.BackupClient, error) { - if ctx.Err() != nil { - return nil, errors.Trace(ctx.Err()) - } - - mgr.grpcClis.mu.Lock() - defer mgr.grpcClis.mu.Unlock() - - if conn, ok := mgr.grpcClis.clis[storeID]; ok { - // Find a cached backup client. - return backuppb.NewBackupClient(conn), nil - } - - conn, err := mgr.getGrpcConnLocked(ctx, storeID) - if err != nil { - return nil, errors.Trace(err) + var cli backuppb.BackupClient + if err := mgr.WithConn(ctx, storeID, func(cc *grpc.ClientConn) { + cli = backuppb.NewBackupClient(cc) + }); err != nil { + return nil, err } - // Cache the conn. - mgr.grpcClis.clis[storeID] = conn - return backuppb.NewBackupClient(conn), nil + return cli, nil } -// ResetBackupClient reset the connection for backup client. -func (mgr *Mgr) ResetBackupClient(ctx context.Context, storeID uint64) (backuppb.BackupClient, error) { - if ctx.Err() != nil { - return nil, errors.Trace(ctx.Err()) - } - - mgr.grpcClis.mu.Lock() - defer mgr.grpcClis.mu.Unlock() - - if conn, ok := mgr.grpcClis.clis[storeID]; ok { - // Find a cached backup client. - log.Info("Reset backup client", zap.Uint64("storeID", storeID)) - err := conn.Close() - if err != nil { - log.Warn("close backup connection failed, ignore it", zap.Uint64("storeID", storeID)) - } - delete(mgr.grpcClis.clis, storeID) - } - var ( - conn *grpc.ClientConn - err error - ) - for retry := 0; retry < resetRetryTimes; retry++ { - conn, err = mgr.getGrpcConnLocked(ctx, storeID) - if err != nil { - log.Warn("failed to reset grpc connection, retry it", - zap.Int("retry time", retry), logutil.ShortError(err)) - time.Sleep(time.Duration(retry+3) * time.Second) - continue - } - mgr.grpcClis.clis[storeID] = conn - break - } - if err != nil { - return nil, errors.Trace(err) +func (mgr *Mgr) GetLogBackupClient(ctx context.Context, storeID uint64) (logbackup.LogBackupClient, error) { + var cli logbackup.LogBackupClient + if err := mgr.WithConn(ctx, storeID, func(cc *grpc.ClientConn) { + cli = logbackup.NewLogBackupClient(cc) + }); err != nil { + return nil, err } - return backuppb.NewBackupClient(conn), nil + return cli, nil } // GetStorage returns a kv storage. @@ -423,7 +262,7 @@ func (mgr *Mgr) GetStorage() kv.Storage { // GetTLSConfig returns the tls config. func (mgr *Mgr) GetTLSConfig() *tls.Config { - return mgr.tlsConf + return mgr.StoreManager.TLSConfig() } // GetLockResolver gets the LockResolver. @@ -436,17 +275,10 @@ func (mgr *Mgr) GetDomain() *domain.Domain { return mgr.dom } -// Close closes all client in Mgr. func (mgr *Mgr) Close() { - mgr.grpcClis.mu.Lock() - for _, cli := range mgr.grpcClis.clis { - err := cli.Close() - if err != nil { - log.Error("fail to close Mgr", zap.Error(err)) - } + if mgr.StoreManager != nil { + mgr.StoreManager.Close() } - mgr.grpcClis.mu.Unlock() - // Gracefully shutdown domain so it does not affect other TiDB DDL. // Must close domain before closing storage, otherwise it gets stuck forever. if mgr.ownsStorage { diff --git a/br/pkg/logutil/logging.go b/br/pkg/logutil/logging.go index 71b882b7af9db..354b900e5605a 100644 --- a/br/pkg/logutil/logging.go +++ b/br/pkg/logutil/logging.go @@ -14,6 +14,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/redact" + "github.com/pingcap/tidb/kv" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -269,3 +270,29 @@ func Redact(field zap.Field) zap.Field { } return field } + +// StringifyRanges wrappes the key range into a stringer. 
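A quick usage sketch for the stringer added below (key rendering goes through the `redact` helpers, so the exact output depends on whether log redaction is enabled):

```go
// Illustrative fragment: an empty EndKey is rendered as "inf", so the output has
// the shape {[<start>, <end>), [<start>, inf)}.
rngs := []kv.KeyRange{
	{StartKey: []byte("a"), EndKey: []byte("b")},
	{StartKey: []byte("c")}, // right-open range with no end key
}
log.Info("pending ranges", zap.Stringer("ranges", logutil.StringifyKeys(rngs)))
```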
+type StringifyKeys []kv.KeyRange + +func (kr StringifyKeys) String() string { + sb := new(strings.Builder) + sb.WriteString("{") + for i, rng := range kr { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString("[") + sb.WriteString(redact.Key(rng.StartKey)) + sb.WriteString(", ") + var endKey string + if len(rng.EndKey) == 0 { + endKey = "inf" + } else { + endKey = redact.Key(rng.EndKey) + } + sb.WriteString(redact.String(endKey)) + sb.WriteString(")") + } + sb.WriteString("}") + return sb.String() +} diff --git a/br/pkg/restore/client.go b/br/pkg/restore/client.go index 411c8fa474cb9..e5f382233f94a 100644 --- a/br/pkg/restore/client.go +++ b/br/pkg/restore/client.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/backup" "github.com/pingcap/tidb/br/pkg/checksum" "github.com/pingcap/tidb/br/pkg/conn" berrors "github.com/pingcap/tidb/br/pkg/errors" @@ -145,6 +146,11 @@ type Client struct { currentTS uint64 storage storage.ExternalStorage + + // the query to insert rows into table `gc_delete_range`, lack of ts. + deleteRangeQuery []string + deleteRangeQueryCh chan string + deleteRangeQueryWaitGroup sync.WaitGroup } // NewRestoreClient returns a new RestoreClient. @@ -155,11 +161,13 @@ func NewRestoreClient( isRawKv bool, ) *Client { return &Client{ - pdClient: pdClient, - toolClient: NewSplitClient(pdClient, tlsConf, isRawKv), - tlsConf: tlsConf, - keepaliveConf: keepaliveConf, - switchCh: make(chan struct{}), + pdClient: pdClient, + toolClient: NewSplitClient(pdClient, tlsConf, isRawKv), + tlsConf: tlsConf, + keepaliveConf: keepaliveConf, + switchCh: make(chan struct{}), + deleteRangeQuery: make([]string, 0), + deleteRangeQueryCh: make(chan string, 10), } } @@ -297,6 +305,8 @@ func (rc *Client) SetRestoreRangeTS(startTs, restoreTS, shiftStartTS uint64) { rc.startTS = startTs rc.restoreTS = restoreTS rc.shiftStartTS = shiftStartTS + log.Info("set restore range ts", zap.Uint64("shift-start-ts", shiftStartTS), + zap.Uint64("start-ts", startTs), zap.Uint64("restored-ts", restoreTS)) } func (rc *Client) SetCurrentTS(ts uint64) { @@ -1605,6 +1615,15 @@ func (rc *Client) ReadStreamDataFiles( log.Debug("backup stream collect data file", zap.String("file", d.Path)) } } + + // sort files firstly. + slices.SortFunc(mFiles, func(i, j *backuppb.DataFileInfo) bool { + if i.ResolvedTs > 0 && j.ResolvedTs > 0 { + return i.ResolvedTs < j.ResolvedTs + } else { + return i.MaxTs < j.MaxTs + } + }) return dFiles, mFiles, nil } @@ -1671,6 +1690,7 @@ func (rc *Client) RestoreKVFiles( ctx context.Context, rules map[int64]*RewriteRules, files []*backuppb.DataFileInfo, + updateStats func(kvCount uint64, size uint64), onProgress func(), ) error { var err error @@ -1712,6 +1732,7 @@ func (rc *Client) RestoreKVFiles( fileStart := time.Now() defer func() { onProgress() + updateStats(uint64(file.NumberOfEntries), file.Length) summary.CollectInt("File", 1) log.Info("import files done", zap.String("name", file.Path), zap.Duration("take", time.Since(fileStart))) }() @@ -1829,13 +1850,14 @@ func (rc *Client) RestoreMetaKVFiles( ctx context.Context, files []*backuppb.DataFileInfo, schemasReplace *stream.SchemasReplace, + updateStats func(kvCount uint64, size uint64), progressInc func(), ) error { filesInWriteCF := make([]*backuppb.DataFileInfo, 0, len(files)) // The k-v events in default CF should be restored firstly. 
The reason is that: - // The error of transactions of meta will happen, - // if restore default CF events successfully, but failed to restore write CF events. + // The error of transactions of meta could happen if restore write CF events successfully, + // but failed to restore default CF events. for _, f := range files { if f.Cf == stream.WriteCF { filesInWriteCF = append(filesInWriteCF, f) @@ -1849,19 +1871,21 @@ func (rc *Client) RestoreMetaKVFiles( continue } - err := rc.RestoreMetaKVFile(ctx, f, schemasReplace) + kvCount, size, err := rc.RestoreMetaKVFile(ctx, f, schemasReplace) if err != nil { return errors.Trace(err) } + updateStats(kvCount, size) progressInc() } // Restore files in write CF. for _, f := range filesInWriteCF { - err := rc.RestoreMetaKVFile(ctx, f, schemasReplace) + kvCount, size, err := rc.RestoreMetaKVFile(ctx, f, schemasReplace) if err != nil { return errors.Trace(err) } + updateStats(kvCount, size) progressInc() } @@ -1877,7 +1901,11 @@ func (rc *Client) RestoreMetaKVFile( ctx context.Context, file *backuppb.DataFileInfo, sr *stream.SchemasReplace, -) error { +) (uint64, uint64, error) { + var ( + kvCount uint64 + size uint64 + ) log.Info("restore meta kv events", zap.String("file", file.Path), zap.String("cf", file.Cf), zap.Int64("kv-count", file.NumberOfEntries), zap.Uint64("min-ts", file.MinTs), zap.Uint64("max-ts", file.MaxTs)) @@ -1885,10 +1913,10 @@ func (rc *Client) RestoreMetaKVFile( rc.rawKVClient.SetColumnFamily(file.GetCf()) buff, err := rc.storage.ReadFile(ctx, file.Path) if err != nil { - return errors.Trace(err) + return 0, 0, errors.Trace(err) } if checksum := sha256.Sum256(buff); !bytes.Equal(checksum[:], file.GetSha256()) { - return errors.Annotatef(berrors.ErrInvalidMetaFile, + return 0, 0, errors.Annotatef(berrors.ErrInvalidMetaFile, "checksum mismatch expect %x, got %x", file.GetSha256(), checksum[:]) } @@ -1896,13 +1924,13 @@ func (rc *Client) RestoreMetaKVFile( for iter.Valid() { iter.Next() if iter.GetError() != nil { - return errors.Trace(iter.GetError()) + return 0, 0, errors.Trace(iter.GetError()) } txnEntry := kv.Entry{Key: iter.Key(), Value: iter.Value()} ts, err := GetKeyTS(txnEntry.Key) if err != nil { - return errors.Trace(err) + return 0, 0, errors.Trace(err) } // The commitTs in write CF need be limited on [startTs, restoreTs]. 
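Both `RestoreKVFiles` and `RestoreMetaKVFiles` above now report progress through an `updateStats(kvCount, size)` callback; the actual wiring lives in the restore task, which is not part of this diff. A minimal sketch of what a caller might pass, assuming it only wants running totals (`client`, `mFiles` and the inline progress closure are illustrative):

```go
// Hypothetical caller-side accumulator (needs "sync/atomic").
var totalKVs, totalBytes uint64
updateStats := func(kvCount, size uint64) {
	atomic.AddUint64(&totalKVs, kvCount)
	atomic.AddUint64(&totalBytes, size)
}
if err := client.RestoreMetaKVFiles(ctx, mFiles, schemasReplace, updateStats, func() { /* tick the progress bar */ }); err != nil {
	return errors.Trace(err)
}
```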
@@ -1924,11 +1952,11 @@ func (rc *Client) RestoreMetaKVFile( } log.Debug("txn entry", zap.Uint64("key-ts", ts), zap.Int("txnKey-len", len(txnEntry.Key)), zap.Int("txnValue-len", len(txnEntry.Value)), zap.ByteString("txnKey", txnEntry.Key)) - newEntry, err := sr.RewriteKvEntry(&txnEntry, file.Cf) + newEntry, err := sr.RewriteKvEntry(&txnEntry, file.Cf, rc.InsertDeleteRangeForTable, rc.InsertDeleteRangeForIndex) if err != nil { log.Error("rewrite txn entry failed", zap.Int("klen", len(txnEntry.Key)), logutil.Key("txn-key", txnEntry.Key)) - return errors.Trace(err) + return 0, 0, errors.Trace(err) } else if newEntry == nil { continue } @@ -1936,11 +1964,14 @@ func (rc *Client) RestoreMetaKVFile( zap.Int("newValue-len", len(txnEntry.Value)), zap.ByteString("newkey", newEntry.Key)) if err := rc.rawKVClient.Put(ctx, newEntry.Key, newEntry.Value, ts); err != nil { - return errors.Trace(err) + return 0, 0, errors.Trace(err) } + + kvCount += 1 + size += uint64(len(newEntry.Key) + len(newEntry.Value)) } - return rc.rawKVClient.PutRest(ctx) + return kvCount, size, rc.rawKVClient.PutRest(ctx) } func transferBoolToValue(enable bool) string { @@ -2026,6 +2057,111 @@ func (rc *Client) UpdateSchemaVersion(ctx context.Context) error { return nil } +const ( + insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES ` + insertDeleteRangeSQLValue = "(%d, %d, '%s', '%s', %%[1]d)" + + batchInsertDeleteRangeSize = 256 +) + +// InsertDeleteRangeForTable generates query to insert table delete job into table `gc_delete_range`. +func (rc *Client) InsertDeleteRangeForTable(jobID int64, tableIDs []int64) { + var elementID int64 = 1 + var tableID int64 + for i := 0; i < len(tableIDs); i += batchInsertDeleteRangeSize { + batchEnd := len(tableIDs) + if batchEnd > i+batchInsertDeleteRangeSize { + batchEnd = i + batchInsertDeleteRangeSize + } + + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) + for j := i; j < batchEnd; j++ { + tableID = tableIDs[j] + startKey := tablecodec.EncodeTablePrefix(tableID) + endKey := tablecodec.EncodeTablePrefix(tableID + 1) + startKeyEncoded := hex.EncodeToString(startKey) + endKeyEncoded := hex.EncodeToString(endKey) + buf.WriteString(fmt.Sprintf(insertDeleteRangeSQLValue, jobID, elementID, startKeyEncoded, endKeyEncoded)) + if j != batchEnd-1 { + buf.WriteString(",") + } + elementID += 1 + } + rc.deleteRangeQueryCh <- buf.String() + } +} + +// InsertDeleteRangeForIndex generates query to insert index delete job into table `gc_delete_range`. +func (rc *Client) InsertDeleteRangeForIndex(jobID int64, elementID *int64, tableID int64, indexIDs []int64) { + var indexID int64 + for i := 0; i < len(indexIDs); i += batchInsertDeleteRangeSize { + batchEnd := len(indexIDs) + if batchEnd > i+batchInsertDeleteRangeSize { + batchEnd = i + batchInsertDeleteRangeSize + } + + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) + for j := i; j < batchEnd; j++ { + indexID = indexIDs[j] + startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) + endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) + startKeyEncoded := hex.EncodeToString(startKey) + endKeyEncoded := hex.EncodeToString(endKey) + buf.WriteString(fmt.Sprintf(insertDeleteRangeSQLValue, jobID, *elementID, startKeyEncoded, endKeyEncoded)) + if j != batchEnd-1 { + buf.WriteString(",") + } + *elementID += 1 + } + rc.deleteRangeQueryCh <- buf.String() + } +} + +// use channel to save the delete-range query to make it thread-safety. 
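Expected call order for the delete-range machinery below, mirroring `TestDeleteRangeQuery` in `client_test.go`: start the loader first, let the two `InsertDeleteRangeFor*` callbacks enqueue queries while meta KVs are replayed, then flush everything once with a single timestamp.

```go
// Caller-side sketch; error handling trimmed.
client.RunGCRowsLoader(ctx) // drains deleteRangeQueryCh in the background
// ... RestoreMetaKVFiles replays meta KVs; rewriting DDL-history jobs invokes
//     InsertDeleteRangeForTable / InsertDeleteRangeForIndex as needed ...
if err := client.InsertGCRows(ctx); err != nil { // close the channel, wait, then
	return errors.Trace(err)                     // run the buffered INSERTs with a fresh TS
}
```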
+func (rc *Client) RunGCRowsLoader(ctx context.Context) { + rc.deleteRangeQueryWaitGroup.Add(1) + + go func() { + defer rc.deleteRangeQueryWaitGroup.Done() + for { + select { + case <-ctx.Done(): + return + case query, ok := <-rc.deleteRangeQueryCh: + if !ok { + return + } + rc.deleteRangeQuery = append(rc.deleteRangeQuery, query) + } + } + }() +} + +// InsertGCRows insert the querys into table `gc_delete_range` +func (rc *Client) InsertGCRows(ctx context.Context) error { + close(rc.deleteRangeQueryCh) + rc.deleteRangeQueryWaitGroup.Wait() + ts, err := rc.GetTS(ctx) + if err != nil { + return errors.Trace(err) + } + for _, query := range rc.deleteRangeQuery { + if err := rc.db.se.ExecuteInternal(ctx, fmt.Sprintf(query, ts)); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// only for unit test +func (rc *Client) GetGCRows() []string { + close(rc.deleteRangeQueryCh) + rc.deleteRangeQueryWaitGroup.Wait() + return rc.deleteRangeQuery +} + func (rc *Client) SaveSchemas( ctx context.Context, sr *stream.SchemasReplace, @@ -2039,7 +2175,7 @@ func (rc *Client) SaveSchemas( m.StartVersion = logStartTS }) - schemas := sr.TidyOldSchemas() + schemas := TidyOldSchemas(sr) schemasConcurrency := uint(mathutil.Min(64, schemas.Len())) err := schemas.BackupSchemas(ctx, metaWriter, nil, nil, rc.restoreTS, schemasConcurrency, 0, true, nil) if err != nil { @@ -2056,3 +2192,31 @@ func (rc *Client) SaveSchemas( func MockClient(dbs map[string]*utils.Database) *Client { return &Client{databases: dbs} } + +// TidyOldSchemas produces schemas information. +func TidyOldSchemas(sr *stream.SchemasReplace) *backup.Schemas { + var schemaIsEmpty bool + schemas := backup.NewBackupSchemas() + + for _, dr := range sr.DbMap { + if dr.OldDBInfo == nil { + continue + } + + schemaIsEmpty = true + for _, tr := range dr.TableMap { + if tr.OldTableInfo == nil { + continue + } + schemas.AddSchema(dr.OldDBInfo, tr.OldTableInfo) + schemaIsEmpty = false + } + + // backup this empty schema if it has nothing table. 
+ if schemaIsEmpty { + schemas.AddSchema(dr.OldDBInfo, nil) + } + } + return schemas + +} diff --git a/br/pkg/restore/client_test.go b/br/pkg/restore/client_test.go index 08e57e83a7095..c7374351b2d4f 100644 --- a/br/pkg/restore/client_test.go +++ b/br/pkg/restore/client_test.go @@ -351,3 +351,50 @@ func TestSetSpeedLimit(t *testing.T) { require.Equal(t, mockStores[i].Id, recordStores.stores[i]) } } + +func TestDeleteRangeQuery(t *testing.T) { + ctx := context.Background() + m := mc + mockStores := []*metapb.Store{ + { + Id: 1, + Labels: []*metapb.StoreLabel{ + { + Key: "engine", + Value: "tiflash", + }, + }, + }, + { + Id: 2, + Labels: []*metapb.StoreLabel{ + { + Key: "engine", + Value: "tiflash", + }, + }, + }, + } + + g := gluetidb.New() + client := restore.NewRestoreClient(fakePDClient{ + stores: mockStores, + }, nil, defaultKeepaliveCfg, false) + err := client.Init(g, m.Storage) + require.NoError(t, err) + + client.RunGCRowsLoader(ctx) + + client.InsertDeleteRangeForTable(2, []int64{3}) + client.InsertDeleteRangeForTable(4, []int64{5, 6}) + + elementID := int64(1) + client.InsertDeleteRangeForIndex(7, &elementID, 8, []int64{1}) + client.InsertDeleteRangeForIndex(9, &elementID, 10, []int64{1, 2}) + + querys := client.GetGCRows() + require.Equal(t, querys[0], "INSERT IGNORE INTO mysql.gc_delete_range VALUES (2, 1, '748000000000000003', '748000000000000004', %[1]d)") + require.Equal(t, querys[1], "INSERT IGNORE INTO mysql.gc_delete_range VALUES (4, 1, '748000000000000005', '748000000000000006', %[1]d),(4, 2, '748000000000000006', '748000000000000007', %[1]d)") + require.Equal(t, querys[2], "INSERT IGNORE INTO mysql.gc_delete_range VALUES (7, 1, '7480000000000000085f698000000000000001', '7480000000000000085f698000000000000002', %[1]d)") + require.Equal(t, querys[3], "INSERT IGNORE INTO mysql.gc_delete_range VALUES (9, 2, '74800000000000000a5f698000000000000001', '74800000000000000a5f698000000000000002', %[1]d),(9, 3, '74800000000000000a5f698000000000000002', '74800000000000000a5f698000000000000003', %[1]d)") +} diff --git a/br/pkg/restore/import_retry.go b/br/pkg/restore/import_retry.go index 20d613d9bbd31..17c706c9e4444 100644 --- a/br/pkg/restore/import_retry.go +++ b/br/pkg/restore/import_retry.go @@ -234,6 +234,7 @@ func (r *RPCResult) StrategyForRetryGoError() RetryStrategy { if r.Err == nil { return StrategyGiveUp } + // we should unwrap the error or we cannot get the write gRPC status. if gRPCErr, ok := status.FromError(errors.Cause(r.Err)); ok { switch gRPCErr.Code() { diff --git a/br/pkg/restore/stream_metas.go b/br/pkg/restore/stream_metas.go index 48a2b21e65ec4..57ebe35eb81d7 100644 --- a/br/pkg/restore/stream_metas.go +++ b/br/pkg/restore/stream_metas.go @@ -9,9 +9,12 @@ import ( "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" + "github.com/pingcap/tidb/util/mathutil" + "go.uber.org/zap" ) type StreamMetadataSet struct { @@ -53,6 +56,21 @@ func (ms *StreamMetadataSet) iterateDataFiles(f func(d *backuppb.DataFileInfo) ( } } +// CalculateShiftTS calculates the shift-ts. 
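Concretely, with the fixtures from `TestCalculateShiftTS` further down: among write-CF files whose `[MinTs, MaxTs]` overlaps `[startTS, restoreTS]`, the shift-ts is the smallest `MinBeginTsInDefaultCf`, i.e. how far back the default CF must be read to cover transactions that commit inside the restore window.

```go
// The store-2 file spans [3000, 4000] and records MinBeginTsInDefaultCf = 2000.
shiftTS, exist := restore.CalculateShiftTS(ms, 2900, 4500) // => 2000, true
// Widening the window to MaxUint also admits the store-3 file (MinBeginTsInDefaultCf = 1800).
shiftTS, exist = restore.CalculateShiftTS(ms, 2900, mathutil.MaxUint) // => 1800, true
```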
+func (ms *StreamMetadataSet) CalculateShiftTS(startTS uint64) uint64 { + metadatas := make([]*backuppb.Metadata, 0, len(ms.metadata)) + for _, m := range ms.metadata { + metadatas = append(metadatas, m) + } + + minBeginTS, exist := CalculateShiftTS(metadatas, startTS, mathutil.MaxUint) + if !exist { + minBeginTS = startTS + } + log.Warn("calculate shift-ts", zap.Uint64("start-ts", startTS), zap.Uint64("shift-ts", minBeginTS)) + return minBeginTS +} + // IterateFilesFullyBefore runs the function over all files contain data before the timestamp only. // 0 before // |------------------------------------------| @@ -214,3 +232,35 @@ func SetTSToFile( content := strconv.FormatUint(safepoint, 10) return truncateAndWrite(ctx, s, filename, []byte(content)) } + +// CalculateShiftTS gets the minimal begin-ts about transaction according to the kv-event in write-cf. +func CalculateShiftTS( + metas []*backuppb.Metadata, + startTS uint64, + restoreTS uint64, +) (uint64, bool) { + var ( + minBeginTS uint64 + isExist bool + ) + for _, m := range metas { + if len(m.Files) == 0 || m.MinTs > restoreTS || m.MaxTs < startTS { + continue + } + + for _, d := range m.Files { + if d.Cf == stream.DefaultCF || d.MinBeginTsInDefaultCf == 0 { + continue + } + if d.MinTs > restoreTS || d.MaxTs < startTS { + continue + } + if d.MinBeginTsInDefaultCf < minBeginTS || !isExist { + isExist = true + minBeginTS = d.MinBeginTsInDefaultCf + } + } + } + + return minBeginTS, isExist +} diff --git a/br/pkg/restore/stream_metas_test.go b/br/pkg/restore/stream_metas_test.go index 96f46cf5c6747..e7e1607f4a0cf 100644 --- a/br/pkg/restore/stream_metas_test.go +++ b/br/pkg/restore/stream_metas_test.go @@ -15,6 +15,7 @@ import ( "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" + "github.com/pingcap/tidb/util/mathutil" "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -129,3 +130,72 @@ func TestTruncateSafepoint(t *testing.T) { require.Equal(t, ts, n, "failed at %d round: truncate safepoint mismatch", i) } } + +func fakeMetaDatas(cf string) []*backuppb.Metadata { + ms := []*backuppb.Metadata{ + { + StoreId: 1, + MinTs: 1500, + MaxTs: 2000, + Files: []*backuppb.DataFileInfo{ + { + MinTs: 1500, + MaxTs: 2000, + Cf: cf, + MinBeginTsInDefaultCf: 800, + }, + }, + }, + { + StoreId: 2, + MinTs: 3000, + MaxTs: 4000, + Files: []*backuppb.DataFileInfo{ + { + MinTs: 3000, + MaxTs: 4000, + Cf: cf, + MinBeginTsInDefaultCf: 2000, + }, + }, + }, + { + StoreId: 3, + MinTs: 5100, + MaxTs: 6100, + Files: []*backuppb.DataFileInfo{ + { + MinTs: 5100, + MaxTs: 6100, + Cf: cf, + MinBeginTsInDefaultCf: 1800, + }, + }, + }, + } + return ms +} + +func TestCalculateShiftTS(t *testing.T) { + var ( + startTs uint64 = 2900 + restoreTS uint64 = 4500 + ) + + ms := fakeMetaDatas(stream.WriteCF) + shiftTS, exist := restore.CalculateShiftTS(ms, startTs, restoreTS) + require.Equal(t, shiftTS, uint64(2000)) + require.Equal(t, exist, true) + + shiftTS, exist = restore.CalculateShiftTS(ms, startTs, mathutil.MaxUint) + require.Equal(t, shiftTS, uint64(1800)) + require.Equal(t, exist, true) + + shiftTS, exist = restore.CalculateShiftTS(ms, 1999, 3001) + require.Equal(t, shiftTS, uint64(800)) + require.Equal(t, exist, true) + + ms = fakeMetaDatas(stream.DefaultCF) + _, exist = restore.CalculateShiftTS(ms, startTs, restoreTS) + require.Equal(t, exist, false) +} diff --git a/br/pkg/stream/rewrite_meta_rawkv.go b/br/pkg/stream/rewrite_meta_rawkv.go index dd90625d9de0d..84f1b3d200048 100644 --- 
a/br/pkg/stream/rewrite_meta_rawkv.go +++ b/br/pkg/stream/rewrite_meta_rawkv.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/backup" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/parser/model" @@ -98,33 +97,6 @@ func NewSchemasReplace( } } -// TidyOldSchemas produces schemas information. -func (sr *SchemasReplace) TidyOldSchemas() *backup.Schemas { - var schemaIsEmpty bool - schemas := backup.NewBackupSchemas() - - for _, dr := range sr.DbMap { - if dr.OldDBInfo == nil { - continue - } - - schemaIsEmpty = true - for _, tr := range dr.TableMap { - if tr.OldTableInfo == nil { - continue - } - schemas.AddSchema(dr.OldDBInfo, tr.OldTableInfo) - schemaIsEmpty = false - } - - // backup this empty schema if it has nothing table. - if schemaIsEmpty { - schemas.AddSchema(dr.OldDBInfo, nil) - } - } - return schemas -} - func (sr *SchemasReplace) rewriteKeyForDB(key []byte, cf string) ([]byte, bool, error) { rawMetaKey, err := ParseTxnMetaKeyFrom(key) if err != nil { @@ -432,9 +404,26 @@ func (sr *SchemasReplace) rewriteValue( } // RewriteKvEntry uses to rewrite tableID/dbID in entry.key and entry.value -func (sr *SchemasReplace) RewriteKvEntry(e *kv.Entry, cf string) (*kv.Entry, error) { +func (sr *SchemasReplace) RewriteKvEntry(e *kv.Entry, cf string, insertDeleteRangeForTable InsertDeleteRangeForTable, insertDeleteRangeForIndex InsertDeleteRangeForIndex) (*kv.Entry, error) { // skip mDDLJob + if !strings.HasPrefix(string(e.Key), "mDB") { + if cf == DefaultCF && strings.HasPrefix(string(e.Key), "mDDLJobH") { // mDDLJobHistory + job := &model.Job{} + if err := job.Decode(e.Value); err != nil { + log.Debug("failed to decode the job", zap.String("error", err.Error()), zap.String("job", string(e.Value))) + // The value in write-cf is like "p\XXXX\XXX" need not restore. skip it + // The value in default-cf that can Decode() need restore. 
+ return nil, nil + } + if jobNeedGC(job) { + return nil, sr.deleteRange(job, insertDeleteRangeForTable, insertDeleteRangeForIndex) + } + + if job.Type == model.ActionExchangeTablePartition { + return nil, errors.Errorf("restore of ddl `exchange-table-partition` is not supported") + } + } return nil, nil } @@ -461,3 +450,321 @@ func (sr *SchemasReplace) RewriteKvEntry(e *kv.Entry, cf string) (*kv.Entry, err return nil, nil } } + +type InsertDeleteRangeForTable func(int64, []int64) +type InsertDeleteRangeForIndex func(int64, *int64, int64, []int64) + +func jobNeedGC(job *model.Job) bool { + if !job.IsCancelled() { + switch job.Type { + case model.ActionAddIndex, model.ActionAddPrimaryKey: + return job.State == model.JobStateRollbackDone + case model.ActionDropSchema, model.ActionDropTable, model.ActionTruncateTable, model.ActionDropIndex, model.ActionDropPrimaryKey, + model.ActionDropTablePartition, model.ActionTruncateTablePartition, model.ActionDropColumn, model.ActionDropColumns, model.ActionModifyColumn, model.ActionDropIndexes: + return job.State == model.JobStateSynced + } + } + return false +} + +func (sr *SchemasReplace) deleteRange(job *model.Job, insertDeleteRangeForTable InsertDeleteRangeForTable, insertDeleteRangeForIndex InsertDeleteRangeForIndex) error { + dbReplace, exist := sr.DbMap[job.SchemaID] + if !exist { + // skip this mddljob, the same below + log.Debug("try to drop a non-existent range, missing oldDBID", zap.Int64("oldDBID", job.SchemaID)) + return nil + } + + // allocate a new fake job id to avoid row conflicts in table `gc_delete_range` + newJobID, err := sr.genGenGlobalID(context.Background()) + if err != nil { + return errors.Trace(err) + } + + switch job.Type { + case model.ActionDropSchema: + var tableIDs []int64 + if err := job.DecodeArgs(&tableIDs); err != nil { + return errors.Trace(err) + } + // Note: tableIDs contains partition ids, cannot directly use dbReplace.TableMap + /* TODO: use global ID replace map + * + * for i := 0; i < len(tableIDs); i++ { + * tableReplace, exist := dbReplace.TableMap[tableIDs[i]] + * if !exist { + * return errors.Errorf("DropSchema: try to drop a non-existent table, missing oldTableID") + * } + * tableIDs[i] = tableReplace.NewTableID + * } + */ + + argsSet := make(map[int64]struct{}, len(tableIDs)) + for _, tableID := range tableIDs { + argsSet[tableID] = struct{}{} + } + + newTableIDs := make([]int64, 0, len(tableIDs)) + for tableID, tableReplace := range dbReplace.TableMap { + if _, exist := argsSet[tableID]; !exist { + log.Debug("DropSchema: record a table, but it doesn't exist in job args", zap.Int64("oldTableID", tableID)) + continue + } + newTableIDs = append(newTableIDs, tableReplace.NewTableID) + for partitionID, newPartitionID := range tableReplace.PartitionMap { + if _, exist := argsSet[partitionID]; !exist { + log.Debug("DropSchema: record a partition, but it doesn't exist in job args", zap.Int64("oldPartitionID", partitionID)) + continue + } + newTableIDs = append(newTableIDs, newPartitionID) + } + } + + if len(newTableIDs) != len(tableIDs) { + log.Debug("DropSchema: try to drop a non-existent table/partition, whose oldID doesn't exist in tableReplace") + // only drop newTableIDs' ranges + } + + if len(newTableIDs) > 0 { + insertDeleteRangeForTable(newJobID, newTableIDs) + } + + return nil + // Truncate will generates new id for table or partition, so ts can be large enough + case model.ActionDropTable, model.ActionTruncateTable: + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + 
log.Debug("DropTable/TruncateTable: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + + // The startKey here is for compatibility with previous versions, old version did not endKey so don't have to deal with. + var startKey kv.Key // unused + var physicalTableIDs []int64 + var ruleIDs []string // unused + if err := job.DecodeArgs(&startKey, &physicalTableIDs, &ruleIDs); err != nil { + return errors.Trace(err) + } + if len(physicalTableIDs) > 0 { + // delete partition id instead of table id + for i := 0; i < len(physicalTableIDs); i++ { + newPid, exist := tableReplace.PartitionMap[physicalTableIDs[i]] + if !exist { + log.Debug("DropTable/TruncateTable: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", physicalTableIDs[i])) + continue + } + physicalTableIDs[i] = newPid + } + if len(physicalTableIDs) > 0 { + insertDeleteRangeForTable(newJobID, physicalTableIDs) + } + return nil + } + + insertDeleteRangeForTable(newJobID, []int64{tableReplace.NewTableID}) + return nil + case model.ActionDropTablePartition, model.ActionTruncateTablePartition: + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + log.Debug("DropTablePartition/TruncateTablePartition: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + var physicalTableIDs []int64 + if err := job.DecodeArgs(&physicalTableIDs); err != nil { + return errors.Trace(err) + } + + for i := 0; i < len(physicalTableIDs); i++ { + newPid, exist := tableReplace.PartitionMap[physicalTableIDs[i]] + if !exist { + log.Debug("DropTablePartition/TruncateTablePartition: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", physicalTableIDs[i])) + continue + } + physicalTableIDs[i] = newPid + } + if len(physicalTableIDs) > 0 { + insertDeleteRangeForTable(newJobID, physicalTableIDs) + } + return nil + // ActionAddIndex, ActionAddPrimaryKey needs do it, because it needs to be rolled back when it's canceled. 
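	// For reference, the args decoded in this branch are [indexID, partitionIDs];
	// e.g. rollBackTable0IndexJob in the tests below carries RawArgs `[2,[72,73,74]]`,
	// which fans out into one delete-range per rewritten partition for index ID 2.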
+ case model.ActionAddIndex, model.ActionAddPrimaryKey: + // iff job.State = model.JobStateRollbackDone + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + log.Debug("AddIndex/AddPrimaryKey roll-back: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + var indexID int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&indexID, &partitionIDs); err != nil { + return errors.Trace(err) + } + + var elementID int64 = 1 + indexIDs := []int64{indexID} + + if len(partitionIDs) > 0 { + for _, oldPid := range partitionIDs { + newPid, exist := tableReplace.PartitionMap[oldPid] + if !exist { + log.Debug("AddIndex/AddPrimaryKey roll-back: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", oldPid)) + continue + } + + insertDeleteRangeForIndex(newJobID, &elementID, newPid, indexIDs) + } + } else { + insertDeleteRangeForIndex(newJobID, &elementID, tableReplace.NewTableID, indexIDs) + } + return nil + case model.ActionDropIndex, model.ActionDropPrimaryKey: + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + log.Debug("DropIndex/DropPrimaryKey: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + + var indexName interface{} + var indexID int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&indexName, &indexID, &partitionIDs); err != nil { + return errors.Trace(err) + } + + var elementID int64 = 1 + indexIDs := []int64{indexID} + + if len(partitionIDs) > 0 { + for _, oldPid := range partitionIDs { + newPid, exist := tableReplace.PartitionMap[oldPid] + if !exist { + log.Debug("DropIndex/DropPrimaryKey: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", oldPid)) + continue + } + // len(indexIDs) = 1 + insertDeleteRangeForIndex(newJobID, &elementID, newPid, indexIDs) + } + } else { + insertDeleteRangeForIndex(newJobID, &elementID, tableReplace.NewTableID, indexIDs) + } + return nil + case model.ActionDropIndexes: + var indexIDs []int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&[]model.CIStr{}, &[]bool{}, &indexIDs, &partitionIDs); err != nil { + return errors.Trace(err) + } + // Remove data in TiKV. 
+ if len(indexIDs) == 0 { + return nil + } + + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + log.Debug("DropIndexes: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + + var elementID int64 = 1 + if len(partitionIDs) > 0 { + for _, oldPid := range partitionIDs { + newPid, exist := tableReplace.PartitionMap[oldPid] + if !exist { + log.Debug("DropIndexes: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", oldPid)) + continue + } + insertDeleteRangeForIndex(newJobID, &elementID, newPid, indexIDs) + } + } else { + insertDeleteRangeForIndex(newJobID, &elementID, tableReplace.NewTableID, indexIDs) + } + return nil + case model.ActionDropColumn: + var colName model.CIStr + var indexIDs []int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&colName, &indexIDs, &partitionIDs); err != nil { + return errors.Trace(err) + } + if len(indexIDs) > 0 { + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + log.Debug("DropColumn: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + + var elementID int64 = 1 + if len(partitionIDs) > 0 { + for _, oldPid := range partitionIDs { + newPid, exist := tableReplace.PartitionMap[oldPid] + if !exist { + log.Debug("DropColumn: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", oldPid)) + continue + } + insertDeleteRangeForIndex(newJobID, &elementID, newPid, indexIDs) + } + } else { + insertDeleteRangeForIndex(newJobID, &elementID, tableReplace.NewTableID, indexIDs) + } + } + return nil + case model.ActionDropColumns: + var colNames []model.CIStr + var ifExists []bool + var indexIDs []int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&colNames, &ifExists, &indexIDs, &partitionIDs); err != nil { + return errors.Trace(err) + } + if len(indexIDs) > 0 { + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + log.Debug("DropColumns: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + + var elementID int64 = 1 + if len(partitionIDs) > 0 { + for _, oldPid := range partitionIDs { + newPid, exist := tableReplace.PartitionMap[oldPid] + if !exist { + log.Debug("DropColumns: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", oldPid)) + continue + } + insertDeleteRangeForIndex(newJobID, &elementID, newPid, indexIDs) + } + } else { + insertDeleteRangeForIndex(newJobID, &elementID, tableReplace.NewTableID, indexIDs) + } + } + case model.ActionModifyColumn: + var indexIDs []int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&indexIDs, &partitionIDs); err != nil { + return errors.Trace(err) + } + if len(indexIDs) == 0 { + return nil + } + tableReplace, exist := dbReplace.TableMap[job.TableID] + if !exist { + log.Debug("DropColumn: try to drop a non-existent table, missing oldTableID", zap.Int64("oldTableID", job.TableID)) + return nil + } + + var elementID int64 = 1 + if len(partitionIDs) > 0 { + for _, oldPid := range partitionIDs { + newPid, exist := tableReplace.PartitionMap[oldPid] + if !exist { + log.Debug("DropColumn: try to drop a non-existent table, missing oldPartitionID", zap.Int64("oldPartitionID", oldPid)) + continue + } + insertDeleteRangeForIndex(newJobID, &elementID, newPid, indexIDs) + } + } else { + insertDeleteRangeForIndex(newJobID, &elementID, tableReplace.NewTableID, indexIDs) + } + } + return nil 
+} diff --git a/br/pkg/stream/rewrite_meta_rawkv_test.go b/br/pkg/stream/rewrite_meta_rawkv_test.go index 06ac8ef4f042a..1f6ad9fde607a 100644 --- a/br/pkg/stream/rewrite_meta_rawkv_test.go +++ b/br/pkg/stream/rewrite_meta_rawkv_test.go @@ -206,3 +206,303 @@ func TestRewriteValueForPartitionTable(t *testing.T) { ) require.Equal(t, tableInfo.Partition.Definitions[1].ID, newID2) } + +// db:70->80 - +// | - t0:71->81 - +// | | - p0:72->82 +// | | - p1:73->83 +// | | - p2:74->84 +// | - t1:75->85 + +const ( + mDDLJobDBOldID int64 = 70 + iota + mDDLJobTable0OldID + mDDLJobPartition0OldID + mDDLJobPartition1OldID + mDDLJobPartition2OldID + mDDLJobTable1OldID +) + +const ( + mDDLJobDBNewID int64 = 80 + iota + mDDLJobTable0NewID + mDDLJobPartition0NewID + mDDLJobPartition1NewID + mDDLJobPartition2NewID + mDDLJobTable1NewID +) + +var ( + mDDLJobALLNewTableIDSet = map[int64]struct{}{ + mDDLJobTable0NewID: {}, + mDDLJobPartition0NewID: {}, + mDDLJobPartition1NewID: {}, + mDDLJobPartition2NewID: {}, + mDDLJobTable1NewID: {}, + } + mDDLJobALLNewPartitionIDSet = map[int64]struct{}{ + mDDLJobPartition0NewID: {}, + mDDLJobPartition1NewID: {}, + mDDLJobPartition2NewID: {}, + } + mDDLJobALLIndexesIDSet = map[int64]struct{}{ + 2: {}, + 3: {}, + } +) + +var ( + dropSchemaJob = &model.Job{Type: model.ActionDropSchema, SchemaID: mDDLJobDBOldID, RawArgs: json.RawMessage(`[[71,72,73,74,75]]`)} + dropTable0Job = &model.Job{Type: model.ActionDropTable, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`["",[72,73,74],[""]]`)} + dropTable1Job = &model.Job{Type: model.ActionDropTable, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable1OldID, RawArgs: json.RawMessage(`["",[],[""]]`)} + dropTable0Partition1Job = &model.Job{Type: model.ActionDropTablePartition, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`[[73]]`)} + rollBackTable0IndexJob = &model.Job{Type: model.ActionAddIndex, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`[2,[72,73,74]]`)} + rollBackTable1IndexJob = &model.Job{Type: model.ActionAddIndex, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable1OldID, RawArgs: json.RawMessage(`[2,[]]`)} + dropTable0IndexJob = &model.Job{Type: model.ActionDropIndex, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`["",2,[72,73,74]]`)} + dropTable1IndexJob = &model.Job{Type: model.ActionDropIndex, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable1OldID, RawArgs: json.RawMessage(`["",2,[]]`)} + dropTable0IndexesJob = &model.Job{Type: model.ActionDropIndexes, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`[[],[],[2,3],[72,73,74]]`)} + dropTable1IndexesJob = &model.Job{Type: model.ActionDropIndexes, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable1OldID, RawArgs: json.RawMessage(`[[],[],[2,3],[]]`)} + dropTable0ColumnJob = &model.Job{Type: model.ActionDropColumn, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`["",[2,3],[72,73,74]]`)} + dropTable1ColumnJob = &model.Job{Type: model.ActionDropColumn, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable1OldID, RawArgs: json.RawMessage(`["",[2,3],[]]`)} + dropTable0ColumnsJob = &model.Job{Type: model.ActionDropColumns, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`[[],[],[2,3],[72,73,74]]`)} + dropTable1ColumnsJob = &model.Job{Type: model.ActionDropColumns, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable1OldID, RawArgs: json.RawMessage(`[[],[],[2,3],[]]`)} + 
modifyTable0ColumnJob = &model.Job{Type: model.ActionModifyColumn, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable0OldID, RawArgs: json.RawMessage(`[[2,3],[72,73,74]]`)} + modifyTable1ColumnJob = &model.Job{Type: model.ActionModifyColumn, SchemaID: mDDLJobDBOldID, TableID: mDDLJobTable1OldID, RawArgs: json.RawMessage(`[[2,3],[]]`)} +) + +type TableDeletQueryArgs struct { + tableIDs []int64 +} + +type IndexDeleteQueryArgs struct { + tableID int64 + indexIDs []int64 +} + +type mockInsertDeleteRange struct { + tableCh chan TableDeletQueryArgs + indexCh chan IndexDeleteQueryArgs +} + +func newMockInsertDeleteRange() *mockInsertDeleteRange { + // Since there is only single thread, we need to set the channel buf large enough. + return &mockInsertDeleteRange{ + tableCh: make(chan TableDeletQueryArgs, 10), + indexCh: make(chan IndexDeleteQueryArgs, 10), + } +} + +func (midr *mockInsertDeleteRange) mockInsertDeleteRangeForTable(jobID int64, tableIDs []int64) { + midr.tableCh <- TableDeletQueryArgs{ + tableIDs: tableIDs, + } +} + +func (midr *mockInsertDeleteRange) mockInsertDeleteRangeForIndex(jobID int64, elementID *int64, tableID int64, indexIDs []int64) { + midr.indexCh <- IndexDeleteQueryArgs{ + tableID: tableID, + indexIDs: indexIDs, + } +} + +func TestDeleteRangeForMDDLJob(t *testing.T) { + schemaReplace := MockEmptySchemasReplace() + partitionMap := map[int64]int64{ + mDDLJobPartition0OldID: mDDLJobPartition0NewID, + mDDLJobPartition1OldID: mDDLJobPartition1NewID, + mDDLJobPartition2OldID: mDDLJobPartition2NewID, + } + tableReplace0 := &TableReplace{ + NewTableID: mDDLJobTable0NewID, + PartitionMap: partitionMap, + } + tableReplace1 := &TableReplace{ + NewTableID: mDDLJobTable1NewID, + } + tableMap := map[int64]*TableReplace{ + mDDLJobTable0OldID: tableReplace0, + mDDLJobTable1OldID: tableReplace1, + } + dbReplace := &DBReplace{ + NewDBID: mDDLJobDBNewID, + TableMap: tableMap, + } + schemaReplace.DbMap[mDDLJobDBOldID] = dbReplace + + midr := newMockInsertDeleteRange() + + var targs TableDeletQueryArgs + var iargs IndexDeleteQueryArgs + var err error + // drop schema + err = schemaReplace.deleteRange(dropSchemaJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + targs = <-midr.tableCh + require.Equal(t, len(targs.tableIDs), len(mDDLJobALLNewTableIDSet)) + for _, tableID := range targs.tableIDs { + _, exist := mDDLJobALLNewTableIDSet[tableID] + require.True(t, exist) + } + + // drop table0 + err = schemaReplace.deleteRange(dropTable0Job, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + targs = <-midr.tableCh + require.Equal(t, len(targs.tableIDs), len(mDDLJobALLNewPartitionIDSet)) + for _, tableID := range targs.tableIDs { + _, exist := mDDLJobALLNewPartitionIDSet[tableID] + require.True(t, exist) + } + + // drop table1 + err = schemaReplace.deleteRange(dropTable1Job, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + targs = <-midr.tableCh + require.Equal(t, len(targs.tableIDs), 1) + require.Equal(t, targs.tableIDs[0], mDDLJobTable1NewID) + + // drop table partition1 + err = schemaReplace.deleteRange(dropTable0Partition1Job, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + targs = <-midr.tableCh + require.Equal(t, len(targs.tableIDs), 1) + require.Equal(t, targs.tableIDs[0], mDDLJobPartition1NewID) + + // roll back add index for table0 + err = schemaReplace.deleteRange(rollBackTable0IndexJob, 
midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + for i := 0; i < len(mDDLJobALLNewPartitionIDSet); i++ { + iargs = <-midr.indexCh + _, exist := mDDLJobALLNewPartitionIDSet[iargs.tableID] + require.True(t, exist) + require.Equal(t, len(iargs.indexIDs), 1) + require.Equal(t, iargs.indexIDs[0], int64(2)) + } + + // roll back add index for table1 + err = schemaReplace.deleteRange(rollBackTable1IndexJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + iargs = <-midr.indexCh + require.Equal(t, iargs.tableID, mDDLJobTable1NewID) + require.Equal(t, len(iargs.indexIDs), 1) + require.Equal(t, iargs.indexIDs[0], int64(2)) + + // drop index for table0 + err = schemaReplace.deleteRange(dropTable0IndexJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + for i := 0; i < len(mDDLJobALLNewPartitionIDSet); i++ { + iargs = <-midr.indexCh + _, exist := mDDLJobALLNewPartitionIDSet[iargs.tableID] + require.True(t, exist) + require.Equal(t, len(iargs.indexIDs), 1) + require.Equal(t, iargs.indexIDs[0], int64(2)) + } + + // drop index for table1 + err = schemaReplace.deleteRange(dropTable1IndexJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + iargs = <-midr.indexCh + require.Equal(t, iargs.tableID, mDDLJobTable1NewID) + require.Equal(t, len(iargs.indexIDs), 1) + require.Equal(t, iargs.indexIDs[0], int64(2)) + + // drop indexes for table0 + err = schemaReplace.deleteRange(dropTable0IndexesJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + for i := 0; i < len(mDDLJobALLNewPartitionIDSet); i++ { + iargs = <-midr.indexCh + _, exist := mDDLJobALLNewPartitionIDSet[iargs.tableID] + require.True(t, exist) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } + } + + // drop indexes for table1 + err = schemaReplace.deleteRange(dropTable1IndexesJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + iargs = <-midr.indexCh + require.Equal(t, iargs.tableID, mDDLJobTable1NewID) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } + + // drop column for table0 + err = schemaReplace.deleteRange(dropTable0ColumnJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + for i := 0; i < len(mDDLJobALLNewPartitionIDSet); i++ { + iargs = <-midr.indexCh + _, exist := mDDLJobALLNewPartitionIDSet[iargs.tableID] + require.True(t, exist) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } + } + + // drop column for table1 + err = schemaReplace.deleteRange(dropTable1ColumnJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + iargs = <-midr.indexCh + require.Equal(t, iargs.tableID, mDDLJobTable1NewID) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } + + // drop columns for table0 + err = 
schemaReplace.deleteRange(dropTable0ColumnsJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + for i := 0; i < len(mDDLJobALLNewPartitionIDSet); i++ { + iargs = <-midr.indexCh + _, exist := mDDLJobALLNewPartitionIDSet[iargs.tableID] + require.True(t, exist) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } + } + + // drop columns for table1 + err = schemaReplace.deleteRange(dropTable1ColumnsJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + iargs = <-midr.indexCh + require.Equal(t, iargs.tableID, mDDLJobTable1NewID) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } + + // drop columns for table0 + err = schemaReplace.deleteRange(modifyTable0ColumnJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + for i := 0; i < len(mDDLJobALLNewPartitionIDSet); i++ { + iargs = <-midr.indexCh + _, exist := mDDLJobALLNewPartitionIDSet[iargs.tableID] + require.True(t, exist) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } + } + + // drop columns for table1 + err = schemaReplace.deleteRange(modifyTable1ColumnJob, midr.mockInsertDeleteRangeForTable, midr.mockInsertDeleteRangeForIndex) + require.NoError(t, err) + iargs = <-midr.indexCh + require.Equal(t, iargs.tableID, mDDLJobTable1NewID) + require.Equal(t, len(iargs.indexIDs), len(mDDLJobALLIndexesIDSet)) + for _, indexID := range iargs.indexIDs { + _, exist := mDDLJobALLIndexesIDSet[indexID] + require.True(t, exist) + } +} diff --git a/br/pkg/stream/stream_mgr.go b/br/pkg/stream/stream_mgr.go index 23061fcaa5ddf..0eccada182f06 100644 --- a/br/pkg/stream/stream_mgr.go +++ b/br/pkg/stream/stream_mgr.go @@ -16,6 +16,7 @@ package stream import ( "context" + "strings" "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" @@ -172,7 +173,11 @@ func FastUnmarshalMetaData( m := &backuppb.Metadata{} err = m.Unmarshal(b) if err != nil { - return err + if !strings.HasSuffix(path, ".meta") { + return nil + } else { + return err + } } return fn(readPath, m) }) diff --git a/br/pkg/stream/stream_misc_test.go b/br/pkg/stream/stream_misc_test.go index ac31254ffd641..3a057ed2a16df 100644 --- a/br/pkg/stream/stream_misc_test.go +++ b/br/pkg/stream/stream_misc_test.go @@ -7,6 +7,7 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/tidb/br/pkg/stream" + "github.com/pingcap/tidb/br/pkg/streamhelper" "github.com/stretchr/testify/require" ) @@ -15,7 +16,7 @@ func TestGetCheckpointOfTask(t *testing.T) { Info: backuppb.StreamBackupTaskInfo{ StartTs: 8, }, - Checkpoints: []stream.Checkpoint{ + Checkpoints: []streamhelper.Checkpoint{ { ID: 1, TS: 10, diff --git a/br/pkg/stream/stream_status.go b/br/pkg/stream/stream_status.go index 70d9b1708f938..e08f3f6c34513 100644 --- a/br/pkg/stream/stream_status.go +++ b/br/pkg/stream/stream_status.go @@ -4,6 +4,7 @@ package stream import ( "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -17,12 +18,13 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" - 
"github.com/pingcap/tidb/br/pkg/conn" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/httputil" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/storage" + . "github.com/pingcap/tidb/br/pkg/streamhelper" "github.com/tikv/client-go/v2/oracle" + pd "github.com/tikv/pd/client" "go.uber.org/zap" "golang.org/x/sync/errgroup" ) @@ -104,6 +106,9 @@ func (t TaskStatus) GetMinStoreCheckpoint() Checkpoint { initialized = true checkpoint = cp } + if cp.Type() == CheckpointTypeGlobal { + return cp + } } return checkpoint } @@ -131,7 +136,6 @@ func (p *printByTable) AddTask(task TaskStatus) { info := fmt.Sprintf("%s; gap=%s", pTime, gapColor.Sprint(gap)) return info } - table.Add("checkpoint[global]", formatTS(task.GetMinStoreCheckpoint().TS)) p.addCheckpoints(&task, table, formatTS) for store, e := range task.LastErrors { table.Add(fmt.Sprintf("error[store=%d]", store), e.ErrorCode) @@ -142,11 +146,21 @@ func (p *printByTable) AddTask(task TaskStatus) { } func (p *printByTable) addCheckpoints(task *TaskStatus, table *glue.Table, formatTS func(uint64) string) { - for _, cp := range task.Checkpoints { - switch cp.Type() { - case CheckpointTypeStore: - table.Add(fmt.Sprintf("checkpoint[store=%d]", cp.ID), formatTS(cp.TS)) + cp := task.GetMinStoreCheckpoint() + items := make([][2]string, 0, len(task.Checkpoints)) + if cp.Type() != CheckpointTypeGlobal { + for _, cp := range task.Checkpoints { + switch cp.Type() { + case CheckpointTypeStore: + items = append(items, [2]string{fmt.Sprintf("checkpoint[store=%d]", cp.ID), formatTS(cp.TS)}) + } } + } else { + items = append(items, [2]string{"checkpoint[central-global]", formatTS(cp.TS)}) + } + + for _, item := range items { + table.Add(item[0], item[1]) } } @@ -241,10 +255,15 @@ func (p *printByJSON) PrintTasks() { var logCountSumRe = regexp.MustCompile(`tikv_stream_handle_kv_batch_sum ([0-9]+)`) +type PDInfoProvider interface { + GetPDClient() pd.Client + GetTLSConfig() *tls.Config +} + // MaybeQPS get a number like the QPS of last seconds for each store via the prometheus interface. // TODO: this is a temporary solution(aha, like in a Hackthon), // we MUST find a better way for providing this information. -func MaybeQPS(ctx context.Context, mgr *conn.Mgr) (float64, error) { +func MaybeQPS(ctx context.Context, mgr PDInfoProvider) (float64, error) { c := mgr.GetPDClient() prefix := "http://" if mgr.GetTLSConfig() != nil { @@ -316,12 +335,12 @@ func MaybeQPS(ctx context.Context, mgr *conn.Mgr) (float64, error) { // StatusController is the controller type (or context type) for the command `stream status`. type StatusController struct { meta *MetaDataClient - mgr *conn.Mgr + mgr PDInfoProvider view TaskPrinter } // NewStatusContorller make a status controller via some resource accessors. -func NewStatusController(meta *MetaDataClient, mgr *conn.Mgr, view TaskPrinter) *StatusController { +func NewStatusController(meta *MetaDataClient, mgr PDInfoProvider, view TaskPrinter) *StatusController { return &StatusController{ meta: meta, mgr: mgr, diff --git a/br/pkg/streamhelper/advancer.go b/br/pkg/streamhelper/advancer.go new file mode 100644 index 0000000000000..e20516285e962 --- /dev/null +++ b/br/pkg/streamhelper/advancer.go @@ -0,0 +1,514 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package streamhelper + +import ( + "bytes" + "context" + "math" + "reflect" + "sort" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/streamhelper/config" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/metrics" + "github.com/tikv/client-go/v2/oracle" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +// CheckpointAdvancer is the central node for advancing the checkpoint of log backup. +// It's a part of "checkpoint v3". +// Generally, it scan the regions in the task range, collect checkpoints from tikvs. +// ┌──────┐ +// ┌────►│ TiKV │ +// │ └──────┘ +// │ +// │ +// ┌──────────┐GetLastFlushTSOfRegion│ ┌──────┐ +// │ Advancer ├──────────────────────┼────►│ TiKV │ +// └────┬─────┘ │ └──────┘ +// │ │ +// │ │ +// │ │ ┌──────┐ +// │ └────►│ TiKV │ +// │ └──────┘ +// │ +// │ UploadCheckpointV3 ┌──────────────────┐ +// └─────────────────────►│ PD │ +// └──────────────────┘ +type CheckpointAdvancer struct { + env Env + + // The concurrency accessed task: + // both by the task listener and ticking. + task *backuppb.StreamBackupTaskInfo + taskMu sync.Mutex + + // the read-only config. + // once tick begin, this should not be changed for now. + cfg config.Config + + // the cache of region checkpoints. + // so we can advance only ranges with huge gap. + cache CheckpointsCache + + // the internal state of advancer. + state advancerState + // the cached last checkpoint. + // if no progress, this cache can help us don't to send useless requests. + lastCheckpoint uint64 +} + +// advancerState is the sealed type for the state of advancer. +// the advancer has two stage: full scan and update small tree. +type advancerState interface { + // Note: + // Go doesn't support sealed classes or ADTs currently. + // (it can only be used at generic constraints...) + // Leave it empty for now. + + // ~*fullScan | ~*updateSmallTree +} + +// fullScan is the initial state of advancer. +// in this stage, we would "fill" the cache: +// insert ranges that union of them become the full range of task. +type fullScan struct { + fullScanTick int +} + +// updateSmallTree is the "incremental stage" of advancer. +// we have build a "filled" cache, and we can pop a subrange of it, +// try to advance the checkpoint of those ranges. +type updateSmallTree struct { + consistencyCheckTick int +} + +// NewCheckpointAdvancer creates a checkpoint advancer with the env. +func NewCheckpointAdvancer(env Env) *CheckpointAdvancer { + return &CheckpointAdvancer{ + env: env, + cfg: config.Default(), + cache: NewCheckpoints(), + state: &fullScan{}, + } +} + +// disableCache removes the cache. +// note this won't lock the checkpoint advancer at `fullScan` state forever, +// you may need to change the config `AdvancingByCache`. +func (c *CheckpointAdvancer) disableCache() { + c.cache = NoOPCheckpointCache{} + c.state = fullScan{} +} + +// enable the cache. +// also check `AdvancingByCache` in the config. +func (c *CheckpointAdvancer) enableCache() { + c.cache = NewCheckpoints() + c.state = fullScan{} +} + +// UpdateConfig updates the config for the advancer. +// Note this should be called before starting the loop, because there isn't locks, +// TODO: support updating config when advancer starts working. +// (Maybe by applying changes at begin of ticking, and add locks.) 
+func (c *CheckpointAdvancer) UpdateConfig(newConf config.Config) { + needRefreshCache := newConf.AdvancingByCache != c.cfg.AdvancingByCache + c.cfg = newConf + if needRefreshCache { + if c.cfg.AdvancingByCache { + c.enableCache() + } else { + c.disableCache() + } + } +} + +// UpdateConfigWith updates the config by modifying the current config. +func (c *CheckpointAdvancer) UpdateConfigWith(f func(*config.Config)) { + cfg := c.cfg + f(&cfg) + c.UpdateConfig(cfg) +} + +// Config returns the current config. +func (c *CheckpointAdvancer) Config() config.Config { + return c.cfg +} + +// GetCheckpointInRange scans the regions in the range, +// collect them to the collector. +func (c *CheckpointAdvancer) GetCheckpointInRange(ctx context.Context, start, end []byte, collector *clusterCollector) error { + log.Debug("scanning range", logutil.Key("start", start), logutil.Key("end", end)) + iter := IterateRegion(c.env, start, end) + for !iter.Done() { + rs, err := iter.Next(ctx) + if err != nil { + return err + } + log.Debug("scan region", zap.Int("len", len(rs))) + for _, r := range rs { + err := collector.collectRegion(r) + if err != nil { + log.Warn("meet error during getting checkpoint", logutil.ShortError(err)) + return err + } + } + } + return nil +} + +func (c *CheckpointAdvancer) recordTimeCost(message string, fields ...zap.Field) func() { + now := time.Now() + label := strings.ReplaceAll(message, " ", "-") + return func() { + cost := time.Since(now) + fields = append(fields, zap.Stringer("take", cost)) + metrics.AdvancerTickDuration.WithLabelValues(label).Observe(cost.Seconds()) + log.Debug(message, fields...) + } +} + +// tryAdvance tries to advance the checkpoint ts of a set of ranges which shares the same checkpoint. +func (c *CheckpointAdvancer) tryAdvance(ctx context.Context, rst RangesSharesTS) (err error) { + defer c.recordTimeCost("try advance", zap.Uint64("checkpoint", rst.TS), zap.Int("len", len(rst.Ranges)))() + defer func() { + if err != nil { + c.cache.InsertRanges(rst) + } + }() + defer utils.PanicToErr(&err) + + ranges := CollapseRanges(len(rst.Ranges), func(i int) kv.KeyRange { return rst.Ranges[i] }) + workers := utils.NewWorkerPool(4, "sub ranges") + eg, cx := errgroup.WithContext(ctx) + collector := NewClusterCollector(ctx, c.env) + collector.setOnSuccessHook(c.cache.InsertRange) + for _, r := range ranges { + r := r + workers.ApplyOnErrorGroup(eg, func() (e error) { + defer c.recordTimeCost("get regions in range", zap.Uint64("checkpoint", rst.TS))() + defer utils.PanicToErr(&e) + return c.GetCheckpointInRange(cx, r.StartKey, r.EndKey, collector) + }) + } + err = eg.Wait() + if err != nil { + return err + } + + result, err := collector.Finish(ctx) + if err != nil { + return err + } + fr := result.FailureSubRanges + if len(fr) != 0 { + log.Debug("failure regions collected", zap.Int("size", len(fr))) + c.cache.InsertRanges(RangesSharesTS{ + TS: rst.TS, + Ranges: fr, + }) + } + return nil +} + +// CalculateGlobalCheckpointLight tries to advance the global checkpoint by the cache. 
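For reference, this is how a caller tunes the advancer before starting its loop; the tests added later in this patch use the same `UpdateConfigWith` pattern. The wrapper function is illustrative only. Turning `AdvancingByCache` off swaps the cache for `NoOPCheckpointCache` and keeps the advancer in the full-scan state, and a `FullScanTick` of zero makes the next tick scan immediately instead of waiting out the backoff ticks.

```go
package example

import (
	"github.com/pingcap/tidb/br/pkg/streamhelper"
	"github.com/pingcap/tidb/br/pkg/streamhelper/config"
)

// newTunedAdvancer sketches adjusting the advancer before the tick loop
// starts (UpdateConfig must not be called concurrently with ticking).
func newTunedAdvancer(env streamhelper.Env) *streamhelper.CheckpointAdvancer {
	adv := streamhelper.NewCheckpointAdvancer(env)
	adv.UpdateConfigWith(func(c *config.Config) {
		c.FullScanTick = 0         // do the first full scan on the next tick
		c.AdvancingByCache = false // disable the cached heap; full scan every tick
	})
	return adv
}
```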
+func (c *CheckpointAdvancer) CalculateGlobalCheckpointLight(ctx context.Context) (uint64, error) { + log.Info("advancer with cache: current tree", zap.Stringer("ct", c.cache)) + rsts := c.cache.PopRangesWithGapGT(config.DefaultTryAdvanceThreshold) + if len(rsts) == 0 { + return 0, nil + } + workers := utils.NewWorkerPool(uint(config.DefaultMaxConcurrencyAdvance), "regions") + eg, cx := errgroup.WithContext(ctx) + for _, rst := range rsts { + rst := rst + workers.ApplyOnErrorGroup(eg, func() (err error) { + return c.tryAdvance(cx, *rst) + }) + } + err := eg.Wait() + if err != nil { + return 0, err + } + log.Info("advancer with cache: new tree", zap.Stringer("cache", c.cache)) + ts := c.cache.CheckpointTS() + return ts, nil +} + +// CalculateGlobalCheckpoint calculates the global checkpoint, which won't use the cache. +func (c *CheckpointAdvancer) CalculateGlobalCheckpoint(ctx context.Context) (uint64, error) { + var ( + cp = uint64(math.MaxInt64) + // TODO: Use The task range here. + thisRun []kv.KeyRange = []kv.KeyRange{{}} + nextRun []kv.KeyRange + ) + defer c.recordTimeCost("record all") + cx, cancel := context.WithTimeout(ctx, c.cfg.MaxBackoffTime) + defer cancel() + for { + coll := NewClusterCollector(ctx, c.env) + coll.setOnSuccessHook(c.cache.InsertRange) + for _, u := range thisRun { + err := c.GetCheckpointInRange(cx, u.StartKey, u.EndKey, coll) + if err != nil { + return 0, err + } + } + result, err := coll.Finish(ctx) + if err != nil { + return 0, err + } + log.Debug("full: a run finished", zap.Any("checkpoint", result)) + + nextRun = append(nextRun, result.FailureSubRanges...) + if cp > result.Checkpoint { + cp = result.Checkpoint + } + if len(nextRun) == 0 { + return cp, nil + } + thisRun = nextRun + nextRun = nil + log.Debug("backoffing with subranges", zap.Int("subranges", len(thisRun))) + time.Sleep(c.cfg.BackoffTime) + } +} + +// CollapseRanges collapse ranges overlapping or adjacent. +// Example: +// CollapseRanges({[1, 4], [2, 8], [3, 9]}) == {[1, 9]} +// CollapseRanges({[1, 3], [4, 7], [2, 3]}) == {[1, 3], [4, 7]} +func CollapseRanges(length int, getRange func(int) kv.KeyRange) []kv.KeyRange { + frs := make([]kv.KeyRange, 0, length) + for i := 0; i < length; i++ { + frs = append(frs, getRange(i)) + } + + sort.Slice(frs, func(i, j int) bool { + return bytes.Compare(frs[i].StartKey, frs[j].StartKey) < 0 + }) + + result := make([]kv.KeyRange, 0, len(frs)) + i := 0 + for i < len(frs) { + item := frs[i] + for { + i++ + if i >= len(frs) || (len(item.EndKey) != 0 && bytes.Compare(frs[i].StartKey, item.EndKey) > 0) { + break + } + if len(item.EndKey) != 0 && bytes.Compare(item.EndKey, frs[i].EndKey) < 0 || len(frs[i].EndKey) == 0 { + item.EndKey = frs[i].EndKey + } + } + result = append(result, item) + } + return result +} + +func (c *CheckpointAdvancer) consumeAllTask(ctx context.Context, ch <-chan TaskEvent) error { + for { + select { + case e, ok := <-ch: + if !ok { + return nil + } + log.Info("meet task event", zap.Stringer("event", &e)) + if err := c.onTaskEvent(e); err != nil { + if errors.Cause(e.Err) != context.Canceled { + log.Error("listen task meet error, would reopen.", logutil.ShortError(err)) + return err + } + return nil + } + default: + return nil + } + } +} + +// beginListenTaskChange bootstraps the initial task set, +// and returns a channel respecting the change of tasks. 
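A small runnable illustration of `CollapseRanges`, taken directly from its doc comment: overlapping or adjacent ranges are merged, disjoint ones are kept as-is. Only the wrapper `main` is illustrative.

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/br/pkg/streamhelper"
	"github.com/pingcap/tidb/kv"
)

func main() {
	rs := []kv.KeyRange{
		{StartKey: []byte("1"), EndKey: []byte("4")},
		{StartKey: []byte("2"), EndKey: []byte("8")},
		{StartKey: []byte("3"), EndKey: []byte("9")},
	}
	merged := streamhelper.CollapseRanges(len(rs), func(i int) kv.KeyRange { return rs[i] })
	for _, r := range merged {
		fmt.Printf("[%s, %s)\n", r.StartKey, r.EndKey) // prints: [1, 9)
	}
}
```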
+func (c *CheckpointAdvancer) beginListenTaskChange(ctx context.Context) (<-chan TaskEvent, error) { + ch := make(chan TaskEvent, 1024) + if err := c.env.Begin(ctx, ch); err != nil { + return nil, err + } + err := c.consumeAllTask(ctx, ch) + if err != nil { + return nil, err + } + return ch, nil +} + +// StartTaskListener starts the task listener for the advancer. +// When no task detected, advancer would do nothing, please call this before begin the tick loop. +func (c *CheckpointAdvancer) StartTaskListener(ctx context.Context) { + cx, cancel := context.WithCancel(ctx) + var ch <-chan TaskEvent + for { + if cx.Err() != nil { + // make linter happy. + cancel() + return + } + var err error + ch, err = c.beginListenTaskChange(cx) + if err == nil { + break + } + log.Warn("failed to begin listening, retrying...", logutil.ShortError(err)) + time.Sleep(c.cfg.BackoffTime) + } + + go func() { + defer cancel() + for { + select { + case <-ctx.Done(): + return + case e, ok := <-ch: + if !ok { + return + } + log.Info("meet task event", zap.Stringer("event", &e)) + if err := c.onTaskEvent(e); err != nil { + if errors.Cause(e.Err) != context.Canceled { + log.Error("listen task meet error, would reopen.", logutil.ShortError(err)) + time.AfterFunc(c.cfg.BackoffTime, func() { c.StartTaskListener(ctx) }) + } + return + } + } + } + }() +} + +func (c *CheckpointAdvancer) onTaskEvent(e TaskEvent) error { + c.taskMu.Lock() + defer c.taskMu.Unlock() + switch e.Type { + case EventAdd: + c.task = e.Info + case EventDel: + c.task = nil + c.state = &fullScan{} + c.cache.Clear() + case EventErr: + return e.Err + } + return nil +} + +// advanceCheckpointBy advances the checkpoint by a checkpoint getter function. +func (c *CheckpointAdvancer) advanceCheckpointBy(ctx context.Context, getCheckpoint func(context.Context) (uint64, error)) error { + start := time.Now() + cp, err := getCheckpoint(ctx) + if err != nil { + return err + } + if cp < c.lastCheckpoint { + log.Warn("failed to update global checkpoint: stale", zap.Uint64("old", c.lastCheckpoint), zap.Uint64("new", cp)) + } + if cp <= c.lastCheckpoint { + return nil + } + + log.Info("uploading checkpoint for task", + zap.Stringer("checkpoint", oracle.GetTimeFromTS(cp)), + zap.Uint64("checkpoint", cp), + zap.String("task", c.task.Name), + zap.Stringer("take", time.Since(start))) + if err := c.env.UploadV3GlobalCheckpointForTask(ctx, c.task.Name, cp); err != nil { + return errors.Annotate(err, "failed to upload global checkpoint") + } + c.lastCheckpoint = cp + metrics.LastCheckpoint.WithLabelValues(c.task.GetName()).Set(float64(c.lastCheckpoint)) + return nil +} + +// OnTick advances the inner logic clock for the advancer. +// It's synchronous: this would only return after the events triggered by the clock has all been done. +// It's generally panic-free, you may not need to trying recover a panic here. +func (c *CheckpointAdvancer) OnTick(ctx context.Context) (err error) { + defer c.recordTimeCost("tick")() + defer func() { + e := recover() + if e != nil { + log.Error("panic during handing tick", zap.Stack("stack"), logutil.ShortError(err)) + err = errors.Annotatef(berrors.ErrUnknown, "panic during handling tick: %s", e) + } + }() + err = c.tick(ctx) + return +} + +func (c *CheckpointAdvancer) onConsistencyCheckTick(s *updateSmallTree) error { + if s.consistencyCheckTick > 0 { + s.consistencyCheckTick-- + return nil + } + defer c.recordTimeCost("consistency check")() + err := c.cache.ConsistencyCheck() + if err != nil { + log.Error("consistency check failed! 
log backup may lose data! rolling back to full scan for saving.", logutil.ShortError(err)) + c.state = &fullScan{} + return err + } else { + log.Debug("consistency check passed.") + } + s.consistencyCheckTick = config.DefaultConsistencyCheckTick + return nil +} + +func (c *CheckpointAdvancer) tick(ctx context.Context) error { + c.taskMu.Lock() + defer c.taskMu.Unlock() + + switch s := c.state.(type) { + case *fullScan: + if s.fullScanTick > 0 { + s.fullScanTick-- + break + } + if c.task == nil { + log.Debug("No tasks yet, skipping advancing.") + return nil + } + defer func() { + s.fullScanTick = c.cfg.FullScanTick + }() + err := c.advanceCheckpointBy(ctx, c.CalculateGlobalCheckpoint) + if err != nil { + return err + } + + if c.cfg.AdvancingByCache { + c.state = &updateSmallTree{} + } + case *updateSmallTree: + if err := c.onConsistencyCheckTick(s); err != nil { + return err + } + err := c.advanceCheckpointBy(ctx, c.CalculateGlobalCheckpointLight) + if err != nil { + return err + } + default: + log.Error("Unknown state type, skipping tick", zap.Stringer("type", reflect.TypeOf(c.state))) + } + return nil +} diff --git a/br/pkg/streamhelper/advancer_daemon.go b/br/pkg/streamhelper/advancer_daemon.go new file mode 100644 index 0000000000000..909bdd85df3c6 --- /dev/null +++ b/br/pkg/streamhelper/advancer_daemon.go @@ -0,0 +1,81 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package streamhelper + +import ( + "context" + "time" + + "github.com/google/uuid" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/owner" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +const ( + ownerPrompt = "log-backup" + ownerPath = "/tidb/br-stream/owner" +) + +// AdvancerDaemon is a "high-availability" version of advancer. +// It involved the manager for electing a owner and doing things. +// You can embed it into your code by simply call: +// +// ad := NewAdvancerDaemon(adv, mgr) +// loop, err := ad.Begin(ctx) +// if err != nil { +// return err +// } +// loop() +type AdvancerDaemon struct { + adv *CheckpointAdvancer + manager owner.Manager +} + +func NewAdvancerDaemon(adv *CheckpointAdvancer, manager owner.Manager) *AdvancerDaemon { + return &AdvancerDaemon{ + adv: adv, + manager: manager, + } +} + +func OwnerManagerForLogBackup(ctx context.Context, etcdCli *clientv3.Client) owner.Manager { + id := uuid.New() + return owner.NewOwnerManager(ctx, etcdCli, ownerPrompt, id.String(), ownerPath) +} + +// Begin starts the daemon. +// It would do some bootstrap task, and return a closure that would begin the main loop. 
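Putting the daemon pieces together: the sketch below follows the usage shown in the `AdvancerDaemon` doc comment, building an owner manager over etcd, wrapping the advancer, and running the returned loop. The wrapper function is illustrative and assumes the caller already holds an etcd client and a constructed advancer.

```go
package example

import (
	"context"

	"github.com/pingcap/tidb/br/pkg/streamhelper"
	clientv3 "go.etcd.io/etcd/client/v3"
)

// runAdvancerHA wires the HA pieces added in this patch: only the elected
// owner ticks the advancer; the others just keep campaigning.
func runAdvancerHA(ctx context.Context, adv *streamhelper.CheckpointAdvancer, etcdCli *clientv3.Client) error {
	mgr := streamhelper.OwnerManagerForLogBackup(ctx, etcdCli)
	ad := streamhelper.NewAdvancerDaemon(adv, mgr)
	loop, err := ad.Begin(ctx)
	if err != nil {
		return err
	}
	loop() // blocks until ctx is canceled
	return nil
}
```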
+func (ad *AdvancerDaemon) Begin(ctx context.Context) (func(), error) { + log.Info("begin advancer daemon", zap.String("id", ad.manager.ID())) + if err := ad.manager.CampaignOwner(); err != nil { + return nil, err + } + + ad.adv.StartTaskListener(ctx) + tick := time.NewTicker(ad.adv.cfg.TickDuration) + loop := func() { + log.Info("begin advancer daemon loop", zap.String("id", ad.manager.ID())) + for { + select { + case <-ctx.Done(): + log.Info("advancer loop exits", zap.String("id", ad.manager.ID())) + return + case <-tick.C: + log.Debug("deamon tick start", zap.Bool("is-owner", ad.manager.IsOwner())) + if ad.manager.IsOwner() { + metrics.AdvancerOwner.Set(1.0) + if err := ad.adv.OnTick(ctx); err != nil { + log.Warn("failed on tick", logutil.ShortError(err)) + } + } else { + metrics.AdvancerOwner.Set(0.0) + } + } + } + } + return loop, nil +} diff --git a/br/pkg/streamhelper/advancer_env.go b/br/pkg/streamhelper/advancer_env.go new file mode 100644 index 0000000000000..21c61ff129ce2 --- /dev/null +++ b/br/pkg/streamhelper/advancer_env.go @@ -0,0 +1,107 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package streamhelper + +import ( + "context" + "time" + + logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/config" + pd "github.com/tikv/pd/client" + clientv3 "go.etcd.io/etcd/client/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +// Env is the interface required by the advancer. +type Env interface { + // The region scanner provides the region information. + RegionScanner + // LogBackupService connects to the TiKV, so we can collect the region checkpoints. + LogBackupService + // StreamMeta connects to the metadata service (normally PD). + StreamMeta +} + +// PDRegionScanner is a simple wrapper over PD +// to adapt the requirement of `RegionScan`. +type PDRegionScanner struct { + pd.Client +} + +// RegionScan gets a list of regions, starts from the region that contains key. +// Limit limits the maximum number of regions returned. +func (c PDRegionScanner) RegionScan(ctx context.Context, key []byte, endKey []byte, limit int) ([]RegionWithLeader, error) { + rs, err := c.Client.ScanRegions(ctx, key, endKey, limit) + if err != nil { + return nil, err + } + rls := make([]RegionWithLeader, 0, len(rs)) + for _, r := range rs { + rls = append(rls, RegionWithLeader{ + Region: r.Meta, + Leader: r.Leader, + }) + } + return rls, nil +} + +// clusterEnv is the environment for running in the real cluster. +type clusterEnv struct { + clis *utils.StoreManager + *TaskEventClient + PDRegionScanner +} + +// GetLogBackupClient gets the log backup client. +func (t clusterEnv) GetLogBackupClient(ctx context.Context, storeID uint64) (logbackup.LogBackupClient, error) { + var cli logbackup.LogBackupClient + err := t.clis.WithConn(ctx, storeID, func(cc *grpc.ClientConn) { + cli = logbackup.NewLogBackupClient(cc) + }) + if err != nil { + return nil, err + } + return cli, nil +} + +// CliEnv creates the Env for CLI usage. +func CliEnv(cli *utils.StoreManager, etcdCli *clientv3.Client) Env { + return clusterEnv{ + clis: cli, + TaskEventClient: &TaskEventClient{MetaDataClient: *NewMetaDataClient(etcdCli)}, + PDRegionScanner: PDRegionScanner{cli.PDClient()}, + } +} + +// TiDBEnv creates the Env by TiDB config. 
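On the `Env` side, the CLI path bundles region scanning (PD), the log backup RPCs (TiKV, via the store manager) and task metadata (etcd) through `CliEnv`. A sketch of the wiring; the helper is illustrative and assumes the caller already owns a `StoreManager` and an etcd client.

```go
package example

import (
	"github.com/pingcap/tidb/br/pkg/streamhelper"
	"github.com/pingcap/tidb/br/pkg/utils"
	clientv3 "go.etcd.io/etcd/client/v3"
)

// newCLIAdvancer builds an advancer backed by the CLI environment.
func newCLIAdvancer(stores *utils.StoreManager, etcdCli *clientv3.Client) *streamhelper.CheckpointAdvancer {
	env := streamhelper.CliEnv(stores, etcdCli)
	return streamhelper.NewCheckpointAdvancer(env)
}
```

TiDB-embedded deployments would go through `TiDBEnv` instead, which derives the keepalive and TLS settings from the TiDB config.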
+func TiDBEnv(pdCli pd.Client, etcdCli *clientv3.Client, conf *config.Config) (Env, error) { + tconf, err := conf.GetTiKVConfig().Security.ToTLSConfig() + if err != nil { + return nil, err + } + return clusterEnv{ + clis: utils.NewStoreManager(pdCli, keepalive.ClientParameters{ + Time: time.Duration(conf.TiKVClient.GrpcKeepAliveTime) * time.Second, + Timeout: time.Duration(conf.TiKVClient.GrpcKeepAliveTimeout) * time.Second, + }, tconf), + TaskEventClient: &TaskEventClient{MetaDataClient: *NewMetaDataClient(etcdCli)}, + PDRegionScanner: PDRegionScanner{Client: pdCli}, + }, nil +} + +type LogBackupService interface { + // GetLogBackupClient gets the log backup client. + GetLogBackupClient(ctx context.Context, storeID uint64) (logbackup.LogBackupClient, error) +} + +// StreamMeta connects to the metadata service (normally PD). +// It provides the global checkpoint information. +type StreamMeta interface { + // Begin begins listen the task event change. + Begin(ctx context.Context, ch chan<- TaskEvent) error + // UploadV3GlobalCheckpointForTask uploads the global checkpoint to the meta store. + UploadV3GlobalCheckpointForTask(ctx context.Context, taskName string, checkpoint uint64) error +} diff --git a/br/pkg/streamhelper/advancer_test.go b/br/pkg/streamhelper/advancer_test.go new file mode 100644 index 0000000000000..f32b099069726 --- /dev/null +++ b/br/pkg/streamhelper/advancer_test.go @@ -0,0 +1,185 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package streamhelper_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/streamhelper" + "github.com/pingcap/tidb/br/pkg/streamhelper/config" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestBasic(t *testing.T) { + c := createFakeCluster(t, 4, false) + defer func() { + fmt.Println(c) + }() + c.splitAndScatter("01", "02", "022", "023", "033", "04", "043") + ctx := context.Background() + minCheckpoint := c.advanceCheckpoints() + env := &testEnv{fakeCluster: c, testCtx: t} + adv := streamhelper.NewCheckpointAdvancer(env) + coll := streamhelper.NewClusterCollector(ctx, env) + err := adv.GetCheckpointInRange(ctx, []byte{}, []byte{}, coll) + require.NoError(t, err) + r, err := coll.Finish(ctx) + require.NoError(t, err) + require.Len(t, r.FailureSubRanges, 0) + require.Equal(t, r.Checkpoint, minCheckpoint, "%d %d", r.Checkpoint, minCheckpoint) +} + +func TestTick(t *testing.T) { + c := createFakeCluster(t, 4, false) + defer func() { + fmt.Println(c) + }() + c.splitAndScatter("01", "02", "022", "023", "033", "04", "043") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + env := &testEnv{fakeCluster: c, testCtx: t} + adv := streamhelper.NewCheckpointAdvancer(env) + adv.StartTaskListener(ctx) + adv.UpdateConfigWith(func(cac *config.Config) { + cac.FullScanTick = 0 + }) + require.NoError(t, adv.OnTick(ctx)) + for i := 0; i < 5; i++ { + cp := c.advanceCheckpoints() + require.NoError(t, adv.OnTick(ctx)) + require.Equal(t, env.getCheckpoint(), cp) + } +} + +func TestWithFailure(t *testing.T) { + log.SetLevel(zapcore.DebugLevel) + c := createFakeCluster(t, 4, true) + defer func() { + fmt.Println(c) + }() + c.splitAndScatter("01", "02", "022", "023", "033", "04", "043") + c.flushAll() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + env := &testEnv{fakeCluster: c, testCtx: t} + adv := 
streamhelper.NewCheckpointAdvancer(env) + adv.StartTaskListener(ctx) + adv.UpdateConfigWith(func(cac *config.Config) { + cac.FullScanTick = 0 + }) + require.NoError(t, adv.OnTick(ctx)) + + cp := c.advanceCheckpoints() + for _, v := range c.stores { + v.flush() + break + } + require.NoError(t, adv.OnTick(ctx)) + require.Less(t, env.getCheckpoint(), cp, "%d %d", env.getCheckpoint(), cp) + + for _, v := range c.stores { + v.flush() + } + + require.NoError(t, adv.OnTick(ctx)) + require.Equal(t, env.getCheckpoint(), cp) +} + +func shouldFinishInTime(t *testing.T, d time.Duration, name string, f func()) { + ch := make(chan struct{}) + go func() { + f() + close(ch) + }() + select { + case <-time.After(d): + t.Fatalf("%s should finish in %s, but not", name, d) + case <-ch: + } +} + +func TestCollectorFailure(t *testing.T) { + log.SetLevel(zapcore.DebugLevel) + c := createFakeCluster(t, 4, true) + c.onGetClient = func(u uint64) error { + return status.Error(codes.DataLoss, + "Exiled requests from the client, please slow down and listen a story: "+ + "the server has been dropped, we are longing for new nodes, however the goddess(k8s) never allocates new resource. "+ + "May you take the sword named `vim`, refactoring the definition of the nature, in the yaml file hidden at somewhere of the cluster, "+ + "to save all of us and gain the response you desiring?") + } + ctx := context.Background() + splitKeys := make([]string, 0, 10000) + for i := 0; i < 10000; i++ { + splitKeys = append(splitKeys, fmt.Sprintf("%04d", i)) + } + c.splitAndScatter(splitKeys...) + + env := &testEnv{fakeCluster: c, testCtx: t} + adv := streamhelper.NewCheckpointAdvancer(env) + coll := streamhelper.NewClusterCollector(ctx, env) + + shouldFinishInTime(t, 30*time.Second, "scan with always fail", func() { + // At this time, the sending may or may not fail because the sending and batching is doing asynchronously. + _ = adv.GetCheckpointInRange(ctx, []byte{}, []byte{}, coll) + // ...but this must fail, not getting stuck. + _, err := coll.Finish(ctx) + require.Error(t, err) + }) +} + +func oneStoreFailure() func(uint64) error { + victim := uint64(0) + mu := new(sync.Mutex) + return func(u uint64) error { + mu.Lock() + defer mu.Unlock() + if victim == 0 { + victim = u + } + if victim == u { + return status.Error(codes.NotFound, + "The place once lit by the warm lamplight has been swallowed up by the debris now.") + } + return nil + } +} + +func TestOneStoreFailure(t *testing.T) { + log.SetLevel(zapcore.DebugLevel) + c := createFakeCluster(t, 4, true) + ctx := context.Background() + splitKeys := make([]string, 0, 1000) + for i := 0; i < 1000; i++ { + splitKeys = append(splitKeys, fmt.Sprintf("%04d", i)) + } + c.splitAndScatter(splitKeys...) + c.flushAll() + + env := &testEnv{fakeCluster: c, testCtx: t} + adv := streamhelper.NewCheckpointAdvancer(env) + adv.StartTaskListener(ctx) + require.NoError(t, adv.OnTick(ctx)) + c.onGetClient = oneStoreFailure() + + for i := 0; i < 100; i++ { + c.advanceCheckpoints() + c.flushAll() + require.ErrorContains(t, adv.OnTick(ctx), "the warm lamplight") + } + + c.onGetClient = nil + cp := c.advanceCheckpoints() + c.flushAll() + require.NoError(t, adv.OnTick(ctx)) + require.Equal(t, cp, env.checkpoint) +} diff --git a/br/pkg/streamhelper/basic_lib_for_test.go b/br/pkg/streamhelper/basic_lib_for_test.go new file mode 100644 index 0000000000000..14d777f1d24e7 --- /dev/null +++ b/br/pkg/streamhelper/basic_lib_for_test.go @@ -0,0 +1,432 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package streamhelper_test + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "math" + "math/rand" + "sort" + "strings" + "sync" + "testing" + + backup "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/errorpb" + logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/br/pkg/streamhelper" + "github.com/pingcap/tidb/kv" + "google.golang.org/grpc" +) + +type flushSimulator struct { + flushedEpoch uint64 + enabled bool +} + +func (c flushSimulator) makeError(requestedEpoch uint64) *errorpb.Error { + if !c.enabled { + return nil + } + if c.flushedEpoch == 0 { + e := errorpb.Error{ + Message: "not flushed", + } + return &e + } + if c.flushedEpoch != requestedEpoch { + e := errorpb.Error{ + Message: "flushed epoch not match", + } + return &e + } + return nil +} + +func (c flushSimulator) fork() flushSimulator { + return flushSimulator{ + enabled: c.enabled, + } +} + +type region struct { + rng kv.KeyRange + leader uint64 + epoch uint64 + id uint64 + checkpoint uint64 + + fsim flushSimulator +} + +type fakeStore struct { + id uint64 + regions map[uint64]*region +} + +type fakeCluster struct { + mu sync.Mutex + idAlloced uint64 + stores map[uint64]*fakeStore + regions []*region + testCtx *testing.T + + onGetClient func(uint64) error +} + +func overlaps(a, b kv.KeyRange) bool { + if len(b.EndKey) == 0 { + return len(a.EndKey) == 0 || bytes.Compare(a.EndKey, b.StartKey) > 0 + } + if len(a.EndKey) == 0 { + return len(b.EndKey) == 0 || bytes.Compare(b.EndKey, a.StartKey) > 0 + } + return bytes.Compare(a.StartKey, b.EndKey) < 0 && bytes.Compare(b.StartKey, a.EndKey) < 0 +} + +func (f *region) splitAt(newID uint64, k string) *region { + newRegion := ®ion{ + rng: kv.KeyRange{StartKey: []byte(k), EndKey: f.rng.EndKey}, + leader: f.leader, + epoch: f.epoch + 1, + id: newID, + checkpoint: f.checkpoint, + fsim: f.fsim.fork(), + } + f.rng.EndKey = []byte(k) + f.epoch += 1 + f.fsim = f.fsim.fork() + return newRegion +} + +func (f *region) flush() { + f.fsim.flushedEpoch = f.epoch +} + +func (f *fakeStore) GetLastFlushTSOfRegion(ctx context.Context, in *logbackup.GetLastFlushTSOfRegionRequest, opts ...grpc.CallOption) (*logbackup.GetLastFlushTSOfRegionResponse, error) { + resp := &logbackup.GetLastFlushTSOfRegionResponse{ + Checkpoints: []*logbackup.RegionCheckpoint{}, + } + for _, r := range in.Regions { + region, ok := f.regions[r.Id] + if !ok || region.leader != f.id { + resp.Checkpoints = append(resp.Checkpoints, &logbackup.RegionCheckpoint{ + Err: &errorpb.Error{ + Message: "not found", + }, + Region: &logbackup.RegionIdentity{ + Id: region.id, + EpochVersion: region.epoch, + }, + }) + continue + } + if err := region.fsim.makeError(r.EpochVersion); err != nil { + resp.Checkpoints = append(resp.Checkpoints, &logbackup.RegionCheckpoint{ + Err: err, + Region: &logbackup.RegionIdentity{ + Id: region.id, + EpochVersion: region.epoch, + }, + }) + continue + } + if region.epoch != r.EpochVersion { + resp.Checkpoints = append(resp.Checkpoints, &logbackup.RegionCheckpoint{ + Err: &errorpb.Error{ + Message: "epoch not match", + }, + Region: &logbackup.RegionIdentity{ + Id: region.id, + EpochVersion: region.epoch, + }, + }) + continue + } + resp.Checkpoints = append(resp.Checkpoints, &logbackup.RegionCheckpoint{ + Checkpoint: region.checkpoint, + Region: &logbackup.RegionIdentity{ + Id: region.id, + EpochVersion: region.epoch, + }, + }) + } + return resp, nil +} + +// RegionScan gets a list of regions, starts from 
the region that contains key. +// Limit limits the maximum number of regions returned. +func (f *fakeCluster) RegionScan(ctx context.Context, key []byte, endKey []byte, limit int) ([]streamhelper.RegionWithLeader, error) { + f.mu.Lock() + defer f.mu.Unlock() + sort.Slice(f.regions, func(i, j int) bool { + return bytes.Compare(f.regions[i].rng.StartKey, f.regions[j].rng.StartKey) < 0 + }) + + result := make([]streamhelper.RegionWithLeader, 0, limit) + for _, region := range f.regions { + if overlaps(kv.KeyRange{StartKey: key, EndKey: endKey}, region.rng) && len(result) < limit { + regionInfo := streamhelper.RegionWithLeader{ + Region: &metapb.Region{ + Id: region.id, + StartKey: region.rng.StartKey, + EndKey: region.rng.EndKey, + RegionEpoch: &metapb.RegionEpoch{ + Version: region.epoch, + }, + }, + Leader: &metapb.Peer{ + StoreId: region.leader, + }, + } + result = append(result, regionInfo) + } else if bytes.Compare(region.rng.StartKey, key) > 0 { + break + } + } + return result, nil +} + +func (f *fakeCluster) GetLogBackupClient(ctx context.Context, storeID uint64) (logbackup.LogBackupClient, error) { + if f.onGetClient != nil { + err := f.onGetClient(storeID) + if err != nil { + return nil, err + } + } + cli, ok := f.stores[storeID] + if !ok { + f.testCtx.Fatalf("the store %d doesn't exist", storeID) + } + return cli, nil +} + +func (f *fakeCluster) findRegionById(rid uint64) *region { + for _, r := range f.regions { + if r.id == rid { + return r + } + } + return nil +} + +func (f *fakeCluster) findRegionByKey(key []byte) *region { + for _, r := range f.regions { + if bytes.Compare(key, r.rng.StartKey) >= 0 && (len(r.rng.EndKey) == 0 || bytes.Compare(key, r.rng.EndKey) < 0) { + return r + } + } + panic(fmt.Sprintf("inconsistent key space; key = %X", key)) +} + +func (f *fakeCluster) transferRegionTo(rid uint64, newPeers []uint64) { + r := f.findRegionById(rid) +storeLoop: + for _, store := range f.stores { + for _, pid := range newPeers { + if pid == store.id { + store.regions[rid] = r + continue storeLoop + } + } + delete(store.regions, rid) + } +} + +func (f *fakeCluster) splitAt(key string) { + k := []byte(key) + r := f.findRegionByKey(k) + newRegion := r.splitAt(f.idAlloc(), key) + for _, store := range f.stores { + _, ok := store.regions[r.id] + if ok { + store.regions[newRegion.id] = newRegion + } + } + f.regions = append(f.regions, newRegion) +} + +func (f *fakeCluster) idAlloc() uint64 { + f.idAlloced++ + return f.idAlloced +} + +func (f *fakeCluster) chooseStores(n int) []uint64 { + s := make([]uint64, 0, len(f.stores)) + for id := range f.stores { + s = append(s, id) + } + rand.Shuffle(len(s), func(i, j int) { + s[i], s[j] = s[j], s[i] + }) + return s[:n] +} + +func (f *fakeCluster) findPeers(rid uint64) (result []uint64) { + for _, store := range f.stores { + if _, ok := store.regions[rid]; ok { + result = append(result, store.id) + } + } + return +} + +func (f *fakeCluster) shuffleLeader(rid uint64) { + r := f.findRegionById(rid) + peers := f.findPeers(rid) + rand.Shuffle(len(peers), func(i, j int) { + peers[i], peers[j] = peers[j], peers[i] + }) + + newLeader := peers[0] + r.leader = newLeader +} + +func (f *fakeCluster) splitAndScatter(keys ...string) { + f.mu.Lock() + defer f.mu.Unlock() + for _, key := range keys { + f.splitAt(key) + } + for _, r := range f.regions { + f.transferRegionTo(r.id, f.chooseStores(3)) + f.shuffleLeader(r.id) + } +} + +// a stub once in the future we want to make different stores hold different region instances. 
+func (f *fakeCluster) updateRegion(rid uint64, mut func(*region)) { + r := f.findRegionById(rid) + mut(r) +} + +func (f *fakeCluster) advanceCheckpoints() uint64 { + minCheckpoint := uint64(math.MaxUint64) + for _, r := range f.regions { + f.updateRegion(r.id, func(r *region) { + r.checkpoint += rand.Uint64() % 256 + if r.checkpoint < minCheckpoint { + minCheckpoint = r.checkpoint + } + r.fsim.flushedEpoch = 0 + }) + } + return minCheckpoint +} + +func createFakeCluster(t *testing.T, n int, simEnabled bool) *fakeCluster { + c := &fakeCluster{ + stores: map[uint64]*fakeStore{}, + regions: []*region{}, + testCtx: t, + } + stores := make([]*fakeStore, 0, n) + for i := 0; i < n; i++ { + s := new(fakeStore) + s.id = c.idAlloc() + s.regions = map[uint64]*region{} + stores = append(stores, s) + } + initialRegion := ®ion{ + rng: kv.KeyRange{}, + leader: stores[0].id, + epoch: 0, + id: c.idAlloc(), + checkpoint: 0, + fsim: flushSimulator{ + enabled: simEnabled, + }, + } + for i := 0; i < 3; i++ { + if i < len(stores) { + stores[i].regions[initialRegion.id] = initialRegion + } + } + for _, s := range stores { + c.stores[s.id] = s + } + c.regions = append(c.regions, initialRegion) + return c +} + +func (r *region) String() string { + return fmt.Sprintf("%d(%d):[%s,%s);%dL%d", r.id, r.epoch, hex.EncodeToString(r.rng.StartKey), hex.EncodeToString(r.rng.EndKey), r.checkpoint, r.leader) +} + +func (s *fakeStore) String() string { + buf := new(strings.Builder) + fmt.Fprintf(buf, "%d: ", s.id) + for _, r := range s.regions { + fmt.Fprintf(buf, "%s ", r) + } + return buf.String() +} + +func (f *fakeCluster) flushAll() { + for _, r := range f.regions { + r.flush() + } +} + +func (s *fakeStore) flush() { + for _, r := range s.regions { + if r.leader == s.id { + r.flush() + } + } +} + +func (f *fakeCluster) String() string { + buf := new(strings.Builder) + fmt.Fprint(buf, ">>> fake cluster <<<\nregions: ") + for _, region := range f.regions { + fmt.Fprint(buf, region, " ") + } + fmt.Fprintln(buf) + for _, store := range f.stores { + fmt.Fprintln(buf, store) + } + return buf.String() +} + +type testEnv struct { + *fakeCluster + checkpoint uint64 + testCtx *testing.T + + mu sync.Mutex +} + +func (t *testEnv) Begin(ctx context.Context, ch chan<- streamhelper.TaskEvent) error { + tsk := streamhelper.TaskEvent{ + Type: streamhelper.EventAdd, + Name: "whole", + Info: &backup.StreamBackupTaskInfo{ + Name: "whole", + }, + } + ch <- tsk + return nil +} + +func (t *testEnv) UploadV3GlobalCheckpointForTask(ctx context.Context, _ string, checkpoint uint64) error { + t.mu.Lock() + defer t.mu.Unlock() + + if checkpoint < t.checkpoint { + t.testCtx.Fatalf("checkpoint rolling back (from %d to %d)", t.checkpoint, checkpoint) + } + t.checkpoint = checkpoint + return nil +} + +func (t *testEnv) getCheckpoint() uint64 { + t.mu.Lock() + defer t.mu.Unlock() + + return t.checkpoint +} diff --git a/br/pkg/stream/client.go b/br/pkg/streamhelper/client.go similarity index 90% rename from br/pkg/stream/client.go rename to br/pkg/streamhelper/client.go index cbeaf8a4b5437..95c5cb07e2da5 100644 --- a/br/pkg/stream/client.go +++ b/br/pkg/streamhelper/client.go @@ -1,5 +1,5 @@ // Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. 
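The `testEnv`/`fakeCluster` pair above shows that the advancer only depends on the `Env` interface, so a test or an embedder can plug in its own implementation. A hypothetical skeleton of such an implementation; the stub returns empty results and is not part of the patch.

```go
package example

import (
	"context"

	logbackup "github.com/pingcap/kvproto/pkg/logbackuppb"
	"github.com/pingcap/tidb/br/pkg/streamhelper"
)

// stubEnv mirrors what testEnv/fakeCluster do: implement region scanning,
// the per-store log backup client accessor, and the task-metadata side.
type stubEnv struct{}

func (stubEnv) RegionScan(ctx context.Context, key, endKey []byte, limit int) ([]streamhelper.RegionWithLeader, error) {
	return nil, nil // no regions: the advancer has nothing to do
}

func (stubEnv) GetLogBackupClient(ctx context.Context, storeID uint64) (logbackup.LogBackupClient, error) {
	return nil, nil // a real Env returns a per-store client here
}

func (stubEnv) Begin(ctx context.Context, ch chan<- streamhelper.TaskEvent) error { return nil }

func (stubEnv) UploadV3GlobalCheckpointForTask(ctx context.Context, name string, cp uint64) error {
	return nil
}

// Compile-time check: stubEnv can be handed to NewCheckpointAdvancer.
var _ streamhelper.Env = stubEnv{}
```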
-package stream +package streamhelper import ( "bytes" @@ -28,6 +28,8 @@ type Checkpoint struct { ID uint64 `json:"id,omitempty"` Version uint64 `json:"epoch_version,omitempty"` TS uint64 `json:"ts"` + + IsGlobal bool `json:"-"` } type CheckpointType int @@ -36,12 +38,15 @@ const ( CheckpointTypeStore CheckpointType = iota CheckpointTypeRegion CheckpointTypeTask + CheckpointTypeGlobal CheckpointTypeInvalid ) // Type returns the type(provider) of the checkpoint. func (cp Checkpoint) Type() CheckpointType { switch { + case cp.IsGlobal: + return CheckpointTypeGlobal case cp.ID == 0 && cp.Version == 0: return CheckpointTypeTask case cp.ID != 0 && cp.Version == 0: @@ -72,7 +77,7 @@ func ParseCheckpoint(task string, key, value []byte) (Checkpoint, error) { segs := bytes.Split(key, []byte("/")) var checkpoint Checkpoint switch string(segs[0]) { - case "store": + case checkpointTypeStore: if len(segs) != 2 { return checkpoint, errors.Annotatef(berrors.ErrPiTRMalformedMetadata, "the store checkpoint seg mismatch; segs = %v", segs) @@ -82,7 +87,9 @@ func ParseCheckpoint(task string, key, value []byte) (Checkpoint, error) { return checkpoint, err } checkpoint.ID = id - case "region": + case checkpointTypeGlobal: + checkpoint.IsGlobal = true + case checkpointTypeRegion: if len(segs) != 3 { return checkpoint, errors.Annotatef(berrors.ErrPiTRMalformedMetadata, "the region checkpoint seg mismatch; segs = %v", segs) @@ -187,6 +194,17 @@ func (c *MetaDataClient) CleanLastErrorOfTask(ctx context.Context, taskName stri return nil } +func (c *MetaDataClient) UploadV3GlobalCheckpointForTask(ctx context.Context, taskName string, checkpoint uint64) error { + key := GlobalCheckpointOf(taskName) + value := string(encodeUint64(checkpoint)) + _, err := c.KV.Put(ctx, key, value) + + if err != nil { + return err + } + return nil +} + // GetTask get the basic task handle from the metadata storage. func (c *MetaDataClient) GetTask(ctx context.Context, taskName string) (*Task, error) { resp, err := c.Get(ctx, TaskOf(taskName)) @@ -235,25 +253,35 @@ func (c *MetaDataClient) GetTaskWithPauseStatus(ctx context.Context, taskName st return &Task{cli: c, Info: taskInfo}, paused, nil } -// GetAllTasks get all of tasks from metadata storage. -func (c *MetaDataClient) GetAllTasks(ctx context.Context) ([]Task, error) { - scanner := scanEtcdPrefix(c.Client, PrefixOfTask()) - kvs, err := scanner.AllPages(ctx, 1) +func (c *MetaDataClient) TaskByInfo(t backuppb.StreamBackupTaskInfo) *Task { + return &Task{cli: c, Info: t} +} + +func (c *MetaDataClient) GetAllTasksWithRevision(ctx context.Context) ([]Task, int64, error) { + resp, err := c.KV.Get(ctx, PrefixOfTask(), clientv3.WithPrefix()) if err != nil { - return nil, errors.Trace(err) - } else if len(kvs) == 0 { - return nil, nil + return nil, 0, errors.Trace(err) + } + kvs := resp.Kvs + if len(kvs) == 0 { + return nil, resp.Header.GetRevision(), nil } tasks := make([]Task, len(kvs)) for idx, kv := range kvs { err = proto.Unmarshal(kv.Value, &tasks[idx].Info) if err != nil { - return nil, errors.Trace(err) + return nil, 0, errors.Trace(err) } tasks[idx].cli = c } - return tasks, nil + return tasks, resp.Header.GetRevision(), nil +} + +// GetAllTasks get all of tasks from metadata storage. +func (c *MetaDataClient) GetAllTasks(ctx context.Context) ([]Task, error) { + tasks, _, err := c.GetAllTasksWithRevision(ctx) + return tasks, err } // GetTaskCount get the count of tasks from metadata storage. 
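The client changes above introduce the "central global" checkpoint: the advancer writes a single per-task value through `UploadV3GlobalCheckpointForTask`, and the new `IsGlobal` flag lets `Checkpoint.Type()` distinguish it from per-store checkpoints (which is what the status printer relies on). A short sketch; the wrapper function is illustrative.

```go
package example

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/br/pkg/streamhelper"
)

// uploadGlobal writes the central checkpoint for a task and shows how the
// new checkpoint type is reported.
func uploadGlobal(ctx context.Context, cli *streamhelper.MetaDataClient, task string, ts uint64) error {
	if err := cli.UploadV3GlobalCheckpointForTask(ctx, task, ts); err != nil {
		return err
	}
	cp := streamhelper.Checkpoint{TS: ts, IsGlobal: true}
	fmt.Println(cp.Type() == streamhelper.CheckpointTypeGlobal) // true
	return nil
}
```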
@@ -375,6 +403,14 @@ func (t *Task) Step(ctx context.Context, store uint64, ts uint64) error { return nil } +func (t *Task) UploadGlobalCheckpoint(ctx context.Context, ts uint64) error { + _, err := t.cli.KV.Put(ctx, GlobalCheckpointOf(t.Info.Name), string(encodeUint64(ts))) + if err != nil { + return err + } + return nil +} + func (t *Task) LastError(ctx context.Context) (map[uint64]backuppb.StreamBackupError, error) { storeToError := map[uint64]backuppb.StreamBackupError{} prefix := LastErrorPrefixOf(t.Info.Name) diff --git a/br/pkg/streamhelper/collector.go b/br/pkg/streamhelper/collector.go new file mode 100644 index 0000000000000..1df39d0633d68 --- /dev/null +++ b/br/pkg/streamhelper/collector.go @@ -0,0 +1,315 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package streamhelper + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/pingcap/errors" + logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/metrics" + "go.uber.org/zap" +) + +const ( + defaultBatchSize = 1024 +) + +type onSuccessHook = func(uint64, kv.KeyRange) + +// storeCollector collects the region checkpoints from some store. +// it receives requests from the input channel, batching the requests, and send them to the store. +// because the server supports batching, the range of request regions can be discrete. +// note this is a temporary struct, its lifetime is shorter that the tick of advancer. +type storeCollector struct { + storeID uint64 + batchSize int + + service LogBackupService + + input chan RegionWithLeader + // the oneshot error reporter. + err *atomic.Value + // whether the recv and send loop has exited. + doneMessenger chan struct{} + onSuccess onSuccessHook + + // concurrency safety: + // those fields should only be write on the goroutine running `recvLoop`. + // Once it exits, we can read those fields. 
+ currentRequest logbackup.GetLastFlushTSOfRegionRequest + checkpoint uint64 + inconsistent []kv.KeyRange + regionMap map[uint64]kv.KeyRange +} + +func newStoreCollector(storeID uint64, srv LogBackupService) *storeCollector { + return &storeCollector{ + storeID: storeID, + batchSize: defaultBatchSize, + service: srv, + input: make(chan RegionWithLeader, defaultBatchSize), + err: new(atomic.Value), + doneMessenger: make(chan struct{}), + regionMap: make(map[uint64]kv.KeyRange), + } +} + +func (c *storeCollector) reportErr(err error) { + if oldErr := c.Err(); oldErr != nil { + log.Warn("reporting error twice, ignoring", logutil.AShortError("old", err), logutil.AShortError("new", oldErr)) + return + } + c.err.Store(err) +} + +func (c *storeCollector) Err() error { + err, ok := c.err.Load().(error) + if !ok { + return nil + } + return err +} + +func (c *storeCollector) setOnSuccessHook(hook onSuccessHook) { + c.onSuccess = hook +} + +func (c *storeCollector) begin(ctx context.Context) { + err := c.recvLoop(ctx) + if err != nil { + log.Warn("collector loop meet error", logutil.ShortError(err)) + c.reportErr(err) + } + close(c.doneMessenger) +} + +func (c *storeCollector) recvLoop(ctx context.Context) (err error) { + defer utils.PanicToErr(&err) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case r, ok := <-c.input: + if !ok { + return c.sendPendingRequests(ctx) + } + + if r.Leader.StoreId != c.storeID { + log.Warn("trying to request to store which isn't the leader of region.", + zap.Uint64("region", r.Region.Id), + zap.Uint64("target-store", c.storeID), + zap.Uint64("leader", r.Leader.StoreId), + ) + } + c.appendRegionMap(r) + c.currentRequest.Regions = append(c.currentRequest.Regions, &logbackup.RegionIdentity{ + Id: r.Region.GetId(), + EpochVersion: r.Region.GetRegionEpoch().GetVersion(), + }) + if len(c.currentRequest.Regions) >= c.batchSize { + err := c.sendPendingRequests(ctx) + if err != nil { + return err + } + } + } + } +} + +func (c *storeCollector) appendRegionMap(r RegionWithLeader) { + c.regionMap[r.Region.GetId()] = kv.KeyRange{StartKey: r.Region.StartKey, EndKey: r.Region.EndKey} +} + +type StoreCheckpoints struct { + HasCheckpoint bool + Checkpoint uint64 + FailureSubRanges []kv.KeyRange +} + +func (s *StoreCheckpoints) merge(other StoreCheckpoints) { + if other.HasCheckpoint && (other.Checkpoint < s.Checkpoint || !s.HasCheckpoint) { + s.Checkpoint = other.Checkpoint + s.HasCheckpoint = true + } + s.FailureSubRanges = append(s.FailureSubRanges, other.FailureSubRanges...) 
+} + +func (s *StoreCheckpoints) String() string { + sb := new(strings.Builder) + sb.WriteString("StoreCheckpoints:") + if s.HasCheckpoint { + sb.WriteString(strconv.Itoa(int(s.Checkpoint))) + } else { + sb.WriteString("none") + } + fmt.Fprintf(sb, ":(remaining %d ranges)", len(s.FailureSubRanges)) + return sb.String() +} + +func (c *storeCollector) spawn(ctx context.Context) func(context.Context) (StoreCheckpoints, error) { + go c.begin(ctx) + return func(cx context.Context) (StoreCheckpoints, error) { + close(c.input) + select { + case <-cx.Done(): + return StoreCheckpoints{}, cx.Err() + case <-c.doneMessenger: + } + if err := c.Err(); err != nil { + return StoreCheckpoints{}, err + } + sc := StoreCheckpoints{ + HasCheckpoint: c.checkpoint != 0, + Checkpoint: c.checkpoint, + FailureSubRanges: c.inconsistent, + } + return sc, nil + } +} + +func (c *storeCollector) sendPendingRequests(ctx context.Context) error { + log.Debug("sending batch", zap.Int("size", len(c.currentRequest.Regions)), zap.Uint64("store", c.storeID)) + cli, err := c.service.GetLogBackupClient(ctx, c.storeID) + if err != nil { + return err + } + cps, err := cli.GetLastFlushTSOfRegion(ctx, &c.currentRequest) + if err != nil { + return err + } + metrics.GetCheckpointBatchSize.WithLabelValues("checkpoint").Observe(float64(len(c.currentRequest.GetRegions()))) + c.currentRequest = logbackup.GetLastFlushTSOfRegionRequest{} + for _, checkpoint := range cps.Checkpoints { + if checkpoint.Err != nil { + log.Debug("failed to get region checkpoint", zap.Stringer("err", checkpoint.Err)) + c.inconsistent = append(c.inconsistent, c.regionMap[checkpoint.Region.Id]) + } else { + if c.onSuccess != nil { + c.onSuccess(checkpoint.Checkpoint, c.regionMap[checkpoint.Region.Id]) + } + // assuming the checkpoint would never be zero, use it as the placeholder. (1970 is so far away...) + if checkpoint.Checkpoint < c.checkpoint || c.checkpoint == 0 { + c.checkpoint = checkpoint.Checkpoint + } + } + } + return nil +} + +type runningStoreCollector struct { + collector *storeCollector + wait func(context.Context) (StoreCheckpoints, error) +} + +// clusterCollector is the controller for collecting region checkpoints for the cluster. +// It creates multi store collectors. +// ┌──────────────────────┐ Requesting ┌────────────┐ +// ┌─►│ StoreCollector[id=1] ├─────────────►│ TiKV[id=1] │ +// │ └──────────────────────┘ └────────────┘ +// │ +// │Owns +// ┌──────────────────┐ │ ┌──────────────────────┐ Requesting ┌────────────┐ +// │ ClusterCollector ├─────┼─►│ StoreCollector[id=4] ├─────────────►│ TiKV[id=4] │ +// └──────────────────┘ │ └──────────────────────┘ └────────────┘ +// │ +// │ +// │ ┌──────────────────────┐ Requesting ┌────────────┐ +// └─►│ StoreCollector[id=5] ├─────────────►│ TiKV[id=5] │ +// └──────────────────────┘ └────────────┘ +type clusterCollector struct { + mu sync.Mutex + collectors map[uint64]runningStoreCollector + noLeaders []kv.KeyRange + onSuccess onSuccessHook + + // The context for spawning sub collectors. + // Because the collectors are running lazily, + // keep the initial context for all subsequent goroutines, + // so we can make sure we can cancel all subtasks. + masterCtx context.Context + cancel context.CancelFunc + srv LogBackupService +} + +// NewClusterCollector creates a new cluster collector. +// collectors are the structure transform region information to checkpoint information, +// by requesting the checkpoint of regions in the store. 
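For completeness, the collector is driven exactly the way `TestBasic` exercises it: scan a key range, feed every region into the per-store collectors, then `Finish` to get the minimal checkpoint plus the sub-ranges that failed and must be retried. The wrapper function below is illustrative.

```go
package example

import (
	"context"

	"github.com/pingcap/tidb/br/pkg/streamhelper"
)

// collectOnce performs one collection round over the whole key space.
func collectOnce(ctx context.Context, adv *streamhelper.CheckpointAdvancer, env streamhelper.Env) (uint64, error) {
	coll := streamhelper.NewClusterCollector(ctx, env)
	if err := adv.GetCheckpointInRange(ctx, []byte{}, []byte{}, coll); err != nil {
		return 0, err
	}
	res, err := coll.Finish(ctx)
	if err != nil {
		return 0, err
	}
	// res.FailureSubRanges is what the advancer re-inserts into its cache
	// for the next round.
	return res.Checkpoint, nil
}
```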
+func NewClusterCollector(ctx context.Context, srv LogBackupService) *clusterCollector { + cx, cancel := context.WithCancel(ctx) + return &clusterCollector{ + collectors: map[uint64]runningStoreCollector{}, + masterCtx: cx, + cancel: cancel, + srv: srv, + } +} + +// setOnSuccessHook sets the hook when getting checkpoint of some region. +func (c *clusterCollector) setOnSuccessHook(hook onSuccessHook) { + c.onSuccess = hook +} + +// collectRegion adds a region to the collector. +func (c *clusterCollector) collectRegion(r RegionWithLeader) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.masterCtx.Err() != nil { + return nil + } + + if r.Leader.GetStoreId() == 0 { + log.Warn("there are regions without leader", zap.Uint64("region", r.Region.GetId())) + c.noLeaders = append(c.noLeaders, kv.KeyRange{StartKey: r.Region.StartKey, EndKey: r.Region.EndKey}) + return nil + } + leader := r.Leader.StoreId + _, ok := c.collectors[leader] + if !ok { + coll := newStoreCollector(leader, c.srv) + if c.onSuccess != nil { + coll.setOnSuccessHook(c.onSuccess) + } + c.collectors[leader] = runningStoreCollector{ + collector: coll, + wait: coll.spawn(c.masterCtx), + } + } + + sc := c.collectors[leader].collector + select { + case sc.input <- r: + return nil + case <-sc.doneMessenger: + err := sc.Err() + if err != nil { + c.cancel() + } + return err + } +} + +// Finish finishes collecting the region checkpoints, wait and returning the final result. +// Note this takes the ownership of this collector, you may create a new collector for next use. +func (c *clusterCollector) Finish(ctx context.Context) (StoreCheckpoints, error) { + defer c.cancel() + result := StoreCheckpoints{FailureSubRanges: c.noLeaders} + for id, coll := range c.collectors { + r, err := coll.wait(ctx) + if err != nil { + return StoreCheckpoints{}, errors.Annotatef(err, "store %d", id) + } + result.merge(r) + log.Debug("get checkpoint", zap.Stringer("checkpoint", &r), zap.Stringer("merged", &result)) + } + return result, nil +} diff --git a/br/pkg/streamhelper/config/advancer_conf.go b/br/pkg/streamhelper/config/advancer_conf.go new file mode 100644 index 0000000000000..21fac65ae0323 --- /dev/null +++ b/br/pkg/streamhelper/config/advancer_conf.go @@ -0,0 +1,82 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package config + +import ( + "time" + + "github.com/spf13/pflag" +) + +const ( + flagBackoffTime = "backoff-time" + flagMaxBackoffTime = "max-backoff-time" + flagTickInterval = "tick-interval" + flagFullScanDiffTick = "full-scan-tick" + flagAdvancingByCache = "advancing-by-cache" + + DefaultConsistencyCheckTick = 5 + DefaultTryAdvanceThreshold = 3 * time.Minute +) + +var ( + DefaultMaxConcurrencyAdvance = 8 +) + +type Config struct { + // The gap between two retries. + BackoffTime time.Duration `toml:"backoff-time" json:"backoff-time"` + // When after this time we cannot collect the safe resolved ts, give up. + MaxBackoffTime time.Duration `toml:"max-backoff-time" json:"max-backoff-time"` + // The gap between calculating checkpoints. + TickDuration time.Duration `toml:"tick-interval" json:"tick-interval"` + // The backoff time of full scan. + FullScanTick int `toml:"full-scan-tick" json:"full-scan-tick"` + + // Whether enable the optimization -- use a cached heap to advancing the global checkpoint. + // This may reduce the gap of checkpoint but may cost more CPU. 
+ AdvancingByCache bool `toml:"advancing-by-cache" json:"advancing-by-cache"` +} + +func DefineFlagsForCheckpointAdvancerConfig(f *pflag.FlagSet) { + f.Duration(flagBackoffTime, 5*time.Second, "The gap between two retries.") + f.Duration(flagMaxBackoffTime, 20*time.Minute, "After how long we should advance the checkpoint.") + f.Duration(flagTickInterval, 12*time.Second, "From how log we trigger the tick (advancing the checkpoint).") + f.Bool(flagAdvancingByCache, true, "Whether enable the optimization -- use a cached heap to advancing the global checkpoint.") + f.Int(flagFullScanDiffTick, 4, "The backoff of full scan.") +} + +func Default() Config { + return Config{ + BackoffTime: 5 * time.Second, + MaxBackoffTime: 20 * time.Minute, + TickDuration: 12 * time.Second, + FullScanTick: 4, + AdvancingByCache: true, + } +} + +func (conf *Config) GetFromFlags(f *pflag.FlagSet) error { + var err error + conf.BackoffTime, err = f.GetDuration(flagBackoffTime) + if err != nil { + return err + } + conf.MaxBackoffTime, err = f.GetDuration(flagMaxBackoffTime) + if err != nil { + return err + } + conf.TickDuration, err = f.GetDuration(flagTickInterval) + if err != nil { + return err + } + conf.FullScanTick, err = f.GetInt(flagFullScanDiffTick) + if err != nil { + return err + } + conf.AdvancingByCache, err = f.GetBool(flagAdvancingByCache) + if err != nil { + return err + } + return nil +} diff --git a/br/pkg/stream/integration_test.go b/br/pkg/streamhelper/integration_test.go similarity index 68% rename from br/pkg/stream/integration_test.go rename to br/pkg/streamhelper/integration_test.go index 92a465172afec..09f50f46e0011 100644 --- a/br/pkg/stream/integration_test.go +++ b/br/pkg/streamhelper/integration_test.go @@ -1,7 +1,7 @@ // Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. // This package tests the login in MetaClient with a embed etcd. -package stream_test +package streamhelper_test import ( "context" @@ -15,7 +15,7 @@ import ( berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/stream" + "github.com/pingcap/tidb/br/pkg/streamhelper" "github.com/pingcap/tidb/tablecodec" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/kv" @@ -63,11 +63,11 @@ func runEtcd(t *testing.T) (*embed.Etcd, *clientv3.Client) { return etcd, cli } -func simpleRanges(tableCount int) stream.Ranges { - ranges := stream.Ranges{} +func simpleRanges(tableCount int) streamhelper.Ranges { + ranges := streamhelper.Ranges{} for i := 0; i < tableCount; i++ { base := int64(i*2 + 1) - ranges = append(ranges, stream.Range{ + ranges = append(ranges, streamhelper.Range{ StartKey: tablecodec.EncodeTablePrefix(base), EndKey: tablecodec.EncodeTablePrefix(base + 1), }) @@ -75,9 +75,9 @@ func simpleRanges(tableCount int) stream.Ranges { return ranges } -func simpleTask(name string, tableCount int) stream.TaskInfo { +func simpleTask(name string, tableCount int) streamhelper.TaskInfo { backend, _ := storage.ParseBackend("noop://", nil) - task, err := stream.NewTask(name). + task, err := streamhelper.NewTask(name). FromTS(1). UntilTS(1000). WithRanges(simpleRanges(tableCount)...). 
@@ -110,7 +110,7 @@ func keyNotExists(t *testing.T, key []byte, etcd *embed.Etcd) { require.Len(t, r.KVs, 0) } -func rangeMatches(t *testing.T, ranges stream.Ranges, etcd *embed.Etcd) { +func rangeMatches(t *testing.T, ranges streamhelper.Ranges, etcd *embed.Etcd) { r, err := etcd.Server.KV().Range(context.TODO(), ranges[0].StartKey, ranges[len(ranges)-1].EndKey, mvcc.RangeOptions{}) require.NoError(t, err) if len(r.KVs) != len(ranges) { @@ -133,33 +133,34 @@ func rangeIsEmpty(t *testing.T, prefix []byte, etcd *embed.Etcd) { func TestIntegration(t *testing.T) { etcd, cli := runEtcd(t) defer etcd.Server.Stop() - metaCli := stream.MetaDataClient{Client: cli} + metaCli := streamhelper.MetaDataClient{Client: cli} t.Run("TestBasic", func(t *testing.T) { testBasic(t, metaCli, etcd) }) t.Run("TestForwardProgress", func(t *testing.T) { testForwardProgress(t, metaCli, etcd) }) + t.Run("TestStreamListening", func(t *testing.T) { testStreamListening(t, streamhelper.TaskEventClient{MetaDataClient: metaCli}) }) } func TestChecking(t *testing.T) { noop, _ := storage.ParseBackend("noop://", nil) // The name must not contains slash. - _, err := stream.NewTask("/root"). + _, err := streamhelper.NewTask("/root"). WithRange([]byte("1"), []byte("2")). WithTableFilter("*.*"). ToStorage(noop). Check() require.ErrorIs(t, errors.Cause(err), berrors.ErrPiTRInvalidTaskInfo) // Must specify the external storage. - _, err = stream.NewTask("root"). + _, err = streamhelper.NewTask("root"). WithRange([]byte("1"), []byte("2")). WithTableFilter("*.*"). Check() require.ErrorIs(t, errors.Cause(err), berrors.ErrPiTRInvalidTaskInfo) // Must specift the table filter and range? - _, err = stream.NewTask("root"). + _, err = streamhelper.NewTask("root"). ToStorage(noop). Check() require.ErrorIs(t, errors.Cause(err), berrors.ErrPiTRInvalidTaskInfo) // Happy path. - _, err = stream.NewTask("root"). + _, err = streamhelper.NewTask("root"). WithRange([]byte("1"), []byte("2")). WithTableFilter("*.*"). ToStorage(noop). 
@@ -167,43 +168,43 @@ func TestChecking(t *testing.T) { require.NoError(t, err) } -func testBasic(t *testing.T, metaCli stream.MetaDataClient, etcd *embed.Etcd) { +func testBasic(t *testing.T, metaCli streamhelper.MetaDataClient, etcd *embed.Etcd) { ctx := context.Background() taskName := "two_tables" task := simpleTask(taskName, 2) taskData, err := task.PBInfo.Marshal() require.NoError(t, err) require.NoError(t, metaCli.PutTask(ctx, task)) - keyIs(t, []byte(stream.TaskOf(taskName)), taskData, etcd) - keyNotExists(t, []byte(stream.Pause(taskName)), etcd) - rangeMatches(t, []stream.Range{ - {StartKey: []byte(stream.RangeKeyOf(taskName, tablecodec.EncodeTablePrefix(1))), EndKey: tablecodec.EncodeTablePrefix(2)}, - {StartKey: []byte(stream.RangeKeyOf(taskName, tablecodec.EncodeTablePrefix(3))), EndKey: tablecodec.EncodeTablePrefix(4)}, + keyIs(t, []byte(streamhelper.TaskOf(taskName)), taskData, etcd) + keyNotExists(t, []byte(streamhelper.Pause(taskName)), etcd) + rangeMatches(t, []streamhelper.Range{ + {StartKey: []byte(streamhelper.RangeKeyOf(taskName, tablecodec.EncodeTablePrefix(1))), EndKey: tablecodec.EncodeTablePrefix(2)}, + {StartKey: []byte(streamhelper.RangeKeyOf(taskName, tablecodec.EncodeTablePrefix(3))), EndKey: tablecodec.EncodeTablePrefix(4)}, }, etcd) remoteTask, err := metaCli.GetTask(ctx, taskName) require.NoError(t, err) require.NoError(t, remoteTask.Pause(ctx)) - keyExists(t, []byte(stream.Pause(taskName)), etcd) + keyExists(t, []byte(streamhelper.Pause(taskName)), etcd) require.NoError(t, metaCli.PauseTask(ctx, taskName)) - keyExists(t, []byte(stream.Pause(taskName)), etcd) + keyExists(t, []byte(streamhelper.Pause(taskName)), etcd) paused, err := remoteTask.IsPaused(ctx) require.NoError(t, err) require.True(t, paused) require.NoError(t, metaCli.ResumeTask(ctx, taskName)) - keyNotExists(t, []byte(stream.Pause(taskName)), etcd) + keyNotExists(t, []byte(streamhelper.Pause(taskName)), etcd) require.NoError(t, metaCli.ResumeTask(ctx, taskName)) - keyNotExists(t, []byte(stream.Pause(taskName)), etcd) + keyNotExists(t, []byte(streamhelper.Pause(taskName)), etcd) paused, err = remoteTask.IsPaused(ctx) require.NoError(t, err) require.False(t, paused) require.NoError(t, metaCli.DeleteTask(ctx, taskName)) - keyNotExists(t, []byte(stream.TaskOf(taskName)), etcd) - rangeIsEmpty(t, []byte(stream.RangesOf(taskName)), etcd) + keyNotExists(t, []byte(streamhelper.TaskOf(taskName)), etcd) + rangeIsEmpty(t, []byte(streamhelper.RangesOf(taskName)), etcd) } -func testForwardProgress(t *testing.T, metaCli stream.MetaDataClient, etcd *embed.Etcd) { +func testForwardProgress(t *testing.T, metaCli streamhelper.MetaDataClient, etcd *embed.Etcd) { ctx := context.Background() taskName := "many_tables" taskInfo := simpleTask(taskName, 65) @@ -227,3 +228,34 @@ func testForwardProgress(t *testing.T, metaCli stream.MetaDataClient, etcd *embe require.NoError(t, err) require.Equal(t, store2Checkpoint, uint64(40)) } + +func testStreamListening(t *testing.T, metaCli streamhelper.TaskEventClient) { + ctx, cancel := context.WithCancel(context.Background()) + taskName := "simple" + taskInfo := simpleTask(taskName, 4) + + require.NoError(t, metaCli.PutTask(ctx, taskInfo)) + ch := make(chan streamhelper.TaskEvent, 1024) + require.NoError(t, metaCli.Begin(ctx, ch)) + require.NoError(t, metaCli.DeleteTask(ctx, taskName)) + + taskName2 := "simple2" + taskInfo2 := simpleTask(taskName2, 4) + require.NoError(t, metaCli.PutTask(ctx, taskInfo2)) + require.NoError(t, metaCli.DeleteTask(ctx, taskName2)) + first := <-ch + 
require.Equal(t, first.Type, streamhelper.EventAdd) + require.Equal(t, first.Name, taskName) + second := <-ch + require.Equal(t, second.Type, streamhelper.EventDel) + require.Equal(t, second.Name, taskName) + third := <-ch + require.Equal(t, third.Type, streamhelper.EventAdd) + require.Equal(t, third.Name, taskName2) + forth := <-ch + require.Equal(t, forth.Type, streamhelper.EventDel) + require.Equal(t, forth.Name, taskName2) + cancel() + _, ok := <-ch + require.False(t, ok) +} diff --git a/br/pkg/stream/models.go b/br/pkg/streamhelper/models.go similarity index 92% rename from br/pkg/stream/models.go rename to br/pkg/streamhelper/models.go index 7aee22de0c239..265669799a581 100644 --- a/br/pkg/stream/models.go +++ b/br/pkg/streamhelper/models.go @@ -1,5 +1,5 @@ // Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. -package stream +package streamhelper import ( "bytes" @@ -21,10 +21,13 @@ const ( streamKeyPrefix = "/tidb/br-stream" taskInfoPath = "/info" // nolint:deadcode,varcheck - taskCheckpointPath = "/checkpoint" - taskRangesPath = "/ranges" - taskPausePath = "/pause" - taskLastErrorPath = "/last-error" + taskCheckpointPath = "/checkpoint" + taskRangesPath = "/ranges" + taskPausePath = "/pause" + taskLastErrorPath = "/last-error" + checkpointTypeGlobal = "central_global" + checkpointTypeRegion = "region" + checkpointTypeStore = "store" ) var ( @@ -78,6 +81,11 @@ func CheckPointsOf(task string) string { return buf.String() } +// GlobalCheckpointOf returns the path to the "global" checkpoint of some task. +func GlobalCheckpointOf(task string) string { + return path.Join(streamKeyPrefix, taskCheckpointPath, task, checkpointTypeGlobal) +} + // CheckpointOf returns the checkpoint prefix of some store. // Normally it would be /checkpoint//. func CheckPointOf(task string, store uint64) string { diff --git a/br/pkg/stream/prefix_scanner.go b/br/pkg/streamhelper/prefix_scanner.go similarity index 99% rename from br/pkg/stream/prefix_scanner.go rename to br/pkg/streamhelper/prefix_scanner.go index 4700b26c5acd2..c06b3b9a26867 100644 --- a/br/pkg/stream/prefix_scanner.go +++ b/br/pkg/streamhelper/prefix_scanner.go @@ -1,5 +1,5 @@ // Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. -package stream +package streamhelper import ( "context" diff --git a/br/pkg/streamhelper/regioniter.go b/br/pkg/streamhelper/regioniter.go new file mode 100644 index 0000000000000..b2bfa0820316c --- /dev/null +++ b/br/pkg/streamhelper/regioniter.go @@ -0,0 +1,122 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package streamhelper + +import ( + "bytes" + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/metapb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/redact" + "github.com/pingcap/tidb/br/pkg/utils" +) + +const ( + defaultPageSize = 2048 +) + +type RegionWithLeader struct { + Region *metapb.Region + Leader *metapb.Peer +} + +type RegionScanner interface { + // RegionScan gets a list of regions, starts from the region that contains key. + // Limit limits the maximum number of regions returned. + RegionScan(ctx context.Context, key, endKey []byte, limit int) ([]RegionWithLeader, error) +} + +type RegionIter struct { + cli RegionScanner + startKey, endKey []byte + currentStartKey []byte + // When the endKey become "", we cannot check whether the scan is done by + // comparing currentStartKey and endKey (because "" has different meaning in start key and end key). 
+ // So set this to `ture` when endKey == "" and the scan is done. + infScanFinished bool + + // The max slice size returned by `Next`. + // This can be changed before calling `Next` each time, + // however no thread safety provided. + PageSize int +} + +// IterateRegion creates an iterater over the region range. +func IterateRegion(cli RegionScanner, startKey, endKey []byte) *RegionIter { + return &RegionIter{ + cli: cli, + startKey: startKey, + endKey: endKey, + currentStartKey: startKey, + PageSize: defaultPageSize, + } +} + +func CheckRegionConsistency(startKey, endKey []byte, regions []RegionWithLeader) error { + // current pd can't guarantee the consistency of returned regions + if len(regions) == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan region return empty result, startKey: %s, endKey: %s", + redact.Key(startKey), redact.Key(endKey)) + } + + if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "first region's startKey > startKey, startKey: %s, regionStartKey: %s", + redact.Key(startKey), redact.Key(regions[0].Region.StartKey)) + } else if len(regions[len(regions)-1].Region.EndKey) != 0 && bytes.Compare(regions[len(regions)-1].Region.EndKey, endKey) < 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "last region's endKey < endKey, endKey: %s, regionEndKey: %s", + redact.Key(endKey), redact.Key(regions[len(regions)-1].Region.EndKey)) + } + + cur := regions[0] + for _, r := range regions[1:] { + if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "region endKey not equal to next region startKey, endKey: %s, startKey: %s", + redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey)) + } + cur = r + } + + return nil +} + +// Next get the next page of regions. +func (r *RegionIter) Next(ctx context.Context) ([]RegionWithLeader, error) { + var rs []RegionWithLeader + state := utils.InitialRetryState(30, 500*time.Millisecond, 500*time.Millisecond) + err := utils.WithRetry(ctx, func() error { + regions, err := r.cli.RegionScan(ctx, r.currentStartKey, r.endKey, r.PageSize) + if err != nil { + return err + } + if len(regions) > 0 { + endKey := regions[len(regions)-1].Region.GetEndKey() + if err := CheckRegionConsistency(r.currentStartKey, endKey, regions); err != nil { + return err + } + rs = regions + return nil + } + return CheckRegionConsistency(r.currentStartKey, r.endKey, regions) + }, &state) + if err != nil { + return nil, err + } + endKey := rs[len(rs)-1].Region.EndKey + // We have meet the last region. + if len(endKey) == 0 { + r.infScanFinished = true + } + r.currentStartKey = endKey + return rs, nil +} + +// Done checks whether the iteration is done. +func (r *RegionIter) Done() bool { + if len(r.endKey) == 0 { + return r.infScanFinished + } + return bytes.Compare(r.currentStartKey, r.endKey) >= 0 +} diff --git a/br/pkg/streamhelper/stream_listener.go b/br/pkg/streamhelper/stream_listener.go new file mode 100644 index 0000000000000..e48064613efdb --- /dev/null +++ b/br/pkg/streamhelper/stream_listener.go @@ -0,0 +1,170 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
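Before the listener file below, a sketch of consuming the region iterator defined above: Next returns one page per call and Done reports when the whole range (including an unbounded end key) has been visited; any RegionScanner implementation can be plugged in.

package sketch

import (
	"context"

	"github.com/pingcap/tidb/br/pkg/streamhelper"
)

// countRegions pages through every region in [startKey, endKey) using the iterator.
func countRegions(ctx context.Context, scanner streamhelper.RegionScanner, startKey, endKey []byte) (int, error) {
	iter := streamhelper.IterateRegion(scanner, startKey, endKey)
	total := 0
	for !iter.Done() {
		page, err := iter.Next(ctx)
		if err != nil {
			return 0, err
		}
		total += len(page)
	}
	return total, nil
}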
+ +package streamhelper + +import ( + "bytes" + "context" + "fmt" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + clientv3 "go.etcd.io/etcd/client/v3" +) + +type EventType int + +const ( + EventAdd EventType = iota + EventDel + EventErr +) + +func (t EventType) String() string { + switch t { + case EventAdd: + return "Add" + case EventDel: + return "Del" + case EventErr: + return "Err" + } + return "Unknown" +} + +type TaskEvent struct { + Type EventType + Name string + Info *backuppb.StreamBackupTaskInfo + Err error +} + +func (t *TaskEvent) String() string { + if t.Err != nil { + return fmt.Sprintf("%s(%s, err = %s)", t.Type, t.Name, t.Err) + } + return fmt.Sprintf("%s(%s)", t.Type, t.Name) +} + +type TaskEventClient struct { + MetaDataClient +} + +func errorEvent(err error) TaskEvent { + return TaskEvent{ + Type: EventErr, + Err: err, + } +} + +func toTaskEvent(event *clientv3.Event) (TaskEvent, error) { + if !bytes.HasPrefix(event.Kv.Key, []byte(PrefixOfTask())) { + return TaskEvent{}, errors.Annotatef(berrors.ErrInvalidArgument, "the path isn't a task path (%s)", string(event.Kv.Key)) + } + + te := TaskEvent{} + te.Name = strings.TrimPrefix(string(event.Kv.Key), PrefixOfTask()) + if event.Type == clientv3.EventTypeDelete { + te.Type = EventDel + } else if event.Type == clientv3.EventTypePut { + te.Type = EventAdd + } else { + return TaskEvent{}, errors.Annotatef(berrors.ErrInvalidArgument, "event type is wrong (%s)", event.Type) + } + te.Info = new(backuppb.StreamBackupTaskInfo) + if err := proto.Unmarshal(event.Kv.Value, te.Info); err != nil { + return TaskEvent{}, err + } + return te, nil +} + +func eventFromWatch(resp clientv3.WatchResponse) ([]TaskEvent, error) { + result := make([]TaskEvent, 0, len(resp.Events)) + for _, event := range resp.Events { + te, err := toTaskEvent(event) + if err != nil { + te.Type = EventErr + te.Err = err + } + result = append(result, te) + } + return result, nil +} + +func (t TaskEventClient) startListen(ctx context.Context, rev int64, ch chan<- TaskEvent) { + c := t.Client.Watcher.Watch(ctx, PrefixOfTask(), clientv3.WithPrefix(), clientv3.WithRev(rev)) + handleResponse := func(resp clientv3.WatchResponse) bool { + events, err := eventFromWatch(resp) + if err != nil { + ch <- errorEvent(err) + return false + } + for _, event := range events { + ch <- event + } + return true + } + + go func() { + defer close(ch) + for { + select { + case resp, ok := <-c: + if !ok { + return + } + if !handleResponse(resp) { + return + } + case <-ctx.Done(): + // drain the remain event from channel. + for { + select { + case resp, ok := <-c: + if !ok { + return + } + if !handleResponse(resp) { + return + } + default: + return + } + } + } + } + }() +} + +func (t TaskEventClient) getFullTasksAsEvent(ctx context.Context) ([]TaskEvent, int64, error) { + tasks, rev, err := t.GetAllTasksWithRevision(ctx) + if err != nil { + return nil, 0, err + } + events := make([]TaskEvent, 0, len(tasks)) + for _, task := range tasks { + te := TaskEvent{ + Type: EventAdd, + Name: task.Info.Name, + Info: &task.Info, + } + events = append(events, te) + } + return events, rev, nil +} + +func (t TaskEventClient) Begin(ctx context.Context, ch chan<- TaskEvent) error { + initialTasks, rev, err := t.getFullTasksAsEvent(ctx) + if err != nil { + return err + } + // Note: maybe `go` here so we won't block? 
+ for _, task := range initialTasks { + ch <- task + } + t.startListen(ctx, rev+1, ch) + return nil +} diff --git a/br/pkg/streamhelper/tsheap.go b/br/pkg/streamhelper/tsheap.go new file mode 100644 index 0000000000000..64669a151467a --- /dev/null +++ b/br/pkg/streamhelper/tsheap.go @@ -0,0 +1,216 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package streamhelper + +import ( + "fmt" + "strings" + "sync" + "time" + + "github.com/google/btree" + "github.com/pingcap/errors" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/kv" + "github.com/tikv/client-go/v2/oracle" +) + +// CheckpointsCache is the heap-like cache for checkpoints. +// +// "Checkpoint" is the "Resolved TS" of some range. +// A resolved ts is a "watermark" for the system, which: +// - implies there won't be any transactions (in some range) commit with `commit_ts` smaller than this TS. +// - is monotonic increasing. +// A "checkpoint" is a "safe" Resolved TS, which: +// - is a TS *less than* the real resolved ts of now. +// - is based on range (it only promises there won't be new committed txns in the range). +// - the checkpoint of union of ranges is the minimal checkpoint of all ranges. +// As an example: +// +----------------------------------+ +// ^-----------^ (Checkpoint = 42) +// ^---------------^ (Checkpoint = 76) +// ^-----------------------^ (Checkpoint = min(42, 76) = 42) +// +// For calculating the global checkpoint, we can make a heap-like structure: +// Checkpoint Ranges +// 42 -> {[0, 8], [16, 100]} +// 1002 -> {[8, 16]} +// 1082 -> {[100, inf]} +// For now, the checkpoint of range [8, 16] and [100, inf] won't affect the global checkpoint +// directly, so we can try to advance only the ranges of {[0, 8], [16, 100]} (which's checkpoint is steal). +// Once them get advance, the global checkpoint would be advanced then, +// and we don't need to update all ranges (because some new ranges don't need to be advanced so quickly.) +type CheckpointsCache interface { + fmt.Stringer + // InsertRange inserts a range with specified TS to the cache. + InsertRange(ts uint64, rng kv.KeyRange) + // InsertRanges inserts a set of ranges that sharing checkpoint to the cache. + InsertRanges(rst RangesSharesTS) + // CheckpointTS returns the now global (union of all ranges) checkpoint of the cache. + CheckpointTS() uint64 + // PopRangesWithGapGT pops the ranges which's checkpoint is + PopRangesWithGapGT(d time.Duration) []*RangesSharesTS + // Check whether the ranges in the cache is integrate. + ConsistencyCheck() error + // Clear the cache. + Clear() +} + +// NoOPCheckpointCache is used when cache disabled. 
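A sketch of a caller-side loop for the TaskEventClient above: Begin replays the existing tasks as Add events and then streams watch events until the context is canceled, at which point the channel is closed (as the integration test checks).

package sketch

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/br/pkg/streamhelper"
)

// watchTasks drains task events until the watcher channel is closed.
func watchTasks(ctx context.Context, cli streamhelper.TaskEventClient) error {
	ch := make(chan streamhelper.TaskEvent, 8)
	if err := cli.Begin(ctx, ch); err != nil {
		return err
	}
	for ev := range ch {
		if ev.Type == streamhelper.EventErr {
			return ev.Err
		}
		fmt.Println(ev.String())
	}
	return nil
}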
+type NoOPCheckpointCache struct{} + +func (NoOPCheckpointCache) InsertRange(ts uint64, rng kv.KeyRange) {} + +func (NoOPCheckpointCache) InsertRanges(rst RangesSharesTS) {} + +func (NoOPCheckpointCache) Clear() {} + +func (NoOPCheckpointCache) String() string { + return "NoOPCheckpointCache" +} + +func (NoOPCheckpointCache) CheckpointTS() uint64 { + panic("invalid state: NoOPCheckpointCache should never be used in advancing!") +} + +func (NoOPCheckpointCache) PopRangesWithGapGT(d time.Duration) []*RangesSharesTS { + panic("invalid state: NoOPCheckpointCache should never be used in advancing!") +} + +func (NoOPCheckpointCache) ConsistencyCheck() error { + return errors.Annotatef(berrors.ErrUnsupportedOperation, "invalid state: NoOPCheckpointCache should never be used in advancing!") +} + +// RangesSharesTS is a set of ranges shares the same timestamp. +type RangesSharesTS struct { + TS uint64 + Ranges []kv.KeyRange +} + +func (rst *RangesSharesTS) String() string { + // Make a more friendly string. + return fmt.Sprintf("@%sR%d", oracle.GetTimeFromTS(rst.TS).Format("0405"), len(rst.Ranges)) +} + +func (rst *RangesSharesTS) Less(other btree.Item) bool { + return rst.TS < other.(*RangesSharesTS).TS +} + +// Checkpoints is a heap that collects all checkpoints of +// regions, it supports query the latest checkpoint fast. +// This structure is thread safe. +type Checkpoints struct { + tree *btree.BTree + + mu sync.Mutex +} + +func NewCheckpoints() *Checkpoints { + return &Checkpoints{ + tree: btree.New(32), + } +} + +// String formats the slowest 5 ranges sharing TS to string. +func (h *Checkpoints) String() string { + h.mu.Lock() + defer h.mu.Unlock() + + b := new(strings.Builder) + count := 0 + total := h.tree.Len() + h.tree.Ascend(func(i btree.Item) bool { + rst := i.(*RangesSharesTS) + b.WriteString(rst.String()) + b.WriteString(";") + count++ + return count < 5 + }) + if total-count > 0 { + fmt.Fprintf(b, "O%d", total-count) + } + return b.String() +} + +// InsertRanges insert a RangesSharesTS directly to the tree. +func (h *Checkpoints) InsertRanges(r RangesSharesTS) { + h.mu.Lock() + defer h.mu.Unlock() + if items := h.tree.Get(&r); items != nil { + i := items.(*RangesSharesTS) + i.Ranges = append(i.Ranges, r.Ranges...) + } else { + h.tree.ReplaceOrInsert(&r) + } +} + +// InsertRange inserts the region and its TS into the region tree. +func (h *Checkpoints) InsertRange(ts uint64, rng kv.KeyRange) { + h.mu.Lock() + defer h.mu.Unlock() + r := h.tree.Get(&RangesSharesTS{TS: ts}) + if r == nil { + r = &RangesSharesTS{TS: ts} + h.tree.ReplaceOrInsert(r) + } + rr := r.(*RangesSharesTS) + rr.Ranges = append(rr.Ranges, rng) +} + +// Clear removes all records in the checkpoint cache. +func (h *Checkpoints) Clear() { + h.mu.Lock() + defer h.mu.Unlock() + h.tree.Clear(false) +} + +// PopRangesWithGapGT pops ranges with gap greater than the specified duration. +// NOTE: maybe make something like `DrainIterator` for better composing? +func (h *Checkpoints) PopRangesWithGapGT(d time.Duration) []*RangesSharesTS { + h.mu.Lock() + defer h.mu.Unlock() + result := []*RangesSharesTS{} + for { + item, ok := h.tree.Min().(*RangesSharesTS) + if !ok { + return result + } + if time.Since(oracle.GetTimeFromTS(item.TS)) >= d { + result = append(result, item) + h.tree.DeleteMin() + } else { + return result + } + } +} + +// CheckpointTS returns the cached checkpoint TS by the current state of the cache. 
+func (h *Checkpoints) CheckpointTS() uint64 { + h.mu.Lock() + defer h.mu.Unlock() + item, ok := h.tree.Min().(*RangesSharesTS) + if !ok { + return 0 + } + return item.TS +} + +// ConsistencyCheck checks whether the tree contains the full range of key space. +// TODO: add argument to it and check a sub range. +func (h *Checkpoints) ConsistencyCheck() error { + h.mu.Lock() + ranges := make([]kv.KeyRange, 0, 1024) + h.tree.Ascend(func(i btree.Item) bool { + ranges = append(ranges, i.(*RangesSharesTS).Ranges...) + return true + }) + h.mu.Unlock() + + r := CollapseRanges(len(ranges), func(i int) kv.KeyRange { return ranges[i] }) + if len(r) != 1 || len(r[0].StartKey) != 0 || len(r[0].EndKey) != 0 { + return errors.Annotatef(berrors.ErrPiTRMalformedMetadata, + "the region tree cannot cover the key space, collapsed: %s", logutil.StringifyKeys(r)) + } + return nil +} diff --git a/br/pkg/streamhelper/tsheap_test.go b/br/pkg/streamhelper/tsheap_test.go new file mode 100644 index 0000000000000..843dbf3f42f09 --- /dev/null +++ b/br/pkg/streamhelper/tsheap_test.go @@ -0,0 +1,161 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. +package streamhelper_test + +import ( + "math" + "testing" + + "github.com/pingcap/tidb/br/pkg/streamhelper" + "github.com/pingcap/tidb/kv" + "github.com/stretchr/testify/require" +) + +func TestInsert(t *testing.T) { + cases := []func(func(ts uint64, a, b string)){ + func(insert func(ts uint64, a, b string)) { + insert(1, "", "01") + insert(1, "01", "02") + insert(2, "02", "022") + insert(4, "022", "") + }, + func(insert func(ts uint64, a, b string)) { + insert(1, "", "01") + insert(2, "", "01") + insert(2, "011", "02") + insert(1, "", "") + insert(65, "03", "04") + }, + } + + for _, c := range cases { + cps := streamhelper.NewCheckpoints() + expected := map[uint64]*streamhelper.RangesSharesTS{} + checkpoint := uint64(math.MaxUint64) + insert := func(ts uint64, a, b string) { + cps.InsertRange(ts, kv.KeyRange{ + StartKey: []byte(a), + EndKey: []byte(b), + }) + i, ok := expected[ts] + if !ok { + expected[ts] = &streamhelper.RangesSharesTS{TS: ts, Ranges: []kv.KeyRange{{StartKey: []byte(a), EndKey: []byte(b)}}} + } else { + i.Ranges = append(i.Ranges, kv.KeyRange{StartKey: []byte(a), EndKey: []byte(b)}) + } + if ts < checkpoint { + checkpoint = ts + } + } + c(insert) + require.Equal(t, checkpoint, cps.CheckpointTS()) + rngs := cps.PopRangesWithGapGT(0) + for _, rng := range rngs { + other := expected[rng.TS] + require.Equal(t, other, rng) + } + } +} + +func TestMergeRanges(t *testing.T) { + r := func(a, b string) kv.KeyRange { + return kv.KeyRange{StartKey: []byte(a), EndKey: []byte(b)} + } + type Case struct { + expected []kv.KeyRange + parameter []kv.KeyRange + } + cases := []Case{ + { + parameter: []kv.KeyRange{r("01", "01111"), r("0111", "0112")}, + expected: []kv.KeyRange{r("01", "0112")}, + }, + { + parameter: []kv.KeyRange{r("01", "03"), r("02", "04")}, + expected: []kv.KeyRange{r("01", "04")}, + }, + { + parameter: []kv.KeyRange{r("04", "08"), r("09", "10")}, + expected: []kv.KeyRange{r("04", "08"), r("09", "10")}, + }, + { + parameter: []kv.KeyRange{r("01", "03"), r("02", "04"), r("05", "07"), r("08", "09")}, + expected: []kv.KeyRange{r("01", "04"), r("05", "07"), r("08", "09")}, + }, + { + parameter: []kv.KeyRange{r("01", "02"), r("012", "")}, + expected: []kv.KeyRange{r("01", "")}, + }, + { + parameter: []kv.KeyRange{r("", "01"), r("02", "03"), r("021", "")}, + expected: []kv.KeyRange{r("", "01"), r("02", "")}, + }, + { + parameter: []kv.KeyRange{r("", 
"01"), r("001", "")}, + expected: []kv.KeyRange{r("", "")}, + }, + { + parameter: []kv.KeyRange{r("", "01"), r("", ""), r("", "02")}, + expected: []kv.KeyRange{r("", "")}, + }, + { + parameter: []kv.KeyRange{r("", "01"), r("01", ""), r("", "02"), r("", "03"), r("01", "02")}, + expected: []kv.KeyRange{r("", "")}, + }, + } + + for i, c := range cases { + result := streamhelper.CollapseRanges(len(c.parameter), func(i int) kv.KeyRange { + return c.parameter[i] + }) + require.Equal(t, c.expected, result, "case = %d", i) + } + +} + +func TestInsertRanges(t *testing.T) { + r := func(a, b string) kv.KeyRange { + return kv.KeyRange{StartKey: []byte(a), EndKey: []byte(b)} + } + rs := func(ts uint64, ranges ...kv.KeyRange) streamhelper.RangesSharesTS { + return streamhelper.RangesSharesTS{TS: ts, Ranges: ranges} + } + + type Case struct { + Expected []streamhelper.RangesSharesTS + Parameters []streamhelper.RangesSharesTS + } + + cases := []Case{ + { + Parameters: []streamhelper.RangesSharesTS{ + rs(1, r("0", "1"), r("1", "2")), + rs(1, r("2", "3"), r("3", "4")), + }, + Expected: []streamhelper.RangesSharesTS{ + rs(1, r("0", "1"), r("1", "2"), r("2", "3"), r("3", "4")), + }, + }, + { + Parameters: []streamhelper.RangesSharesTS{ + rs(1, r("0", "1")), + rs(2, r("2", "3")), + rs(1, r("4", "5"), r("6", "7")), + }, + Expected: []streamhelper.RangesSharesTS{ + rs(1, r("0", "1"), r("4", "5"), r("6", "7")), + rs(2, r("2", "3")), + }, + }, + } + + for _, c := range cases { + theTree := streamhelper.NewCheckpoints() + for _, p := range c.Parameters { + theTree.InsertRanges(p) + } + ranges := theTree.PopRangesWithGapGT(0) + for i, rs := range ranges { + require.ElementsMatch(t, c.Expected[i].Ranges, rs.Ranges, "case = %#v", c) + } + } +} diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index 0c25a01c56023..160aa5ad6712a 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -39,6 +39,8 @@ import ( "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" + "github.com/pingcap/tidb/br/pkg/streamhelper" + advancercfg "github.com/pingcap/tidb/br/pkg/streamhelper/config" "github.com/pingcap/tidb/br/pkg/summary" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/kv" @@ -70,6 +72,7 @@ var ( StreamStatus = "log status" StreamTruncate = "log truncate" StreamMetadata = "log metadata" + StreamCtl = "log ctl" skipSummaryCommandList = map[string]struct{}{ StreamStatus: {}, @@ -90,6 +93,7 @@ var StreamCommandMap = map[string]func(c context.Context, g glue.Glue, cmdName s StreamStatus: RunStreamStatus, StreamTruncate: RunStreamTruncate, StreamMetadata: RunStreamMetadata, + StreamCtl: RunStreamAdvancer, } // StreamConfig specifies the configure about backup stream @@ -111,6 +115,9 @@ type StreamConfig struct { // Spec for the command `status`. JSONOutput bool `json:"json-output" toml:"json-output"` + + // Spec for the command `advancer`. + AdvancerCfg advancercfg.Config `json:"advancer-config" toml:"advancer-config"` } func (cfg *StreamConfig) makeStorage(ctx context.Context) (storage.ExternalStorage, error) { @@ -521,7 +528,7 @@ func RunStreamStart( return errors.Trace(err) } - cli := stream.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) + cli := streamhelper.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) // It supports single stream log task currently. 
if count, err := cli.GetTaskCount(ctx); err != nil { return errors.Trace(err) @@ -548,7 +555,7 @@ func RunStreamStart( return errors.Annotate(berrors.ErrInvalidArgument, "nothing need to observe") } - ti := stream.TaskInfo{ + ti := streamhelper.TaskInfo{ PBInfo: backuppb.StreamBackupTaskInfo{ Storage: streamMgr.bc.GetStorageBackend(), StartTs: cfg.StartTS, @@ -623,7 +630,7 @@ func RunStreamStop( } defer streamMgr.close() - cli := stream.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) + cli := streamhelper.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) // to add backoff ti, err := cli.GetTask(ctx, cfg.TaskName) if err != nil { @@ -673,7 +680,7 @@ func RunStreamPause( } defer streamMgr.close() - cli := stream.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) + cli := streamhelper.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) // to add backoff ti, isPaused, err := cli.GetTaskWithPauseStatus(ctx, cfg.TaskName) if err != nil { @@ -691,7 +698,7 @@ func RunStreamPause( utils.BRServiceSafePoint{ ID: buildPauseSafePointName(ti.Info.Name), TTL: cfg.SafePointTTL, - BackupTS: globalCheckPointTS, + BackupTS: globalCheckPointTS - 1, }, ); err != nil { return errors.Trace(err) @@ -731,7 +738,7 @@ func RunStreamResume( } defer streamMgr.close() - cli := stream.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) + cli := streamhelper.NewMetaDataClient(streamMgr.mgr.GetDomain().GetEtcdClient()) // to add backoff ti, isPaused, err := cli.GetTaskWithPauseStatus(ctx, cfg.TaskName) if err != nil { @@ -776,6 +783,31 @@ func RunStreamResume( return nil } +func RunStreamAdvancer(c context.Context, g glue.Glue, cmdName string, cfg *StreamConfig) error { + ctx, cancel := context.WithCancel(c) + defer cancel() + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), + cfg.CheckRequirements, false) + if err != nil { + return err + } + + etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) + if err != nil { + return err + } + env := streamhelper.CliEnv(mgr.StoreManager, etcdCLI) + advancer := streamhelper.NewCheckpointAdvancer(env) + advancer.UpdateConfig(cfg.AdvancerCfg) + daemon := streamhelper.NewAdvancerDaemon(advancer, streamhelper.OwnerManagerForLogBackup(ctx, etcdCLI)) + loop, err := daemon.Begin(ctx) + if err != nil { + return err + } + loop() + return nil +} + func checkConfigForStatus(cfg *StreamConfig) error { if len(cfg.PD) == 0 { return errors.Annotatef(berrors.ErrInvalidArgument, @@ -793,7 +825,7 @@ func makeStatusController(ctx context.Context, cfg *StreamConfig, g glue.Glue) ( if err != nil { return nil, err } - cli := stream.NewMetaDataClient(etcdCLI) + cli := streamhelper.NewMetaDataClient(etcdCLI) var printer stream.TaskPrinter if !cfg.JSONOutput { printer = stream.PrintTaskByTable(console) @@ -888,10 +920,17 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre } readMetaDone() - fileCount := 0 - shiftUntilTS := ShiftTS(cfg.Until) + var ( + fileCount uint64 = 0 + kvCount int64 = 0 + totalSize uint64 = 0 + shiftUntilTS = metas.CalculateShiftTS(cfg.Until) + ) + metas.IterateFilesFullyBefore(shiftUntilTS, func(d *backuppb.DataFileInfo) (shouldBreak bool) { fileCount++ + totalSize += d.Length + kvCount += d.NumberOfEntries return }) console.Printf("We are going to remove %s files, until %s.\n", @@ -904,6 +943,7 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre removed := metas.RemoveDataBefore(shiftUntilTS) + // remove metadata removeMetaDone := 
console.StartTask("Removing metadata... ") if !cfg.DryRun { if err := metas.DoWriteBack(ctx, storage); err != nil { @@ -911,7 +951,10 @@ func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *Stre } } removeMetaDone() - clearDataFileDone := console.StartTask("Clearing data files... ") + + // remove log + clearDataFileDone := console.StartTask( + fmt.Sprintf("Clearing data files done. kv-count = %v, total-size = %v", kvCount, totalSize)) worker := utils.NewWorkerPool(128, "delete files") wg := new(sync.WaitGroup) for _, f := range removed { @@ -938,16 +981,6 @@ func RunStreamRestore( cmdName string, cfg *RestoreConfig, ) (err error) { - startTime := time.Now() - defer func() { - dur := time.Since(startTime) - if err != nil { - summary.Log(cmdName+" failed summary", zap.Error(err)) - } else { - summary.Log(cmdName+" success summary", zap.Duration("total-take", dur), - zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS)) - } - }() ctx, cancelFn := context.WithCancel(c) defer cancelFn() @@ -1012,7 +1045,23 @@ func restoreStream( g glue.Glue, cfg *RestoreConfig, logMinTS, logMaxTS uint64, -) error { +) (err error) { + var ( + totalKVCount uint64 + totalSize uint64 + mu sync.Mutex + startTime = time.Now() + ) + defer func() { + if err != nil { + summary.Log("restore log failed summary", zap.Error(err)) + } else { + summary.Log("restore log success summary", zap.Duration("total-take", time.Since(startTime)), + zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS), + zap.Uint64("total-kv-count", totalKVCount), zap.Uint64("total-size", totalSize)) + } + }() + ctx, cancelFn := context.WithCancel(c) defer cancelFn() @@ -1042,7 +1091,6 @@ func restoreStream( if err != nil { return errors.Trace(err) } - client.SetRestoreRangeTS(cfg.StartTS, cfg.RestoreTS, ShiftTS(cfg.StartTS)) client.SetCurrentTS(currentTS) restoreSchedulers, err := restorePreWork(ctx, client, mgr, false) @@ -1063,6 +1111,12 @@ func restoreStream( return nil } + shiftStartTS, exist := restore.CalculateShiftTS(metas, cfg.StartTS, cfg.RestoreTS) + if !exist { + shiftStartTS = cfg.StartTS + } + client.SetRestoreRangeTS(cfg.StartTS, cfg.RestoreTS, shiftStartTS) + // read data file by given ts. 
dmlFiles, ddlFiles, err := client.ReadStreamDataFiles(ctx, metas) if err != nil { @@ -1081,9 +1135,16 @@ func restoreStream( return errors.Trace(err) } + updateStats := func(kvCount uint64, size uint64) { + mu.Lock() + defer mu.Unlock() + totalKVCount += kvCount + totalSize += size + } pm := g.StartProgress(ctx, "Restore Meta Files", int64(len(ddlFiles)), !cfg.LogProgress) if err = withProgress(pm, func(p glue.Progress) error { - return client.RestoreMetaKVFiles(ctx, ddlFiles, schemasReplace, p.Inc) + client.RunGCRowsLoader(ctx) + return client.RestoreMetaKVFiles(ctx, ddlFiles, schemasReplace, updateStats, p.Inc) }); err != nil { return errors.Annotate(err, "failed to restore meta files") } @@ -1097,7 +1158,7 @@ func restoreStream( pd := g.StartProgress(ctx, "Restore KV Files", int64(len(dmlFiles)), !cfg.LogProgress) err = withProgress(pd, func(p glue.Progress) error { - return client.RestoreKVFiles(ctx, rewriteRules, dmlFiles, p.Inc) + return client.RestoreKVFiles(ctx, rewriteRules, dmlFiles, updateStats, p.Inc) }) if err != nil { return errors.Annotate(err, "failed to restore kv files") @@ -1110,6 +1171,11 @@ func restoreStream( if err = client.SaveSchemas(ctx, schemasReplace, logMinTS, cfg.RestoreTS); err != nil { return errors.Trace(err) } + + if err = client.InsertGCRows(ctx); err != nil { + return errors.Annotate(err, "failed to insert rows into gc_delete_range") + } + return nil } diff --git a/br/pkg/utils/store_manager.go b/br/pkg/utils/store_manager.go new file mode 100644 index 0000000000000..db7381842e6b1 --- /dev/null +++ b/br/pkg/utils/store_manager.go @@ -0,0 +1,244 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package utils + +import ( + "context" + "crypto/tls" + "os" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" +) + +const ( + dialTimeout = 30 * time.Second + resetRetryTimes = 3 +) + +// Pool is a lazy pool of gRPC channels. +// When `Get` called, it lazily allocates new connection if connection not full. +// If it's full, then it will return allocated channels round-robin. +type Pool struct { + mu sync.Mutex + + conns []*grpc.ClientConn + next int + cap int + newConn func(ctx context.Context) (*grpc.ClientConn, error) +} + +func (p *Pool) takeConns() (conns []*grpc.ClientConn) { + p.mu.Lock() + defer p.mu.Unlock() + p.conns, conns = nil, p.conns + p.next = 0 + return conns +} + +// Close closes the conn pool. +func (p *Pool) Close() { + for _, c := range p.takeConns() { + if err := c.Close(); err != nil { + log.Warn("failed to close clientConn", zap.String("target", c.Target()), zap.Error(err)) + } + } +} + +// Get tries to get an existing connection from the pool, or make a new one if the pool not full. +func (p *Pool) Get(ctx context.Context) (*grpc.ClientConn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.conns) < p.cap { + c, err := p.newConn(ctx) + if err != nil { + return nil, err + } + p.conns = append(p.conns, c) + return c, nil + } + + conn := p.conns[p.next] + p.next = (p.next + 1) % p.cap + return conn, nil +} + +// NewConnPool creates a new Pool by the specified conn factory function and capacity. 
+func NewConnPool(capacity int, newConn func(ctx context.Context) (*grpc.ClientConn, error)) *Pool { + return &Pool{ + cap: capacity, + conns: make([]*grpc.ClientConn, 0, capacity), + newConn: newConn, + + mu: sync.Mutex{}, + } +} + +type StoreManager struct { + pdClient pd.Client + grpcClis struct { + mu sync.Mutex + clis map[uint64]*grpc.ClientConn + } + keepalive keepalive.ClientParameters + tlsConf *tls.Config +} + +// NewStoreManager create a new manager for gRPC connections to stores. +func NewStoreManager(pdCli pd.Client, kl keepalive.ClientParameters, tlsConf *tls.Config) *StoreManager { + return &StoreManager{ + pdClient: pdCli, + grpcClis: struct { + mu sync.Mutex + clis map[uint64]*grpc.ClientConn + }{clis: make(map[uint64]*grpc.ClientConn)}, + keepalive: kl, + tlsConf: tlsConf, + } +} + +func (mgr *StoreManager) PDClient() pd.Client { + return mgr.pdClient +} + +func (mgr *StoreManager) getGrpcConnLocked(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { + failpoint.Inject("hint-get-backup-client", func(v failpoint.Value) { + log.Info("failpoint hint-get-backup-client injected, "+ + "process will notify the shell.", zap.Uint64("store", storeID)) + if sigFile, ok := v.(string); ok { + file, err := os.Create(sigFile) + if err != nil { + log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) + } + if file != nil { + file.Close() + } + } + time.Sleep(3 * time.Second) + }) + store, err := mgr.pdClient.GetStore(ctx, storeID) + if err != nil { + return nil, errors.Trace(err) + } + opt := grpc.WithInsecure() + if mgr.tlsConf != nil { + opt = grpc.WithTransportCredentials(credentials.NewTLS(mgr.tlsConf)) + } + ctx, cancel := context.WithTimeout(ctx, dialTimeout) + bfConf := backoff.DefaultConfig + bfConf.MaxDelay = time.Second * 3 + addr := store.GetPeerAddress() + if addr == "" { + addr = store.GetAddress() + } + conn, err := grpc.DialContext( + ctx, + addr, + opt, + grpc.WithBlock(), + grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), + grpc.WithKeepaliveParams(mgr.keepalive), + ) + cancel() + if err != nil { + return nil, berrors.ErrFailedToConnect.Wrap(err).GenWithStack("failed to make connection to store %d", storeID) + } + return conn, nil +} + +func (mgr *StoreManager) WithConn(ctx context.Context, storeID uint64, f func(*grpc.ClientConn)) error { + if ctx.Err() != nil { + return errors.Trace(ctx.Err()) + } + + mgr.grpcClis.mu.Lock() + defer mgr.grpcClis.mu.Unlock() + + if conn, ok := mgr.grpcClis.clis[storeID]; ok { + // Find a cached backup client. + f(conn) + return nil + } + + conn, err := mgr.getGrpcConnLocked(ctx, storeID) + if err != nil { + return errors.Trace(err) + } + // Cache the conn. + mgr.grpcClis.clis[storeID] = conn + f(conn) + return nil +} + +// ResetBackupClient reset the connection for backup client. +func (mgr *StoreManager) ResetBackupClient(ctx context.Context, storeID uint64) (backuppb.BackupClient, error) { + if ctx.Err() != nil { + return nil, errors.Trace(ctx.Err()) + } + + mgr.grpcClis.mu.Lock() + defer mgr.grpcClis.mu.Unlock() + + if conn, ok := mgr.grpcClis.clis[storeID]; ok { + // Find a cached backup client. 
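A sketch of borrowing a cached store connection through the new StoreManager: WithConn dials (or reuses) the gRPC channel for one store and lends it to the callback; building a backup client on it is only one possible use, and the helper name is illustrative.

package sketch

import (
	"context"

	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	"github.com/pingcap/tidb/br/pkg/utils"
	"google.golang.org/grpc"
)

// withBackupClient builds a BackupClient on a pooled connection to the given store.
func withBackupClient(ctx context.Context, mgr *utils.StoreManager, storeID uint64, use func(backuppb.BackupClient)) error {
	return mgr.WithConn(ctx, storeID, func(conn *grpc.ClientConn) {
		use(backuppb.NewBackupClient(conn))
	})
}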
+ log.Info("Reset backup client", zap.Uint64("storeID", storeID)) + err := conn.Close() + if err != nil { + log.Warn("close backup connection failed, ignore it", zap.Uint64("storeID", storeID)) + } + delete(mgr.grpcClis.clis, storeID) + } + var ( + conn *grpc.ClientConn + err error + ) + for retry := 0; retry < resetRetryTimes; retry++ { + conn, err = mgr.getGrpcConnLocked(ctx, storeID) + if err != nil { + log.Warn("failed to reset grpc connection, retry it", + zap.Int("retry time", retry), logutil.ShortError(err)) + time.Sleep(time.Duration(retry+3) * time.Second) + continue + } + mgr.grpcClis.clis[storeID] = conn + break + } + if err != nil { + return nil, errors.Trace(err) + } + return backuppb.NewBackupClient(conn), nil +} + +// Close closes all client in Mgr. +func (mgr *StoreManager) Close() { + if mgr == nil { + return + } + mgr.grpcClis.mu.Lock() + for _, cli := range mgr.grpcClis.clis { + err := cli.Close() + if err != nil { + log.Error("fail to close Mgr", zap.Error(err)) + } + } + mgr.grpcClis.mu.Unlock() +} + +func (mgr *StoreManager) TLSConfig() *tls.Config { + if mgr == nil { + return nil + } + return mgr.tlsConf +} diff --git a/br/pkg/utils/worker.go b/br/pkg/utils/worker.go index 773cfd41a64da..cf80770d0ae67 100644 --- a/br/pkg/utils/worker.go +++ b/br/pkg/utils/worker.go @@ -3,7 +3,10 @@ package utils import ( + "github.com/pingcap/errors" "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" "go.uber.org/zap" "golang.org/x/sync/errgroup" ) @@ -107,3 +110,23 @@ func (pool *WorkerPool) RecycleWorker(worker *Worker) { func (pool *WorkerPool) HasWorker() bool { return pool.IdleCount() > 0 } + +// PanicToErr recovers when the execution get panicked, and set the error provided by the arg. +// generally, this would be used with named return value and `defer`, like: +// +// func foo() (err error) { +// defer utils.PanicToErr(&err) +// return maybePanic() +// } +// +// Before using this, there are some hints for reducing resource leakage or bugs: +// - If any of clean work (by `defer`) relies on the error (say, when error happens, rollback some operations.), please +// place `defer this` AFTER that. +// - All resources allocated should be freed by the `defer` syntax, or when panicking, they may not be recycled. +func PanicToErr(err *error) { + item := recover() + if item != nil { + *err = errors.Annotatef(berrors.ErrUnknown, "panicked when executing, message: %v", item) + log.Warn("checkpoint advancer panicked, recovering", zap.StackSkip("stack", 1), logutil.ShortError(*err)) + } +} diff --git a/config/config.go b/config/config.go index 1070e6847f2ba..d5bca9f1c2692 100644 --- a/config/config.go +++ b/config/config.go @@ -32,6 +32,7 @@ import ( "github.com/BurntSushi/toml" "github.com/pingcap/errors" zaplog "github.com/pingcap/log" + logbackupconf "github.com/pingcap/tidb/br/pkg/streamhelper/config" "github.com/pingcap/tidb/parser/terror" typejson "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/logutil" @@ -256,8 +257,10 @@ type Config struct { // BallastObjectSize set the initial size of the ballast object, the unit is byte. BallastObjectSize int `toml:"ballast-object-size" json:"ballast-object-size"` // EnableGlobalKill indicates whether to enable global kill. 
- EnableGlobalKill bool `toml:"enable-global-kill" json:"enable-global-kill"` TrxSummary TrxSummary `toml:"transaction-summary" json:"transaction-summary"` + EnableGlobalKill bool `toml:"enable-global-kill" json:"enable-global-kill"` + // LogBackup controls the log backup related items. + LogBackup LogBackup `toml:"log-backup" json:"log-backup"` // The following items are deprecated. We need to keep them here temporarily // to support the upgrade process. They can be removed in future. @@ -416,6 +419,13 @@ func (b *AtomicBool) UnmarshalText(text []byte) error { return nil } +// LogBackup is the config for log backup service. +// For now, it includes the embed advancer. +type LogBackup struct { + Advancer logbackupconf.Config `toml:"advancer" json:"advancer"` + Enabled bool `toml:"enabled" json:"enabled"` +} + // Log is the log section of config. type Log struct { // Log level. @@ -942,6 +952,10 @@ var defaultConf = Config{ NewCollationsEnabledOnFirstBootstrap: true, EnableGlobalKill: true, TrxSummary: DefaultTrxSummary(), + LogBackup: LogBackup{ + Advancer: logbackupconf.Default(), + Enabled: false, + }, } var ( diff --git a/domain/domain.go b/domain/domain.go index 8b06ca3c736d0..0a83758aae4f4 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/bindinfo" + "github.com/pingcap/tidb/br/pkg/streamhelper" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" ddlutil "github.com/pingcap/tidb/ddl/util" @@ -92,6 +93,7 @@ type Domain struct { indexUsageSyncLease time.Duration dumpFileGcChecker *dumpFileGcChecker expiredTimeStamp4PC types.Time + logBackupAdvancer *streamhelper.AdvancerDaemon serverID uint64 serverIDSession *concurrency.Session @@ -889,10 +891,33 @@ func (do *Domain) Init(ddlLease time.Duration, sysExecutorFactory func(*Domain) do.wg.Add(1) go do.topologySyncerKeeper() } + err = do.initLogBackup(ctx, pdClient) + if err != nil { + return err + } return nil } +func (do *Domain) initLogBackup(ctx context.Context, pdClient pd.Client) error { + cfg := config.GetGlobalConfig() + if cfg.LogBackup.Enabled { + env, err := streamhelper.TiDBEnv(pdClient, do.etcdClient, cfg) + if err != nil { + return err + } + adv := streamhelper.NewCheckpointAdvancer(env) + adv.UpdateConfig(cfg.LogBackup.Advancer) + do.logBackupAdvancer = streamhelper.NewAdvancerDaemon(adv, streamhelper.OwnerManagerForLogBackup(ctx, do.etcdClient)) + loop, err := do.logBackupAdvancer.Begin(ctx) + if err != nil { + return err + } + do.wg.Run(loop) + } + return nil +} + type sessionPool struct { resources chan pools.Resource factory pools.Factory diff --git a/go.mod b/go.mod index 2d67c63904fc3..8cce36cca8da7 100644 --- a/go.mod +++ b/go.mod @@ -46,7 +46,7 @@ require ( github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20220423142525-ae43b7f4e5c3 github.com/pingcap/fn v0.0.0-20200306044125-d5540d389059 - github.com/pingcap/kvproto v0.0.0-20220705053936-aa9c2d20cd2a + github.com/pingcap/kvproto v0.0.0-20220705090230-a5d4ffd2ba33 github.com/pingcap/log v1.1.0 github.com/pingcap/sysutil v0.0.0-20220114020952-ea68d2dbf5b4 github.com/pingcap/tidb/parser v0.0.0-20211011031125-9b13dc409c5e diff --git a/go.sum b/go.sum index f4e4fb8b40f4f..53f2f6f7ec9ed 100644 --- a/go.sum +++ b/go.sum @@ -665,8 +665,8 @@ github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989/go.mod h1:O17Xtb github.com/pingcap/kvproto v0.0.0-20191211054548-3c6b38ea5107/go.mod 
h1:WWLmULLO7l8IOcQG+t+ItJ3fEcrL5FxF0Wu+HrMy26w= github.com/pingcap/kvproto v0.0.0-20220302110454-c696585a961b/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= github.com/pingcap/kvproto v0.0.0-20220525022339-6aaebf466305/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= -github.com/pingcap/kvproto v0.0.0-20220705053936-aa9c2d20cd2a h1:nP2wmyw9JTRsk5rm+tZtfAso6c/1FvuaFNbXTaYz3FE= -github.com/pingcap/kvproto v0.0.0-20220705053936-aa9c2d20cd2a/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= +github.com/pingcap/kvproto v0.0.0-20220705090230-a5d4ffd2ba33 h1:VKMmvYhtG28j1sCCBdq4s+V9UOYqNgQ6CQviQwOgTeg= +github.com/pingcap/kvproto v0.0.0-20220705090230-a5d4ffd2ba33/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20200511115504-543df19646ad/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7/go.mod h1:8AanEdAHATuRurdGxZXBz0At+9avep+ub7U1AGYLIMM= diff --git a/metrics/log_backup.go b/metrics/log_backup.go new file mode 100644 index 0000000000000..b477f447c2dbb --- /dev/null +++ b/metrics/log_backup.go @@ -0,0 +1,51 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +// log backup metrics. +// see the `Help` field for details. 
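A sketch of feeding the advancer tick histogram declared just below; the step label is chosen by the caller, and the helper is illustrative rather than part of the patch.

package sketch

import (
	"time"

	"github.com/pingcap/tidb/metrics"
)

// observeTick records how long one advancer step took, labeled by step name.
func observeTick(step string, f func()) {
	start := time.Now()
	f()
	metrics.AdvancerTickDuration.WithLabelValues(step).Observe(time.Since(start).Seconds())
}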
+var ( + LastCheckpoint = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "log_backup", + Name: "last_checkpoint", + Help: "The last global checkpoint of log backup.", + }, []string{"task"}) + AdvancerOwner = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "log_backup", + Name: "advancer_owner", + Help: "If the node is the owner of advancers, set this to `1`, otherwise `0`.", + ConstLabels: map[string]string{}, + }) + AdvancerTickDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "log_backup", + Name: "advancer_tick_duration_sec", + Help: "The time cost of each step during advancer ticking.", + Buckets: prometheus.ExponentialBuckets(0.01, 3.0, 8), + }, []string{"step"}) + GetCheckpointBatchSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "log_backup", + Name: "advancer_batch_size", + Help: "The batch size of scanning region or get region checkpoint.", + Buckets: prometheus.ExponentialBuckets(1, 2.0, 12), + }, []string{"type"}) +) diff --git a/metrics/metrics.go b/metrics/metrics.go index 19809bd9c85d2..4011e587cec71 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -191,6 +191,10 @@ func RegisterMetrics() { prometheus.MustRegister(StatsHealthyGauge) prometheus.MustRegister(TxnStatusEnteringCounter) prometheus.MustRegister(TxnDurationHistogram) + prometheus.MustRegister(LastCheckpoint) + prometheus.MustRegister(AdvancerOwner) + prometheus.MustRegister(AdvancerTickDuration) + prometheus.MustRegister(GetCheckpointBatchSize) tikvmetrics.InitMetrics(TiDB, TiKVClient) tikvmetrics.RegisterMetrics() diff --git a/server/driver_tidb.go b/server/driver_tidb.go index bd96059a30ca4..063dca0e8bf88 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -335,11 +335,11 @@ func (tc *TiDBContext) EncodeSessionStates(ctx context.Context, sctx sessionctx. // Bound params are sent by CMD_STMT_SEND_LONG_DATA, the proxy can wait for COM_STMT_EXECUTE. for _, boundParam := range stmt.BoundParams() { if boundParam != nil { - return session.ErrCannotMigrateSession.GenWithStackByArgs("prepared statements have bound params") + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("prepared statements have bound params") } } if rs := stmt.GetResultSet(); rs != nil && !rs.IsClosed() { - return session.ErrCannotMigrateSession.GenWithStackByArgs("prepared statements have open result sets") + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("prepared statements have open result sets") } preparedStmtInfo.ParamTypes = stmt.GetParamsType() } diff --git a/session/session.go b/session/session.go index 91c38e8322e43..bf5a7a9277caf 100644 --- a/session/session.go +++ b/session/session.go @@ -3367,23 +3367,23 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte valid := s.txn.Valid() s.txn.mu.Unlock() if valid { - return ErrCannotMigrateSession.GenWithStackByArgs("session has an active transaction") + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has an active transaction") } // Data in local temporary tables is hard to encode, so we do not support it. // Check temporary tables here to avoid circle dependency. 
if s.sessionVars.LocalTemporaryTables != nil { localTempTables := s.sessionVars.LocalTemporaryTables.(*infoschema.LocalTemporaryTables) if localTempTables.Count() > 0 { - return ErrCannotMigrateSession.GenWithStackByArgs("session has local temporary tables") + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has local temporary tables") } } // The advisory locks will be released when the session is closed. if len(s.advisoryLocks) > 0 { - return ErrCannotMigrateSession.GenWithStackByArgs("session has advisory locks") + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has advisory locks") } // The TableInfo stores session ID and server ID, so the session cannot be migrated. if len(s.lockedTables) > 0 { - return ErrCannotMigrateSession.GenWithStackByArgs("session has locked tables") + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has locked tables") } if err := s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil { diff --git a/session/tidb.go b/session/tidb.go index d0530488b5b3f..12ee40da2d4be 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -378,6 +378,5 @@ func ResultSetToStringSlice(ctx context.Context, s Session, rs sqlexec.RecordSet // Session errors. var ( - ErrForUpdateCantRetry = dbterror.ClassSession.NewStd(errno.ErrForUpdateCantRetry) - ErrCannotMigrateSession = dbterror.ClassSession.NewStd(errno.ErrCannotMigrateSession) + ErrForUpdateCantRetry = dbterror.ClassSession.NewStd(errno.ErrForUpdateCantRetry) ) diff --git a/sessionctx/sessionstates/BUILD.bazel b/sessionctx/sessionstates/BUILD.bazel index 1cd0a6c172cc2..1aa44bc7fdb32 100644 --- a/sessionctx/sessionstates/BUILD.bazel +++ b/sessionctx/sessionstates/BUILD.bazel @@ -2,20 +2,33 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "sessionstates", - srcs = ["session_states.go"], + srcs = [ + "session_states.go", + "session_token.go", + ], importpath = "github.com/pingcap/tidb/sessionctx/sessionstates", visibility = ["//visibility:public"], deps = [ + "//errno", "//parser/types", "//sessionctx/stmtctx", "//types", + "//util/dbterror", + "//util/logutil", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@org_uber_go_zap//:zap", ], ) go_test( name = "sessionstates_test", timeout = "short", - srcs = ["session_states_test.go"], + srcs = [ + "session_states_test.go", + "session_token_test.go", + ], + embed = [":sessionstates"], deps = [ "//config", "//errno", @@ -25,8 +38,10 @@ go_test( "//sessionctx/variable", "//testkit", "//types", + "//util", "//util/sem", "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", ], ) diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go index a9636e2f90014..36ea0b22455d7 100644 --- a/sessionctx/sessionstates/session_states.go +++ b/sessionctx/sessionstates/session_states.go @@ -17,14 +17,21 @@ package sessionstates import ( "time" + "github.com/pingcap/tidb/errno" ptypes "github.com/pingcap/tidb/parser/types" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/dbterror" ) // SessionStateType is the type of session states. type SessionStateType int +var ( + // ErrCannotMigrateSession indicates the session cannot be migrated. 
+ ErrCannotMigrateSession = dbterror.ClassSession.NewStd(errno.ErrCannotMigrateSession) +) + // These enums represents the types of session state handlers. const ( // StatePrepareStmt represents prepared statements. diff --git a/sessionctx/sessionstates/session_token.go b/sessionctx/sessionstates/session_token.go new file mode 100644 index 0000000000000..975e627848969 --- /dev/null +++ b/sessionctx/sessionstates/session_token.go @@ -0,0 +1,338 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessionstates + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "encoding/json" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" +) + +// Token-based authentication is used in session migration. We don't use typical authentication because the proxy +// cannot store the user passwords for security issues. +// +// The process of token-based authentication: +// 1. Before migrating the session, the proxy requires a token from server A. +// 2. Server A generates a token and signs it with a private key defined in the certificate. +// 3. The proxy authenticates with server B and sends the signed token as the password. +// 4. Server B checks the signature with the public key defined in the certificate and then verifies the token. +// +// The highlight is that the certificates on all the servers should be the same all the time. +// However, the certificates should be rotated periodically. Just in case of using different certificates to +// sign and check, a server should keep the old certificate for a while. A server will try both +// the 2 certificates to check the signature. +const ( + // A token needs a lifetime to avoid brute force attack. + tokenLifetime = time.Minute + // Reload the certificate periodically because it may be rotated. + loadCertInterval = 10 * time.Minute + // After a certificate is replaced, it's still valid for oldCertValidTime. + // oldCertValidTime must be a little longer than loadCertInterval, because the previous server may + // sign with the old cert but the new server checks with the new cert. + // - server A loads the old cert at 00:00:00. + // - the cert is rotated at 00:00:01 on all servers. + // - server B loads the new cert at 00:00:02. + // - server A signs token with the old cert at 00:10:00. + // - server B reloads the same new cert again at 00:10:01, and it has 3 certs now. + // - server B receives the token at 00:10:02, so the old cert should be valid for more than 10m after replacement. + oldCertValidTime = 15 * time.Minute +) + +// SessionToken represents the token used to authenticate with the new server. 
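A sketch of the sign-then-verify round trip described in the comments above, assuming the signing certificate has already been configured via SetCertPath/SetKeyPath on both ends; serializing the token with encoding/json mirrors how it is marshaled internally, though the exact wire format used by the proxy is outside this patch.

package sketch

import (
	"encoding/json"

	"github.com/pingcap/tidb/sessionctx/sessionstates"
)

// tokenRoundTrip signs a token for username and verifies it again, as the target server would.
func tokenRoundTrip(username string) error {
	token, err := sessionstates.CreateSessionToken(username)
	if err != nil {
		return err
	}
	// The proxy forwards the serialized token as the password of the new connection.
	raw, err := json.Marshal(token)
	if err != nil {
		return err
	}
	return sessionstates.ValidateSessionToken(raw, username)
}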
+type SessionToken struct {
+    Username   string    `json:"username"`
+    SignTime   time.Time `json:"sign-time"`
+    ExpireTime time.Time `json:"expire-time"`
+    Signature  []byte    `json:"signature,omitempty"`
+}
+
+// CreateSessionToken creates a token for the proxy.
+func CreateSessionToken(username string) (*SessionToken, error) {
+    now := getNow()
+    token := &SessionToken{
+        Username:   username,
+        SignTime:   now,
+        ExpireTime: now.Add(tokenLifetime),
+    }
+    tokenBytes, err := json.Marshal(token)
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+    if token.Signature, err = globalSigningCert.sign(tokenBytes); err != nil {
+        return nil, ErrCannotMigrateSession.GenWithStackByArgs(err.Error())
+    }
+    return token, nil
+}
+
+// ValidateSessionToken validates the token sent from the proxy.
+func ValidateSessionToken(tokenBytes []byte, username string) (err error) {
+    var token SessionToken
+    if err = json.Unmarshal(tokenBytes, &token); err != nil {
+        return errors.Trace(err)
+    }
+    signature := token.Signature
+    // Clear the signature and marshal it again to get the original content.
+    token.Signature = nil
+    if tokenBytes, err = json.Marshal(token); err != nil {
+        return errors.Trace(err)
+    }
+    if err = globalSigningCert.checkSignature(tokenBytes, signature); err != nil {
+        return ErrCannotMigrateSession.GenWithStackByArgs(err.Error())
+    }
+    now := getNow()
+    if now.After(token.ExpireTime) {
+        return ErrCannotMigrateSession.GenWithStackByArgs("token expired", token.ExpireTime.String())
+    }
+    // An attacker may forge a very long lifetime to enable brute-force attacks, so we also need to check `SignTime`.
+    // However, we need to be tolerant of these problems:
+    // - The `tokenLifetime` may change between TiDB versions, so we can't check `token.SignTime.Add(tokenLifetime).Equal(token.ExpireTime)`
+    // - There may be clock skew between TiDB instances, so we can't check `now.After(token.SignTime)`
+    if token.SignTime.Add(tokenLifetime).Before(now) {
+        return ErrCannotMigrateSession.GenWithStackByArgs("token lifetime is too long", token.SignTime.String())
+    }
+    if !strings.EqualFold(username, token.Username) {
+        return ErrCannotMigrateSession.GenWithStackByArgs("username does not match", username, token.Username)
+    }
+    return nil
+}
+
+// SetKeyPath sets the path of key.pem and forces reloading the certificate.
+func SetKeyPath(keyPath string) {
+    globalSigningCert.setKeyPath(keyPath)
+}
+
+// SetCertPath sets the path of cert.pem and forces reloading the certificate.
+func SetCertPath(certPath string) {
+    globalSigningCert.setCertPath(certPath)
+}
+
+// ReloadSigningCert is used to reload the certificate periodically in a separate goroutine.
+// It's impossible to know when the old certificate should expire without this goroutine:
+// - If the certificate was rotated a minute ago, the old certificate should still be valid for a while.
+// - If the certificate was rotated a month ago, the old certificate should expire for safety.
+func ReloadSigningCert() {
+    globalSigningCert.lockAndLoad()
+}
+
+var globalSigningCert signingCert
+
+// signingCert represents the parsed certificate used for token-based auth.
+type signingCert struct {
+    sync.RWMutex
+    certPath string
+    keyPath  string
+    // The cert file may be rotated between signing and checking, so we keep the old cert for a while.
+    // certs contains all the certificates that are not expired yet.
+    certs []*certInfo
+}
+
+type certInfo struct {
+    cert       *x509.Certificate
+    privKey    crypto.PrivateKey
+    expireTime time.Time
+}
+
+// We cannot guarantee that the cert and key paths are set at the same time because they are set through system variables.
+func (sc *signingCert) setCertPath(certPath string) {
+    sc.Lock()
+    // Check the path first so that repeatedly setting the same global variable doesn't trigger useless reloading.
+    if certPath != sc.certPath {
+        sc.certPath = certPath
+        // This may fail, which is expected, because the key path may not be set yet.
+        sc.checkAndLoadCert()
+    }
+    sc.Unlock()
+}
+
+func (sc *signingCert) setKeyPath(keyPath string) {
+    sc.Lock()
+    if keyPath != sc.keyPath {
+        sc.keyPath = keyPath
+        // This may fail, which is expected, because the cert path may not be set yet.
+        sc.checkAndLoadCert()
+    }
+    sc.Unlock()
+}
+
+func (sc *signingCert) lockAndLoad() {
+    sc.Lock()
+    sc.checkAndLoadCert()
+    sc.Unlock()
+}
+
+func (sc *signingCert) checkAndLoadCert() {
+    if len(sc.certPath) == 0 || len(sc.keyPath) == 0 {
+        return
+    }
+    if err := sc.loadCert(); err != nil {
+        logutil.BgLogger().Warn("loading signing cert failed",
+            zap.String("cert path", sc.certPath),
+            zap.String("key path", sc.keyPath),
+            zap.Error(err))
+    } else {
+        logutil.BgLogger().Info("signing cert is loaded successfully",
+            zap.String("cert path", sc.certPath),
+            zap.String("key path", sc.keyPath))
+    }
+}
+
+// loadCert loads the cert and adds it into the cert list.
+func (sc *signingCert) loadCert() error {
+    tlsCert, err := tls.LoadX509KeyPair(sc.certPath, sc.keyPath)
+    if err != nil {
+        return errors.Wrapf(err, "load x509 failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath)
+    }
+    var cert *x509.Certificate
+    if tlsCert.Leaf != nil {
+        cert = tlsCert.Leaf
+    } else {
+        if cert, err = x509.ParseCertificate(tlsCert.Certificate[0]); err != nil {
+            return errors.Wrapf(err, "parse x509 cert failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath)
+        }
+    }
+
+    // Rotate certs. Ensure that the expireTime of certs is in descending order.
+    now := getNow()
+    newCerts := make([]*certInfo, 0, len(sc.certs)+1)
+    newCerts = append(newCerts, &certInfo{
+        cert:       cert,
+        privKey:    tlsCert.PrivateKey,
+        expireTime: now.Add(loadCertInterval + oldCertValidTime),
+    })
+    for i := 0; i < len(sc.certs); i++ {
+        // Discard the certs that are already expired.
+        if now.After(sc.certs[i].expireTime) {
+            break
+        }
+        newCerts = append(newCerts, sc.certs[i])
+    }
+    sc.certs = newCerts
+    return nil
+}
+
+// sign generates a signature with the content and the private key.
+func (sc *signingCert) sign(content []byte) ([]byte, error) {
+    var (
+        signer crypto.Signer
+        opts   crypto.SignerOpts
+    )
+    sc.RLock()
+    defer sc.RUnlock()
+    if len(sc.certs) == 0 {
+        return nil, errors.New("no certificate or key file to sign the data")
+    }
+    // Always sign the token with the latest cert.
+    certInfo := sc.certs[0]
+    switch key := certInfo.privKey.(type) {
+    case ed25519.PrivateKey:
+        signer = key
+        opts = crypto.Hash(0)
+    case *rsa.PrivateKey:
+        signer = key
+        var pssHash crypto.Hash
+        switch certInfo.cert.SignatureAlgorithm {
+        case x509.SHA256WithRSAPSS:
+            pssHash = crypto.SHA256
+        case x509.SHA384WithRSAPSS:
+            pssHash = crypto.SHA384
+        case x509.SHA512WithRSAPSS:
+            pssHash = crypto.SHA512
+        }
+        if pssHash != 0 {
+            h := pssHash.New()
+            h.Write(content)
+            content = h.Sum(nil)
+            opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: pssHash}
+            break
+        }
+        switch certInfo.cert.SignatureAlgorithm {
+        case x509.SHA256WithRSA:
+            hashed := sha256.Sum256(content)
+            content = hashed[:]
+            opts = crypto.SHA256
+        case x509.SHA384WithRSA:
+            hashed := sha512.Sum384(content)
+            content = hashed[:]
+            opts = crypto.SHA384
+        case x509.SHA512WithRSA:
+            hashed := sha512.Sum512(content)
+            content = hashed[:]
+            opts = crypto.SHA512
+        default:
+            return nil, errors.Errorf("not supported private key type '%s' for signing", certInfo.cert.SignatureAlgorithm.String())
+        }
+    case *ecdsa.PrivateKey:
+        signer = key
+    default:
+        return nil, errors.Errorf("not supported private key type '%s' for signing", certInfo.cert.SignatureAlgorithm.String())
+    }
+    return signer.Sign(rand.Reader, content, opts)
+}
+
+// checkSignature checks the signature and the content.
+func (sc *signingCert) checkSignature(content, signature []byte) error {
+    sc.RLock()
+    defer sc.RUnlock()
+    now := getNow()
+    var err error
+    for _, certInfo := range sc.certs {
+        // The expireTime values are in descending order, so if this one is expired, the rest are too.
+        if now.After(certInfo.expireTime) {
+            break
+        }
+        switch certInfo.privKey.(type) {
+        // ECDSA is special: `PrivateKey.Sign` doesn't match `Certificate.CheckSignature`.
+        case *ecdsa.PrivateKey:
+            if !ecdsa.VerifyASN1(certInfo.cert.PublicKey.(*ecdsa.PublicKey), content, signature) {
+                err = errors.New("x509: ECDSA verification failure")
+            }
+        default:
+            err = certInfo.cert.CheckSignature(certInfo.cert.SignatureAlgorithm, content, signature)
+        }
+        if err == nil {
+            return nil
+        }
+    }
+    // Either there are no certs (possible) or all the certs are expired (impossible).
+    if err == nil {
+        return errors.Errorf("no valid certificate to check the signature, cached certificates: %d", len(sc.certs))
+    }
+    return err
+}
+
+func getNow() time.Time {
+    now := time.Now()
+    failpoint.Inject("mockNowOffset", func(val failpoint.Value) {
+        if s := uint64(val.(int)); s != 0 {
+            now = now.Add(time.Duration(s))
+        }
+    })
+    return now
+}
diff --git a/sessionctx/sessionstates/session_token_test.go b/sessionctx/sessionstates/session_token_test.go
new file mode 100644
index 0000000000000..9b56e0cde67e2
--- /dev/null
+++ b/sessionctx/sessionstates/session_token_test.go
@@ -0,0 +1,266 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package sessionstates + +import ( + "crypto/x509" + "encoding/json" + "fmt" + "path/filepath" + "testing" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/util" + "github.com/stretchr/testify/require" +) + +var ( + mockNowOffset = "github.com/pingcap/tidb/sessionctx/sessionstates/mockNowOffset" +) + +func TestSetCertAndKey(t *testing.T) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "test1_cert.pem") + keyPath := filepath.Join(tempDir, "test1_key.pem") + createRSACert(t, certPath, keyPath) + + // no cert and no key + _, err := CreateSessionToken("test_user") + require.ErrorContains(t, err, "no certificate or key file") + // no cert + SetKeyPath(keyPath) + _, err = CreateSessionToken("test_user") + require.ErrorContains(t, err, "no certificate or key file") + // no key + SetKeyPath("") + SetCertPath(certPath) + _, err = CreateSessionToken("test_user") + require.ErrorContains(t, err, "no certificate or key file") + // both configured + SetKeyPath(keyPath) + _, err = CreateSessionToken("test_user") + require.NoError(t, err) + // When the key and cert don't match, it will still use the old pair. + certPath2 := filepath.Join(tempDir, "test2_cert.pem") + keyPath2 := filepath.Join(tempDir, "test2_key.pem") + err = util.CreateCertificates(certPath2, keyPath2, 4096, x509.RSA, x509.UnknownSignatureAlgorithm) + require.NoError(t, err) + SetKeyPath(keyPath2) + _, err = CreateSessionToken("test_user") + require.NoError(t, err) +} + +func TestSignAlgo(t *testing.T) { + tests := []struct { + pubKeyAlgo x509.PublicKeyAlgorithm + signAlgos []x509.SignatureAlgorithm + keySizes []int + }{ + { + pubKeyAlgo: x509.RSA, + signAlgos: []x509.SignatureAlgorithm{ + x509.SHA256WithRSA, + x509.SHA384WithRSA, + x509.SHA512WithRSA, + x509.SHA256WithRSAPSS, + x509.SHA384WithRSAPSS, + x509.SHA512WithRSAPSS, + }, + keySizes: []int{ + 2048, + 4096, + }, + }, + { + pubKeyAlgo: x509.ECDSA, + signAlgos: []x509.SignatureAlgorithm{ + x509.ECDSAWithSHA256, + x509.ECDSAWithSHA384, + x509.ECDSAWithSHA512, + }, + keySizes: []int{ + 4096, + }, + }, + { + pubKeyAlgo: x509.Ed25519, + signAlgos: []x509.SignatureAlgorithm{ + x509.PureEd25519, + }, + keySizes: []int{ + 4096, + }, + }, + } + + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "test1_cert.pem") + keyPath := filepath.Join(tempDir, "test1_key.pem") + SetKeyPath(keyPath) + SetCertPath(certPath) + for _, test := range tests { + for _, signAlgo := range test.signAlgos { + for _, keySize := range test.keySizes { + msg := fmt.Sprintf("pubKeyAlgo: %s, signAlgo: %s, keySize: %d", test.pubKeyAlgo.String(), + signAlgo.String(), keySize) + err := util.CreateCertificates(certPath, keyPath, keySize, test.pubKeyAlgo, signAlgo) + require.NoError(t, err, msg) + ReloadSigningCert() + _, tokenBytes := createNewToken(t, "test_user") + err = ValidateSessionToken(tokenBytes, "test_user") + require.NoError(t, err, msg) + } + } + } +} + +func TestVerifyToken(t *testing.T) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "test1_cert.pem") + keyPath := filepath.Join(tempDir, "test1_key.pem") + createRSACert(t, certPath, keyPath) + SetKeyPath(keyPath) + SetCertPath(certPath) + + // check succeeds + token, tokenBytes := createNewToken(t, "test_user") + err := ValidateSessionToken(tokenBytes, "test_user") + require.NoError(t, err) + // the token expires + timeOffset := uint64(tokenLifetime + time.Minute) + require.NoError(t, failpoint.Enable(mockNowOffset, fmt.Sprintf(`return(%d)`, timeOffset))) + err = ValidateSessionToken(tokenBytes, 
"test_user") + require.NoError(t, failpoint.Disable(mockNowOffset)) + require.ErrorContains(t, err, "token expired") + // the current user is different with the token + err = ValidateSessionToken(tokenBytes, "another_user") + require.ErrorContains(t, err, "username does not match") + // forge the user name + token.Username = "another_user" + tokenBytes2, err := json.Marshal(token) + require.NoError(t, err) + err = ValidateSessionToken(tokenBytes2, "another_user") + require.ErrorContains(t, err, "verification error") + // forge the expire time + token.Username = "test_user" + token.ExpireTime = time.Now().Add(-time.Minute) + tokenBytes2, err = json.Marshal(token) + require.NoError(t, err) + err = ValidateSessionToken(tokenBytes2, "test_user") + require.ErrorContains(t, err, "verification error") +} + +func TestCertExpire(t *testing.T) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "test1_cert.pem") + keyPath := filepath.Join(tempDir, "test1_key.pem") + createRSACert(t, certPath, keyPath) + SetKeyPath(keyPath) + SetCertPath(certPath) + + _, tokenBytes := createNewToken(t, "test_user") + err := ValidateSessionToken(tokenBytes, "test_user") + require.NoError(t, err) + // replace the cert, but the old cert is still valid for a while + certPath2 := filepath.Join(tempDir, "test2_cert.pem") + keyPath2 := filepath.Join(tempDir, "test2_key.pem") + createRSACert(t, certPath2, keyPath2) + SetKeyPath(keyPath2) + SetCertPath(certPath2) + err = ValidateSessionToken(tokenBytes, "test_user") + require.NoError(t, err) + // the old cert expires and the original token is invalid + timeOffset := uint64(loadCertInterval) + require.NoError(t, failpoint.Enable(mockNowOffset, fmt.Sprintf(`return(%d)`, timeOffset))) + ReloadSigningCert() + timeOffset += uint64(oldCertValidTime + time.Minute) + require.NoError(t, failpoint.Enable(mockNowOffset, fmt.Sprintf(`return(%d)`, timeOffset))) + err = ValidateSessionToken(tokenBytes, "test_user") + require.ErrorContains(t, err, "verification error") + // the new cert is not rotated but is reloaded + _, tokenBytes = createNewToken(t, "test_user") + ReloadSigningCert() + err = ValidateSessionToken(tokenBytes, "test_user") + require.NoError(t, err) + // the cert is rotated but is still valid + createRSACert(t, certPath2, keyPath2) + timeOffset += uint64(loadCertInterval) + require.NoError(t, failpoint.Enable(mockNowOffset, fmt.Sprintf(`return(%d)`, timeOffset))) + ReloadSigningCert() + err = ValidateSessionToken(tokenBytes, "test_user") + require.ErrorContains(t, err, "token expired") + // after some time, it's not valid + timeOffset += uint64(oldCertValidTime + time.Minute) + require.NoError(t, failpoint.Enable(mockNowOffset, fmt.Sprintf(`return(%d)`, timeOffset))) + err = ValidateSessionToken(tokenBytes, "test_user") + require.NoError(t, failpoint.Disable(mockNowOffset)) + require.ErrorContains(t, err, "verification error") +} + +func TestLoadAndReadConcurrently(t *testing.T) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "test1_cert.pem") + keyPath := filepath.Join(tempDir, "test1_key.pem") + createRSACert(t, certPath, keyPath) + SetKeyPath(keyPath) + SetCertPath(certPath) + + deadline := time.Now().Add(5 * time.Second) + var wg util.WaitGroupWrapper + // the writer + wg.Run(func() { + for time.Now().Before(deadline) { + createRSACert(t, certPath, keyPath) + time.Sleep(time.Second) + } + }) + // the loader + for i := 0; i < 2; i++ { + wg.Run(func() { + for time.Now().Before(deadline) { + ReloadSigningCert() + time.Sleep(500 * 
time.Millisecond) + } + }) + } + // the reader + for i := 0; i < 3; i++ { + wg.Run(func() { + username := fmt.Sprintf("test_user_%d", i) + for time.Now().Before(deadline) { + _, tokenBytes := createNewToken(t, username) + time.Sleep(10 * time.Millisecond) + err := ValidateSessionToken(tokenBytes, username) + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + } + }) + } + wg.Wait() +} + +func createNewToken(t *testing.T, username string) (*SessionToken, []byte) { + token, err := CreateSessionToken(username) + require.NoError(t, err) + tokenBytes, err := json.Marshal(token) + require.NoError(t, err) + return token, tokenBytes +} + +func createRSACert(t *testing.T, certPath, keyPath string) { + err := util.CreateCertificates(certPath, keyPath, 4096, x509.RSA, x509.UnknownSignatureAlgorithm) + require.NoError(t, err) +} diff --git a/util/misc.go b/util/misc.go index 9434ee3f98106..498e7433d2fd0 100644 --- a/util/misc.go +++ b/util/misc.go @@ -16,6 +16,10 @@ package util import ( "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -615,12 +619,9 @@ func QueryStrForLog(query string) string { return query } -func createTLSCertificates(certpath string, keypath string, rsaKeySize int) error { - privkey, err := rsa.GenerateKey(rand.Reader, rsaKeySize) - if err != nil { - return err - } - +// CreateCertificates creates and writes a cert based on the params. +func CreateCertificates(certpath string, keypath string, rsaKeySize int, pubKeyAlgo x509.PublicKeyAlgorithm, + signAlgo x509.SignatureAlgorithm) error { certValidity := 90 * 24 * time.Hour // 90 days notBefore := time.Now() notAfter := notBefore.Add(certValidity) @@ -633,14 +634,29 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro Subject: pkix.Name{ CommonName: "TiDB_Server_Auto_Generated_Server_Certificate", }, - SerialNumber: big.NewInt(1), - NotBefore: notBefore, - NotAfter: notAfter, - DNSNames: []string{hostname}, + SerialNumber: big.NewInt(1), + NotBefore: notBefore, + NotAfter: notAfter, + DNSNames: []string{hostname}, + SignatureAlgorithm: signAlgo, + } + + var privKey crypto.Signer + switch pubKeyAlgo { + case x509.RSA: + privKey, err = rsa.GenerateKey(rand.Reader, rsaKeySize) + case x509.ECDSA: + privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + case x509.Ed25519: + _, privKey, err = ed25519.GenerateKey(rand.Reader) + default: + return errors.Errorf("unknown public key algorithm: %s", pubKeyAlgo.String()) + } + if err != nil { + return err } - // DER: Distinguished Encoding Rules, this is the ASN.1 encoding rule of the certificate. 
- derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privkey.PublicKey, privkey) + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, privKey.Public(), privKey) if err != nil { return err } @@ -661,7 +677,7 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro return err } - privBytes, err := x509.MarshalPKCS8PrivateKey(privkey) + privBytes, err := x509.MarshalPKCS8PrivateKey(privKey) if err != nil { return err } @@ -678,3 +694,8 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro zap.Duration("validity", certValidity), zap.Int("rsaKeySize", rsaKeySize)) return nil } + +func createTLSCertificates(certpath string, keypath string, rsaKeySize int) error { + // use RSA and unspecified signature algorithm + return CreateCertificates(certpath, keypath, rsaKeySize, x509.RSA, x509.UnknownSignatureAlgorithm) +}
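
Usage note (not part of the patch): the sketch below shows how the new token API added in sessionctx/sessionstates could be exercised end to end, mirroring what session_token_test.go does. The standalone main wrapper, the temporary directory, and the "test_user" name are illustrative assumptions; in TiDB the cert/key paths are supplied through system variables and the proxy drives the create/validate flow.

package main

import (
    "crypto/x509"
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"

    "github.com/pingcap/tidb/sessionctx/sessionstates"
    "github.com/pingcap/tidb/util"
)

func main() {
    // All servers must share the same cert/key pair; generate a self-signed RSA pair for illustration.
    dir, err := os.MkdirTemp("", "session-token")
    if err != nil {
        panic(err)
    }
    defer os.RemoveAll(dir)
    certPath := filepath.Join(dir, "cert.pem")
    keyPath := filepath.Join(dir, "key.pem")
    if err := util.CreateCertificates(certPath, keyPath, 4096, x509.RSA, x509.UnknownSignatureAlgorithm); err != nil {
        panic(err)
    }

    // In TiDB these paths come from system variables; setting them loads the signing certificate.
    sessionstates.SetCertPath(certPath)
    sessionstates.SetKeyPath(keyPath)

    // "Server A": create and sign a token for the migrating session, then serialize it for the proxy.
    token, err := sessionstates.CreateSessionToken("test_user")
    if err != nil {
        panic(err)
    }
    tokenBytes, err := json.Marshal(token)
    if err != nil {
        panic(err)
    }

    // "Server B": the proxy presents the serialized token as the password; the server
    // checks the signature, the expiration time, and the username.
    if err := sessionstates.ValidateSessionToken(tokenBytes, "test_user"); err != nil {
        panic(err)
    }
    fmt.Println("session token accepted")
}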