mcs: add lock for forward tso stream (#9095) (#9107)
close #9091

Signed-off-by: ti-chi-bot <[email protected]>
Signed-off-by: lhy1024 <[email protected]>

Co-authored-by: lhy1024 <[email protected]>
3 people authored Mar 3, 2025
1 parent 44e25d3 commit 4a40114
Showing 5 changed files with 242 additions and 23 deletions.
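For context: grpc-go forbids multiple goroutines from calling Send concurrently (or Recv concurrently) on the same stream, and even otherwise-safe call patterns can pair one goroutine's request with another goroutine's response on a shared stream. The PD server caches one forwarded TSO stream per TSO primary, so concurrent forwarders need to serialize their round trips. Below is a minimal sketch of the pattern this commit adopts, with illustrative names (lockedStream and roundTrip are not in the diff; the real wrapper is streamWrapper in server/server.go below):

package sketch

import (
	"sync"

	"google.golang.org/grpc"
)

// lockedStream guards a shared bidirectional stream with a mutex.
type lockedStream struct {
	sync.Mutex
	stream grpc.ClientStream
}

// roundTrip serializes one request/response exchange so that concurrent
// goroutines cannot interleave their pairs on the shared stream.
func (s *lockedStream) roundTrip(req, resp any) error {
	s.Lock()
	defer s.Unlock()
	if err := s.stream.SendMsg(req); err != nil {
		return err
	}
	return s.stream.RecvMsg(resp)
}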
14 changes: 14 additions & 0 deletions pkg/utils/grpcutil/grpcutil.go
@@ -18,12 +18,15 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net/url"
"strings"

"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"go.etcd.io/etcd/pkg/transport"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
@@ -172,3 +175,14 @@ func GetForwardedHost(ctx context.Context) string {
}
return ""
}

// NeedRebuildConnection checks if the error is a connection error.
func NeedRebuildConnection(err error) bool {
	return err == io.EOF ||
		strings.Contains(err.Error(), codes.Unavailable.String()) || // Unavailable indicates the service is currently unavailable. This is most likely a transient condition.
		strings.Contains(err.Error(), codes.DeadlineExceeded.String()) || // DeadlineExceeded means the operation expired before completion.
		strings.Contains(err.Error(), codes.Internal.String()) || // Internal errors.
		strings.Contains(err.Error(), codes.Unknown.String()) || // Unknown error.
		strings.Contains(err.Error(), codes.ResourceExhausted.String()) // ResourceExhausted is returned when either the client or the server has exhausted their resources.
	// Besides, we don't need to rebuild the connection if the code is Canceled, which means the client canceled the request.
}
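For illustration, a hedged sketch of consulting this helper before retrying (the error value and the print are hypothetical; the real caller is the handleStreamError closure in server/grpc_service.go below):

package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/tikv/pd/pkg/utils/grpcutil"
)

func main() {
	// A status error whose text carries the Unavailable code name,
	// e.g. "rpc error: code = Unavailable desc = transport is closing".
	err := status.Error(codes.Unavailable, "transport is closing")
	if grpcutil.NeedRebuildConnection(err) {
		fmt.Println("drop the cached stream and redial on the next attempt")
	}
}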
88 changes: 70 additions & 18 deletions server/grpc_service.go
@@ -20,6 +20,7 @@ import (
"io"
"path"
"strconv"
"strings"
"sync/atomic"
"time"

@@ -48,7 +49,9 @@ import (
 )
 
 const (
-	heartbeatSendTimeout = 5 * time.Second
+	heartbeatSendTimeout          = 5 * time.Second
+	maxRetryTimesRequestTSOServer = 3
+	retryIntervalRequestTSOServer = 500 * time.Millisecond
 )

// gRPC errors
@@ -1781,31 +1784,77 @@ func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan struct{}) {
 }
 
 func (s *GrpcServer) getGlobalTSOFromTSOServer(ctx context.Context) (pdpb.Timestamp, error) {
-	forwardedHost, ok := s.GetServicePrimaryAddr(ctx, utils.TSOServiceName)
-	if !ok || forwardedHost == "" {
-		return pdpb.Timestamp{}, ErrNotFoundTSOAddr
-	}
-	forwardStream, err := s.getTSOForwardStream(forwardedHost)
-	if err != nil {
-		return pdpb.Timestamp{}, err
-	}
-	forwardStream.Send(&tsopb.TsoRequest{
+	request := &tsopb.TsoRequest{
 		Header: &tsopb.RequestHeader{
 			ClusterId:       s.clusterID,
 			KeyspaceId:      utils.DefaultKeyspaceID,
 			KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
 		},
 		Count: 1,
-	})
-	ts, err := forwardStream.Recv()
-	if err != nil {
-		log.Error("get global tso from tso server failed", zap.Error(err))
-		return pdpb.Timestamp{}, err
-	}
-	return *ts.GetTimestamp(), nil
+	}
+	var (
+		forwardedHost string
+		forwardStream *streamWrapper
+		ts            *tsopb.TsoResponse
+		err           error
+		ok            bool
+	)
+	handleStreamError := func(err error) (needRetry bool) {
+		if strings.Contains(err.Error(), errs.NotLeaderErr) {
+			s.updateServicePrimaryAddr(utils.TSOServiceName)
+			log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			return true
+		}
+		if grpcutil.NeedRebuildConnection(err) {
+			s.tsoClientPool.Lock()
+			delete(s.tsoClientPool.clients, forwardedHost)
+			s.tsoClientPool.Unlock()
+			log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			return true
+		}
+		return false
+	}
+	for i := 0; i < maxRetryTimesRequestTSOServer; i++ {
+		if i > 0 {
+			time.Sleep(retryIntervalRequestTSOServer)
+		}
+		forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName)
+		if !ok || forwardedHost == "" {
+			return pdpb.Timestamp{}, ErrNotFoundTSOAddr
+		}
+		forwardStream, err = s.getTSOForwardStream(forwardedHost)
+		if err != nil {
+			return pdpb.Timestamp{}, err
+		}
+		start := time.Now()
+		forwardStream.Lock()
+		err = forwardStream.Send(request)
+		if err != nil {
+			if needRetry := handleStreamError(err); needRetry {
+				forwardStream.Unlock()
+				continue
+			}
+			log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			forwardStream.Unlock()
+			return pdpb.Timestamp{}, err
+		}
+		ts, err = forwardStream.Recv()
+		forwardStream.Unlock()
+		forwardTsoDuration.Observe(time.Since(start).Seconds())
+		if err != nil {
+			if needRetry := handleStreamError(err); needRetry {
+				continue
+			}
+			log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			return pdpb.Timestamp{}, err
+		}
+		return *ts.GetTimestamp(), nil
+	}
+	log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost))
+	return pdpb.Timestamp{}, err
 }
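Note the lock scope in the loop above: each attempt holds the stream lock across its Send/Recv pair, so concurrent callers cannot interleave requests and responses on the shared stream, and the lock is released before the duration metric is observed and before any retry sleep. Retries are bounded by the constants introduced above: at most three attempts, with a 500ms pause before each retry.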

-func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) {
+func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (*streamWrapper, error) {
 	s.tsoClientPool.RLock()
 	forwardStream, ok := s.tsoClientPool.clients[forwardedHost]
 	s.tsoClientPool.RUnlock()
@@ -1831,11 +1880,14 @@ func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) {
 	done := make(chan struct{})
 	ctx, cancel := context.WithCancel(s.ctx)
 	go checkStream(ctx, cancel, done)
-	forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx)
+	tsoClient, err := tsopb.NewTSOClient(client).Tso(ctx)
 	done <- struct{}{}
 	if err != nil {
 		return nil, err
 	}
+	forwardStream = &streamWrapper{
+		TSO_TsoClient: tsoClient,
+	}
 	s.tsoClientPool.clients[forwardedHost] = forwardStream
 	return forwardStream, nil
 }
9 changes: 9 additions & 0 deletions server/metrics.go
@@ -151,6 +151,14 @@ var (
Name: "maxprocs",
Help: "The value of GOMAXPROCS.",
})
forwardTsoDuration = prometheus.NewHistogram(
prometheus.HistogramOpts{
Namespace: "pd",
Subsystem: "server",
Name: "forward_tso_duration_seconds",
Help: "Bucketed histogram of processing time (s) of handled forward tso requests.",
Buckets: prometheus.ExponentialBuckets(0.0005, 2, 13),
})
)

func init() {
Expand All @@ -170,4 +178,5 @@ func init() {
prometheus.MustRegister(serviceAuditHistogram)
prometheus.MustRegister(bucketReportInterval)
prometheus.MustRegister(serverMaxProcs)
prometheus.MustRegister(forwardTsoDuration)
}
16 changes: 11 additions & 5 deletions server/server.go
@@ -69,6 +69,7 @@ import (
"github.com/tikv/pd/pkg/utils/grpcutil"
"github.com/tikv/pd/pkg/utils/jsonutil"
"github.com/tikv/pd/pkg/utils/logutil"
"github.com/tikv/pd/pkg/utils/syncutil"
"github.com/tikv/pd/pkg/utils/tsoutil"
"github.com/tikv/pd/pkg/utils/typeutil"
"github.com/tikv/pd/pkg/versioninfo"
@@ -123,6 +124,11 @@ var (
etcdCommittedIndexGauge = etcdStateGauge.WithLabelValues("committedIndex")
)

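// streamWrapper pairs a forwarded TSO client stream with a mutex; holders
// must take the lock around each Send/Recv pair on the shared stream.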
type streamWrapper struct {
tsopb.TSO_TsoClient
syncutil.Mutex
}

// Server is the pd server. It implements bs.Server
// nolint
type Server struct {
@@ -199,8 +205,8 @@ type Server struct {
 	clientConns sync.Map
 
 	tsoClientPool struct {
-		sync.RWMutex
-		clients map[string]tsopb.TSO_TsoClient
+		syncutil.RWMutex
+		clients map[string]*streamWrapper
 	}

// tsoDispatcher is used to dispatch different TSO requests to
@@ -254,10 +260,10 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le
 		DiagnosticsServer: sysutil.NewDiagnosticsServer(cfg.Log.File.Filename),
 		mode:              mode,
 		tsoClientPool: struct {
-			sync.RWMutex
-			clients map[string]tsopb.TSO_TsoClient
+			syncutil.RWMutex
+			clients map[string]*streamWrapper
 		}{
-			clients: make(map[string]tsopb.TSO_TsoClient),
+			clients: make(map[string]*streamWrapper),
 		},
}
s.handler = newHandler(s)
138 changes: 138 additions & 0 deletions tests/integrations/mcs/tso/server_test.go
@@ -23,9 +23,11 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -319,6 +321,142 @@ func (suite *APIServerForwardTestSuite) checkAvailableTSO() {
suite.NoError(err)
}

func TestForwardTsoConcurrently(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cluster, err := tests.NewTestAPICluster(ctx, 3)
re.NoError(err)
defer cluster.Destroy()

err = cluster.RunInitialServers()
re.NoError(err)

leaderName := cluster.WaitLeader()
pdLeader := cluster.GetServer(leaderName)
backendEndpoints := pdLeader.GetAddr()
re.NoError(pdLeader.BootstrapCluster())
leader := cluster.GetServer(cluster.WaitLeader())
rc := leader.GetRaftCluster()
for i := 0; i < 3; i++ {
region := &metapb.Region{
Id: uint64(i*4 + 1),
Peers: []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}},
StartKey: []byte{byte(i)},
EndKey: []byte{byte(i + 1)},
}
rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0]))
}

re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)"))
defer func() {
re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode"))
}()

tc, err := mcs.NewTestTSOCluster(ctx, 2, backendEndpoints)
re.NoError(err)
defer tc.Destroy()
tc.WaitForDefaultPrimaryServing(re)

wg := sync.WaitGroup{}
for i := 0; i < 50; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
pdClient, err := pd.NewClientWithContext(
context.Background(),
[]string{backendEndpoints},
pd.SecurityOption{})
re.NoError(err)
re.NotNil(pdClient)
defer pdClient.Close()
for j := 0; j < 20; j++ {
testutil.Eventually(re, func() bool {
min, err := pdClient.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", i), 1000, 1)
return err == nil && min == 0
})
}
}(i)
}
wg.Wait()
}
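The test above drives 50 clients, each issuing 20 UpdateServiceGCSafePoint calls against the API server; in this setup every such call has the server fetch a timestamp that is forwarded to the TSO microservice over the shared stream, so the previously unsynchronized Send/Recv pairs would race here (#9091).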

func BenchmarkForwardTsoConcurrently(b *testing.B) {
re := require.New(b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cluster, err := tests.NewTestAPICluster(ctx, 3)
re.NoError(err)
defer cluster.Destroy()

err = cluster.RunInitialServers()
re.NoError(err)

leaderName := cluster.WaitLeader()
pdLeader := cluster.GetServer(leaderName)
backendEndpoints := pdLeader.GetAddr()
re.NoError(pdLeader.BootstrapCluster())
leader := cluster.GetServer(cluster.WaitLeader())
rc := leader.GetRaftCluster()
for i := 0; i < 3; i++ {
region := &metapb.Region{
Id: uint64(i*4 + 1),
Peers: []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}},
StartKey: []byte{byte(i)},
EndKey: []byte{byte(i + 1)},
}
rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0]))
}

re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)"))
defer func() {
re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode"))
}()

tc, err := mcs.NewTestTSOCluster(ctx, 1, backendEndpoints)
re.NoError(err)
defer tc.Destroy()
tc.WaitForDefaultPrimaryServing(re)

initClients := func(num int) []pd.Client {
var clients []pd.Client
for i := 0; i < num; i++ {
pdClient, err := pd.NewClientWithContext(context.Background(),
[]string{backendEndpoints}, pd.SecurityOption{}, pd.WithMaxErrorRetry(1))
re.NoError(err)
re.NotNil(pdClient)
clients = append(clients, pdClient)
}
return clients
}

concurrencyLevels := []int{1, 2, 5, 10, 20}
for _, clientsNum := range concurrencyLevels {
clients := initClients(clientsNum)
b.Run(fmt.Sprintf("clients_%d", clientsNum), func(b *testing.B) {
wg := sync.WaitGroup{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j, client := range clients {
wg.Add(1)
go func(j int, client pd.Client) {
defer wg.Done()
for k := 0; k < 1000; k++ {
min, err := client.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", j), 1000, 1)
re.NoError(err)
re.Equal(uint64(0), min)
}
}(j, client)
}
}
wg.Wait()
})
for _, c := range clients {
c.Close()
}
}
}
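To run the new benchmark on its own, a standard Go invocation along these lines should work (adjust the working directory if the integration tests live in their own Go module):

go test -run '^$' -bench BenchmarkForwardTsoConcurrently ./tests/integrations/mcs/tso/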

func TestAdvertiseAddr(t *testing.T) {
re := require.New(t)

