diff --git a/client/go.mod b/client/go.mod
index 9eb066d0fcc..5c20dea65a0 100644
--- a/client/go.mod
+++ b/client/go.mod
@@ -10,6 +10,7 @@ require (
 	github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00
 	github.com/pingcap/kvproto v0.0.0-20230727073445-53e1f8730c30
 	github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3
+	github.com/pkg/errors v0.9.1
 	github.com/prometheus/client_golang v1.11.1
 	github.com/stretchr/testify v1.8.2
 	go.uber.org/goleak v1.1.11
@@ -25,7 +26,6 @@ require (
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
-	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_model v0.2.0 // indirect
 	github.com/prometheus/common v0.26.0 // indirect
diff --git a/client/pd_service_discovery.go b/client/pd_service_discovery.go
index 4499c9e17c0..16489278b19 100644
--- a/client/pd_service_discovery.go
+++ b/client/pd_service_discovery.go
@@ -29,9 +29,11 @@ import (
 	"github.com/pingcap/log"
 	"github.com/tikv/pd/client/errs"
 	"github.com/tikv/pd/client/grpcutil"
+	"github.com/tikv/pd/client/retry"
 	"github.com/tikv/pd/client/tlsutil"
 	"go.uber.org/zap"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/connectivity"
 )
 
 const (
@@ -39,6 +41,7 @@ const (
 	memberUpdateInterval      = time.Minute
 	serviceModeUpdateInterval = 3 * time.Second
 	updateMemberTimeout       = time.Second // Use a shorter timeout to recover faster from network isolation.
+	requestTimeout            = 2 * time.Second
 )
 
 type serviceType int
@@ -61,7 +64,7 @@ type ServiceDiscovery interface {
 	GetKeyspaceID() uint32
 	// GetKeyspaceGroupID returns the ID of the keyspace group
 	GetKeyspaceGroupID() uint32
-	// DiscoverServiceURLs discovers the microservice with the specified type and returns the server urls.
+	// DiscoverMicroservice discovers the microservice with the specified type and returns the server urls.
 	DiscoverMicroservice(svcType serviceType) ([]string, error)
 	// GetServiceURLs returns the URLs of the servers providing the service
 	GetServiceURLs() []string
@@ -95,6 +98,8 @@ type ServiceDiscovery interface {
 	// in a quorum-based cluster or any primary/secondary in a primary/secondary configured cluster
 	// is changed.
 	AddServiceAddrsSwitchedCallback(callbacks ...func())
+	// GetBackoffer returns the backoffer.
+	GetBackoffer() *retry.Backoffer
 }
 
 type updateKeyspaceIDFunc func() error
@@ -153,6 +158,9 @@ type pdServiceDiscovery struct {
 	tlsCfg *tlsutil.TLSConfig
 	// Client option.
 	option *option
+
+	successReConnect chan struct{}
+	bo               *retry.Backoffer
 }
 
 // newPDServiceDiscovery returns a new PD service discovery-based client.
@@ -166,6 +174,7 @@ func newPDServiceDiscovery(
 ) *pdServiceDiscovery {
 	pdsd := &pdServiceDiscovery{
 		checkMembershipCh: make(chan struct{}, 1),
+		successReConnect:  make(chan struct{}, 1),
 		ctx:               ctx,
 		cancel:            cancel,
 		wg:                wg,
@@ -174,6 +183,7 @@ func newPDServiceDiscovery(
 		keyspaceID: keyspaceID,
 		tlsCfg:     tlsCfg,
 		option:     option,
+		bo:         retry.NewBackoffer(ctx, maxRetryTimes),
 	}
 	pdsd.urls.Store(urls)
 	return pdsd
@@ -207,7 +217,7 @@ func (c *pdServiceDiscovery) Init() error {
 	}
 
 	c.wg.Add(2)
-	go c.updateMemberLoop()
+	go c.reconnectMemberLoop()
 	go c.updateServiceModeLoop()
 
 	c.isInitialized = true
@@ -231,13 +241,17 @@ func (c *pdServiceDiscovery) initRetry(f func() error) error {
 	return errors.WithStack(err)
 }
 
-func (c *pdServiceDiscovery) updateMemberLoop() {
+func (c *pdServiceDiscovery) reconnectMemberLoop() {
 	defer c.wg.Done()
 
 	ctx, cancel := context.WithCancel(c.ctx)
 	defer cancel()
 	ticker := time.NewTicker(memberUpdateInterval)
 	defer ticker.Stop()
+	failpoint.Inject("acceleratedMemberUpdateInterval", func() {
+		ticker.Stop()
+		ticker = time.NewTicker(time.Millisecond * 100)
+	})
 
 	for {
 		select {
@@ -246,15 +260,98 @@ func (c *pdServiceDiscovery) updateMemberLoop() {
 		case <-ticker.C:
 		case <-c.checkMembershipCh:
 		}
+
 		failpoint.Inject("skipUpdateMember", func() {
 			failpoint.Continue()
 		})
+
 		if err := c.updateMember(); err != nil {
-			log.Error("[pd] failed to update member", zap.Strings("urls", c.GetServiceURLs()), errs.ZapError(err))
+			log.Error("[pd] failed to update member", errs.ZapError(err))
+		} else {
+			c.SuccessReconnect()
 		}
 	}
 }
 
+func (c *pdServiceDiscovery) waitForReady() error {
+	ctx, cancel := context.WithCancel(c.ctx)
+	defer cancel()
+
+	if e1 := c.waitForLeaderReady(); e1 != nil {
+		log.Error("[pd.waitForReady] failed to wait for leader ready", errs.ZapError(e1))
+		return errors.WithStack(e1)
+	} else if e2 := c.loadMembers(); e2 != nil {
+		log.Error("[pd.waitForReady] failed to load members", errs.ZapError(e2))
+	} else {
+		return nil
+	}
+
+	deadline := time.Now().Add(requestTimeout)
+	failpoint.Inject("acceleratedRequestTimeout", func() {
+		deadline = time.Now().Add(500 * time.Millisecond)
+	})
+	for {
+		select {
+		case <-c.successReConnect:
+			return nil
+		case <-time.After(time.Until(deadline)):
+			log.Error("[pd.waitForReady] timeout")
+			return errors.New("wait for ready timeout")
+		case <-ctx.Done():
+			log.Info("[pd.waitForReady] exit")
+			return nil
+		}
+	}
+}
+
+// waitForLeaderReady waits for the leader to be ready.
+func (c *pdServiceDiscovery) waitForLeaderReady() error {
+	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
+	for {
+		old, ok := c.clientConns.Load(c.getLeaderAddr())
+		if !ok {
+			cancel()
+			return errors.New("no leader")
+		}
+		cc := old.(*grpc.ClientConn)
+
+		s := cc.GetState()
+		if s == connectivity.Ready {
+			cancel()
+			return nil
+		}
+		if !cc.WaitForStateChange(ctx, s) {
+			cancel()
+			// ctx got timeout or canceled.
+			return ctx.Err()
+		}
+	}
+}
+
+func (c *pdServiceDiscovery) loadMembers() error {
+	ctx, cancel := context.WithCancel(c.ctx)
+	defer cancel()
+
+	members, err := c.getMembers(ctx, c.getLeaderAddr(), updateMemberTimeout)
+	if err != nil {
+		log.Error("[pd.loadMembers] failed to load members", zap.String("url", c.getLeaderAddr()), errs.ZapError(err))
+		return errors.WithStack(err)
+	} else if members.GetHeader() == nil || members.GetLeader() == nil || len(members.GetLeader().GetClientUrls()) == 0 {
+		err = errs.ErrClientGetLeader.FastGenByArgs("leader address doesn't exist")
+		log.Error("[pd.loadMembers] leader address doesn't exist", zap.String("url", c.getLeaderAddr()), errs.ZapError(err))
+		return errors.WithStack(err)
+	}
+
+	return nil
+}
+
+func (c *pdServiceDiscovery) SuccessReconnect() {
+	select {
+	case c.successReConnect <- struct{}{}:
+	default:
+	}
+}
+
 func (c *pdServiceDiscovery) updateServiceModeLoop() {
 	defer c.wg.Done()
 	failpoint.Inject("skipUpdateServiceMode", func() {
@@ -319,7 +416,7 @@ func (c *pdServiceDiscovery) GetKeyspaceGroupID() uint32 {
 	return defaultKeySpaceGroupID
 }
 
-// DiscoverServiceURLs discovers the microservice with the specified type and returns the server urls.
+// DiscoverMicroservice discovers the microservice with the specified type and returns the server urls.
 func (c *pdServiceDiscovery) DiscoverMicroservice(svcType serviceType) (urls []string, err error) {
 	switch svcType {
 	case apiService:
@@ -382,11 +479,23 @@ func (c *pdServiceDiscovery) GetBackupAddrs() []string {
 func (c *pdServiceDiscovery) ScheduleCheckMemberChanged() {
 	select {
 	case c.checkMembershipCh <- struct{}{}:
+		if err := c.waitForReady(); err != nil {
+			// Reset the backoffer once its member-update backoff count reaches 10.
+			if c.bo.GetBackoffTimeCnt(retry.BoMemberUpdate.String()) >= 10 {
+				c.bo.Reset()
+			}
+			e := c.bo.Backoff(retry.BoMemberUpdate, err)
+			if e != nil {
+				log.Error("[pd] wait for ready backoff failed", errs.ZapError(e))
+				return
+			}
+			log.Error("[pd] wait for ready failed", errs.ZapError(err))
+		}
 	default:
 	}
 }
 
-// Immediately check if there is any membership change among the leader/followers in a
+// CheckMemberChanged immediately checks if there is any membership change among the leader/followers in a
 // quorum-based cluster or among the primary/secondaries in a primary/secondary configured cluster.
 func (c *pdServiceDiscovery) CheckMemberChanged() error {
 	return c.updateMember()
@@ -669,3 +778,7 @@ func (c *pdServiceDiscovery) switchTSOAllocatorLeaders(allocatorMap map[string]*
 func (c *pdServiceDiscovery) GetOrCreateGRPCConn(addr string) (*grpc.ClientConn, error) {
 	return grpcutil.GetOrCreateGRPCConn(c.ctx, &c.clientConns, addr, c.tlsCfg, c.option.gRPCDialOptions...)
 }
+
+func (c *pdServiceDiscovery) GetBackoffer() *retry.Backoffer {
+	return c.bo
+}
diff --git a/client/retry/backoff.go b/client/retry/backoff.go
new file mode 100644
index 00000000000..c5350b7ef72
--- /dev/null
+++ b/client/retry/backoff.go
@@ -0,0 +1,166 @@
+// Copyright 2023 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package retry
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/opentracing/opentracing-go"
+	"github.com/pingcap/log"
+	"github.com/pkg/errors"
+	"go.uber.org/zap"
+)
+
+// Backoffer is a utility for retrying queries.
+type Backoffer struct {
+	ctx context.Context
+
+	fn         map[string]backoffFn
+	maxSleep   int
+	totalSleep int
+
+	errors         []error
+	configs        []*Config
+	backoffSleepMS map[string]int
+	backoffTimes   map[string]int
+}
+
+// NewBackoffer (Deprecated) creates a Backoffer with the maximum sleep time (in ms).
+func NewBackoffer(ctx context.Context, maxSleep int) *Backoffer {
+	return &Backoffer{
+		ctx:      ctx,
+		maxSleep: maxSleep,
+	}
+}
+
+// Backoff sleeps a while based on the Config and records the error message.
+// It returns the wrapped error early only if the bound context is already done.
+func (b *Backoffer) Backoff(cfg *Config, err error) error {
+	if span := opentracing.SpanFromContext(b.ctx); span != nil && span.Tracer() != nil {
+		span1 := span.Tracer().StartSpan(fmt.Sprintf("pd.client.backoff.%s", cfg), opentracing.ChildOf(span.Context()))
+		defer span1.Finish()
+		opentracing.ContextWithSpan(b.ctx, span1)
+	}
+	return b.BackoffWithCfgAndMaxSleep(cfg, -1, err)
+}
+
+// BackoffWithCfgAndMaxSleep sleeps a while based on the Config and records the
+// error message, and never sleeps more than maxSleepMs in a single sleep.
+func (b *Backoffer) BackoffWithCfgAndMaxSleep(cfg *Config, maxSleepMs int, err error) error {
+	select {
+	case <-b.ctx.Done():
+		return errors.WithStack(err)
+	default:
+	}
+	b.errors = append(b.errors, errors.Errorf("%s at %s", err.Error(), time.Now().Format(time.RFC3339Nano)))
+	b.configs = append(b.configs, cfg)
+
+	// Lazy initialize.
+	if b.fn == nil {
+		b.fn = make(map[string]backoffFn)
+	}
+	f, ok := b.fn[cfg.name]
+	if !ok {
+		f = cfg.createBackoffFn()
+		b.fn[cfg.name] = f
+	}
+	realSleep := f(b.ctx, maxSleepMs)
+
+	b.totalSleep += realSleep
+	if b.backoffSleepMS == nil {
+		b.backoffSleepMS = make(map[string]int)
+	}
+	b.backoffSleepMS[cfg.name] += realSleep
+	if b.backoffTimes == nil {
+		b.backoffTimes = make(map[string]int)
+	}
+	b.backoffTimes[cfg.name]++
+
+	log.Debug("retry later",
+		zap.Error(err),
+		zap.Int("totalSleep", b.totalSleep),
+		zap.Int("maxSleep", b.maxSleep),
+		zap.Stringer("type", cfg))
+	return nil
+}
+
+func (b *Backoffer) String() string {
+	if b.totalSleep == 0 {
+		return ""
+	}
+	return fmt.Sprintf(" backoff(%dms %v)", b.totalSleep, b.configs)
+}
+
+// GetTotalSleep returns the total sleep time.
+func (b *Backoffer) GetTotalSleep() int {
+	return b.totalSleep
+}
+
+// GetCtx returns the bound context.
+func (b *Backoffer) GetCtx() context.Context {
+	return b.ctx
+}
+
+// SetCtx sets the bound context to ctx.
+func (b *Backoffer) SetCtx(ctx context.Context) {
+	b.ctx = ctx
+}
+
+// GetBackoffTimes returns a map containing the backoff count by type.
+func (b *Backoffer) GetBackoffTimes() map[string]int {
+	return b.backoffTimes
+}
+
+// GetBackoffTimeCnt returns the backoff count for the specified type.
+func (b *Backoffer) GetBackoffTimeCnt(s string) int {
+	return b.backoffTimes[s]
+}
+
+// GetTotalBackoffTimes returns the total backoff count of the backoffer.
+func (b *Backoffer) GetTotalBackoffTimes() int {
+	total := 0
+	for _, t := range b.backoffTimes {
+		total += t
+	}
+	return total
+}
+
+// GetBackoffSleepMS returns a map containing the backoff sleep time by type.
+func (b *Backoffer) GetBackoffSleepMS() map[string]int {
+	return b.backoffSleepMS
+}
+
+// ErrorsNum returns the number of errors.
+func (b *Backoffer) ErrorsNum() int {
+	return len(b.errors)
+}
+
+// Reset resets the sleep state of the backoffer, so that the following backoff
+// can sleep shorter. The reason why we don't create a new backoffer is that
+// a backoffer is similar to a context, and it records some metrics that we
+// want to keep for an entire process which is composed of several stages.
+func (b *Backoffer) Reset() {
+	b.fn = nil
+	b.totalSleep = 0
+}
+
+// ResetMaxSleep resets the sleep state and the max sleep limit of the backoffer.
+// It's used when switching to the next stage of the process.
+func (b *Backoffer) ResetMaxSleep(maxSleep int) {
+	b.Reset()
+	b.maxSleep = maxSleep
+}
diff --git a/client/retry/backoff_test.go b/client/retry/backoff_test.go
new file mode 100644
index 00000000000..2b9c943b864
--- /dev/null
+++ b/client/retry/backoff_test.go
@@ -0,0 +1,29 @@
+// Copyright 2023 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package retry
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestBackoffErrorType(t *testing.T) {
+	b := NewBackoffer(context.TODO(), 800)
+	err := b.Backoff(BoMemberUpdate, errors.New("no leader")) // sleeps 50-100 ms (equal jitter)
+	assert.Nil(t, err)
+}
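Reviewer note (not part of the diff): a minimal sketch of how the new `Backoffer` is meant to be driven. `doRequest` is a hypothetical fallible call, and note that in this trimmed-down version `Backoff` returns a non-nil error only when the bound context is done (maxSleep is recorded for logging but not enforced), so the caller decides when to stop retrying:

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/tikv/pd/client/retry"
)

// doRequest is a hypothetical fallible call, used for illustration only.
func doRequest(attempt int) error {
	if attempt < 3 {
		return errors.New("no leader")
	}
	return nil
}

func main() {
	// maxSleep (1000 ms) is logged but not enforced by this Backoffer.
	bo := retry.NewBackoffer(context.Background(), 1000)
	for attempt := 0; ; attempt++ {
		if err := doRequest(attempt); err != nil {
			// Sleep per BoMemberUpdate (100 ms base, 2000 ms cap, equal
			// jitter) and record err; a non-nil return means the bound
			// context is done, so stop retrying.
			if boErr := bo.Backoff(retry.BoMemberUpdate, err); boErr != nil {
				fmt.Println("give up:", boErr)
				return
			}
			continue
		}
		fmt.Printf("ok after %d backoffs, %d ms asleep\n",
			bo.GetTotalBackoffTimes(), bo.GetTotalSleep())
		return
	}
}
```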
diff --git a/client/retry/config.go b/client/retry/config.go
new file mode 100644
index 00000000000..478187f6f18
--- /dev/null
+++ b/client/retry/config.go
@@ -0,0 +1,141 @@
+// Copyright 2023 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package retry
+
+import (
+	"context"
+	"math"
+	"math/rand"
+	"time"
+
+	"github.com/pingcap/log"
+	"go.uber.org/zap"
+)
+
+// Config is the configuration of the Backoff function.
+type Config struct {
+	name  string
+	fnCfg *BackoffFnCfg
+	err   error
+}
+
+// Backoff Config variables.
+var (
+	// BoMemberUpdate is for member change events.
+	BoMemberUpdate = NewConfig("memberUpdate", NewBackoffFnCfg(100, 2000, EqualJitter), nil)
+)
+
+// backoffFn is the backoff function which computes the sleep time and does the sleep.
+type backoffFn func(ctx context.Context, maxSleepMs int) int
+
+func (c *Config) createBackoffFn() backoffFn {
+	return newBackoffFn(c.fnCfg.base, c.fnCfg.cap, c.fnCfg.jitter)
+}
+
+// BackoffFnCfg is the configuration for the backoff func which implements exponential backoff with
+// optional jitters.
+// See http://www.awsarchitectureblog.com/2015/03/backoff.html
+type BackoffFnCfg struct {
+	base   int
+	cap    int
+	jitter int
+}
+
+// NewBackoffFnCfg creates the config for BackoffFn.
+func NewBackoffFnCfg(base, cap, jitter int) *BackoffFnCfg {
+	return &BackoffFnCfg{
+		base,
+		cap,
+		jitter,
+	}
+}
+
+// NewConfig creates a new Config for the Backoff operation.
+func NewConfig(name string, backoffFnCfg *BackoffFnCfg, err error) *Config {
+	return &Config{
+		name:  name,
+		fnCfg: backoffFnCfg,
+		err:   err,
+	}
+}
+
+func (c *Config) String() string {
+	return c.name
+}
+
+// SetErrors sets a more detailed error in place of the config's default one.
+func (c *Config) SetErrors(err error) {
+	c.err = err
+}
+
+const (
+	// NoJitter makes the backoff sequence strictly exponential.
+	NoJitter = 1 + iota
+	// FullJitter applies random factors to the strict exponential.
+	FullJitter
+	// EqualJitter is also randomized, but prevents very short sleeps.
+	EqualJitter
+	// DecorrJitter increases the maximum jitter based on the last random value.
+	DecorrJitter
+)
+
+// newBackoffFn creates a backoff func which implements exponential backoff with
+// optional jitters.
+// See http://www.awsarchitectureblog.com/2015/03/backoff.html
+func newBackoffFn(base, cap, jitter int) backoffFn {
+	if base < 2 {
+		// To prevent a panic in 'rand.Intn'.
+		base = 2
+	}
+	attempts := 0
+	lastSleep := base
+	return func(ctx context.Context, maxSleepMs int) int {
+		var sleep int
+		switch jitter {
+		case NoJitter:
+			sleep = expo(base, cap, attempts)
+		case FullJitter:
+			v := expo(base, cap, attempts)
+			sleep = rand.Intn(v)
+		case EqualJitter:
+			v := expo(base, cap, attempts)
+			sleep = v/2 + rand.Intn(v/2)
+		case DecorrJitter:
+			sleep = int(math.Min(float64(cap), float64(base+rand.Intn(lastSleep*3-base))))
+		}
+		log.Debug("backoff",
+			zap.Int("base", base),
+			zap.Int("sleep", sleep),
+			zap.Int("attempts", attempts))
+
+		realSleep := sleep
+		// When maxSleepMs >= 0, cap each sleep at maxSleepMs milliseconds.
+		if maxSleepMs >= 0 && realSleep > maxSleepMs {
+			realSleep = maxSleepMs
+		}
+		select {
+		case <-time.After(time.Duration(realSleep) * time.Millisecond):
+			attempts++
+			lastSleep = sleep
+			return realSleep
+		case <-ctx.Done():
+			return 0
+		}
+	}
+}
+
+func expo(base, cap, n int) int {
+	return int(math.Min(float64(cap), float64(base)*math.Pow(2.0, float64(n))))
+}
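Reviewer note (not part of the diff): to make `BoMemberUpdate`'s parameters (base 100 ms, cap 2000 ms, `EqualJitter`) concrete, this sketch re-derives the per-attempt sleep range; `expo` is re-implemented here for illustration only:

```go
package main

import "fmt"

// expo mirrors retry.expo: capped binary exponential growth.
func expo(base, cap, n int) int {
	v := base
	for i := 0; i < n && v < cap; i++ {
		v *= 2
	}
	if v > cap {
		v = cap
	}
	return v
}

func main() {
	// With EqualJitter the sleep is v/2 + rand.Intn(v/2), i.e. in [v/2, v).
	base, cap := 100, 2000
	for attempt := 0; attempt < 7; attempt++ {
		v := expo(base, cap, attempt)
		fmt.Printf("attempt %d: sleep in [%d, %d) ms\n", attempt, v/2, v)
	}
	// Output: [50,100), [100,200), [200,400), [400,800), [800,1600),
	// then [1000,2000) once the 2000 ms cap is hit.
}
```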
diff --git a/client/tso_service_discovery.go b/client/tso_service_discovery.go
index 2aeb49e1523..62484aef2a8 100644
--- a/client/tso_service_discovery.go
+++ b/client/tso_service_discovery.go
@@ -30,6 +30,7 @@ import (
 	"github.com/pingcap/log"
 	"github.com/tikv/pd/client/errs"
 	"github.com/tikv/pd/client/grpcutil"
+	"github.com/tikv/pd/client/retry"
 	"github.com/tikv/pd/client/tlsutil"
 	"go.uber.org/zap"
 	"google.golang.org/grpc"
@@ -351,7 +352,7 @@ func (c *tsoServiceDiscovery) ScheduleCheckMemberChanged() {
 	}
 }
 
-// Immediately check if there is any membership change among the primary/secondaries in
+// CheckMemberChanged immediately checks if there is any membership change among the primary/secondaries in
 // a primary/secondary configured cluster.
 func (c *tsoServiceDiscovery) CheckMemberChanged() error {
 	c.apiSvcDiscovery.CheckMemberChanged()
@@ -641,3 +642,7 @@ func (c *tsoServiceDiscovery) discoverWithLegacyPath() ([]string, error) {
 	}
 	return listenUrls, nil
 }
+
+func (c *tsoServiceDiscovery) GetBackoffer() *retry.Backoffer {
+	panic("unimplemented")
+}
diff --git a/pkg/utils/testutil/leak.go b/pkg/utils/testutil/leak.go
index d1329aef0e6..1bc70855a3e 100644
--- a/pkg/utils/testutil/leak.go
+++ b/pkg/utils/testutil/leak.go
@@ -28,4 +28,5 @@ var LeakOptions = []goleak.Option{
 	goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"),
 	// natefinch/lumberjack#56, It's a goroutine leak bug. Another ignore option PR https://github.com/pingcap/tidb/pull/27405/
 	goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
+	goleak.IgnoreTopFunction("google.golang.org/grpc.(*ClientConn).WaitForStateChange"),
 }
diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go
index 41e7e650261..10f2f03ff5b 100644
--- a/tests/integrations/client/client_test.go
+++ b/tests/integrations/client/client_test.go
@@ -36,6 +36,7 @@ import (
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
 	pd "github.com/tikv/pd/client"
+	"github.com/tikv/pd/client/retry"
 	"github.com/tikv/pd/pkg/core"
 	"github.com/tikv/pd/pkg/errs"
 	"github.com/tikv/pd/pkg/mock/mockid"
@@ -96,6 +97,10 @@ func TestClientClusterIDCheck(t *testing.T) {
 
 func TestClientLeaderChange(t *testing.T) {
 	re := require.New(t)
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/acceleratedMemberUpdateInterval", `return(true)`))
+	defer func() {
+		failpoint.Disable("github.com/tikv/pd/client/acceleratedMemberUpdateInterval")
+	}()
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	cluster, err := tests.NewTestCluster(ctx, 3)
@@ -312,6 +317,10 @@ func TestTSOFollowerProxy(t *testing.T) {
 // TestUnavailableTimeAfterLeaderIsReady is used to test https://github.com/tikv/pd/issues/5207
 func TestUnavailableTimeAfterLeaderIsReady(t *testing.T) {
 	re := require.New(t)
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/acceleratedMemberUpdateInterval", `return(true)`))
+	defer func() {
+		failpoint.Disable("github.com/tikv/pd/client/acceleratedMemberUpdateInterval")
+	}()
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	cluster, err := tests.NewTestCluster(ctx, 3)
@@ -375,6 +384,10 @@ func TestUnavailableTimeAfterLeaderIsReady(t *testing.T) {
 // TODO: migrate the Local/Global TSO tests to TSO integration test folder.
 func TestGlobalAndLocalTSO(t *testing.T) {
 	re := require.New(t)
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/acceleratedMemberUpdateInterval", `return(true)`))
+	defer func() {
+		failpoint.Disable("github.com/tikv/pd/client/acceleratedMemberUpdateInterval")
+	}()
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	dcLocationConfig := map[string]string{
@@ -1506,3 +1519,54 @@ func TestClientWatchWithRevision(t *testing.T) {
 		}
 	}
 }
+
+func (suite *clientTestSuite) TestRetryMemberUpdate() {
+	re := suite.Require()
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/acceleratedRequestTimeout", `return(true)`))
+	defer func() {
+		failpoint.Disable("github.com/tikv/pd/client/acceleratedRequestTimeout")
+	}()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	cluster, err := tests.NewTestCluster(ctx, 3)
+	re.NoError(err)
+	defer cluster.Destroy()
+
+	endpoints := runServer(re, cluster)
+	cli := setupCli(re, ctx, endpoints)
+	defer cli.Close()
+	innerCli, ok := cli.(interface{ GetServiceDiscovery() pd.ServiceDiscovery })
+	re.True(ok)
+
+	leader := cluster.GetLeader()
+	waitLeader(re, innerCli.GetServiceDiscovery(), cluster.GetServer(leader).GetConfig().ClientUrls)
+	memberID := cluster.GetServer(leader).GetLeader().GetMemberId()
+
+	re.NoError(failpoint.Enable("github.com/tikv/pd/server/leaderLoopCheckAgain", fmt.Sprintf("return(\"%d\")", memberID)))
+	re.NoError(failpoint.Enable("github.com/tikv/pd/server/exitCampaignLeader", fmt.Sprintf("return(\"%d\")", memberID)))
+	re.NoError(failpoint.Enable("github.com/tikv/pd/server/timeoutWaitPDLeader", `return(true)`))
+
+	leader2 := waitLeaderChange(re, cluster, leader, innerCli.GetServiceDiscovery())
+	re.NotEqual(leader, leader2)
+
+	re.NoError(failpoint.Disable("github.com/tikv/pd/server/leaderLoopCheckAgain"))
+	re.NoError(failpoint.Disable("github.com/tikv/pd/server/exitCampaignLeader"))
+	re.NoError(failpoint.Disable("github.com/tikv/pd/server/timeoutWaitPDLeader"))
+
+	bo := innerCli.GetServiceDiscovery().GetBackoffer()
+	retryTimes := bo.GetBackoffTimes()[retry.BoMemberUpdate.String()]
+	re.Greater(retryTimes, 0)
+}
+
+func waitLeaderChange(re *require.Assertions, cluster *tests.TestCluster, old string, cli pd.ServiceDiscovery) string {
+	var leader string
+	testutil.Eventually(re, func() bool {
+		cli.ScheduleCheckMemberChanged()
+		leader = cluster.GetLeader()
+		if leader == old || leader == "" {
+			return false
+		}
+		return true
+	})
+	return leader
+}
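Reviewer note (not part of the diff): the assertion at the end of `TestRetryMemberUpdate` is the intended way to observe the retry counters from outside the client. A sketch of that read path as a helper; `memberUpdateRetries` and the `clientutil` package are hypothetical, and remember that `GetBackoffer` still panics on the TSO service discovery implementation:

```go
// Package clientutil is a hypothetical home for this helper.
package clientutil

import (
	pd "github.com/tikv/pd/client"
	"github.com/tikv/pd/client/retry"
)

// memberUpdateRetries returns how many times the client has backed off on
// member updates. It assumes sd is backed by pdServiceDiscovery; indexing
// the possibly-nil map from GetBackoffTimes safely yields 0.
func memberUpdateRetries(sd pd.ServiceDiscovery) int {
	return sd.GetBackoffer().GetBackoffTimes()[retry.BoMemberUpdate.String()]
}
```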