From 4a401141e32a493ea0923721542fccfbf6d5a64d Mon Sep 17 00:00:00 2001
From: Ti Chi Robot
Date: Mon, 3 Mar 2025 15:18:15 +0800
Subject: [PATCH] mcs: add lock for forward tso stream (#9095) (#9107)

close tikv/pd#9091

Signed-off-by: ti-chi-bot
Signed-off-by: lhy1024
Co-authored-by: lhy1024
Co-authored-by: lhy1024
---
 pkg/utils/grpcutil/grpcutil.go            |  14 +++
 server/grpc_service.go                    |  88 +++++++++++---
 server/metrics.go                         |   9 ++
 server/server.go                          |  16 ++-
 tests/integrations/mcs/tso/server_test.go | 138 ++++++++++++++++++++++
 5 files changed, 242 insertions(+), 23 deletions(-)

diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go
index 97d1719da52..0fc3a15c1ad 100644
--- a/pkg/utils/grpcutil/grpcutil.go
+++ b/pkg/utils/grpcutil/grpcutil.go
@@ -18,12 +18,15 @@ import (
 	"context"
 	"crypto/tls"
 	"crypto/x509"
+	"io"
 	"net/url"
+	"strings"
 
 	"github.com/pingcap/log"
 	"github.com/tikv/pd/pkg/errs"
 	"go.etcd.io/etcd/pkg/transport"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/grpc/metadata"
@@ -172,3 +175,14 @@ func GetForwardedHost(ctx context.Context) string {
 	}
 	return ""
 }
+
+// NeedRebuildConnection checks if the error is a connection error.
+func NeedRebuildConnection(err error) bool {
+	return err == io.EOF ||
+		strings.Contains(err.Error(), codes.Unavailable.String()) || // Unavailable indicates the service is currently unavailable. This is most likely a transient condition.
+		strings.Contains(err.Error(), codes.DeadlineExceeded.String()) || // DeadlineExceeded means operation expired before completion.
+		strings.Contains(err.Error(), codes.Internal.String()) || // Internal errors.
+		strings.Contains(err.Error(), codes.Unknown.String()) || // Unknown error.
+		strings.Contains(err.Error(), codes.ResourceExhausted.String()) // ResourceExhausted is returned when either the client or the server has exhausted their resources.
+	// Besides, we don't need to rebuild the connection if the code is Canceled, which means the client cancelled the request.
+}
diff --git a/server/grpc_service.go b/server/grpc_service.go
index 7c587ac18b8..22cfbc885bf 100644
--- a/server/grpc_service.go
+++ b/server/grpc_service.go
@@ -20,6 +20,7 @@ import (
 	"io"
 	"path"
 	"strconv"
+	"strings"
 	"sync/atomic"
 	"time"
 
@@ -48,7 +49,9 @@ import (
 )
 
 const (
-	heartbeatSendTimeout = 5 * time.Second
+	heartbeatSendTimeout          = 5 * time.Second
+	maxRetryTimesRequestTSOServer = 3
+	retryIntervalRequestTSOServer = 500 * time.Millisecond
 )
 
 // gRPC errors
@@ -1781,31 +1784,77 @@ func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan
 }
 
 func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timestamp, error) {
-	forwardedHost, ok := s.GetServicePrimaryAddr(ctx, utils.TSOServiceName)
-	if !ok || forwardedHost == "" {
-		return pdpb.Timestamp{}, ErrNotFoundTSOAddr
-	}
-	forwardStream, err := s.getTSOForwardStream(forwardedHost)
-	if err != nil {
-		return pdpb.Timestamp{}, err
-	}
-	forwardStream.Send(&tsopb.TsoRequest{
+	request := &tsopb.TsoRequest{
 		Header: &tsopb.RequestHeader{
 			ClusterId:       s.clusterID,
 			KeyspaceId:      utils.DefaultKeyspaceID,
 			KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
 		},
 		Count: 1,
-	})
-	ts, err := forwardStream.Recv()
-	if err != nil {
-		log.Error("get global tso from tso server failed", zap.Error(err))
-		return pdpb.Timestamp{}, err
 	}
-	return *ts.GetTimestamp(), nil
+	var (
+		forwardedHost string
+		forwardStream *streamWrapper
+		ts            *tsopb.TsoResponse
+		err           error
+		ok            bool
+	)
+	handleStreamError := func(err error) (needRetry bool) {
+		if strings.Contains(err.Error(), errs.NotLeaderErr) {
+			s.updateServicePrimaryAddr(utils.TSOServiceName)
+			log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			return true
+		}
+		if grpcutil.NeedRebuildConnection(err) {
+			s.tsoClientPool.Lock()
+			delete(s.tsoClientPool.clients, forwardedHost)
+			s.tsoClientPool.Unlock()
+			log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			return true
+		}
+		return false
+	}
+	for i := 0; i < maxRetryTimesRequestTSOServer; i++ {
+		if i > 0 {
+			time.Sleep(retryIntervalRequestTSOServer)
+		}
+		forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName)
+		if !ok || forwardedHost == "" {
+			return pdpb.Timestamp{}, ErrNotFoundTSOAddr
+		}
+		forwardStream, err = s.getTSOForwardStream(forwardedHost)
+		if err != nil {
+			return pdpb.Timestamp{}, err
+		}
+		start := time.Now()
+		forwardStream.Lock()
+		err = forwardStream.Send(request)
+		if err != nil {
+			if needRetry := handleStreamError(err); needRetry {
+				forwardStream.Unlock()
+				continue
+			}
+			log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			forwardStream.Unlock()
+			return pdpb.Timestamp{}, err
+		}
+		ts, err = forwardStream.Recv()
+		forwardStream.Unlock()
+		forwardTsoDuration.Observe(time.Since(start).Seconds())
+		if err != nil {
+			if needRetry := handleStreamError(err); needRetry {
+				continue
+			}
+			log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			return pdpb.Timestamp{}, err
+		}
+		return *ts.GetTimestamp(), nil
+	}
+	log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost))
+	return pdpb.Timestamp{}, err
 }
 
-func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) {
+func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (*streamWrapper, error) {
 	s.tsoClientPool.RLock()
 	forwardStream, ok := s.tsoClientPool.clients[forwardedHost]
 	s.tsoClientPool.RUnlock()
@@ -1831,11 +1880,14 @@ func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoCli
 	done := make(chan struct{})
 	ctx, cancel := context.WithCancel(s.ctx)
 	go checkStream(ctx, cancel, done)
-	forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx)
+	tsoClient, err := tsopb.NewTSOClient(client).Tso(ctx)
 	done <- struct{}{}
 	if err != nil {
 		return nil, err
 	}
+	forwardStream = &streamWrapper{
+		TSO_TsoClient: tsoClient,
+	}
 	s.tsoClientPool.clients[forwardedHost] = forwardStream
 	return forwardStream, nil
 }
diff --git a/server/metrics.go b/server/metrics.go
index 7eed1020186..6082cf4ea95 100644
--- a/server/metrics.go
+++ b/server/metrics.go
@@ -151,6 +151,14 @@ var (
 			Name:      "maxprocs",
 			Help:      "The value of GOMAXPROCS.",
 		})
+	forwardTsoDuration = prometheus.NewHistogram(
+		prometheus.HistogramOpts{
+			Namespace: "pd",
+			Subsystem: "server",
+			Name:      "forward_tso_duration_seconds",
+			Help:      "Bucketed histogram of processing time (s) of handled forward tso requests.",
+			Buckets:   prometheus.ExponentialBuckets(0.0005, 2, 13),
+		})
 )
 
 func init() {
@@ -170,4 +178,5 @@ func init() {
 	prometheus.MustRegister(serviceAuditHistogram)
 	prometheus.MustRegister(bucketReportInterval)
 	prometheus.MustRegister(serverMaxProcs)
+	prometheus.MustRegister(forwardTsoDuration)
 }
diff --git a/server/server.go b/server/server.go
index e21501bc61e..0e11039ca14 100644
--- a/server/server.go
+++ b/server/server.go
@@ -69,6 +69,7 @@ import (
 	"github.com/tikv/pd/pkg/utils/grpcutil"
 	"github.com/tikv/pd/pkg/utils/jsonutil"
 	"github.com/tikv/pd/pkg/utils/logutil"
+	"github.com/tikv/pd/pkg/utils/syncutil"
 	"github.com/tikv/pd/pkg/utils/tsoutil"
 	"github.com/tikv/pd/pkg/utils/typeutil"
 	"github.com/tikv/pd/pkg/versioninfo"
@@ -123,6 +124,11 @@ var (
 	etcdCommittedIndexGauge = etcdStateGauge.WithLabelValues("committedIndex")
 )
 
+type streamWrapper struct {
+	tsopb.TSO_TsoClient
+	syncutil.Mutex
+}
+
 // Server is the pd server. It implements bs.Server
 // nolint
 type Server struct {
@@ -199,8 +205,8 @@ type Server struct {
 	clientConns sync.Map
 
 	tsoClientPool struct {
-		sync.RWMutex
-		clients map[string]tsopb.TSO_TsoClient
+		syncutil.RWMutex
+		clients map[string]*streamWrapper
 	}
 
 	// tsoDispatcher is used to dispatch different TSO requests to
@@ -254,10 +260,10 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le
 		DiagnosticsServer: sysutil.NewDiagnosticsServer(cfg.Log.File.Filename),
 		mode:              mode,
 		tsoClientPool: struct {
-			sync.RWMutex
-			clients map[string]tsopb.TSO_TsoClient
+			syncutil.RWMutex
+			clients map[string]*streamWrapper
 		}{
-			clients: make(map[string]tsopb.TSO_TsoClient),
+			clients: make(map[string]*streamWrapper),
 		},
 	}
 	s.handler = newHandler(s)
diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go
index 5a3782446a6..916d5ea6edd 100644
--- a/tests/integrations/mcs/tso/server_test.go
+++ b/tests/integrations/mcs/tso/server_test.go
@@ -23,9 +23,11 @@ import (
 	"net/http"
 	"strconv"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
@@ -319,6 +321,142 @@ func (suite *APIServerForwardTestSuite) checkAvailableTSO() {
 	suite.NoError(err)
 }
 
+func TestForwardTsoConcurrently(t *testing.T) {
+	re := require.New(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	cluster, err := tests.NewTestAPICluster(ctx, 3)
+	re.NoError(err)
+	defer cluster.Destroy()
+
+	err = cluster.RunInitialServers()
+	re.NoError(err)
+
+	leaderName := cluster.WaitLeader()
+	pdLeader := cluster.GetServer(leaderName)
+	backendEndpoints := pdLeader.GetAddr()
+	re.NoError(pdLeader.BootstrapCluster())
+	leader := cluster.GetServer(cluster.WaitLeader())
+	rc := leader.GetRaftCluster()
+	for i := 0; i < 3; i++ {
+		region := &metapb.Region{
+			Id:       uint64(i*4 + 1),
+			Peers:    []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}},
+			StartKey: []byte{byte(i)},
+			EndKey:   []byte{byte(i + 1)},
+		}
+		rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0]))
+	}
+
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)"))
+	defer func() {
+		re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode"))
+	}()
+
+	tc, err := mcs.NewTestTSOCluster(ctx, 2, backendEndpoints)
+	re.NoError(err)
+	defer tc.Destroy()
+	tc.WaitForDefaultPrimaryServing(re)
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < 50; i++ {
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			pdClient, err := pd.NewClientWithContext(
+				context.Background(),
+				[]string{backendEndpoints},
+				pd.SecurityOption{})
+			re.NoError(err)
+			re.NotNil(pdClient)
+			defer pdClient.Close()
+			for j := 0; j < 20; j++ {
+				testutil.Eventually(re, func() bool {
+					min, err := pdClient.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", i), 1000, 1)
+					return err == nil && min == 0
+				})
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+func BenchmarkForwardTsoConcurrently(b *testing.B) {
+	re := require.New(b)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	cluster, err := tests.NewTestAPICluster(ctx, 3)
+	re.NoError(err)
+	defer cluster.Destroy()
+
+	err = cluster.RunInitialServers()
+	re.NoError(err)
+
+	leaderName := cluster.WaitLeader()
+	pdLeader := cluster.GetServer(leaderName)
+	backendEndpoints := pdLeader.GetAddr()
+	re.NoError(pdLeader.BootstrapCluster())
+	leader := cluster.GetServer(cluster.WaitLeader())
+	rc := leader.GetRaftCluster()
+	for i := 0; i < 3; i++ {
+		region := &metapb.Region{
+			Id:       uint64(i*4 + 1),
+			Peers:    []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}},
+			StartKey: []byte{byte(i)},
+			EndKey:   []byte{byte(i + 1)},
+		}
+		rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0]))
+	}
+
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)"))
+	defer func() {
+		re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode"))
+	}()
+
+	tc, err := mcs.NewTestTSOCluster(ctx, 1, backendEndpoints)
+	re.NoError(err)
+	defer tc.Destroy()
+	tc.WaitForDefaultPrimaryServing(re)
+
+	initClients := func(num int) []pd.Client {
+		var clients []pd.Client
+		for i := 0; i < num; i++ {
+			pdClient, err := pd.NewClientWithContext(context.Background(),
+				[]string{backendEndpoints}, pd.SecurityOption{}, pd.WithMaxErrorRetry(1))
+			re.NoError(err)
+			re.NotNil(pdClient)
+			clients = append(clients, pdClient)
+		}
+		return clients
+	}
+
+	concurrencyLevels := []int{1, 2, 5, 10, 20}
+	for _, clientsNum := range concurrencyLevels {
+		clients := initClients(clientsNum)
+		b.Run(fmt.Sprintf("clients_%d", clientsNum), func(b *testing.B) {
+			wg := sync.WaitGroup{}
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				for j, client := range clients {
+					wg.Add(1)
+					go func(j int, client pd.Client) {
+						defer wg.Done()
+						for k := 0; k < 1000; k++ {
+							min, err := client.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", j), 1000, 1)
+							re.NoError(err)
+							re.Equal(uint64(0), min)
+						}
+					}(j, client)
+				}
+			}
+			wg.Wait()
+		})
+		for _, c := range clients {
+			c.Close()
+		}
+	}
+}
+
 func TestAdvertiseAddr(t *testing.T) {
 	re := require.New(t)