From 94475933c5ed5ed08f19e70cf4ed3609d82cee2c Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 24 Feb 2025 15:58:45 +0800 Subject: [PATCH 1/3] Refactor TSO server forwarding function Signed-off-by: JmPotato --- server/forward.go | 123 ++++++++++++++++++++++++----------------- server/grpc_service.go | 23 +++----- 2 files changed, 81 insertions(+), 65 deletions(-) diff --git a/server/forward.go b/server/forward.go index 9a604410fc0..2c2af6f37a5 100644 --- a/server/forward.go +++ b/server/forward.go @@ -39,10 +39,12 @@ import ( "github.com/tikv/pd/server/cluster" ) +// forwardTSORequest sends the TSO request with the given forward stream. func forwardTSORequest( ctx context.Context, request *pdpb.TsoRequest, - forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) { + forwardStream tsopb.TSO_TsoClient, +) (*tsopb.TsoResponse, error) { tsopbReq := &tsopb.TsoRequest{ Header: &tsopb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -85,20 +87,15 @@ func forwardTSORequest( // forwardTSO forward the TSO requests to the TSO service. func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { var ( - server = &tsoServer{stream: stream} - forwardStream tsopb.TSO_TsoClient - forwardCtx context.Context - cancelForward context.CancelFunc - tsoStreamErr error - lastForwardedHost string + server = &tsoServer{stream: stream} + forwarder = newTSOForwarder(server) + tsoStreamErr error ) defer func() { s.concurrentTSOProxyStreamings.Add(-1) - if cancelForward != nil { - cancelForward() - } + forwarder.cancel() if grpcutil.NeedRebuildConnection(tsoStreamErr) { - s.closeDelegateClient(lastForwardedHost) + s.closeDelegateClient(forwarder.host) } }() @@ -132,7 +129,7 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") return errs.ErrUnknown(err) } - forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, server, request, tsDeadlineCh, lastForwardedHost, cancelForward) + tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh) if tsoStreamErr != nil { return tsoStreamErr } @@ -142,38 +139,62 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { } } -func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStream tsopb.TSO_TsoClient, stream pdpb.PD_TsoServer, server *tsoServer, - request *pdpb.TsoRequest, tsDeadlineCh chan<- *tsoutil.TSDeadline, lastForwardedHost string, cancelForward context.CancelFunc) ( - context.Context, - context.CancelFunc, - tsopb.TSO_TsoClient, - string, - error, // tso stream error - error, // send error -) { - forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(errs.ErrNotFoundTSOAddr), nil - } - if forwardStream == nil || lastForwardedHost != forwardedHost { - if cancelForward != nil { - cancelForward() - } +type tsoForwarder struct { + // The original source that we need to send the response back to. + responser interface{ Send(*pdpb.TsoResponse) error } + // The context for the forwarding stream. + ctx context.Context + // The cancel function for the forwarding stream. + canceller context.CancelFunc + // The current forwarding stream. + stream tsopb.TSO_TsoClient + // The current host of the forwarding stream. + host string +} - clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) +func newTSOForwarder(responser interface{ Send(*pdpb.TsoResponse) error }) *tsoForwarder { + return &tsoForwarder{ + responser: responser, + } +} + +func (f *tsoForwarder) cancel() { + if f != nil && f.canceller != nil { + f.canceller() + } +} + +func (s *GrpcServer) handleTSOForwarding( + ctx context.Context, + forwarder *tsoForwarder, + request *pdpb.TsoRequest, + tsDeadlineCh chan<- *tsoutil.TSDeadline, +) (tsoStreamErr, sendErr error) { + // Get the latest TSO primary address. + targetHost, ok := s.GetServicePrimaryAddr(ctx, constant.TSOServiceName) + if !ok || len(targetHost) == 0 { + return errors.WithStack(errs.ErrNotFoundTSOAddr), nil + } + // Check if the forwarder is already built with the target host. + if forwarder.stream == nil || forwarder.host != targetHost { + // Cancel the old forwarder. + forwarder.cancel() + // Build a new forward stream. + clientConn, err := s.getDelegateClient(s.ctx, targetHost) if err != nil { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + return errors.WithStack(err), nil } - forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn) + forwarder.stream, forwarder.ctx, forwarder.canceller, err = createTSOForwardStream(ctx, clientConn) if err != nil { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + return errors.WithStack(err), nil } - lastForwardedHost = forwardedHost + forwarder.host = targetHost } - tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) + // Forward the TSO request with the deadline. + tsopbResp, err := s.forwardTSORequestWithDeadLine(forwarder, request, tsDeadlineCh) if err != nil { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + return errors.WithStack(err), nil } // The error types defined for tsopb and pdpb are different, so we need to convert them. @@ -193,31 +214,31 @@ func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStre } } } - - response := &pdpb.TsoResponse{ + // Send the TSO response back to the original source. + sendErr = forwarder.responser.Send(&pdpb.TsoResponse{ Header: &pdpb.ResponseHeader{ ClusterId: tsopbResp.GetHeader().GetClusterId(), Error: pdpbErr, }, Count: tsopbResp.GetCount(), Timestamp: tsopbResp.GetTimestamp(), - } - if server != nil { - err = server.send(response) - } else { - err = stream.Send(response) - } - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, nil, errors.WithStack(err) + }) + + return nil, errors.WithStack(sendErr) } func (s *GrpcServer) forwardTSORequestWithDeadLine( - forwardCtx context.Context, - cancelForward context.CancelFunc, - forwardStream tsopb.TSO_TsoClient, + forwarder *tsoForwarder, request *pdpb.TsoRequest, - tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) { - done := make(chan struct{}) - dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) + tsDeadlineCh chan<- *tsoutil.TSDeadline, +) (*tsopb.TsoResponse, error) { + var ( + forwardCtx = forwarder.ctx + forwardCancel = forwarder.canceller + forwardStream = forwarder.stream + done = make(chan struct{}) + dl = tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, forwardCancel) + ) select { case tsDeadlineCh <- dl: case <-forwardCtx.Done(): diff --git a/server/grpc_service.go b/server/grpc_service.go index a9295ac6ab3..58c7828f5e2 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -96,7 +96,8 @@ type pdpbTSORequest struct { err error } -func (s *tsoServer) send(m *pdpb.TsoResponse) error { +// Send wraps Send() of PD_TsoServer. +func (s *tsoServer) Send(m *pdpb.TsoResponse) error { if atomic.LoadInt32(&s.closed) == 1 { return io.EOF } @@ -493,22 +494,16 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) var ( - doneCh chan struct{} - errCh chan error - // The following are tso forward stream related variables. - forwardStream tsopb.TSO_TsoClient - cancelForward context.CancelFunc - forwardCtx context.Context - tsoStreamErr error - lastForwardedHost string + doneCh chan struct{} + errCh chan error + forwarder = newTSOForwarder(stream) + tsoStreamErr error ) defer func() { - if cancelForward != nil { - cancelForward() - } + forwarder.cancel() if grpcutil.NeedRebuildConnection(tsoStreamErr) { - s.closeDelegateClient(lastForwardedHost) + s.closeDelegateClient(forwarder.host) } }() @@ -559,7 +554,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") return errs.ErrUnknown(err) } - forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, nil, request, tsDeadlineCh, lastForwardedHost, cancelForward) + tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh) if tsoStreamErr != nil { return tsoStreamErr } From 6c725cb64bae748244b3409672d39781c6fbfb45 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 24 Feb 2025 20:28:58 +0800 Subject: [PATCH 2/3] Make forwardTSORequest a method of tsoForwarder Signed-off-by: JmPotato --- server/forward.go | 91 +++++++++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 47 deletions(-) diff --git a/server/forward.go b/server/forward.go index 2c2af6f37a5..de3fe024197 100644 --- a/server/forward.go +++ b/server/forward.go @@ -39,51 +39,6 @@ import ( "github.com/tikv/pd/server/cluster" ) -// forwardTSORequest sends the TSO request with the given forward stream. -func forwardTSORequest( - ctx context.Context, - request *pdpb.TsoRequest, - forwardStream tsopb.TSO_TsoClient, -) (*tsopb.TsoResponse, error) { - tsopbReq := &tsopb.TsoRequest{ - Header: &tsopb.RequestHeader{ - ClusterId: request.GetHeader().GetClusterId(), - SenderId: request.GetHeader().GetSenderId(), - KeyspaceId: constant.DefaultKeyspaceID, - KeyspaceGroupId: constant.DefaultKeyspaceGroupID, - }, - Count: request.GetCount(), - } - - failpoint.Inject("tsoProxySendToTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-ctx.Done() - }) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - if err := forwardStream.Send(tsopbReq); err != nil { - return nil, err - } - - failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-ctx.Done() - }) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - return forwardStream.Recv() -} - // forwardTSO forward the TSO requests to the TSO service. func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { var ( @@ -164,6 +119,49 @@ func (f *tsoForwarder) cancel() { } } +// forwardTSORequest sends the TSO request with the current forward stream. +func (f *tsoForwarder) forwardTSORequest( + request *pdpb.TsoRequest, +) (*tsopb.TsoResponse, error) { + tsopbReq := &tsopb.TsoRequest{ + Header: &tsopb.RequestHeader{ + ClusterId: request.GetHeader().GetClusterId(), + SenderId: request.GetHeader().GetSenderId(), + KeyspaceId: constant.DefaultKeyspaceID, + KeyspaceGroupId: constant.DefaultKeyspaceGroupID, + }, + Count: request.GetCount(), + } + + failpoint.Inject("tsoProxySendToTSOTimeout", func() { + // block until watchDeadline routine cancels the context. + <-f.ctx.Done() + }) + + select { + case <-f.ctx.Done(): + return nil, f.ctx.Err() + default: + } + + if err := f.stream.Send(tsopbReq); err != nil { + return nil, err + } + + failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { + // block until watchDeadline routine cancels the context. + <-f.ctx.Done() + }) + + select { + case <-f.ctx.Done(): + return nil, f.ctx.Err() + default: + } + + return f.stream.Recv() +} + func (s *GrpcServer) handleTSOForwarding( ctx context.Context, forwarder *tsoForwarder, @@ -235,7 +233,6 @@ func (s *GrpcServer) forwardTSORequestWithDeadLine( var ( forwardCtx = forwarder.ctx forwardCancel = forwarder.canceller - forwardStream = forwarder.stream done = make(chan struct{}) dl = tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, forwardCancel) ) @@ -246,7 +243,7 @@ func (s *GrpcServer) forwardTSORequestWithDeadLine( } start := time.Now() - resp, err := forwardTSORequest(forwardCtx, request, forwardStream) + resp, err := forwarder.forwardTSORequest(request) close(done) if err != nil { if strings.Contains(err.Error(), errs.NotLeaderErr) { From 09f5f8f21a8d9c33a1cbc60024008eb4794ac32d Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 25 Feb 2025 14:40:07 +0800 Subject: [PATCH 3/3] Rename forwardTSO to forwardToTSOService Signed-off-by: JmPotato --- server/forward.go | 4 ++-- server/grpc_service.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/forward.go b/server/forward.go index de3fe024197..b2d7a8be549 100644 --- a/server/forward.go +++ b/server/forward.go @@ -39,8 +39,8 @@ import ( "github.com/tikv/pd/server/cluster" ) -// forwardTSO forward the TSO requests to the TSO service. -func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { +// forwardToTSOService forwards the TSO requests to the TSO service. +func (s *GrpcServer) forwardToTSOService(stream pdpb.PD_TsoServer) error { var ( server = &tsoServer{stream: stream} forwarder = newTSOForwarder(server) diff --git a/server/grpc_service.go b/server/grpc_service.go index 58c7828f5e2..a311b8aedc3 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -487,7 +487,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { defer done() } if s.IsServiceIndependent(constant.TSOServiceName) { - return s.forwardTSO(stream) + return s.forwardToTSOService(stream) } tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1)