Skip to content

Commit

Permalink
Refactor TSO server forwarding function
Browse files Browse the repository at this point in the history
Signed-off-by: JmPotato <[email protected]>
  • Loading branch information
JmPotato committed Feb 24, 2025
1 parent 8428133 commit 5b6c547
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 64 deletions.
129 changes: 79 additions & 50 deletions server/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ import (
"github.com/tikv/pd/server/cluster"
)

// forwardTSORequest sends the TSO request with the given forward stream.
func forwardTSORequest(
ctx context.Context,
request *pdpb.TsoRequest,
forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) {
forwardStream tsopb.TSO_TsoClient,
) (*tsopb.TsoResponse, error) {
tsopbReq := &tsopb.TsoRequest{
Header: &tsopb.RequestHeader{
ClusterId: request.GetHeader().GetClusterId(),
Expand Down Expand Up @@ -85,20 +87,15 @@ func forwardTSORequest(
// forwardTSO forward the TSO requests to the TSO service.
func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
var (
server = &tsoServer{stream: stream}
forwardStream tsopb.TSO_TsoClient
forwardCtx context.Context
cancelForward context.CancelFunc
tsoStreamErr error
lastForwardedHost string
server = &tsoServer{stream: stream}
forwarder = newTSOForwarder(server)
tsoStreamErr error
)
defer func() {
s.concurrentTSOProxyStreamings.Add(-1)
if cancelForward != nil {
cancelForward()
}
forwarder.cancel()
if grpcutil.NeedRebuildConnection(tsoStreamErr) {
s.closeDelegateClient(lastForwardedHost)
s.closeDelegateClient(forwarder.forwardHost)
}
}()

Expand Down Expand Up @@ -132,7 +129,7 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
return errs.ErrUnknown(err)
}
forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, server, request, tsDeadlineCh, lastForwardedHost, cancelForward)
forwarder, tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh)
if tsoStreamErr != nil {
return tsoStreamErr
}
Expand All @@ -142,38 +139,70 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
}
}

func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStream tsopb.TSO_TsoClient, stream pdpb.PD_TsoServer, server *tsoServer,
request *pdpb.TsoRequest, tsDeadlineCh chan<- *tsoutil.TSDeadline, lastForwardedHost string, cancelForward context.CancelFunc) (
context.Context,
context.CancelFunc,
tsopb.TSO_TsoClient,
string,
error, // tso stream error
error, // send error
) {
forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName)
if !ok || len(forwardedHost) == 0 {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(errs.ErrNotFoundTSOAddr), nil
type tsoForwarder struct {
// The original source that we need to send the response back to.
responser interface{ Send(*pdpb.TsoResponse) error }
// The context for the forwarding stream.
forwardCtx context.Context
// The cancel function for the forwarding stream.
forwardCancel context.CancelFunc
// The current forwarding stream.
forwardStream tsopb.TSO_TsoClient
// The current host of the forwarding stream.
forwardHost string
}

func newTSOForwarder(responser interface{ Send(*pdpb.TsoResponse) error }) *tsoForwarder {
return &tsoForwarder{
responser: responser,
}
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancelForward != nil {
cancelForward()
}
}

func (f *tsoForwarder) cancel() {
if f != nil && f.forwardCancel != nil {
f.forwardCancel()
}
}

clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
func (s *GrpcServer) handleTSOForwarding(
ctx context.Context,
curForwarder *tsoForwarder,
request *pdpb.TsoRequest,
tsDeadlineCh chan<- *tsoutil.TSDeadline,
) (
nextForwarder *tsoForwarder,
tsoStreamErr, sendErr error,
) {
// Get the latest TSO primary address.
targetHost, ok := s.GetServicePrimaryAddr(ctx, constant.TSOServiceName)
if !ok || len(targetHost) == 0 {
return curForwarder, errors.WithStack(errs.ErrNotFoundTSOAddr), nil
}
// Check if the forwarder is already built with the target host.
if curForwarder.forwardStream == nil || curForwarder.forwardHost != targetHost {
// Cancel the old forwarder.
curForwarder.cancel()
// Create a new forwarder with the target host.
nextForwarder = newTSOForwarder(curForwarder.responser)
// Build a new forward stream.
clientConn, err := s.getDelegateClient(s.ctx, targetHost)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
return curForwarder, errors.WithStack(err), nil
}
forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn)
nextForwarder.forwardStream, nextForwarder.forwardCtx, nextForwarder.forwardCancel, err = createTSOForwardStream(ctx, clientConn)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
return curForwarder, errors.WithStack(err), nil
}
lastForwardedHost = forwardedHost
nextForwarder.forwardHost = targetHost
} else {
// Keep using the old forwarder.
nextForwarder = curForwarder
}

tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh)
// Forward the TSO request with the deadline.
tsopbResp, err := s.forwardTSORequestWithDeadLine(nextForwarder, request, tsDeadlineCh)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
return nextForwarder, errors.WithStack(err), nil
}

// The error types defined for tsopb and pdpb are different, so we need to convert them.
Expand All @@ -193,31 +222,31 @@ func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStre
}
}
}

response := &pdpb.TsoResponse{
// Send the TSO response back to the original source.
sendErr = curForwarder.responser.Send(&pdpb.TsoResponse{
Header: &pdpb.ResponseHeader{
ClusterId: tsopbResp.GetHeader().GetClusterId(),
Error: pdpbErr,
},
Count: tsopbResp.GetCount(),
Timestamp: tsopbResp.GetTimestamp(),
}
if server != nil {
err = server.send(response)
} else {
err = stream.Send(response)
}
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, nil, errors.WithStack(err)
})

return nextForwarder, nil, errors.WithStack(sendErr)
}

func (s *GrpcServer) forwardTSORequestWithDeadLine(
forwardCtx context.Context,
cancelForward context.CancelFunc,
forwardStream tsopb.TSO_TsoClient,
forwarder *tsoForwarder,
request *pdpb.TsoRequest,
tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) {
done := make(chan struct{})
dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward)
tsDeadlineCh chan<- *tsoutil.TSDeadline,
) (*tsopb.TsoResponse, error) {
var (
forwardCtx = forwarder.forwardCtx
forwardCancel = forwarder.forwardCancel
forwardStream = forwarder.forwardStream
done = make(chan struct{})
dl = tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, forwardCancel)
)
select {
case tsDeadlineCh <- dl:
case <-forwardCtx.Done():
Expand Down
23 changes: 9 additions & 14 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ type pdpbTSORequest struct {
err error
}

func (s *tsoServer) send(m *pdpb.TsoResponse) error {
// Send wraps Send() of PD_TsoServer.
func (s *tsoServer) Send(m *pdpb.TsoResponse) error {
if atomic.LoadInt32(&s.closed) == 1 {
return io.EOF
}
Expand Down Expand Up @@ -493,22 +494,16 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh)

var (
doneCh chan struct{}
errCh chan error
// The following are tso forward stream related variables.
forwardStream tsopb.TSO_TsoClient
cancelForward context.CancelFunc
forwardCtx context.Context
tsoStreamErr error
lastForwardedHost string
doneCh chan struct{}
errCh chan error
forwarder = newTSOForwarder(stream)
tsoStreamErr error
)

defer func() {
if cancelForward != nil {
cancelForward()
}
forwarder.cancel()
if grpcutil.NeedRebuildConnection(tsoStreamErr) {
s.closeDelegateClient(lastForwardedHost)
s.closeDelegateClient(forwarder.forwardHost)
}
}()

Expand Down Expand Up @@ -559,7 +554,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
return errs.ErrUnknown(err)
}
forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, nil, request, tsDeadlineCh, lastForwardedHost, cancelForward)
forwarder, tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh)
if tsoStreamErr != nil {
return tsoStreamErr
}
Expand Down

0 comments on commit 5b6c547

Please sign in to comment.