*: remove in-struct context (pingcap#452)
Signed-off-by: Neil Shen <[email protected]>
overvenus authored Sep 18, 2020
1 parent be23a9c commit 5af97a1
Showing 7 changed files with 33 additions and 48 deletions.
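The commit follows the common Go guideline that a context.Context should be passed explicitly as a call parameter rather than stored inside a struct, so each operation is bounded by its caller's cancellation and deadline instead of by whatever context the struct captured at construction time. The following is a minimal, hypothetical sketch of that before/after shape, not BR code; the worker/Run names are illustrative only.

package main

import (
    "context"
    "fmt"
    "time"
)

// Before this kind of change, a struct would capture a context at
// construction time:
//
//	type worker struct {
//		ctx context.Context
//	}
//
// After it, the struct keeps only long-lived dependencies and every
// method takes its own ctx, which is the shape this commit gives to
// newPushDown, pushBackup, NewRestoreClient, RestoreFiles, and so on.
type worker struct {
    name string
}

// Run is bounded by the caller's ctx, not by the constructor's.
func (w *worker) Run(ctx context.Context) error {
    select {
    case <-time.After(50 * time.Millisecond):
        fmt.Println(w.name, "done")
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

func main() {
    w := &worker{name: "push-down"}

    // Each call gets its own cancellation scope.
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    if err := w.Run(ctx); err != nil {
        fmt.Println("run failed:", err)
    }
}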
10 changes: 2 additions & 8 deletions pkg/backup/client.go
@@ -400,8 +400,6 @@ func (bc *Client) BackupRanges(
updateCh glue.Progress,
) ([]*kvproto.File, error) {
errCh := make(chan error)
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()

// we collect all files in a single goroutine to avoid thread safety issues.
filesCh := make(chan []*kvproto.File, concurrency)
@@ -477,8 +475,6 @@ func (bc *Client) BackupRange(
zap.Stringer("EndKey", utils.WrapKey(endKey)),
zap.Uint64("RateLimit", req.RateLimit),
zap.Uint32("Concurrency", req.Concurrency))
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()

var allStores []*metapb.Store
allStores, err = conn.GetAllTiKVStores(ctx, bc.mgr.GetPDClient(), conn.SkipTiFlash)
@@ -491,10 +487,10 @@ func (bc *Client) BackupRange(
req.EndKey = endKey
req.StorageBackend = bc.backend

- push := newPushDown(ctx, bc.mgr, len(allStores))
+ push := newPushDown(bc.mgr, len(allStores))

var results rtree.RangeTree
- results, err = push.pushBackup(req, allStores, updateCh)
+ results, err = push.pushBackup(ctx, req, allStores, updateCh)
if err != nil {
return nil, err
}
@@ -801,8 +797,6 @@ func SendBackup(
zap.Stringer("EndKey", utils.WrapKey(req.EndKey)),
zap.Uint64("storeID", storeID),
)
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
bcli, err := client.Backup(ctx, &req)
if err != nil {
log.Error("fail to backup", zap.Uint64("StoreID", storeID))
10 changes: 4 additions & 6 deletions pkg/backup/push.go
@@ -18,17 +18,14 @@ import (

// pushDown warps a backup task.
type pushDown struct {
- ctx context.Context
mgr ClientMgr
respCh chan *backup.BackupResponse
errCh chan error
}

// newPushDown creates a push down backup.
- func newPushDown(ctx context.Context, mgr ClientMgr, cap int) *pushDown {
- log.Info("new backup client")
+ func newPushDown(mgr ClientMgr, cap int) *pushDown {
return &pushDown{
- ctx: ctx,
mgr: mgr,
respCh: make(chan *backup.BackupResponse, cap),
errCh: make(chan error, cap),
@@ -37,6 +34,7 @@ func newPushDown(ctx context.Context, mgr ClientMgr, cap int) *pushDown {

// FullBackup make a full backup of a tikv cluster.
func (push *pushDown) pushBackup(
+ ctx context.Context,
req backup.BackupRequest,
stores []*metapb.Store,
updateCh glue.Progress,
@@ -50,7 +48,7 @@ func (push *pushDown) pushBackup(
log.Warn("skip store", zap.Uint64("StoreID", storeID), zap.Stringer("State", s.GetState()))
continue
}
- client, err := push.mgr.GetBackupClient(push.ctx, storeID)
+ client, err := push.mgr.GetBackupClient(ctx, storeID)
if err != nil {
log.Error("fail to connect store", zap.Uint64("StoreID", storeID))
return res, errors.Trace(err)
@@ -59,7 +57,7 @@ func (push *pushDown) pushBackup(
go func() {
defer wg.Done()
err := SendBackup(
- push.ctx, storeID, client, req,
+ ctx, storeID, client, req,
func(resp *backup.BackupResponse) error {
// Forward all responses (including error).
push.respCh <- resp
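With the ctx field removed from pushDown, the context that reaches GetBackupClient and SendBackup is the one passed to pushBackup itself, so cancelling the caller's context stops every per-store goroutine at once. Below is a minimal fan-out sketch of that pattern, using hypothetical sendToStore/pushToAll helpers rather than the real BR functions.

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// sendToStore stands in for a per-store backup RPC; it stops as soon as
// the shared ctx is cancelled or times out.
func sendToStore(ctx context.Context, storeID uint64) error {
    select {
    case <-time.After(time.Duration(storeID*100) * time.Millisecond):
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

// pushToAll fans out one goroutine per store, all bounded by the same
// caller-supplied ctx (the shape pushBackup now has).
func pushToAll(ctx context.Context, storeIDs []uint64) {
    var wg sync.WaitGroup
    for _, id := range storeIDs {
        id := id
        wg.Add(1)
        go func() {
            defer wg.Done()
            if err := sendToStore(ctx, id); err != nil {
                fmt.Printf("store %d: %v\n", id, err)
            }
        }()
    }
    wg.Wait()
}

func main() {
    // Stores 3 and 4 exceed the 250ms budget and report context errors.
    ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
    defer cancel()
    pushToAll(ctx, []uint64{1, 2, 3, 4})
}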
42 changes: 18 additions & 24 deletions pkg/restore/client.go
@@ -47,9 +47,6 @@ const defaultChecksumConcurrency = 64

// Client sends requests to restore files.
type Client struct {
- ctx context.Context
- cancel context.CancelFunc
-
pdClient pd.Client
toolClient SplitClient
fileImporter FileImporter
@@ -84,22 +81,17 @@ type Client struct {

// NewRestoreClient returns a new RestoreClient.
func NewRestoreClient(
- ctx context.Context,
g glue.Glue,
pdClient pd.Client,
store kv.Storage,
tlsConf *tls.Config,
) (*Client, error) {
- ctx, cancel := context.WithCancel(ctx)
db, err := NewDB(g, store)
if err != nil {
- cancel()
return nil, errors.Trace(err)
}

return &Client{
- ctx: ctx,
- cancel: cancel,
pdClient: pdClient,
toolClient: NewSplitClient(pdClient, tlsConf),
db: db,
@@ -145,7 +137,6 @@ func (rc *Client) Close() {
if rc.db != nil {
rc.db.Close()
}
- rc.cancel()
log.Info("Restore client closed")
}

@@ -258,22 +249,22 @@ func (rc *Client) GetTS(ctx context.Context) (uint64, error) {
}

// ResetTS resets the timestamp of PD to a bigger value.
- func (rc *Client) ResetTS(pdAddrs []string) error {
+ func (rc *Client) ResetTS(ctx context.Context, pdAddrs []string) error {
restoreTS := rc.backupMeta.GetEndVersion()
log.Info("reset pd timestamp", zap.Uint64("ts", restoreTS))
i := 0
- return utils.WithRetry(rc.ctx, func() error {
+ return utils.WithRetry(ctx, func() error {
idx := i % len(pdAddrs)
i++
return utils.ResetTS(pdAddrs[idx], restoreTS, rc.tlsConf)
}, newPDReqBackoffer())
}

// GetPlacementRules return the current placement rules.
- func (rc *Client) GetPlacementRules(pdAddrs []string) ([]placement.Rule, error) {
+ func (rc *Client) GetPlacementRules(ctx context.Context, pdAddrs []string) ([]placement.Rule, error) {
var placementRules []placement.Rule
i := 0
- errRetry := utils.WithRetry(rc.ctx, func() error {
+ errRetry := utils.WithRetry(ctx, func() error {
var err error
idx := i % len(pdAddrs)
i++
@@ -317,12 +308,12 @@ func (rc *Client) GetTableSchema(
}

// CreateDatabase creates a database.
- func (rc *Client) CreateDatabase(db *model.DBInfo) error {
+ func (rc *Client) CreateDatabase(ctx context.Context, db *model.DBInfo) error {
if rc.IsSkipCreateSQL() {
log.Info("skip create database", zap.Stringer("database", db.Name))
return nil
}
- return rc.db.CreateDatabase(rc.ctx, db)
+ return rc.db.CreateDatabase(ctx, db)
}

// CreateTables creates multiple tables, and returns their rewrite rules.
@@ -472,14 +463,14 @@ func (rc *Client) createTablesWithDBPool(ctx context.Context,
}

// ExecDDLs executes the queries of the ddl jobs.
- func (rc *Client) ExecDDLs(ddlJobs []*model.Job) error {
+ func (rc *Client) ExecDDLs(ctx context.Context, ddlJobs []*model.Job) error {
// Sort the ddl jobs by schema version in ascending order.
sort.Slice(ddlJobs, func(i, j int) bool {
return ddlJobs[i].BinlogInfo.SchemaVersion < ddlJobs[j].BinlogInfo.SchemaVersion
})

for _, job := range ddlJobs {
- err := rc.db.ExecDDL(rc.ctx, job)
+ err := rc.db.ExecDDL(ctx, job)
if err != nil {
return errors.Trace(err)
}
@@ -491,14 +482,14 @@ func (rc *Client) ExecDDLs(ddlJobs []*model.Job) error {
return nil
}

- func (rc *Client) setSpeedLimit() error {
+ func (rc *Client) setSpeedLimit(ctx context.Context) error {
if !rc.hasSpeedLimited && rc.rateLimit != 0 {
- stores, err := conn.GetAllTiKVStores(rc.ctx, rc.pdClient, conn.SkipTiFlash)
+ stores, err := conn.GetAllTiKVStores(ctx, rc.pdClient, conn.SkipTiFlash)
if err != nil {
return err
}
for _, store := range stores {
- err = rc.fileImporter.setDownloadSpeedLimit(rc.ctx, store.GetId())
+ err = rc.fileImporter.setDownloadSpeedLimit(ctx, store.GetId())
if err != nil {
return err
}
@@ -510,6 +501,7 @@ func (rc *Client) setSpeedLimit() error {

// RestoreFiles tries to restore the files.
func (rc *Client) RestoreFiles(
+ ctx context.Context,
files []*backup.File,
rewriteRules *RewriteRules,
updateCh glue.Progress,
@@ -527,8 +519,8 @@ func (rc *Client) RestoreFiles(
log.Debug("start to restore files",
zap.Int("files", len(files)),
)
- eg, ectx := errgroup.WithContext(rc.ctx)
- err = rc.setSpeedLimit()
+ eg, ectx := errgroup.WithContext(ctx)
+ err = rc.setSpeedLimit(ctx)
if err != nil {
return err
}
@@ -553,7 +545,9 @@ func (rc *Client) RestoreFiles(
}

// RestoreRaw tries to restore raw keys in the specified range.
- func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.File, updateCh glue.Progress) error {
+ func (rc *Client) RestoreRaw(
+ ctx context.Context, startKey []byte, endKey []byte, files []*backup.File, updateCh glue.Progress,
+ ) error {
start := time.Now()
defer func() {
elapsed := time.Since(start)
@@ -563,7 +557,7 @@ func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.Fil
zap.Duration("take", elapsed))
}()
errCh := make(chan error, len(files))
- eg, ectx := errgroup.WithContext(rc.ctx)
+ eg, ectx := errgroup.WithContext(ctx)
defer close(errCh)

err := rc.fileImporter.SetRawRange(startKey, endKey)
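After this change, RestoreFiles and RestoreRaw derive their errgroup context from the ctx argument instead of rc.ctx, so a failing restore task, or a cancellation by the caller, stops the remaining tasks. A small sketch of that errgroup wiring follows, with illustrative file names and a stand-in restoreOne helper; it needs golang.org/x/sync/errgroup, which the real code also uses.

package main

import (
    "context"
    "fmt"

    "golang.org/x/sync/errgroup"
)

// restoreOne stands in for restoring a single file. It gives up early if
// the group context was already cancelled by another task's failure.
func restoreOne(ctx context.Context, file string) error {
    if err := ctx.Err(); err != nil {
        return err
    }
    if file == "broken.sst" {
        return fmt.Errorf("restore %s: corrupted", file)
    }
    return nil
}

func restoreFiles(ctx context.Context, files []string) error {
    // The group context is derived from the caller's ctx, mirroring
    // eg, ectx := errgroup.WithContext(ctx) in the diff above.
    eg, ectx := errgroup.WithContext(ctx)
    for _, f := range files {
        f := f
        eg.Go(func() error { return restoreOne(ectx, f) })
    }
    return eg.Wait() // first non-nil error wins
}

func main() {
    err := restoreFiles(context.Background(), []string{"a.sst", "broken.sst", "c.sst"})
    fmt.Println("restore result:", err)
}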
5 changes: 2 additions & 3 deletions pkg/restore/client_test.go
@@ -3,7 +3,6 @@
package restore_test

import (
"context"
"math"
"strconv"

@@ -40,7 +39,7 @@ func (s *testRestoreClientSuite) TestCreateTables(c *C) {
c.Assert(s.mock.Start(), IsNil)
defer s.mock.Stop()

- client, err := restore.NewRestoreClient(context.Background(), gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil)
+ client, err := restore.NewRestoreClient(gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil)
c.Assert(err, IsNil)

info, err := s.mock.Domain.GetSnapshotInfoSchema(math.MaxInt64)
@@ -98,7 +97,7 @@ func (s *testRestoreClientSuite) TestIsOnline(c *C) {
c.Assert(s.mock.Start(), IsNil)
defer s.mock.Stop()

- client, err := restore.NewRestoreClient(context.Background(), gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil)
+ client, err := restore.NewRestoreClient(gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil)
c.Assert(err, IsNil)

c.Assert(client.IsOnline(), IsFalse)
2 changes: 1 addition & 1 deletion pkg/restore/pipeline_items.go
@@ -164,7 +164,7 @@ func (b *tikvSender) RestoreBatch(ctx context.Context, ranges []rtree.Range, rew
files = append(files, fs.Files...)
}

- if err := b.client.RestoreFiles(files, rewriteRules, b.updateCh); err != nil {
+ if err := b.client.RestoreFiles(ctx, files, rewriteRules, b.updateCh); err != nil {
return err
}

8 changes: 4 additions & 4 deletions pkg/task/restore.go
@@ -99,7 +99,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
}
defer mgr.Close()

- client, err := restore.NewRestoreClient(ctx, g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig())
+ client, err := restore.NewRestoreClient(g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig())
if err != nil {
return err
}
@@ -158,7 +158,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
defer restoreDBConfig()

// execute DDL first
- err = client.ExecDDLs(ddlJobs)
+ err = client.ExecDDLs(ctx, ddlJobs)
if err != nil {
return errors.Trace(err)
}
@@ -172,7 +172,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
}

for _, db := range dbs {
- err = client.CreateDatabase(db.Info)
+ err = client.CreateDatabase(ctx, db.Info)
if err != nil {
return err
}
@@ -226,7 +226,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf
// Do not reset timestamp if we are doing incremental restore, because
// we are not allowed to decrease timestamp.
if !client.IsIncremental() {
- if err = client.ResetTS(cfg.PD); err != nil {
+ if err = client.ResetTS(ctx, cfg.PD); err != nil {
log.Error("reset pd TS failed", zap.Error(err))
return err
}
4 changes: 2 additions & 2 deletions pkg/task/restore_raw.go
@@ -56,7 +56,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR
}
defer mgr.Close()

- client, err := restore.NewRestoreClient(ctx, g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig())
+ client, err := restore.NewRestoreClient(g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig())
if err != nil {
return err
}
@@ -116,7 +116,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR
}
defer restorePostWork(ctx, client, restoreSchedulers)

- err = client.RestoreRaw(cfg.StartKey, cfg.EndKey, files, updateCh)
+ err = client.RestoreRaw(ctx, cfg.StartKey, cfg.EndKey, files, updateCh)
if err != nil {
return errors.Trace(err)
}
