diff --git a/server/forward.go b/server/forward.go index 9a604410fc01..2c2af6f37a55 100644 --- a/server/forward.go +++ b/server/forward.go @@ -39,10 +39,12 @@ import ( "github.com/tikv/pd/server/cluster" ) +// forwardTSORequest sends the TSO request with the given forward stream. func forwardTSORequest( ctx context.Context, request *pdpb.TsoRequest, - forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) { + forwardStream tsopb.TSO_TsoClient, +) (*tsopb.TsoResponse, error) { tsopbReq := &tsopb.TsoRequest{ Header: &tsopb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -85,20 +87,15 @@ func forwardTSORequest( // forwardTSO forward the TSO requests to the TSO service. func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { var ( - server = &tsoServer{stream: stream} - forwardStream tsopb.TSO_TsoClient - forwardCtx context.Context - cancelForward context.CancelFunc - tsoStreamErr error - lastForwardedHost string + server = &tsoServer{stream: stream} + forwarder = newTSOForwarder(server) + tsoStreamErr error ) defer func() { s.concurrentTSOProxyStreamings.Add(-1) - if cancelForward != nil { - cancelForward() - } + forwarder.cancel() if grpcutil.NeedRebuildConnection(tsoStreamErr) { - s.closeDelegateClient(lastForwardedHost) + s.closeDelegateClient(forwarder.host) } }() @@ -132,7 +129,7 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") return errs.ErrUnknown(err) } - forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, server, request, tsDeadlineCh, lastForwardedHost, cancelForward) + tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh) if tsoStreamErr != nil { return tsoStreamErr } @@ -142,38 +139,62 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { } } -func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStream tsopb.TSO_TsoClient, stream pdpb.PD_TsoServer, server *tsoServer, - request *pdpb.TsoRequest, tsDeadlineCh chan<- *tsoutil.TSDeadline, lastForwardedHost string, cancelForward context.CancelFunc) ( - context.Context, - context.CancelFunc, - tsopb.TSO_TsoClient, - string, - error, // tso stream error - error, // send error -) { - forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(errs.ErrNotFoundTSOAddr), nil - } - if forwardStream == nil || lastForwardedHost != forwardedHost { - if cancelForward != nil { - cancelForward() - } +type tsoForwarder struct { + // The original source that we need to send the response back to. + responser interface{ Send(*pdpb.TsoResponse) error } + // The context for the forwarding stream. + ctx context.Context + // The cancel function for the forwarding stream. + canceller context.CancelFunc + // The current forwarding stream. + stream tsopb.TSO_TsoClient + // The current host of the forwarding stream. + host string +} - clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) +func newTSOForwarder(responser interface{ Send(*pdpb.TsoResponse) error }) *tsoForwarder { + return &tsoForwarder{ + responser: responser, + } +} + +func (f *tsoForwarder) cancel() { + if f != nil && f.canceller != nil { + f.canceller() + } +} + +func (s *GrpcServer) handleTSOForwarding( + ctx context.Context, + forwarder *tsoForwarder, + request *pdpb.TsoRequest, + tsDeadlineCh chan<- *tsoutil.TSDeadline, +) (tsoStreamErr, sendErr error) { + // Get the latest TSO primary address. + targetHost, ok := s.GetServicePrimaryAddr(ctx, constant.TSOServiceName) + if !ok || len(targetHost) == 0 { + return errors.WithStack(errs.ErrNotFoundTSOAddr), nil + } + // Check if the forwarder is already built with the target host. + if forwarder.stream == nil || forwarder.host != targetHost { + // Cancel the old forwarder. + forwarder.cancel() + // Build a new forward stream. + clientConn, err := s.getDelegateClient(s.ctx, targetHost) if err != nil { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + return errors.WithStack(err), nil } - forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn) + forwarder.stream, forwarder.ctx, forwarder.canceller, err = createTSOForwardStream(ctx, clientConn) if err != nil { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + return errors.WithStack(err), nil } - lastForwardedHost = forwardedHost + forwarder.host = targetHost } - tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) + // Forward the TSO request with the deadline. + tsopbResp, err := s.forwardTSORequestWithDeadLine(forwarder, request, tsDeadlineCh) if err != nil { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + return errors.WithStack(err), nil } // The error types defined for tsopb and pdpb are different, so we need to convert them. @@ -193,31 +214,31 @@ func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStre } } } - - response := &pdpb.TsoResponse{ + // Send the TSO response back to the original source. + sendErr = forwarder.responser.Send(&pdpb.TsoResponse{ Header: &pdpb.ResponseHeader{ ClusterId: tsopbResp.GetHeader().GetClusterId(), Error: pdpbErr, }, Count: tsopbResp.GetCount(), Timestamp: tsopbResp.GetTimestamp(), - } - if server != nil { - err = server.send(response) - } else { - err = stream.Send(response) - } - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, nil, errors.WithStack(err) + }) + + return nil, errors.WithStack(sendErr) } func (s *GrpcServer) forwardTSORequestWithDeadLine( - forwardCtx context.Context, - cancelForward context.CancelFunc, - forwardStream tsopb.TSO_TsoClient, + forwarder *tsoForwarder, request *pdpb.TsoRequest, - tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) { - done := make(chan struct{}) - dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) + tsDeadlineCh chan<- *tsoutil.TSDeadline, +) (*tsopb.TsoResponse, error) { + var ( + forwardCtx = forwarder.ctx + forwardCancel = forwarder.canceller + forwardStream = forwarder.stream + done = make(chan struct{}) + dl = tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, forwardCancel) + ) select { case tsDeadlineCh <- dl: case <-forwardCtx.Done(): diff --git a/server/grpc_service.go b/server/grpc_service.go index c3698763fa2e..0f6f7c6edca1 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -96,7 +96,8 @@ type pdpbTSORequest struct { err error } -func (s *tsoServer) send(m *pdpb.TsoResponse) error { +// Send wraps Send() of PD_TsoServer. +func (s *tsoServer) Send(m *pdpb.TsoResponse) error { if atomic.LoadInt32(&s.closed) == 1 { return io.EOF } @@ -493,22 +494,16 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) var ( - doneCh chan struct{} - errCh chan error - // The following are tso forward stream related variables. - forwardStream tsopb.TSO_TsoClient - cancelForward context.CancelFunc - forwardCtx context.Context - tsoStreamErr error - lastForwardedHost string + doneCh chan struct{} + errCh chan error + forwarder = newTSOForwarder(stream) + tsoStreamErr error ) defer func() { - if cancelForward != nil { - cancelForward() - } + forwarder.cancel() if grpcutil.NeedRebuildConnection(tsoStreamErr) { - s.closeDelegateClient(lastForwardedHost) + s.closeDelegateClient(forwarder.host) } }() @@ -559,7 +554,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") return errs.ErrUnknown(err) } - forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, nil, request, tsDeadlineCh, lastForwardedHost, cancelForward) + tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh) if tsoStreamErr != nil { return tsoStreamErr }