Skip to content

Commit

Permalink
Refactor some common utilities which will be used by tso mcs
Browse files Browse the repository at this point in the history
This change is split from "basic implement tso gPRC service tikv#5949" tikv#5949

Signed-off-by: Bin Shi <[email protected]>
  • Loading branch information
binshi-bing committed Feb 13, 2023
1 parent e6086c4 commit 34a68cb
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pkg/mcs/tso/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Server struct {

// TODO: Implement the following methods defined in bs.Server

// Name returns the unique etcd Name for this server in etcd cluster.
// Name returns the unique etcd Name for this server in etcd cluster
func (s *Server) Name() string {
return ""
}
Expand Down Expand Up @@ -71,7 +71,7 @@ func (s *Server) GetHTTPClient() *http.Client {
// CreateServerWrapper encapsulates the configuration/log/metrics initialization and create the server
func CreateServerWrapper(args []string) (context.Context, context.CancelFunc, bs.Server) {
cfg := tso.NewConfig()
err := cfg.Parse(os.Args[1:])
err := cfg.Parse(args)

if cfg.Version {
printVersionInfo()
Expand Down
12 changes: 12 additions & 0 deletions pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,15 @@ func ResetForwardContext(ctx context.Context) context.Context {
md.Set(ForwardMetadataKey, "")
return metadata.NewOutgoingContext(ctx, md)
}

// GetForwardedHost returns the forwarded host in metadata.
func GetForwardedHost(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Debug("failed to get forwarding metadata")
}
if t, ok := md[ForwardMetadataKey]; ok {
return t[0]
}
return ""
}
25 changes: 7 additions & 18 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (s *GrpcServer) unaryMiddleware(ctx context.Context, header *pdpb.RequestHe
failpoint.Inject("customTimeout", func() {
time.Sleep(5 * time.Second)
})
forwardedHost := getForwardedHost(ctx)
forwardedHost := grpctutil.GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down Expand Up @@ -167,7 +167,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
}

streamCtx := stream.Context()
forwardedHost := getForwardedHost(streamCtx)
forwardedHost := GetForwardedHost(streamCtx)
if !s.isLocalRequest(forwardedHost) {
if errCh == nil {
doneCh = make(chan struct{})
Expand Down Expand Up @@ -766,7 +766,7 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error {
if err != nil {
return errors.WithStack(err)
}
forwardedHost := getForwardedHost(stream.Context())
forwardedHost := GetForwardedHost(stream.Context())
failpoint.Inject("grpcClientClosed", func() {
forwardedHost = s.GetMember().Member().GetClientUrls()[0]
})
Expand Down Expand Up @@ -861,7 +861,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error
return errors.WithStack(err)
}

forwardedHost := getForwardedHost(stream.Context())
forwardedHost := GetForwardedHost(stream.Context())
if !s.isLocalRequest(forwardedHost) {
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancel != nil {
Expand Down Expand Up @@ -1786,17 +1786,6 @@ func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string
return client.(*grpc.ClientConn), nil
}

func getForwardedHost(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Debug("failed to get forwarding metadata")
}
if t, ok := md[grpcutil.ForwardMetadataKey]; ok {
return t[0]
}
return ""
}

func (s *GrpcServer) isLocalRequest(forwardedHost string) bool {
failpoint.Inject("useForwardRequest", func() {
failpoint.Return(false)
Expand Down Expand Up @@ -2044,7 +2033,7 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) {

// ReportMinResolvedTS implements gRPC PDServer.
func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down Expand Up @@ -2078,7 +2067,7 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo

// SetExternalTimestamp implements gRPC PDServer.
func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand All @@ -2105,7 +2094,7 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set

// GetExternalTimestamp implements gRPC PDServer.
func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) {
forwardedHost := getForwardedHost(ctx)
forwardedHost := GetForwardedHost(ctx)
if !s.isLocalRequest(forwardedHost) {
client, err := s.getDelegateClient(ctx, forwardedHost)
if err != nil {
Expand Down

0 comments on commit 34a68cb

Please sign in to comment.