Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: refactor TSO server forwarding function #9096

Merged
merged 4 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 114 additions & 96 deletions server/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,66 +39,18 @@
"github.com/tikv/pd/server/cluster"
)

func forwardTSORequest(
ctx context.Context,
request *pdpb.TsoRequest,
forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) {
tsopbReq := &tsopb.TsoRequest{
Header: &tsopb.RequestHeader{
ClusterId: request.GetHeader().GetClusterId(),
SenderId: request.GetHeader().GetSenderId(),
KeyspaceId: constant.DefaultKeyspaceID,
KeyspaceGroupId: constant.DefaultKeyspaceGroupID,
},
Count: request.GetCount(),
}

failpoint.Inject("tsoProxySendToTSOTimeout", func() {
// block until watchDeadline routine cancels the context.
<-ctx.Done()
})

select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

if err := forwardStream.Send(tsopbReq); err != nil {
return nil, err
}

failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() {
// block until watchDeadline routine cancels the context.
<-ctx.Done()
})

select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

return forwardStream.Recv()
}

// forwardTSO forward the TSO requests to the TSO service.
func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
// forwardToTSOService forwards the TSO requests to the TSO service.
func (s *GrpcServer) forwardToTSOService(stream pdpb.PD_TsoServer) error {
var (
server = &tsoServer{stream: stream}
forwardStream tsopb.TSO_TsoClient
forwardCtx context.Context
cancelForward context.CancelFunc
tsoStreamErr error
lastForwardedHost string
server = &tsoServer{stream: stream}
forwarder = newTSOForwarder(server)
tsoStreamErr error
)
defer func() {
s.concurrentTSOProxyStreamings.Add(-1)
if cancelForward != nil {
cancelForward()
}
forwarder.cancel()
if grpcutil.NeedRebuildConnection(tsoStreamErr) {
s.closeDelegateClient(lastForwardedHost)
s.closeDelegateClient(forwarder.host)
}
}()

Expand Down Expand Up @@ -132,7 +84,7 @@
err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
return errs.ErrUnknown(err)
}
forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, server, request, tsDeadlineCh, lastForwardedHost, cancelForward)
tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh)
if tsoStreamErr != nil {
return tsoStreamErr
}
Expand All @@ -142,38 +94,105 @@
}
}

func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStream tsopb.TSO_TsoClient, stream pdpb.PD_TsoServer, server *tsoServer,
request *pdpb.TsoRequest, tsDeadlineCh chan<- *tsoutil.TSDeadline, lastForwardedHost string, cancelForward context.CancelFunc) (
context.Context,
context.CancelFunc,
tsopb.TSO_TsoClient,
string,
error, // tso stream error
error, // send error
) {
forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName)
if !ok || len(forwardedHost) == 0 {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(errs.ErrNotFoundTSOAddr), nil
}
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancelForward != nil {
cancelForward()
}
type tsoForwarder struct {
// The original source that we need to send the response back to.
responser interface{ Send(*pdpb.TsoResponse) error }
// The context for the forwarding stream.
ctx context.Context
// The cancel function for the forwarding stream.
canceller context.CancelFunc
// The current forwarding stream.
stream tsopb.TSO_TsoClient
// The current host of the forwarding stream.
host string
}

func newTSOForwarder(responser interface{ Send(*pdpb.TsoResponse) error }) *tsoForwarder {
return &tsoForwarder{
responser: responser,
}
}

func (f *tsoForwarder) cancel() {
if f != nil && f.canceller != nil {
f.canceller()
}
}

clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
// forwardTSORequest sends the TSO request with the current forward stream.
func (f *tsoForwarder) forwardTSORequest(
request *pdpb.TsoRequest,
) (*tsopb.TsoResponse, error) {
tsopbReq := &tsopb.TsoRequest{
Header: &tsopb.RequestHeader{
ClusterId: request.GetHeader().GetClusterId(),
SenderId: request.GetHeader().GetSenderId(),
KeyspaceId: constant.DefaultKeyspaceID,
KeyspaceGroupId: constant.DefaultKeyspaceGroupID,
},
Count: request.GetCount(),
}

failpoint.Inject("tsoProxySendToTSOTimeout", func() {
// block until watchDeadline routine cancels the context.
<-f.ctx.Done()
})

select {
case <-f.ctx.Done():
return nil, f.ctx.Err()
default:
}

if err := f.stream.Send(tsopbReq); err != nil {
return nil, err
}

Check warning on line 149 in server/forward.go

View check run for this annotation

Codecov / codecov/patch

server/forward.go#L148-L149

Added lines #L148 - L149 were not covered by tests

failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() {
// block until watchDeadline routine cancels the context.
<-f.ctx.Done()
})

select {
case <-f.ctx.Done():
return nil, f.ctx.Err()
default:
}

return f.stream.Recv()
}

func (s *GrpcServer) handleTSOForwarding(
ctx context.Context,
forwarder *tsoForwarder,
request *pdpb.TsoRequest,
tsDeadlineCh chan<- *tsoutil.TSDeadline,
) (tsoStreamErr, sendErr error) {
// Get the latest TSO primary address.
targetHost, ok := s.GetServicePrimaryAddr(ctx, constant.TSOServiceName)
if !ok || len(targetHost) == 0 {
return errors.WithStack(errs.ErrNotFoundTSOAddr), nil
}
// Check if the forwarder is already built with the target host.
if forwarder.stream == nil || forwarder.host != targetHost {
// Cancel the old forwarder.
forwarder.cancel()
// Build a new forward stream.
clientConn, err := s.getDelegateClient(s.ctx, targetHost)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
return errors.WithStack(err), nil

Check warning on line 183 in server/forward.go

View check run for this annotation

Codecov / codecov/patch

server/forward.go#L183

Added line #L183 was not covered by tests
}
forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn)
forwarder.stream, forwarder.ctx, forwarder.canceller, err = createTSOForwardStream(ctx, clientConn)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
return errors.WithStack(err), nil

Check warning on line 187 in server/forward.go

View check run for this annotation

Codecov / codecov/patch

server/forward.go#L187

Added line #L187 was not covered by tests
}
lastForwardedHost = forwardedHost
forwarder.host = targetHost
}

tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh)
// Forward the TSO request with the deadline.
tsopbResp, err := s.forwardTSORequestWithDeadLine(forwarder, request, tsDeadlineCh)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
return errors.WithStack(err), nil
}

// The error types defined for tsopb and pdpb are different, so we need to convert them.
Expand All @@ -193,39 +212,38 @@
}
}
}

response := &pdpb.TsoResponse{
// Send the TSO response back to the original source.
sendErr = forwarder.responser.Send(&pdpb.TsoResponse{
Header: &pdpb.ResponseHeader{
ClusterId: tsopbResp.GetHeader().GetClusterId(),
Error: pdpbErr,
},
Count: tsopbResp.GetCount(),
Timestamp: tsopbResp.GetTimestamp(),
}
if server != nil {
err = server.send(response)
} else {
err = stream.Send(response)
}
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, nil, errors.WithStack(err)
})

return nil, errors.WithStack(sendErr)
}

func (s *GrpcServer) forwardTSORequestWithDeadLine(
forwardCtx context.Context,
cancelForward context.CancelFunc,
forwardStream tsopb.TSO_TsoClient,
forwarder *tsoForwarder,
request *pdpb.TsoRequest,
tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) {
done := make(chan struct{})
dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward)
tsDeadlineCh chan<- *tsoutil.TSDeadline,
) (*tsopb.TsoResponse, error) {
var (
forwardCtx = forwarder.ctx
forwardCancel = forwarder.canceller
done = make(chan struct{})
dl = tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, forwardCancel)
)
select {
case tsDeadlineCh <- dl:
case <-forwardCtx.Done():
return nil, forwardCtx.Err()
}

start := time.Now()
resp, err := forwardTSORequest(forwardCtx, request, forwardStream)
resp, err := forwarder.forwardTSORequest(request)
close(done)
if err != nil {
if strings.Contains(err.Error(), errs.NotLeaderErr) {
Expand Down
25 changes: 10 additions & 15 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@
err error
}

func (s *tsoServer) send(m *pdpb.TsoResponse) error {
// Send wraps Send() of PD_TsoServer.
func (s *tsoServer) Send(m *pdpb.TsoResponse) error {
if atomic.LoadInt32(&s.closed) == 1 {
return io.EOF
}
Expand Down Expand Up @@ -486,29 +487,23 @@
defer done()
}
if s.IsServiceIndependent(constant.TSOServiceName) {
return s.forwardTSO(stream)
return s.forwardToTSOService(stream)
}

tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1)
go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh)

var (
doneCh chan struct{}
errCh chan error
// The following are tso forward stream related variables.
forwardStream tsopb.TSO_TsoClient
cancelForward context.CancelFunc
forwardCtx context.Context
tsoStreamErr error
lastForwardedHost string
doneCh chan struct{}
errCh chan error
forwarder = newTSOForwarder(stream)
tsoStreamErr error
)

defer func() {
if cancelForward != nil {
cancelForward()
}
forwarder.cancel()
if grpcutil.NeedRebuildConnection(tsoStreamErr) {
s.closeDelegateClient(lastForwardedHost)
s.closeDelegateClient(forwarder.host)

Check warning on line 506 in server/grpc_service.go

View check run for this annotation

Codecov / codecov/patch

server/grpc_service.go#L506

Added line #L506 was not covered by tests
}
}()

Expand Down Expand Up @@ -559,7 +554,7 @@
err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
return errs.ErrUnknown(err)
}
forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, nil, request, tsDeadlineCh, lastForwardedHost, cancelForward)
tsoStreamErr, err = s.handleTSOForwarding(stream.Context(), forwarder, request, tsDeadlineCh)

Check warning on line 557 in server/grpc_service.go

View check run for this annotation

Codecov / codecov/patch

server/grpc_service.go#L557

Added line #L557 was not covered by tests
if tsoStreamErr != nil {
return tsoStreamErr
}
Expand Down