Skip to content

Commit

Permalink
mcs: add option to control redirect http request (tikv#7691)
Browse files Browse the repository at this point in the history
close tikv#7690

Signed-off-by: lhy1024 <[email protected]>
Signed-off-by: pingandb <[email protected]>
  • Loading branch information
lhy1024 authored and pingandb committed Jan 18, 2024
1 parent b1cdf6b commit d2c7fbb
Show file tree
Hide file tree
Showing 8 changed files with 338 additions and 99 deletions.
6 changes: 4 additions & 2 deletions pkg/utils/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ const (
XRealIPHeader = "X-Real-Ip"
// XCallerIDHeader is used to mark the caller ID.
XCallerIDHeader = "X-Caller-ID"
// ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service.
ForwardToMicroServiceHeader = "Forward-To-Micro-Service"
// XForbiddenForwardToMicroServiceHeader is used to indicate that forwarding the request to a microservice is explicitly disallowed.
XForbiddenForwardToMicroServiceHeader = "X-Forbidden-Forward-To-MicroService"
// XForwardedToMicroServiceHeader is used to signal that the request has already been forwarded to a microservice.
XForwardedToMicroServiceHeader = "X-Forwarded-To-MicroService"

chunkSize = 4096
)
Expand Down
11 changes: 5 additions & 6 deletions pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"net/url"
"strings"

"github.com/pingcap/failpoint"
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
mcsutils "github.com/tikv/pd/pkg/mcs/utils"
Expand Down Expand Up @@ -119,6 +118,9 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri
if len(h.microserviceRedirectRules) == 0 {
return false, ""
}
if r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) == "true" {
return false, ""
}
// Remove trailing '/' from the URL path
// It will be helpful when matching the redirect rules "schedulers" or "schedulers/{name}"
r.URL.Path = strings.TrimRight(r.URL.Path, "/")
Expand Down Expand Up @@ -200,11 +202,8 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http
return
}
clientUrls = append(clientUrls, targetAddr)
// Add a header to the response, this is not a failure injection
// it is used for testing, to check whether the request is forwarded to the micro service
failpoint.Inject("checkHeader", func() {
w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true")
})
// Add a header to the response, it is used to mark whether the request has been forwarded to the micro service.
w.Header().Add(apiutil.XForwardedToMicroServiceHeader, "true")
} else {
leader := h.s.GetMember().GetLeader()
if leader == nil {
Expand Down
9 changes: 6 additions & 3 deletions server/api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ func newConfHandler(svr *server.Server, rd *render.Render) *confHandler {
// @Router /config [get]
func (h *confHandler) GetConfig(w http.ResponseWriter, r *http.Request) {
cfg := h.svr.GetConfig()
if h.svr.IsServiceIndependent(utils.SchedulingServiceName) {
if h.svr.IsServiceIndependent(utils.SchedulingServiceName) &&
r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" {
schedulingServerConfig, err := h.GetSchedulingServerConfig()
if err != nil {
h.rd.JSON(w, http.StatusInternalServerError, err.Error())
Expand Down Expand Up @@ -313,7 +314,8 @@ func getConfigMap(cfg map[string]interface{}, key []string, value interface{}) m
// @Success 200 {object} sc.ScheduleConfig
// @Router /config/schedule [get]
func (h *confHandler) GetScheduleConfig(w http.ResponseWriter, r *http.Request) {
if h.svr.IsServiceIndependent(utils.SchedulingServiceName) {
if h.svr.IsServiceIndependent(utils.SchedulingServiceName) &&
r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" {
cfg, err := h.GetSchedulingServerConfig()
if err != nil {
h.rd.JSON(w, http.StatusInternalServerError, err.Error())
Expand Down Expand Up @@ -386,7 +388,8 @@ func (h *confHandler) SetScheduleConfig(w http.ResponseWriter, r *http.Request)
// @Success 200 {object} sc.ReplicationConfig
// @Router /config/replicate [get]
func (h *confHandler) GetReplicationConfig(w http.ResponseWriter, r *http.Request) {
if h.svr.IsServiceIndependent(utils.SchedulingServiceName) {
if h.svr.IsServiceIndependent(utils.SchedulingServiceName) &&
r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" {
cfg, err := h.GetSchedulingServerConfig()
if err != nil {
h.rd.JSON(w, http.StatusInternalServerError, err.Error())
Expand Down
139 changes: 75 additions & 64 deletions tests/integrations/mcs/scheduling/api_test.go

Large diffs are not rendered by default.

17 changes: 4 additions & 13 deletions tests/integrations/mcs/tso/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,6 @@ func (suite *tsoAPITestSuite) TestGetKeyspaceGroupMembers() {

func (suite *tsoAPITestSuite) TestForwardResetTS() {
re := suite.Require()
re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)"))
defer func() {
re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader"))
}()

primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re)
re.NotNil(primary)
Expand All @@ -115,13 +111,13 @@ func (suite *tsoAPITestSuite) TestForwardResetTS() {
// Test reset ts
input := []byte(`{"tso":"121312", "force-use-larger":true}`)
err := testutil.CheckPostJSON(dialClient, url, input,
testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true"))
re.NoError(err)

// Test reset ts with invalid tso
input = []byte(`{}`)
err = testutil.CheckPostJSON(dialClient, url, input,
testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value"), testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true"))
re.NoError(err)
}

Expand Down Expand Up @@ -205,11 +201,6 @@ func TestTSOServerStartFirst(t *testing.T) {

func TestForwardOnlyTSONoScheduling(t *testing.T) {
re := require.New(t)
re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)"))
defer func() {
re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader"))
}()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tc, err := tests.NewTestAPICluster(ctx, 1)
Expand All @@ -229,14 +220,14 @@ func TestForwardOnlyTSONoScheduling(t *testing.T) {
// Test /operators, it should not forward when there is no scheduling server.
var slice []string
err = testutil.ReadGetJSON(re, dialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice,
testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader))
testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader))
re.NoError(err)
re.Empty(slice)

// Test admin/reset-ts, it should forward to tso server.
input := []byte(`{"tso":"121312", "force-use-larger":true}`)
err = testutil.CheckPostJSON(dialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input,
testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true"))
re.NoError(err)

// If close tso server, it should try forward to tso server, but return error in api mode.
Expand Down
49 changes: 38 additions & 11 deletions tools/pd-ctl/pdctl/command/config_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ import (

"github.com/spf13/cobra"
"github.com/tikv/pd/pkg/schedule/placement"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/reflectutil"
"github.com/tikv/pd/server/config"
)

var (
const (
configPrefix = "pd/api/v1/config"
schedulePrefix = "pd/api/v1/config/schedule"
replicatePrefix = "pd/api/v1/config/replicate"
Expand All @@ -45,6 +46,8 @@ var (
replicationModePrefix = "pd/api/v1/config/replication-mode"
ruleBundlePrefix = "pd/api/v1/config/placement-rule"
pdServerPrefix = "pd/api/v1/config/pd-server"
// flagFromAPIServer has no influence for pd mode, but it is useful for us to debug in api mode.
flagFromAPIServer = "from_api_server"
)

// NewConfigCommand return a config subcommand of rootCmd
Expand Down Expand Up @@ -74,6 +77,7 @@ func NewShowConfigCommand() *cobra.Command {
sc.AddCommand(NewShowClusterVersionCommand())
sc.AddCommand(newShowReplicationModeCommand())
sc.AddCommand(NewShowServerConfigCommand())
sc.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
return sc
}

Expand All @@ -84,6 +88,7 @@ func NewShowAllConfigCommand() *cobra.Command {
Short: "show all config of PD",
Run: showAllConfigCommandFunc,
}
sc.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
return sc
}

Expand All @@ -94,6 +99,7 @@ func NewShowScheduleConfigCommand() *cobra.Command {
Short: "show schedule config of PD",
Run: showScheduleConfigCommandFunc,
}
sc.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
return sc
}

Expand All @@ -104,6 +110,7 @@ func NewShowReplicationConfigCommand() *cobra.Command {
Short: "show replication config of PD",
Run: showReplicationConfigCommandFunc,
}
sc.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
return sc
}

Expand Down Expand Up @@ -206,7 +213,8 @@ func NewDeleteLabelPropertyConfigCommand() *cobra.Command {
}

func showConfigCommandFunc(cmd *cobra.Command, args []string) {
allR, err := doRequest(cmd, configPrefix, http.MethodGet, http.Header{})
header := buildHeader(cmd)
allR, err := doRequest(cmd, configPrefix, http.MethodGet, header)
if err != nil {
cmd.Printf("Failed to get config: %s\n", err)
return
Expand Down Expand Up @@ -261,7 +269,8 @@ var hideConfig = []string{
}

func showScheduleConfigCommandFunc(cmd *cobra.Command, args []string) {
r, err := doRequest(cmd, schedulePrefix, http.MethodGet, http.Header{})
header := buildHeader(cmd)
r, err := doRequest(cmd, schedulePrefix, http.MethodGet, header)
if err != nil {
cmd.Printf("Failed to get config: %s\n", err)
return
Expand All @@ -270,7 +279,8 @@ func showScheduleConfigCommandFunc(cmd *cobra.Command, args []string) {
}

func showReplicationConfigCommandFunc(cmd *cobra.Command, args []string) {
r, err := doRequest(cmd, replicatePrefix, http.MethodGet, http.Header{})
header := buildHeader(cmd)
r, err := doRequest(cmd, replicatePrefix, http.MethodGet, header)
if err != nil {
cmd.Printf("Failed to get config: %s\n", err)
return
Expand All @@ -288,7 +298,8 @@ func showLabelPropertyConfigCommandFunc(cmd *cobra.Command, args []string) {
}

func showAllConfigCommandFunc(cmd *cobra.Command, args []string) {
r, err := doRequest(cmd, configPrefix, http.MethodGet, http.Header{})
header := buildHeader(cmd)
r, err := doRequest(cmd, configPrefix, http.MethodGet, header)
if err != nil {
cmd.Printf("Failed to get config: %s\n", err)
return
Expand Down Expand Up @@ -437,6 +448,7 @@ func NewPlacementRulesCommand() *cobra.Command {
show.Flags().String("id", "", "rule id")
show.Flags().String("region", "", "region id")
show.Flags().Bool("detail", false, "detailed match info for region")
show.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
load := &cobra.Command{
Use: "load",
Short: "load placement rules to a file",
Expand All @@ -446,6 +458,7 @@ func NewPlacementRulesCommand() *cobra.Command {
load.Flags().String("id", "", "rule id")
load.Flags().String("region", "", "region id")
load.Flags().String("out", "rules.json", "the filename contains rules")
load.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
save := &cobra.Command{
Use: "save",
Short: "save rules from file",
Expand All @@ -461,6 +474,7 @@ func NewPlacementRulesCommand() *cobra.Command {
Short: "show rule group configuration(s)",
Run: showRuleGroupFunc,
}
ruleGroupShow.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
ruleGroupSet := &cobra.Command{
Use: "set <id> <index> <override>",
Short: "update rule group configuration",
Expand All @@ -483,6 +497,7 @@ func NewPlacementRulesCommand() *cobra.Command {
Run: getRuleBundle,
}
ruleBundleGet.Flags().String("out", "", "the output file")
ruleBundleGet.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
ruleBundleSet := &cobra.Command{
Use: "set",
Short: "set rule group config and its rules from file",
Expand All @@ -501,6 +516,7 @@ func NewPlacementRulesCommand() *cobra.Command {
Run: loadRuleBundle,
}
ruleBundleLoad.Flags().String("out", "rules.json", "the output file")
ruleBundleLoad.Flags().Bool(flagFromAPIServer, false, "read data from api server rather than micor service")
ruleBundleSave := &cobra.Command{
Use: "save",
Short: "save all group configs and rules from file",
Expand Down Expand Up @@ -561,7 +577,8 @@ func getPlacementRulesFunc(cmd *cobra.Command, args []string) {
cmd.Println(`"region" should not be specified with "group" or "id" at the same time`)
return
}
res, err := doRequest(cmd, reqPath, http.MethodGet, http.Header{})
header := buildHeader(cmd)
res, err := doRequest(cmd, reqPath, http.MethodGet, header)
if err != nil {
cmd.Println(err)
return
Expand Down Expand Up @@ -629,8 +646,8 @@ func showRuleGroupFunc(cmd *cobra.Command, args []string) {
if len(args) > 0 {
reqPath = path.Join(ruleGroupPrefix, args[0])
}

res, err := doRequest(cmd, reqPath, http.MethodGet, http.Header{})
header := buildHeader(cmd)
res, err := doRequest(cmd, reqPath, http.MethodGet, header)
if err != nil {
cmd.Println(err)
return
Expand Down Expand Up @@ -671,8 +688,8 @@ func getRuleBundle(cmd *cobra.Command, args []string) {
}

reqPath := path.Join(ruleBundlePrefix, args[0])

res, err := doRequest(cmd, reqPath, http.MethodGet, http.Header{})
header := buildHeader(cmd)
res, err := doRequest(cmd, reqPath, http.MethodGet, header)
if err != nil {
cmd.Println(err)
return
Expand Down Expand Up @@ -747,7 +764,8 @@ func delRuleBundle(cmd *cobra.Command, args []string) {
}

func loadRuleBundle(cmd *cobra.Command, args []string) {
res, err := doRequest(cmd, ruleBundlePrefix, http.MethodGet, http.Header{})
header := buildHeader(cmd)
res, err := doRequest(cmd, ruleBundlePrefix, http.MethodGet, header)
if err != nil {
cmd.Println(err)
return
Expand Down Expand Up @@ -794,3 +812,12 @@ func saveRuleBundle(cmd *cobra.Command, args []string) {

cmd.Println(res)
}

func buildHeader(cmd *cobra.Command) http.Header {
header := http.Header{}
forbiddenRedirectToMicroService, err := cmd.Flags().GetBool(flagFromAPIServer)
if err == nil && forbiddenRedirectToMicroService {
header.Add(apiutil.XForbiddenForwardToMicroServiceHeader, "true")
}
return header
}
5 changes: 5 additions & 0 deletions tools/pd-ctl/pdctl/command/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ func dial(req *http.Request) (string, error) {
if err != nil {
return "", err
}
if req.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) == "true" {
if resp.Header.Get(apiutil.XForwardedToMicroServiceHeader) == "true" {
return string(content), errors.Errorf("the request is forwarded to micro service unexpectedly")
}
}
return string(content), nil
}

Expand Down
Loading

0 comments on commit d2c7fbb

Please sign in to comment.