Skip to content

Commit

Permalink
Refactor some common utilities which will be used by tso mcs
Browse files Browse the repository at this point in the history
This change is split from "basic implement tso gPRC service tikv#5949" tikv#5949

Signed-off-by: Bin Shi <[email protected]>
  • Loading branch information
binshi-bing committed Feb 13, 2023
1 parent e6086c4 commit ee78f47
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 86 deletions.
4 changes: 2 additions & 2 deletions pkg/mcs/tso/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Server struct {

// TODO: Implement the following methods defined in bs.Server

// Name returns the unique etcd Name for this server in etcd cluster.
// Name returns the unique etcd Name for this server in etcd cluster
func (s *Server) Name() string {
return ""
}
Expand Down Expand Up @@ -71,7 +71,7 @@ func (s *Server) GetHTTPClient() *http.Client {
// CreateServerWrapper encapsulates the configuration/log/metrics initialization and create the server
func CreateServerWrapper(args []string) (context.Context, context.CancelFunc, bs.Server) {
cfg := tso.NewConfig()
err := cfg.Parse(os.Args[1:])
err := cfg.Parse(args)

if cfg.Version {
printVersionInfo()
Expand Down
60 changes: 60 additions & 0 deletions pkg/utils/etcdutil/etcdutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/tls"
"fmt"
"math/rand"
"net/http"
"net/url"
"testing"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/utils/tempurl"
"github.com/tikv/pd/pkg/utils/typeutil"
"go.etcd.io/etcd/clientv3"
"go.etcd.io/etcd/embed"
"go.etcd.io/etcd/etcdserver"
Expand Down Expand Up @@ -230,3 +232,61 @@ func CreateClients(tlsConfig *tls.Config, acUrls []url.URL) (*clientv3.Client, *
log.Info("create etcd v3 client", zap.Strings("endpoints", endpoints))
return client, httpClient, nil
}

// InitClusterID will create a cluster ID for the given key if it hasn't existed.
// This function assumes the cluster ID has already existed and always use a
// cheaper read to retrieve it; if it doesn't exist, invoke the more expensive
// operation InitOrGetClusterID().
func InitClusterID(c *clientv3.Client, key string) (clusterID uint64, err error) {
// Get any cluster key to parse the cluster ID.
resp, err := EtcdKVGet(c, key)
if err != nil {
return 0, err
}
// If no key exist, generate a random cluster ID.
if len(resp.Kvs) == 0 {
return InitOrGetClusterID(c, key)
}
return typeutil.BytesToUint64(resp.Kvs[0].Value)
}

// InitClusterID will create a cluster ID for the given key if it hasn't existed
// with a CAS operation.
func InitOrGetClusterID(c *clientv3.Client, key string) (uint64, error) {
ctx, cancel := context.WithTimeout(c.Ctx(), DefaultRequestTimeout)
defer cancel()

// Generate a random cluster ID.
ts := uint64(time.Now().Unix())
clusterID := (ts << 32) + uint64(rand.Uint32())
value := typeutil.Uint64ToBytes(clusterID)

// Multiple servers may try to init the cluster ID at the same time.
// Only one server can commit this transaction, then other servers
// can get the committed cluster ID.
resp, err := c.Txn(ctx).
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
Then(clientv3.OpPut(key, string(value))).
Else(clientv3.OpGet(key)).
Commit()
if err != nil {
return 0, errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause()
}

// Txn commits ok, return the generated cluster ID.
if resp.Succeeded {
return clusterID, nil
}

// Otherwise, parse the committed cluster ID.
if len(resp.Responses) == 0 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

response := resp.Responses[0].GetResponseRange()
if response == nil || len(response.Kvs) != 1 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

return typeutil.BytesToUint64(response.Kvs[0].Value)
}
33 changes: 33 additions & 0 deletions pkg/utils/etcdutil/etcdutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,36 @@ func TestEtcdKVPutWithTTL(t *testing.T) {
re.NoError(err)
re.Equal(int64(0), resp.Count)
}

func TestInitClusterID(t *testing.T) {
t.Parallel()
re := require.New(t)
cfg := NewTestSingleConfig(t)
etcd, err := embed.StartEtcd(cfg)
defer func() {
etcd.Close()
}()
re.NoError(err)

ep := cfg.LCUrls[0].String()
client, err := clientv3.New(clientv3.Config{
Endpoints: []string{ep},
})
re.NoError(err)

<-etcd.Server.ReadyNotify()

pdClusterIDPath := "test/TestInitClusterID/pd/cluster_id"
// Get any cluster key to parse the cluster ID.
resp, err := EtcdKVGet(client, pdClusterIDPath)
re.NoError(err)
re.Equal(0, len(resp.Kvs))

clusterID, err := InitClusterID(client, pdClusterIDPath)
re.NoError(err)
re.NotEqual(0, clusterID)

clusterID1, err := InitClusterID(client, pdClusterIDPath)
re.NoError(err)
re.Equal(clusterID, clusterID1)
}
12 changes: 12 additions & 0 deletions pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,15 @@ func ResetForwardContext(ctx context.Context) context.Context {
md.Set(ForwardMetadataKey, "")
return metadata.NewOutgoingContext(ctx, md)
}

// GetForwardedHost returns the forwarded host in metadata.
func GetForwardedHost(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Debug("failed to get forwarding metadata")
}
if t, ok := md[ForwardMetadataKey]; ok {
return t[0]
}
return ""
}
26 changes: 7 additions & 19 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -74,7 +73,7 @@ func (s *GrpcServer) unaryMiddleware(ctx context.Context, header *pdpb.RequestHe
failpoint.Inject("customTimeout", func() {
time.Sleep(5 * time.Second)
})
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down Expand Up @@ -167,7 +166,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
}

streamCtx := stream.Context()
forwardedHost := getForwardedHost(streamCtx)
forwardedHost := grpcutil.GetForwardedHost(streamCtx)
if !s.isLocalRequest(forwardedHost) {
if errCh == nil {
doneCh = make(chan struct{})
Expand Down Expand Up @@ -766,7 +765,7 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error {
if err != nil {
return errors.WithStack(err)
}
forwardedHost := getForwardedHost(stream.Context())
forwardedHost := grpcutil.GetForwardedHost(stream.Context())
failpoint.Inject("grpcClientClosed", func() {
forwardedHost = s.GetMember().Member().GetClientUrls()[0]
})
Expand Down Expand Up @@ -861,7 +860,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error
return errors.WithStack(err)
}

forwardedHost := getForwardedHost(stream.Context())
forwardedHost := grpcutil.GetForwardedHost(stream.Context())
if !s.isLocalRequest(forwardedHost) {
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancel != nil {
Expand Down Expand Up @@ -1786,17 +1785,6 @@ func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string
return client.(*grpc.ClientConn), nil
}

func getForwardedHost(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Debug("failed to get forwarding metadata")
}
if t, ok := md[grpcutil.ForwardMetadataKey]; ok {
return t[0]
}
return ""
}

func (s *GrpcServer) isLocalRequest(forwardedHost string) bool {
failpoint.Inject("useForwardRequest", func() {
failpoint.Return(false)
Expand Down Expand Up @@ -2044,7 +2032,7 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) {

// ReportMinResolvedTS implements gRPC PDServer.
func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down Expand Up @@ -2078,7 +2066,7 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo

// SetExternalTimestamp implements gRPC PDServer.
func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand All @@ -2105,7 +2093,7 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set

// GetExternalTimestamp implements gRPC PDServer.
func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpcutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down
18 changes: 1 addition & 17 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ func (s *Server) AddStartCallback(callbacks ...func()) {

func (s *Server) startServer(ctx context.Context) error {
var err error
if err = s.initClusterID(); err != nil {
if s.clusterID, err = etcdutil.InitClusterID(s.client, pdClusterIDPath); err != nil {
return err
}
log.Info("init cluster id", zap.Uint64("cluster-id", s.clusterID))
Expand Down Expand Up @@ -408,22 +408,6 @@ func (s *Server) startServer(ctx context.Context) error {
return nil
}

func (s *Server) initClusterID() error {
// Get any cluster key to parse the cluster ID.
resp, err := etcdutil.EtcdKVGet(s.client, pdClusterIDPath)
if err != nil {
return err
}

// If no key exist, generate a random cluster ID.
if len(resp.Kvs) == 0 {
s.clusterID, err = initOrGetClusterID(s.client, pdClusterIDPath)
return err
}
s.clusterID, err = typeutil.BytesToUint64(resp.Kvs[0].Value)
return err
}

// AddCloseCallback adds a callback in the Close phase.
func (s *Server) AddCloseCallback(callbacks ...func()) {
s.closeCallbacks = append(s.closeCallbacks, callbacks...)
Expand Down
48 changes: 0 additions & 48 deletions server/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,21 @@ package server
import (
"context"
"fmt"
"math/rand"
"net/http"
"strings"
"time"

"github.com/gorilla/mux"
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/etcdutil"
"github.com/tikv/pd/pkg/utils/typeutil"
"github.com/tikv/pd/pkg/versioninfo"
"github.com/tikv/pd/server/config"
"github.com/urfave/negroni"
"go.etcd.io/etcd/clientv3"
"go.uber.org/zap"
)

const (
requestTimeout = etcdutil.DefaultRequestTimeout
)

// LogPDInfo prints the PD version information.
func LogPDInfo() {
log.Info("Welcome to Placement Driver (PD)")
Expand Down Expand Up @@ -88,45 +79,6 @@ func CheckPDVersion(opt *config.PersistOptions) {
}
}

func initOrGetClusterID(c *clientv3.Client, key string) (uint64, error) {
ctx, cancel := context.WithTimeout(c.Ctx(), requestTimeout)
defer cancel()

// Generate a random cluster ID.
ts := uint64(time.Now().Unix())
clusterID := (ts << 32) + uint64(rand.Uint32())
value := typeutil.Uint64ToBytes(clusterID)

// Multiple PDs may try to init the cluster ID at the same time.
// Only one PD can commit this transaction, then other PDs can get
// the committed cluster ID.
resp, err := c.Txn(ctx).
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
Then(clientv3.OpPut(key, string(value))).
Else(clientv3.OpGet(key)).
Commit()
if err != nil {
return 0, errs.ErrEtcdTxnInternal.Wrap(err).GenWithStackByCause()
}

// Txn commits ok, return the generated cluster ID.
if resp.Succeeded {
return clusterID, nil
}

// Otherwise, parse the committed cluster ID.
if len(resp.Responses) == 0 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

response := resp.Responses[0].GetResponseRange()
if response == nil || len(response.Kvs) != 1 {
return 0, errs.ErrEtcdTxnConflict.FastGenByArgs()
}

return typeutil.BytesToUint64(response.Kvs[0].Value)
}

func checkBootstrapRequest(clusterID uint64, req *pdpb.BootstrapRequest) error {
// TODO: do more check for request fields validation.

Expand Down

0 comments on commit ee78f47

Please sign in to comment.