From 8ec1dafa8b0f40b01b9e3a4c7c96c969970d04dd Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Thu, 27 Feb 2025 19:41:13 +0800 Subject: [PATCH] fix and pick #6341 #6279 #7327 Signed-off-by: lhy1024 --- pkg/utils/grpcutil/grpcutil.go | 14 + server/forward.go | 547 ---------------------- server/grpc_service.go | 88 +++- server/metrics.go | 21 +- server/server.go | 11 +- tests/integrations/mcs/tso/server_test.go | 201 +++----- 6 files changed, 161 insertions(+), 721 deletions(-) delete mode 100644 server/forward.go diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index 97d1719da52..0fc3a15c1ad 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -18,12 +18,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "io" "net/url" + "strings" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" "go.etcd.io/etcd/pkg/transport" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -172,3 +175,14 @@ func GetForwardedHost(ctx context.Context) string { } return "" } + +// NeedRebuildConnection checks if the error is a connection error. +func NeedRebuildConnection(err error) bool { + return err == io.EOF || + strings.Contains(err.Error(), codes.Unavailable.String()) || // Unavailable indicates the service is currently unavailable. This is a most likely a transient condition. + strings.Contains(err.Error(), codes.DeadlineExceeded.String()) || // DeadlineExceeded means operation expired before completion. + strings.Contains(err.Error(), codes.Internal.String()) || // Internal errors. + strings.Contains(err.Error(), codes.Unknown.String()) || // Unknown error. + strings.Contains(err.Error(), codes.ResourceExhausted.String()) // ResourceExhausted is returned when either the client or the server has exhausted their resources. + // Besides, we don't need to rebuild the connection if the code is Canceled, which means the client cancelled the request. +} diff --git a/server/forward.go b/server/forward.go deleted file mode 100644 index a7ce4e3c7d8..00000000000 --- a/server/forward.go +++ /dev/null @@ -1,547 +0,0 @@ -// Copyright 2023 TiKV Project Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package server - -import ( - "context" - "io" - "strings" - "time" - - "go.uber.org/zap" - "google.golang.org/grpc" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/pingcap/kvproto/pkg/schedulingpb" - "github.com/pingcap/kvproto/pkg/tsopb" - "github.com/pingcap/log" - - "github.com/tikv/pd/pkg/errs" - "github.com/tikv/pd/pkg/mcs/utils/constant" - "github.com/tikv/pd/pkg/utils/grpcutil" - "github.com/tikv/pd/pkg/utils/keypath" - "github.com/tikv/pd/pkg/utils/logutil" - "github.com/tikv/pd/pkg/utils/tsoutil" - "github.com/tikv/pd/server/cluster" -) - -// forwardToTSOService forwards the TSO requests to the TSO service. 
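As a rough illustration of how the new grpcutil.NeedRebuildConnection helper is intended to be used, the sketch below drops a cached delegate connection when the helper classifies an error as a connection-level failure; the pool type, field names, and package name are assumptions for illustration only, not part of this patch. It mirrors how handleStreamError in grpc_service.go removes the pooled TSO stream for forwardedHost before retrying.

package example // hypothetical package, for illustration only

import (
	"sync"

	"google.golang.org/grpc"

	"github.com/tikv/pd/pkg/utils/grpcutil"
)

// connPool is a hypothetical cache of delegate connections keyed by target address.
type connPool struct {
	sync.Mutex
	conns map[string]*grpc.ClientConn
}

// handleRPCError reports whether the caller may retry. When the error looks like a
// broken connection (EOF, Unavailable, DeadlineExceeded, ...), the cached connection
// is closed and removed so the next attempt redials; other errors such as Canceled
// are returned to the caller without rebuilding anything.
func (p *connPool) handleRPCError(addr string, err error) bool {
	if err == nil {
		return false
	}
	if !grpcutil.NeedRebuildConnection(err) {
		return false
	}
	p.Lock()
	defer p.Unlock()
	if conn, ok := p.conns[addr]; ok {
		conn.Close()
		delete(p.conns, addr)
	}
	return true
}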
-func (s *GrpcServer) forwardToTSOService(stream pdpb.PD_TsoServer) error { - var ( - server = &tsoServer{stream: stream} - forwarder = newTSOForwarder(server) - tsoStreamErr error - ) - defer func() { - s.concurrentTSOProxyStreamings.Add(-1) - forwarder.cancel() - if grpcutil.NeedRebuildConnection(tsoStreamErr) { - s.closeDelegateClient(forwarder.host) - } - }() - - maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) - if maxConcurrentTSOProxyStreamings >= 0 { - if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { - return errors.WithStack(errs.ErrMaxCountTSOProxyRoutinesExceeded) - } - } - - tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) - go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) - - for { - select { - case <-s.ctx.Done(): - return errors.WithStack(s.ctx.Err()) - case <-stream.Context().Done(): - return stream.Context().Err() - default: - } - - request, err := server.recv(s.GetTSOProxyRecvFromClientTimeout()) - if err == io.EOF { - return nil - } - if err != nil { - return errors.WithStack(err) - } - if request.GetCount() == 0 { - err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") - return errs.ErrUnknown(err) - } - tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh) - if tsoStreamErr != nil { - return tsoStreamErr - } - if err != nil { - return err - } - } -} - -type tsoForwarder struct { - // The original source that we need to send the response back to. - responser interface{ Send(*pdpb.TsoResponse) error } - // The context for the forwarding stream. - ctx context.Context - // The cancel function for the forwarding stream. - canceller context.CancelFunc - // The current forwarding stream. - stream tsopb.TSO_TsoClient - // The current host of the forwarding stream. - host string -} - -func newTSOForwarder(responser interface{ Send(*pdpb.TsoResponse) error }) *tsoForwarder { - return &tsoForwarder{ - responser: responser, - } -} - -func (f *tsoForwarder) cancel() { - if f != nil && f.canceller != nil { - f.canceller() - } -} - -// forwardTSORequest sends the TSO request with the current forward stream. -func (f *tsoForwarder) forwardTSORequest( - request *pdpb.TsoRequest, -) (*tsopb.TsoResponse, error) { - tsopbReq := &tsopb.TsoRequest{ - Header: &tsopb.RequestHeader{ - ClusterId: request.GetHeader().GetClusterId(), - SenderId: request.GetHeader().GetSenderId(), - KeyspaceId: constant.DefaultKeyspaceID, - KeyspaceGroupId: constant.DefaultKeyspaceGroupID, - }, - Count: request.GetCount(), - } - - failpoint.Inject("tsoProxySendToTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-f.ctx.Done() - }) - - select { - case <-f.ctx.Done(): - return nil, f.ctx.Err() - default: - } - - if err := f.stream.Send(tsopbReq); err != nil { - return nil, err - } - - failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-f.ctx.Done() - }) - - select { - case <-f.ctx.Done(): - return nil, f.ctx.Err() - default: - } - - return f.stream.Recv() -} - -func (s *GrpcServer) handleTSOForwarding( - ctx context.Context, - forwarder *tsoForwarder, - request *pdpb.TsoRequest, - tsDeadlineCh chan<- *tsoutil.TSDeadline, -) (tsoStreamErr, sendErr error) { - // Get the latest TSO primary address. 
- targetHost, ok := s.GetServicePrimaryAddr(ctx, constant.TSOServiceName) - if !ok || len(targetHost) == 0 { - return errors.WithStack(errs.ErrNotFoundTSOAddr), nil - } - // Check if the forwarder is already built with the target host. - if forwarder.stream == nil || forwarder.host != targetHost { - // Cancel the old forwarder. - forwarder.cancel() - // Build a new forward stream. - clientConn, err := s.getDelegateClient(s.ctx, targetHost) - if err != nil { - return errors.WithStack(err), nil - } - forwarder.stream, forwarder.ctx, forwarder.canceller, err = createTSOForwardStream(ctx, clientConn) - if err != nil { - return errors.WithStack(err), nil - } - forwarder.host = targetHost - } - - // Forward the TSO request with the deadline. - tsopbResp, err := s.forwardTSORequestWithDeadLine(forwarder, request, tsDeadlineCh) - if err != nil { - return errors.WithStack(err), nil - } - - // The error types defined for tsopb and pdpb are different, so we need to convert them. - var pdpbErr *pdpb.Error - tsopbErr := tsopbResp.GetHeader().GetError() - if tsopbErr != nil { - if tsopbErr.Type == tsopb.ErrorType_OK { - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_OK, - Message: tsopbErr.GetMessage(), - } - } else { - // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_UNKNOWN, - Message: tsopbErr.GetMessage(), - } - } - } - // Send the TSO response back to the original source. - sendErr = forwarder.responser.Send(&pdpb.TsoResponse{ - Header: &pdpb.ResponseHeader{ - ClusterId: tsopbResp.GetHeader().GetClusterId(), - Error: pdpbErr, - }, - Count: tsopbResp.GetCount(), - Timestamp: tsopbResp.GetTimestamp(), - }) - - return nil, errors.WithStack(sendErr) -} - -func (s *GrpcServer) forwardTSORequestWithDeadLine( - forwarder *tsoForwarder, - request *pdpb.TsoRequest, - tsDeadlineCh chan<- *tsoutil.TSDeadline, -) (*tsopb.TsoResponse, error) { - var ( - forwardCtx = forwarder.ctx - forwardCancel = forwarder.canceller - done = make(chan struct{}) - dl = tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, forwardCancel) - ) - select { - case tsDeadlineCh <- dl: - case <-forwardCtx.Done(): - return nil, forwardCtx.Err() - } - - start := time.Now() - resp, err := forwarder.forwardTSORequest(request) - close(done) - if err != nil { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - } - return nil, err - } - tsoProxyBatchSize.Observe(float64(request.GetCount())) - tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) - return resp, nil -} - -func createTSOForwardStream(ctx context.Context, client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) { - done := make(chan struct{}) - forwardCtx, cancelForward := context.WithCancel(ctx) - go grpcutil.CheckStream(forwardCtx, cancelForward, done) - forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx) - done <- struct{}{} - return forwardStream, forwardCtx, cancelForward, err -} - -func (s *GrpcServer) createRegionHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) { - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func createRegionHeartbeatSchedulingStream(ctx context.Context, client *grpc.ClientConn) (schedulingpb.Scheduling_RegionHeartbeatClient, 
context.Context, context.CancelFunc, error) { - done := make(chan struct{}) - forwardCtx, cancelForward := context.WithCancel(ctx) - go grpcutil.CheckStream(forwardCtx, cancelForward, done) - forwardStream, err := schedulingpb.NewSchedulingClient(client).RegionHeartbeat(forwardCtx) - done <- struct{}{} - return forwardStream, forwardCtx, cancelForward, err -} - -func forwardRegionHeartbeatToScheduling(rc *cluster.RaftCluster, forwardStream schedulingpb.Scheduling_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.Recv() - if err == io.EOF { - errCh <- errors.WithStack(err) - return - } - if err != nil { - errCh <- errors.WithStack(err) - return - } - // TODO: find a better way to halt scheduling immediately. - if rc.IsSchedulingHalted() { - continue - } - // The error types defined for schedulingpb and pdpb are different, so we need to convert them. - var pdpbErr *pdpb.Error - schedulingpbErr := resp.GetHeader().GetError() - if schedulingpbErr != nil { - if schedulingpbErr.Type == schedulingpb.ErrorType_OK { - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_OK, - Message: schedulingpbErr.GetMessage(), - } - } else { - // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_UNKNOWN, - Message: schedulingpbErr.GetMessage(), - } - } - } - response := &pdpb.RegionHeartbeatResponse{ - Header: &pdpb.ResponseHeader{ - ClusterId: resp.GetHeader().GetClusterId(), - Error: pdpbErr, - }, - ChangePeer: resp.GetChangePeer(), - TransferLeader: resp.GetTransferLeader(), - RegionId: resp.GetRegionId(), - RegionEpoch: resp.GetRegionEpoch(), - TargetPeer: resp.GetTargetPeer(), - Merge: resp.GetMerge(), - SplitRegion: resp.GetSplitRegion(), - ChangePeerV2: resp.GetChangePeerV2(), - SwitchWitnesses: resp.GetSwitchWitnesses(), - } - - if err := server.Send(response); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.Recv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - if err := server.Send(resp); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.CloseAndRecv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - if err := server.send(resp); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) { - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) { - client, ok := s.clientConns.Load(forwardedHost) - if ok { - // Mostly, the connection is already established, and return it directly. 
- return client.(*grpc.ClientConn), nil - } - - tlsConfig, err := s.GetTLSConfig().ToTLSConfig() - if err != nil { - return nil, err - } - ctxTimeout, cancel := context.WithTimeout(ctx, defaultGRPCDialTimeout) - defer cancel() - newConn, err := grpcutil.GetClientConn(ctxTimeout, forwardedHost, tlsConfig) - if err != nil { - return nil, err - } - conn, loaded := s.clientConns.LoadOrStore(forwardedHost, newConn) - if !loaded { - // Successfully stored the connection we created. - return newConn, nil - } - // Loaded a connection created/stored by another goroutine, so close the one we created - // and return the one we loaded. - newConn.Close() - return conn.(*grpc.ClientConn), nil -} - -func (s *GrpcServer) closeDelegateClient(forwardedHost string) { - client, ok := s.clientConns.LoadAndDelete(forwardedHost) - if !ok { - return - } - client.(*grpc.ClientConn).Close() - log.Debug("close delegate client connection", zap.String("forwarded-host", forwardedHost)) -} - -func (s *GrpcServer) isLocalRequest(host string) bool { - failpoint.Inject("useForwardRequest", func() { - failpoint.Return(false) - }) - if host == "" { - return true - } - memberAddrs := s.GetMember().Member().GetClientUrls() - for _, addr := range memberAddrs { - if addr == host { - return true - } - } - return false -} - -func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { - if !s.IsServiceIndependent(constant.TSOServiceName) { - return s.tsoAllocatorManager.HandleRequest(ctx, 1) - } - request := &tsopb.TsoRequest{ - Header: &tsopb.RequestHeader{ - ClusterId: keypath.ClusterID(), - KeyspaceId: constant.DefaultKeyspaceID, - KeyspaceGroupId: constant.DefaultKeyspaceGroupID, - }, - Count: 1, - } - var ( - forwardedHost string - forwardStream *streamWrapper - ts *tsopb.TsoResponse - err error - ok bool - ) - handleStreamError := func(err error) (needRetry bool) { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return true - } - if grpcutil.NeedRebuildConnection(err) { - s.tsoClientPool.Lock() - delete(s.tsoClientPool.clients, forwardedHost) - s.tsoClientPool.Unlock() - log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return true - } - return false - } - for i := range maxRetryTimesRequestTSOServer { - if i > 0 { - time.Sleep(retryIntervalRequestTSOServer) - } - forwardedHost, ok = s.GetServicePrimaryAddr(ctx, constant.TSOServiceName) - if !ok || forwardedHost == "" { - return pdpb.Timestamp{}, errs.ErrNotFoundTSOAddr - } - forwardStream, err = s.getTSOForwardStream(forwardedHost) - if err != nil { - return pdpb.Timestamp{}, err - } - start := time.Now() - forwardStream.Lock() - err = forwardStream.Send(request) - if err != nil { - if needRetry := handleStreamError(err); needRetry { - forwardStream.Unlock() - continue - } - log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) - forwardStream.Unlock() - return pdpb.Timestamp{}, err - } - ts, err = forwardStream.Recv() - forwardStream.Unlock() - forwardTsoDuration.Observe(time.Since(start).Seconds()) - if err != nil { - if needRetry := handleStreamError(err); needRetry { - continue - } - log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err - } - return *ts.GetTimestamp(), nil - } - 
log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err -} - -func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (*streamWrapper, error) { - s.tsoClientPool.RLock() - forwardStream, ok := s.tsoClientPool.clients[forwardedHost] - s.tsoClientPool.RUnlock() - if ok { - // This is the common case to return here - return forwardStream, nil - } - - s.tsoClientPool.Lock() - defer s.tsoClientPool.Unlock() - - // Double check after entering the critical section - forwardStream, ok = s.tsoClientPool.clients[forwardedHost] - if ok { - return forwardStream, nil - } - - // Now let's create the client connection and the forward stream - client, err := s.getDelegateClient(s.ctx, forwardedHost) - if err != nil { - return nil, err - } - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - tsoClient, err := tsopb.NewTSOClient(client).Tso(ctx) - done <- struct{}{} - if err != nil { - return nil, err - } - forwardStream = &streamWrapper{ - TSO_TsoClient: tsoClient, - } - s.tsoClientPool.clients[forwardedHost] = forwardStream - return forwardStream, nil -} diff --git a/server/grpc_service.go b/server/grpc_service.go index 7c587ac18b8..22cfbc885bf 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -20,6 +20,7 @@ import ( "io" "path" "strconv" + "strings" "sync/atomic" "time" @@ -48,7 +49,9 @@ import ( ) const ( - heartbeatSendTimeout = 5 * time.Second + heartbeatSendTimeout = 5 * time.Second + maxRetryTimesRequestTSOServer = 3 + retryIntervalRequestTSOServer = 500 * time.Millisecond ) // gRPC errors @@ -1781,31 +1784,77 @@ func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan } func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timestamp, error) { - forwardedHost, ok := s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) - if !ok || forwardedHost == "" { - return pdpb.Timestamp{}, ErrNotFoundTSOAddr - } - forwardStream, err := s.getTSOForwardStream(forwardedHost) - if err != nil { - return pdpb.Timestamp{}, err - } - forwardStream.Send(&tsopb.TsoRequest{ + request := &tsopb.TsoRequest{ Header: &tsopb.RequestHeader{ ClusterId: s.clusterID, KeyspaceId: utils.DefaultKeyspaceID, KeyspaceGroupId: utils.DefaultKeyspaceGroupID, }, Count: 1, - }) - ts, err := forwardStream.Recv() - if err != nil { - log.Error("get global tso from tso server failed", zap.Error(err)) - return pdpb.Timestamp{}, err } - return *ts.GetTimestamp(), nil + var ( + forwardedHost string + forwardStream *streamWrapper + ts *tsopb.TsoResponse + err error + ok bool + ) + handleStreamError := func(err error) (needRetry bool) { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.updateServicePrimaryAddr(utils.TSOServiceName) + log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + if grpcutil.NeedRebuildConnection(err) { + s.tsoClientPool.Lock() + delete(s.tsoClientPool.clients, forwardedHost) + s.tsoClientPool.Unlock() + log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + return false + } + for i := 0; i < maxRetryTimesRequestTSOServer; i++ { + if i > 0 { + time.Sleep(retryIntervalRequestTSOServer) + } + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) + if !ok || forwardedHost == "" { + return pdpb.Timestamp{}, ErrNotFoundTSOAddr 
+ } + forwardStream, err = s.getTSOForwardStream(forwardedHost) + if err != nil { + return pdpb.Timestamp{}, err + } + start := time.Now() + forwardStream.Lock() + err = forwardStream.Send(request) + if err != nil { + if needRetry := handleStreamError(err); needRetry { + forwardStream.Unlock() + continue + } + log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + forwardStream.Unlock() + return pdpb.Timestamp{}, err + } + ts, err = forwardStream.Recv() + forwardStream.Unlock() + forwardTsoDuration.Observe(time.Since(start).Seconds()) + if err != nil { + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err + } + return *ts.GetTimestamp(), nil + } + log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err } -func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) { +func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (*streamWrapper, error) { s.tsoClientPool.RLock() forwardStream, ok := s.tsoClientPool.clients[forwardedHost] s.tsoClientPool.RUnlock() @@ -1831,11 +1880,14 @@ func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoCli done := make(chan struct{}) ctx, cancel := context.WithCancel(s.ctx) go checkStream(ctx, cancel, done) - forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx) + tsoClient, err := tsopb.NewTSOClient(client).Tso(ctx) done <- struct{}{} if err != nil { return nil, err } + forwardStream = &streamWrapper{ + TSO_TsoClient: tsoClient, + } s.tsoClientPool.clients[forwardedHost] = forwardStream return forwardStream, nil } diff --git a/server/metrics.go b/server/metrics.go index 23866c01dae..6082cf4ea95 100644 --- a/server/metrics.go +++ b/server/metrics.go @@ -147,23 +147,10 @@ var ( serverMaxProcs = prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: "pd", -<<<<<<< HEAD Subsystem: "service", Name: "maxprocs", Help: "The value of GOMAXPROCS.", -======= - Subsystem: "server", - Name: "api_concurrency", - Help: "Concurrency number of the api.", - }, []string{"kind", "api"}) - - forwardFailCounter = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "pd", - Subsystem: "server", - Name: "forward_fail_total", - Help: "Counter of forward fail.", - }, []string{"request", "type"}) + }) forwardTsoDuration = prometheus.NewHistogram( prometheus.HistogramOpts{ Namespace: "pd", @@ -171,7 +158,6 @@ var ( Name: "forward_tso_duration_seconds", Help: "Bucketed histogram of processing time (s) of handled forward tso requests.", Buckets: prometheus.ExponentialBuckets(0.0005, 2, 13), ->>>>>>> 0c13897bf (mcs: add lock for forward tso stream (#9095)) }) ) @@ -191,11 +177,6 @@ func init() { prometheus.MustRegister(bucketReportLatency) prometheus.MustRegister(serviceAuditHistogram) prometheus.MustRegister(bucketReportInterval) -<<<<<<< HEAD prometheus.MustRegister(serverMaxProcs) -======= - prometheus.MustRegister(apiConcurrencyGauge) - prometheus.MustRegister(forwardFailCounter) prometheus.MustRegister(forwardTsoDuration) ->>>>>>> 0c13897bf (mcs: add lock for forward tso stream (#9095)) } diff --git a/server/server.go b/server/server.go index 532ede58ca8..0e11039ca14 100644 --- a/server/server.go +++ b/server/server.go @@ -69,6 +69,7 @@ import ( "github.com/tikv/pd/pkg/utils/grpcutil" 
"github.com/tikv/pd/pkg/utils/jsonutil" "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/utils/syncutil" "github.com/tikv/pd/pkg/utils/tsoutil" "github.com/tikv/pd/pkg/utils/typeutil" "github.com/tikv/pd/pkg/versioninfo" @@ -204,13 +205,8 @@ type Server struct { clientConns sync.Map tsoClientPool struct { -<<<<<<< HEAD - sync.RWMutex - clients map[string]tsopb.TSO_TsoClient -======= syncutil.RWMutex clients map[string]*streamWrapper ->>>>>>> 0c13897bf (mcs: add lock for forward tso stream (#9095)) } // tsoDispatcher is used to dispatch different TSO requests to @@ -264,13 +260,8 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le DiagnosticsServer: sysutil.NewDiagnosticsServer(cfg.Log.File.Filename), mode: mode, tsoClientPool: struct { -<<<<<<< HEAD - sync.RWMutex - clients map[string]tsopb.TSO_TsoClient -======= syncutil.RWMutex clients map[string]*streamWrapper ->>>>>>> 0c13897bf (mcs: add lock for forward tso stream (#9095)) }{ clients: make(map[string]*streamWrapper), }, diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go index 896b1431946..916d5ea6edd 100644 --- a/tests/integrations/mcs/tso/server_test.go +++ b/tests/integrations/mcs/tso/server_test.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -320,59 +321,108 @@ func (suite *APIServerForwardTestSuite) checkAvailableTSO() { suite.NoError(err) } -<<<<<<< HEAD -func TestAdvertiseAddr(t *testing.T) { -======= func TestForwardTsoConcurrently(t *testing.T) { re := require.New(t) - suite := NewPDServiceForward(re) - defer suite.ShutDown() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestAPICluster(ctx, 3) + re.NoError(err) + defer cluster.Destroy() + + err = cluster.RunInitialServers() + re.NoError(err) + + leaderName := cluster.WaitLeader() + pdLeader := cluster.GetServer(leaderName) + backendEndpoints := pdLeader.GetAddr() + re.NoError(pdLeader.BootstrapCluster()) + leader := cluster.GetServer(cluster.WaitLeader()) + rc := leader.GetRaftCluster() + for i := 0; i < 3; i++ { + region := &metapb.Region{ + Id: uint64(i*4 + 1), + Peers: []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}}, + StartKey: []byte{byte(i)}, + EndKey: []byte{byte(i + 1)}, + } + rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0])) + } + + re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)")) + defer func() { + re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode")) + }() - tc, err := tests.NewTestTSOCluster(suite.ctx, 2, suite.backendEndpoints) + tc, err := mcs.NewTestTSOCluster(ctx, 2, backendEndpoints) re.NoError(err) defer tc.Destroy() tc.WaitForDefaultPrimaryServing(re) wg := sync.WaitGroup{} - for i := range 3 { + for i := 0; i < 50; i++ { wg.Add(1) - go func() { + go func(i int) { defer wg.Done() pdClient, err := pd.NewClientWithContext( context.Background(), - caller.TestComponent, - []string{suite.backendEndpoints}, + []string{backendEndpoints}, pd.SecurityOption{}) re.NoError(err) re.NotNil(pdClient) defer pdClient.Close() - for range 10 { + for j := 0; j < 20; j++ { testutil.Eventually(re, func() bool { min, err := pdClient.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", i), 1000, 1) return err == nil && min == 0 }) } - }() + }(i) } wg.Wait() } func 
BenchmarkForwardTsoConcurrently(b *testing.B) { re := require.New(b) - suite := NewPDServiceForward(re) - defer suite.ShutDown() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestAPICluster(ctx, 3) + re.NoError(err) + defer cluster.Destroy() + + err = cluster.RunInitialServers() + re.NoError(err) + + leaderName := cluster.WaitLeader() + pdLeader := cluster.GetServer(leaderName) + backendEndpoints := pdLeader.GetAddr() + re.NoError(pdLeader.BootstrapCluster()) + leader := cluster.GetServer(cluster.WaitLeader()) + rc := leader.GetRaftCluster() + for i := 0; i < 3; i++ { + region := &metapb.Region{ + Id: uint64(i*4 + 1), + Peers: []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}}, + StartKey: []byte{byte(i)}, + EndKey: []byte{byte(i + 1)}, + } + rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0])) + } + + re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)")) + defer func() { + re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode")) + }() - tc, err := tests.NewTestTSOCluster(suite.ctx, 1, suite.backendEndpoints) + tc, err := mcs.NewTestTSOCluster(ctx, 1, backendEndpoints) re.NoError(err) defer tc.Destroy() tc.WaitForDefaultPrimaryServing(re) initClients := func(num int) []pd.Client { var clients []pd.Client - for range num { + for i := 0; i < num; i++ { pdClient, err := pd.NewClientWithContext(context.Background(), - caller.TestComponent, - []string{suite.backendEndpoints}, pd.SecurityOption{}, opt.WithMaxErrorRetry(1)) + []string{backendEndpoints}, pd.SecurityOption{}, pd.WithMaxErrorRetry(1)) re.NoError(err) re.NotNil(pdClient) clients = append(clients, pdClient) @@ -386,17 +436,17 @@ func BenchmarkForwardTsoConcurrently(b *testing.B) { b.Run(fmt.Sprintf("clients_%d", clientsNum), func(b *testing.B) { wg := sync.WaitGroup{} b.ResetTimer() - for range b.N { - for i, client := range clients { + for i := 0; i < b.N; i++ { + for j, client := range clients { wg.Add(1) - go func() { + go func(j int, client pd.Client) { defer wg.Done() - for range 1000 { - min, err := client.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", i), 1000, 1) + for k := 0; k < 1000; k++ { + min, err := client.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", j), 1000, 1) re.NoError(err) re.Equal(uint64(0), min) } - }() + }(j, client) } } wg.Wait() @@ -407,108 +457,7 @@ func BenchmarkForwardTsoConcurrently(b *testing.B) { } } -type CommonTestSuite struct { - suite.Suite - ctx context.Context - cancel context.CancelFunc - cluster *tests.TestCluster - tsoCluster *tests.TestTSOCluster - pdLeader *tests.TestServer - // tsoDefaultPrimaryServer is the primary server of the default keyspace group - tsoDefaultPrimaryServer *tso.Server - backendEndpoints string -} - -func TestCommonTestSuite(t *testing.T) { - suite.Run(t, new(CommonTestSuite)) -} - -func (suite *CommonTestSuite) SetupSuite() { - var err error - re := suite.Require() - suite.ctx, suite.cancel = context.WithCancel(context.Background()) - suite.cluster, err = tests.NewTestClusterWithKeyspaceGroup(suite.ctx, 1) - re.NoError(err) - - err = suite.cluster.RunInitialServers() - re.NoError(err) - - leaderName := suite.cluster.WaitLeader() - re.NotEmpty(leaderName) - suite.pdLeader = suite.cluster.GetServer(leaderName) - suite.backendEndpoints = suite.pdLeader.GetAddr() - re.NoError(suite.pdLeader.BootstrapCluster()) - - suite.tsoCluster, err = tests.NewTestTSOCluster(suite.ctx, 1, 
suite.backendEndpoints) - re.NoError(err) - suite.tsoCluster.WaitForDefaultPrimaryServing(re) - suite.tsoDefaultPrimaryServer = suite.tsoCluster.GetPrimaryServer(constant.DefaultKeyspaceID, constant.DefaultKeyspaceGroupID) -} - -func (suite *CommonTestSuite) TearDownSuite() { - re := suite.Require() - suite.tsoCluster.Destroy() - etcdClient := suite.pdLeader.GetEtcdClient() - endpoints, err := discovery.Discover(etcdClient, constant.TSOServiceName) - re.NoError(err) - if len(endpoints) != 0 { - endpoints, err = discovery.Discover(etcdClient, constant.TSOServiceName) - re.NoError(err) - re.Empty(endpoints) - } - suite.cluster.Destroy() - suite.cancel() -} - -func (suite *CommonTestSuite) TestAdvertiseAddr() { - re := suite.Require() - - conf := suite.tsoDefaultPrimaryServer.GetConfig() - re.Equal(conf.GetListenAddr(), conf.GetAdvertiseListenAddr()) -} - -func (suite *CommonTestSuite) TestBootstrapDefaultKeyspaceGroup() { - re := suite.Require() - - // check the default keyspace group and wait for alloc tso nodes for the default keyspace group - check := func() { - testutil.Eventually(re, func() bool { - resp, err := tests.TestDialClient.Get(suite.pdLeader.GetServer().GetConfig().AdvertiseClientUrls + "/pd/api/v2/tso/keyspace-groups") - re.NoError(err) - defer resp.Body.Close() - re.Equal(http.StatusOK, resp.StatusCode) - respString, err := io.ReadAll(resp.Body) - re.NoError(err) - var kgs []*endpoint.KeyspaceGroup - re.NoError(json.Unmarshal(respString, &kgs)) - re.Len(kgs, 1) - re.Equal(constant.DefaultKeyspaceGroupID, kgs[0].ID) - re.Equal(endpoint.Basic.String(), kgs[0].UserKind) - re.Empty(kgs[0].SplitState) - re.Empty(kgs[0].KeyspaceLookupTable) - return len(kgs[0].Members) == 1 - }) - } - check() - - s, err := suite.cluster.JoinWithKeyspaceGroup(suite.ctx) - re.NoError(err) - re.NoError(s.Run()) - - // transfer leader to the new server - suite.pdLeader.ResignLeader() - suite.pdLeader = suite.cluster.GetServer(suite.cluster.WaitLeader()) - check() - suite.pdLeader.ResignLeader() - suite.pdLeader = suite.cluster.GetServer(suite.cluster.WaitLeader()) -} - -// TestTSOServiceSwitch tests the behavior of TSO service switching when `EnableTSODynamicSwitching` is enabled. -// Initially, the TSO service should be provided by PD. After starting a TSO server, the service should switch to the TSO server. -// When the TSO server is stopped, the PD should resume providing the TSO service if `EnableTSODynamicSwitching` is enabled. -// If `EnableTSODynamicSwitching` is disabled, the PD should not provide TSO service after the TSO server is stopped. -func TestTSOServiceSwitch(t *testing.T) { ->>>>>>> 0c13897bf (mcs: add lock for forward tso stream (#9095)) +func TestAdvertiseAddr(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background())
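To make the intent of the streamWrapper change easier to follow (the lock added around the shared forward TSO stream that TestForwardTsoConcurrently and BenchmarkForwardTsoConcurrently exercise), here is a minimal sketch of the pattern under assumed names: Send and the matching Recv are performed while holding the wrapper's lock, so concurrent callers cannot interleave their request/response pairs on the one stream per TSO primary. In the patch itself, getGlobalTSOFromTSOServer additionally retries up to maxRetryTimesRequestTSOServer times, refreshing the TSO primary address on NotLeader errors and dropping the pooled stream when grpcutil.NeedRebuildConnection returns true.

package example // hypothetical package, for illustration only

import (
	"sync"

	"github.com/pingcap/kvproto/pkg/tsopb"
)

// lockedTSOStream mirrors the streamWrapper idea: a forward TSO client stream
// guarded by a mutex that is shared by all goroutines forwarding to one host.
type lockedTSOStream struct {
	sync.Mutex
	tsopb.TSO_TsoClient
}

// forwardOnce sends one request and receives its response under the lock, so a
// concurrent caller cannot steal or interleave with this caller's timestamp.
func (s *lockedTSOStream) forwardOnce(req *tsopb.TsoRequest) (*tsopb.TsoResponse, error) {
	s.Lock()
	defer s.Unlock()
	if err := s.Send(req); err != nil {
		return nil, err
	}
	return s.Recv()
}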